"""Snapshot retrieval service backed by Qdrant."""
from __future__ import annotations
from pathlib import Path
from re import sub
from typing import Any
import orjson
from sqldbagent.adapters.shared import require_dependency
from sqldbagent.core.config import (
ArtifactSettings,
EmbeddingSettings,
LLMSettings,
RetrievalSettings,
)
from sqldbagent.docs.service import SnapshotDocumentService
from sqldbagent.retrieval.embeddings import build_embeddings
from sqldbagent.retrieval.models import (
RetrievalIndexManifestModel,
RetrievalResultModel,
RetrievedDocumentModel,
)
from sqldbagent.snapshot.models import SnapshotBundleModel
from sqldbagent.snapshot.service import SnapshotService
[docs]
class SnapshotRetrievalService:
"""Index and retrieve snapshot documents through Qdrant."""
_PAYLOAD_INDEX_FIELDS = (
"metadata.datasource_name",
"metadata.schema_name",
"metadata.snapshot_id",
"metadata.artifact_type",
"metadata.table_name",
"metadata.view_name",
"metadata.entity_kind",
"metadata.source_table",
"metadata.target_table",
)
[docs]
def __init__(
self,
*,
datasource_name: str,
snapshotter: SnapshotService,
document_service: SnapshotDocumentService,
artifacts: ArtifactSettings,
embeddings_settings: EmbeddingSettings,
llm_settings: LLMSettings,
retrieval_settings: RetrievalSettings,
embeddings: Any | None = None,
client: Any | None = None,
) -> None:
"""Initialize the retrieval service.
Args:
datasource_name: Datasource identifier.
snapshotter: Snapshot service used to load latest snapshots.
document_service: Service used to export snapshot documents.
artifacts: Artifact directory settings.
embeddings_settings: Embedding backend settings.
llm_settings: Provider API settings.
retrieval_settings: Vectorstore settings.
embeddings: Optional explicit embeddings backend override.
client: Optional explicit Qdrant client override.
"""
self._datasource_name = datasource_name
self._snapshotter = snapshotter
self._document_service = document_service
self._artifacts = artifacts
self._embeddings_settings = embeddings_settings
self._llm_settings = llm_settings
self._retrieval_settings = retrieval_settings
self._embeddings = embeddings
self._client = client
[docs]
def index_snapshot_bundle(
self,
bundle: SnapshotBundleModel,
*,
recreate_collection: bool = False,
) -> RetrievalIndexManifestModel:
"""Index one snapshot bundle into Qdrant.
Args:
bundle: Snapshot bundle to index.
recreate_collection: Whether to recreate the collection first.
Returns:
RetrievalIndexManifestModel: Persisted index manifest.
"""
document_bundle = self._document_service.create_document_bundle(bundle)
document_bundle_path = self._document_service.save_document_bundle(
document_bundle
)
vector_store = self._ensure_vector_store(
recreate_collection=recreate_collection
)
documents = self._document_service.export_langchain_documents(document_bundle)
document_ids = [document.document_id for document in document_bundle.documents]
vector_store.add_documents(documents=documents, ids=document_ids)
manifest = RetrievalIndexManifestModel(
datasource_name=self._datasource_name,
schema_name=bundle.regenerate.schema_name,
snapshot_id=bundle.snapshot_id,
collection_name=self._collection_name,
document_bundle_path=document_bundle_path.as_posix(),
document_count=len(document_bundle.documents),
embedding_provider=self._embeddings_settings.provider,
embedding_model=self._embeddings_settings.model,
summary=(
f"Indexed {len(document_bundle.documents)} documents for datasource "
f"'{self._datasource_name}' schema '{bundle.regenerate.schema_name}' "
f"into collection '{self._collection_name}'."
),
)
self._save_manifest(manifest)
return manifest
[docs]
def index_latest_schema_snapshot(
self,
schema_name: str,
*,
recreate_collection: bool = False,
) -> RetrievalIndexManifestModel:
"""Index the latest saved snapshot for one schema.
Args:
schema_name: Schema name to index.
recreate_collection: Whether to recreate the collection first.
Returns:
RetrievalIndexManifestModel: Persisted index manifest.
"""
bundle = SnapshotService.load_latest_snapshot(
self._artifacts,
datasource_name=self._datasource_name,
schema_name=schema_name,
)
return self.index_snapshot_bundle(
bundle,
recreate_collection=recreate_collection,
)
[docs]
def retrieve(
self,
query: str,
*,
schema_name: str | None = None,
table_name: str | None = None,
snapshot_id: str | None = None,
artifact_types: list[str] | None = None,
limit: int | None = None,
) -> RetrievalResultModel:
"""Retrieve relevant schema context from Qdrant.
Args:
query: Retrieval query.
schema_name: Optional schema filter.
table_name: Optional table filter.
snapshot_id: Optional snapshot filter.
artifact_types: Optional artifact-type filters.
limit: Optional result limit override.
Returns:
RetrievalResultModel: Retrieval result payload.
"""
target_snapshot = self._resolve_target_snapshot(
schema_name=schema_name,
snapshot_id=snapshot_id,
)
auto_indexed = False
if target_snapshot is not None and not self._manifest_exists(
schema_name=target_snapshot.regenerate.schema_name,
snapshot_id=target_snapshot.snapshot_id,
):
self.index_snapshot_bundle(
target_snapshot,
recreate_collection=False,
)
auto_indexed = True
vector_store = self._ensure_vector_store(recreate_collection=False)
result_limit = limit or self._retrieval_settings.default_top_k
qdrant_filter = self._build_filter(
schema_name=schema_name,
table_name=table_name,
snapshot_id=snapshot_id,
artifact_types=artifact_types,
)
retrieved = self._search_vector_store(
vector_store=vector_store,
query=query,
qdrant_filter=qdrant_filter,
result_limit=result_limit,
)
if not retrieved and target_snapshot is not None and not auto_indexed:
self.index_snapshot_bundle(
target_snapshot,
recreate_collection=False,
)
auto_indexed = True
vector_store = self._ensure_vector_store(recreate_collection=False)
retrieved = self._search_vector_store(
vector_store=vector_store,
query=query,
qdrant_filter=qdrant_filter,
result_limit=result_limit,
)
summary_prefix = "Auto-indexed snapshot documents. " if auto_indexed else ""
summary = (
f"{summary_prefix}Retrieved {len(retrieved)} documents from collection "
f"'{self._collection_name}' for datasource '{self._datasource_name}'."
)
return RetrievalResultModel(
query=query,
datasource_name=self._datasource_name,
schema_name=schema_name,
table_name=table_name,
snapshot_id=snapshot_id,
collection_name=self._collection_name,
documents=retrieved,
summary=summary,
)
def _search_vector_store(
self,
*,
vector_store: Any,
query: str,
qdrant_filter: Any,
result_limit: int,
) -> list[RetrievedDocumentModel]:
"""Execute one filtered retrieval query against the vector store."""
if self._retrieval_settings.use_mmr:
documents = vector_store.max_marginal_relevance_search(
query,
k=result_limit,
fetch_k=max(result_limit, self._retrieval_settings.default_fetch_k),
filter=qdrant_filter,
score_threshold=self._retrieval_settings.score_threshold,
)
retrieved = [
RetrievedDocumentModel(
document_id=getattr(document, "id", "") or "",
page_content=document.page_content,
metadata=document.metadata,
summary=self._summarize_document(document.page_content),
)
for document in documents
]
else:
documents = vector_store.similarity_search_with_score(
query,
k=result_limit,
filter=qdrant_filter,
score_threshold=self._retrieval_settings.score_threshold,
)
retrieved = [
RetrievedDocumentModel(
document_id=getattr(document, "id", "") or "",
page_content=document.page_content,
metadata=document.metadata,
score=score,
summary=self._summarize_document(document.page_content),
)
for document, score in documents
]
return retrieved
[docs]
@staticmethod
def load_manifest(path: str | Path) -> RetrievalIndexManifestModel:
"""Load a saved retrieval manifest."""
return RetrievalIndexManifestModel.model_validate(
orjson.loads(Path(path).read_bytes())
)
@property
def manifest_dir(self) -> Path:
"""Return the retrieval-manifest root directory."""
return Path(self._artifacts.root_dir) / self._artifacts.vectorstores_dir
@property
def _collection_name(self) -> str:
"""Return the generated Qdrant collection name."""
raw = (
f"{self._retrieval_settings.collection_prefix}__"
f"{self._datasource_name}__"
f"{self._embeddings_settings.provider}__"
f"{self._embeddings_settings.model}"
)
return sub(r"[^a-zA-Z0-9_]+", "_", raw).strip("_").lower()
def _save_manifest(self, manifest: RetrievalIndexManifestModel) -> Path:
"""Persist one retrieval manifest to disk."""
path = self.manifest_path(
schema_name=manifest.schema_name,
snapshot_id=manifest.snapshot_id,
)
path.parent.mkdir(parents=True, exist_ok=True)
path.write_bytes(
orjson.dumps(
manifest.model_dump(mode="json"),
option=orjson.OPT_INDENT_2,
)
)
return path
def _manifest_exists(self, *, schema_name: str, snapshot_id: str) -> bool:
"""Return whether a retrieval manifest already exists for one snapshot."""
return self.manifest_path(
schema_name=schema_name,
snapshot_id=snapshot_id,
).exists()
[docs]
def manifest_path(self, *, schema_name: str, snapshot_id: str) -> Path:
"""Return the saved manifest path for one schema snapshot."""
return (
self.manifest_dir
/ self._datasource_name
/ schema_name
/ f"{snapshot_id}.json"
)
[docs]
def load_saved_manifest(
self,
*,
schema_name: str,
snapshot_id: str,
) -> RetrievalIndexManifestModel | None:
"""Load a saved retrieval manifest when one exists."""
path = self.manifest_path(schema_name=schema_name, snapshot_id=snapshot_id)
if not path.exists():
return None
return self.load_manifest(path)
def _ensure_vector_store(self, *, recreate_collection: bool) -> Any:
"""Return a ready-to-use LangChain Qdrant vector store."""
qdrant_module = require_dependency("langchain_qdrant", "langchain-qdrant")
client = self._get_client()
embeddings = self._get_embeddings()
if recreate_collection:
self._create_collection(force_recreate=True)
elif not client.collection_exists(self._collection_name):
self._create_collection(force_recreate=False)
return qdrant_module.QdrantVectorStore(
client=client,
collection_name=self._collection_name,
embedding=embeddings,
)
def _create_collection(self, *, force_recreate: bool) -> None:
"""Create the Qdrant collection if it does not exist."""
qdrant_models = require_dependency("qdrant_client", "qdrant-client").models
vector_size = len(self._get_embeddings().embed_query("sqldbagent bootstrap"))
vector_params = qdrant_models.VectorParams(
size=vector_size,
distance=qdrant_models.Distance.COSINE,
)
client = self._get_client()
if force_recreate:
if client.collection_exists(self._collection_name):
client.delete_collection(self._collection_name)
self._call_qdrant_with_exists_tolerance(
lambda: client.create_collection(
self._collection_name,
vectors_config=vector_params,
on_disk_payload=True,
)
)
elif not client.collection_exists(self._collection_name):
self._call_qdrant_with_exists_tolerance(
lambda: client.create_collection(
self._collection_name,
vectors_config=vector_params,
on_disk_payload=True,
)
)
if self._retrieval_settings.create_payload_indexes:
for field_name in self._PAYLOAD_INDEX_FIELDS:
self._call_qdrant_with_exists_tolerance(
lambda field_name=field_name: client.create_payload_index(
self._collection_name,
field_name,
qdrant_models.PayloadSchemaType.KEYWORD,
)
)
def _resolve_target_snapshot(
self,
*,
schema_name: str | None,
snapshot_id: str | None,
) -> SnapshotBundleModel | None:
"""Resolve the snapshot bundle most relevant to a retrieval request."""
entries = SnapshotService.list_saved_snapshots(
self._artifacts,
datasource_name=self._datasource_name,
schema_name=schema_name,
)
root = SnapshotService._snapshot_dir_from_artifacts(self._artifacts)
if snapshot_id is not None:
for entry in entries:
if entry.snapshot_id == snapshot_id:
return SnapshotService.load_snapshot(root / entry.path)
return None
if entries:
return SnapshotService.load_snapshot(root / entries[0].path)
return None
@staticmethod
def _call_qdrant_with_exists_tolerance(operation: Any) -> None:
"""Run one Qdrant mutation while tolerating duplicate-create races."""
try:
operation()
except Exception as exc: # noqa: BLE001
if SnapshotRetrievalService._is_qdrant_exists_conflict(exc):
return
raise
@staticmethod
def _is_qdrant_exists_conflict(exc: Exception) -> bool:
"""Return whether an exception represents an already-exists conflict."""
normalized = str(exc).lower()
return (
"already exists" in normalized
or "409 (conflict)" in normalized
or "status code 409" in normalized
)
def _build_filter(
self,
*,
schema_name: str | None,
table_name: str | None,
snapshot_id: str | None,
artifact_types: list[str] | None,
) -> Any:
"""Build a Qdrant payload filter for retrieval."""
qdrant_models = require_dependency("qdrant_client", "qdrant-client").models
conditions: list[Any] = [
qdrant_models.FieldCondition(
key="metadata.datasource_name",
match=qdrant_models.MatchValue(value=self._datasource_name),
)
]
if schema_name is not None:
conditions.append(
qdrant_models.FieldCondition(
key="metadata.schema_name",
match=qdrant_models.MatchValue(value=schema_name),
)
)
if table_name is not None:
conditions.append(
qdrant_models.FieldCondition(
key="metadata.table_name",
match=qdrant_models.MatchValue(value=table_name),
)
)
if snapshot_id is not None:
conditions.append(
qdrant_models.FieldCondition(
key="metadata.snapshot_id",
match=qdrant_models.MatchValue(value=snapshot_id),
)
)
if artifact_types:
conditions.append(
qdrant_models.FieldCondition(
key="metadata.artifact_type",
match=qdrant_models.MatchAny(any=artifact_types),
)
)
return qdrant_models.Filter(must=conditions)
def _get_client(self) -> Any:
"""Return the configured Qdrant client."""
if self._client is None:
client_module = require_dependency("qdrant_client", "qdrant-client")
self._client = client_module.QdrantClient(
url=self._retrieval_settings.qdrant_url,
api_key=self._retrieval_settings.qdrant_api_key,
grpc_port=self._retrieval_settings.qdrant_grpc_port,
prefer_grpc=self._retrieval_settings.qdrant_prefer_grpc,
check_compatibility=False,
)
return self._client
def _get_embeddings(self) -> Any:
"""Return the configured embeddings backend."""
if self._embeddings is None:
self._embeddings = build_embeddings(
embeddings_settings=self._embeddings_settings,
llm_settings=self._llm_settings,
artifacts=self._artifacts,
)
return self._embeddings
@staticmethod
def _summarize_document(page_content: str) -> str:
"""Return a short preview for one retrieved document."""
first_line = page_content.splitlines()[0] if page_content else ""
return first_line[:160]