Source code for sqldbagent.adapters.langgraph.agent

"""LangGraph agent builders over the shared sqldbagent services."""

from __future__ import annotations

from collections.abc import AsyncIterator, Sequence
from contextlib import asynccontextmanager, contextmanager
from typing import Any, Iterator

from sqldbagent.adapters.langchain.tools import create_langchain_tools
from sqldbagent.adapters.langgraph.checkpoint import (
    create_async_postgres_checkpointer,
    create_sync_postgres_checkpointer,
)
from sqldbagent.adapters.langgraph.middleware import create_sqldbagent_middleware
from sqldbagent.adapters.langgraph.prompts import create_sqldbagent_system_prompt
from sqldbagent.adapters.langgraph.state import SQLDBAgentContext, SQLDBAgentState
from sqldbagent.adapters.shared import require_dependency
from sqldbagent.core.bootstrap import ServiceContainer
from sqldbagent.core.config import AppSettings, load_settings


[docs] def create_sqldbagent_agent( *, services: ServiceContainer, model: str | Any, datasource_name: str, settings: AppSettings | None = None, schema_name: str | None = None, checkpointer: Any | None = None, middleware: Sequence[Any] = (), include_default_middleware: bool = True, interrupt_before: list[str] | None = None, interrupt_after: list[str] | None = None, debug: bool = False, ) -> Any: """Create a LangChain v1 agent compiled on LangGraph. Args: services: Shared sqldbagent service container. model: LangChain-compatible model instance or model identifier. datasource_name: Datasource identifier. settings: Optional application settings. schema_name: Optional schema focus. checkpointer: Optional LangGraph checkpointer. middleware: Optional additional LangChain middleware chain. include_default_middleware: Whether to prepend sqldbagent's default middleware. interrupt_before: Optional LangGraph interrupt hook points. interrupt_after: Optional LangGraph interrupt hook points. debug: Whether LangGraph debug mode should be enabled. Returns: Any: Compiled LangGraph agent. """ resolved_settings = settings or load_settings() agents_module = require_dependency("langchain.agents", "langchain") resolved_middleware = list(middleware) if include_default_middleware: resolved_middleware = [ *create_sqldbagent_middleware( datasource_name=datasource_name, settings=resolved_settings, schema_name=schema_name, ), *resolved_middleware, ] return agents_module.create_agent( model=model, tools=create_langchain_tools(services), system_prompt=( None if include_default_middleware else create_sqldbagent_system_prompt( datasource_name=datasource_name, settings=resolved_settings, schema_name=schema_name, ) ), middleware=resolved_middleware, state_schema=SQLDBAgentState, context_schema=SQLDBAgentContext, checkpointer=checkpointer, interrupt_before=interrupt_before, interrupt_after=interrupt_after, debug=debug, name=resolved_settings.agent.name, )
[docs] @contextmanager def create_sync_postgres_checkpointed_agent( *, services: ServiceContainer, model: str | Any, datasource_name: str, settings: AppSettings | None = None, schema_name: str | None = None, middleware: Sequence[Any] = (), include_default_middleware: bool = True, interrupt_before: list[str] | None = None, interrupt_after: list[str] | None = None, debug: bool = False, ) -> Iterator[Any]: """Create a sync LangGraph agent with Postgres-backed checkpointing. Args: services: Shared sqldbagent service container. model: LangChain-compatible model instance or model identifier. datasource_name: Datasource identifier. settings: Optional application settings. schema_name: Optional schema focus. middleware: Optional additional LangChain middleware chain. include_default_middleware: Whether to prepend sqldbagent's default middleware. interrupt_before: Optional LangGraph interrupt hook points. interrupt_after: Optional LangGraph interrupt hook points. debug: Whether LangGraph debug mode should be enabled. Yields: Any: Compiled LangGraph agent. """ resolved_settings = settings or load_settings() with create_sync_postgres_checkpointer(settings=resolved_settings) as checkpointer: yield create_sqldbagent_agent( services=services, model=model, datasource_name=datasource_name, settings=resolved_settings, schema_name=schema_name, checkpointer=checkpointer, middleware=middleware, include_default_middleware=include_default_middleware, interrupt_before=interrupt_before, interrupt_after=interrupt_after, debug=debug, )
[docs] @asynccontextmanager async def create_async_postgres_checkpointed_agent( *, services: ServiceContainer, model: str | Any, datasource_name: str, settings: AppSettings | None = None, schema_name: str | None = None, middleware: Sequence[Any] = (), include_default_middleware: bool = True, interrupt_before: list[str] | None = None, interrupt_after: list[str] | None = None, debug: bool = False, ) -> AsyncIterator[Any]: """Create an async LangGraph agent with Postgres-backed checkpointing. Args: services: Shared sqldbagent service container. model: LangChain-compatible model instance or model identifier. datasource_name: Datasource identifier. settings: Optional application settings. schema_name: Optional schema focus. middleware: Optional additional LangChain middleware chain. include_default_middleware: Whether to prepend sqldbagent's default middleware. interrupt_before: Optional LangGraph interrupt hook points. interrupt_after: Optional LangGraph interrupt hook points. debug: Whether LangGraph debug mode should be enabled. Yields: Any: Compiled LangGraph agent. """ resolved_settings = settings or load_settings() async with create_async_postgres_checkpointer( settings=resolved_settings ) as checkpointer: yield create_sqldbagent_agent( services=services, model=model, datasource_name=datasource_name, settings=resolved_settings, schema_name=schema_name, checkpointer=checkpointer, middleware=middleware, include_default_middleware=include_default_middleware, interrupt_before=interrupt_before, interrupt_after=interrupt_after, debug=debug, )