# Source code for sqldbagent.retrieval.embeddings
"""Embedding provider helpers."""
from __future__ import annotations
from hashlib import blake2b
from math import sqrt
from pathlib import Path
from re import findall, sub
from typing import Any
from sqldbagent.adapters.shared import require_dependency
from sqldbagent.core.config import ArtifactSettings, EmbeddingSettings, LLMSettings
[docs]
class HashEmbeddings:
    """Deterministic local embeddings for offline tests and smoke flows.

    Each text is lowercased and split into ASCII word tokens. Every token is
    hashed with BLAKE2b into one signed, weighted bucket of a fixed-size
    vector, and the vector is then scaled to unit Euclidean length.
    """

    def __init__(self, *, dimensions: int = 256) -> None:
        """Initialize the hash embeddings backend.

        Args:
            dimensions: Number of output dimensions. Must be at least 1.

        Raises:
            ValueError: If ``dimensions`` is smaller than 1.
        """
        # Fail fast: a non-positive size would otherwise only surface later,
        # inside the bucket arithmetic, as ZeroDivisionError or IndexError.
        if dimensions < 1:
            raise ValueError(
                f"dimensions must be a positive integer, got {dimensions!r}"
            )
        self._dimensions = dimensions

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        """Embed a batch of documents.

        Args:
            texts: Input texts to embed.

        Returns:
            list[list[float]]: One deterministic vector per input text, in
            input order.
        """
        return [self._embed_text(text) for text in texts]

    def embed_query(self, text: str) -> list[float]:
        """Embed one query.

        Args:
            text: Query text.

        Returns:
            list[float]: The same vector ``embed_documents`` produces for
            this text.
        """
        return self._embed_text(text)

    def _embed_text(self, text: str) -> list[float]:
        """Hash one string into a normalized dense vector.

        Args:
            text: Text to embed.

        Returns:
            list[float]: Vector with ``dimensions`` components, scaled to
            unit Euclidean length, or left as all zeros when the token
            contributions cancel out exactly.
        """
        vector = [0.0] * self._dimensions
        lowered = text.lower()
        # Text with no ASCII word characters (including the empty string) is
        # hashed as a single token so that it still maps to a vector.
        tokens = findall(r"[A-Za-z0-9_]+", lowered) or [lowered]
        for token in tokens:
            digest = blake2b(token.encode("utf-8"), digest_size=16).digest()
            # Digest layout: bytes 0-3 pick the bucket, byte 4 the sign,
            # bytes 5-8 the weight.
            index = int.from_bytes(digest[:4], "big") % self._dimensions
            sign = 1.0 if digest[4] % 2 == 0 else -1.0
            # Weight lies in [0.5, 1.5), so no single token contributes zero.
            magnitude = 0.5 + (int.from_bytes(digest[5:9], "big") / 2**32)
            vector[index] += sign * magnitude
        norm = sqrt(sum(component * component for component in vector))
        if norm == 0:
            # Opposite-signed tokens cancelled; avoid dividing by zero.
            return vector
        return [component / norm for component in vector]
[docs]
def build_embeddings(
    *,
    embeddings_settings: EmbeddingSettings,
    llm_settings: LLMSettings,
    artifacts: ArtifactSettings,
) -> Any:
    """Create the configured embeddings backend behind a file-based cache.

    Args:
        embeddings_settings: Provider, model, dimension, batching, and
            query-cache options for the embedding backend.
        llm_settings: Provider API settings forwarded to the backend builder.
        artifacts: Artifact directory settings used to locate the cache
            directory.

    Returns:
        Any: LangChain-compatible embeddings backend produced by
        ``CacheBackedEmbeddings.from_bytes_store``.
    """
    backend = _build_underlying_embeddings(
        embeddings_settings=embeddings_settings,
        llm_settings=llm_settings,
    )
    file_store_module = require_dependency(
        "langchain_classic.storage.file_system",
        "langchain",
    )
    cache_module = require_dependency(
        "langchain_classic.embeddings.cache",
        "langchain",
    )

    # One cache directory per provider below the artifact root.
    cache_root = Path(artifacts.root_dir).joinpath(
        artifacts.embeddings_cache_dir,
        embeddings_settings.provider,
    )
    cache_root.mkdir(parents=True, exist_ok=True)

    # The namespace joins provider, model, and dimensions; any run of
    # characters outside [a-zA-Z0-9_.-/] is collapsed into one underscore.
    raw_namespace = "__".join(
        (
            embeddings_settings.provider,
            embeddings_settings.model,
            str(embeddings_settings.dimensions or "default"),
        )
    )
    safe_namespace = sub(r"[^a-zA-Z0-9_.\-/]+", "_", raw_namespace)

    wrap_with_cache = cache_module.CacheBackedEmbeddings.from_bytes_store
    return wrap_with_cache(
        underlying_embeddings=backend,
        document_embedding_cache=file_store_module.LocalFileStore(cache_root),
        namespace=safe_namespace,
        batch_size=embeddings_settings.batch_size,
        query_embedding_cache=embeddings_settings.cache_query_embeddings,
        key_encoder="sha256",
    )
def _build_underlying_embeddings(
    *,
    embeddings_settings: EmbeddingSettings,
    llm_settings: LLMSettings,
) -> Any:
    """Create the raw embedding backend, without any caching layer.

    Args:
        embeddings_settings: Embedding backend settings (provider, model,
            dimensions, batch size).
        llm_settings: Provider API settings supplying the OpenAI API key and
            base URL.

    Returns:
        Any: ``HashEmbeddings`` when the provider is ``"hash"``; otherwise an
        ``OpenAIEmbeddings`` instance from ``langchain_openai``.
    """
    if embeddings_settings.provider != "hash":
        # Every provider other than "hash" is served by the OpenAI client.
        openai_embeddings_cls = require_dependency(
            "langchain_openai",
            "langchain-openai",
        ).OpenAIEmbeddings
        return openai_embeddings_cls(
            model=embeddings_settings.model,
            dimensions=embeddings_settings.dimensions,
            api_key=llm_settings.openai_api_key,
            base_url=llm_settings.openai_base_url,
            chunk_size=embeddings_settings.batch_size,
        )

    # An unset (falsy) dimension count falls back to 256.
    return HashEmbeddings(dimensions=embeddings_settings.dimensions or 256)