"""LangChain v1 middleware builders for sqldbagent agents."""
from __future__ import annotations
from typing import Any
from sqldbagent.adapters.langgraph.prompts import create_sqldbagent_system_prompt
from sqldbagent.adapters.langgraph.state import SQLDBAgentState
from sqldbagent.adapters.shared import require_dependency
from sqldbagent.core.agent_context import build_sqldbagent_state_seed
from sqldbagent.core.config import AppSettings, load_settings
[docs]
def create_sqldbagent_middleware(
    *,
    datasource_name: str,
    settings: AppSettings | None = None,
    schema_name: str | None = None,
) -> list[Any]:
    """Assemble the default middleware list for a sqldbagent agent.

    The list is built in a fixed order:

    1. state-seed middleware, then dynamic-prompt middleware (always added)
    2. ``TodoListMiddleware`` when ``agent.enable_todo_middleware`` is truthy
    3. ``HumanInTheLoopMiddleware`` configured with ``{"safe_query_sql": True}``
       when ``agent.enable_human_in_the_loop`` is truthy
    4. the tool-error middleware (always added)
    5. the summarization middleware when
       ``agent.enable_summarization_middleware`` is truthy and the factory
       returns something other than ``None``
    6. the tool-digest middleware (always added)
    7. ``ModelCallLimitMiddleware`` / ``ToolCallLimitMiddleware`` when the
       matching per-run limit is not ``None``; both use
       ``exit_behavior="error"``

    Args:
        datasource_name: Datasource identifier.
        settings: Optional application settings; ``load_settings()`` is used
            when omitted.
        schema_name: Optional schema focus.

    Returns:
        list[Any]: Middleware instances in the order listed above.
    """
    active_settings = settings or load_settings()
    lc_middleware = require_dependency("langchain.agents.middleware", "langchain")

    # Both snapshot-aware factories take exactly the same keyword arguments.
    scope: dict[str, Any] = {
        "datasource_name": datasource_name,
        "settings": active_settings,
        "schema_name": schema_name,
    }
    stack: list[Any] = [
        create_sqldbagent_state_middleware(**scope),
        create_sqldbagent_dynamic_prompt_middleware(**scope),
    ]
    agent_cfg = active_settings.agent

    if agent_cfg.enable_todo_middleware:
        stack.append(lc_middleware.TodoListMiddleware())

    if agent_cfg.enable_human_in_the_loop:
        interrupt_on = {"safe_query_sql": True}
        stack.append(lc_middleware.HumanInTheLoopMiddleware(interrupt_on))

    stack.append(create_sqldbagent_tool_error_middleware())

    if agent_cfg.enable_summarization_middleware:
        summarizer = create_sqldbagent_summarization_middleware(
            settings=active_settings
        )
        # The factory returns None when no summarization model can be resolved.
        if summarizer is not None:
            stack.append(summarizer)

    stack.append(create_sqldbagent_tool_digest_middleware(settings=active_settings))

    model_call_cap = agent_cfg.max_model_calls_per_run
    if model_call_cap is not None:
        stack.append(
            lc_middleware.ModelCallLimitMiddleware(
                run_limit=model_call_cap,
                exit_behavior="error",
            )
        )

    tool_call_cap = agent_cfg.max_tool_calls_per_run
    if tool_call_cap is not None:
        stack.append(
            lc_middleware.ToolCallLimitMiddleware(
                run_limit=tool_call_cap,
                exit_behavior="error",
            )
        )

    return stack
[docs]
def create_sqldbagent_state_middleware(
    *,
    datasource_name: str,
    settings: AppSettings | None = None,
    schema_name: str | None = None,
) -> Any:
    """Create the middleware that seeds agent state before the agent runs.

    The hook returns whatever ``build_sqldbagent_state_seed`` produces for
    the captured datasource, settings and schema.

    Args:
        datasource_name: Datasource identifier.
        settings: Optional application settings; ``load_settings()`` is used
            when omitted.
        schema_name: Optional schema focus.

    Returns:
        Any: The object produced by applying LangChain's
        ``before_agent(state_schema=SQLDBAgentState)`` decorator to the
        seeding hook.
    """
    active_settings = settings or load_settings()
    lc_middleware = require_dependency("langchain.agents.middleware", "langchain")

    def sqldbagent_state_seed(_state: SQLDBAgentState, _runtime: Any) -> dict[str, Any]:
        # The incoming state and runtime are ignored; the seed depends only on
        # the values captured when the factory was called.
        return build_sqldbagent_state_seed(
            datasource_name=datasource_name,
            settings=active_settings,
            schema_name=schema_name,
        )

    # Equivalent to decorating the hook with @before_agent(state_schema=...).
    as_before_agent = lc_middleware.before_agent(state_schema=SQLDBAgentState)
    return as_before_agent(sqldbagent_state_seed)
[docs]
def create_sqldbagent_dynamic_prompt_middleware(
    *,
    datasource_name: str,
    settings: AppSettings | None = None,
    schema_name: str | None = None,
) -> Any:
    """Create the middleware that supplies the system prompt on demand.

    Each time the hook is invoked it calls
    ``create_sqldbagent_system_prompt`` with the captured datasource,
    settings and schema, and returns the resulting string.

    Args:
        datasource_name: Datasource identifier.
        settings: Optional application settings; ``load_settings()`` is used
            when omitted.
        schema_name: Optional schema focus.

    Returns:
        Any: The object produced by applying LangChain's ``dynamic_prompt``
        decorator to the prompt hook.
    """
    active_settings = settings or load_settings()
    lc_middleware = require_dependency("langchain.agents.middleware", "langchain")

    def sqldbagent_dynamic_prompt(_request: Any) -> str:
        # The request is ignored; the prompt depends only on captured values.
        return create_sqldbagent_system_prompt(
            datasource_name=datasource_name,
            settings=active_settings,
            schema_name=schema_name,
        )

    # Equivalent to decorating the hook with @dynamic_prompt.
    return lc_middleware.dynamic_prompt(sqldbagent_dynamic_prompt)
[docs]
def create_sqldbagent_summarization_middleware(
    *,
    settings: AppSettings | None = None,
) -> Any | None:
    """Create the summarization middleware, or ``None`` when no model resolves.

    The model is ``agent.summarization_model`` when that is truthy; otherwise
    it falls back to the provider-qualified default built by
    ``_build_model_reference``.

    Args:
        settings: Optional application settings; ``load_settings()`` is used
            when omitted.

    Returns:
        Any | None: A ``SummarizationMiddleware`` instance, or ``None`` when
        neither a summarization model nor a default model reference exists.
    """
    active_settings = settings or load_settings()
    # The dependency is resolved before the model check, so a missing
    # langchain install surfaces even when no model is configured.
    lc_middleware = require_dependency("langchain.agents.middleware", "langchain")
    agent_cfg = active_settings.agent

    summarizer_model = agent_cfg.summarization_model
    if not summarizer_model:
        summarizer_model = _build_model_reference(active_settings)
    if summarizer_model is None:
        return None

    return lc_middleware.SummarizationMiddleware(
        model=summarizer_model,
        trigger=("fraction", agent_cfg.summarization_trigger_fraction),
        keep=("messages", agent_cfg.summarization_keep_messages),
        summary_prompt=_build_summary_prompt(),
    )
def _build_model_reference(settings: AppSettings) -> str | None:
    """Return ``"<provider>:<model>"`` from the LLM settings, if both are set.

    Args:
        settings: Application settings; ``settings.llm.default_provider`` and
            ``settings.llm.default_model`` are read.

    Returns:
        str | None: The provider-qualified model reference, or ``None`` when
        either the provider or the model is falsy.
    """
    llm = settings.llm
    if llm.default_provider and llm.default_model:
        return f"{llm.default_provider}:{llm.default_model}"
    return None
def _compress_tool_messages(messages: list[Any], *, limit: int) -> list[str]:
    """Compress recent tool messages into short digest lines.

    Only messages whose ``type`` attribute equals ``"tool"`` are considered.
    Each digest line has the form ``"<tool name>: <content>"``, where newlines
    in the content are replaced by spaces, surrounding whitespace is stripped,
    and content longer than 160 characters is cut to 157 characters plus
    ``"..."``.

    Args:
        messages: Agent message history; ``None`` is treated as empty.
        limit: Maximum digest entries to retain. Zero or a negative value
            yields an empty digest.

    Returns:
        list[str]: The last ``limit`` digest entries, oldest first.
    """
    # ``digest[-0:]`` is the whole list and a negative limit would drop
    # leading entries instead, so non-positive limits are handled up front.
    if limit <= 0:
        return []

    digest: list[str] = []
    for message in messages or ():
        if getattr(message, "type", None) != "tool":
            continue
        # ``name`` may be present but None; fall back to the generic label
        # rather than emitting "None: ...".
        tool_name = getattr(message, "name", None) or "tool"
        content = str(getattr(message, "content", "")).replace("\n", " ").strip()
        if len(content) > 160:
            content = f"{content[:157]}..."
        digest.append(f"{tool_name}: {content}")
    return digest[-limit:]
def _build_summary_prompt() -> str:
    """Return the summarization prompt template used for long agent sessions.

    The template is plain text with ``<role>``, ``<goal>``, ``<instructions>``
    and ``<messages>`` sections and a single ``{messages}`` placeholder.

    Returns:
        str: Prompt template, with no leading or trailing whitespace.
    """
    prompt_lines = (
        "<role>",
        "sqldbagent Context Compression Assistant",
        "</role>",
        "<goal>",
        "You are compressing a long-running database intelligence session so the agent can continue working without losing critical context.",
        "</goal>",
        "<instructions>",
        "Summarize only the most important database-specific context and execution history. Keep it concrete and reusable.",
        "Always include these sections:",
        "## OBJECTIVE",
        "What the user is trying to learn or produce.",
        "## DATABASE CONTEXT",
        "Datasource, schema focus, important entities, relationships, row-count/storage/profile hints, and safety constraints.",
        "## SNAPSHOT AND ARTIFACT CONTEXT",
        "Relevant snapshot ids, summaries, docs, diagrams, prompts, or exports already created or loaded.",
        "## TOOL AND QUERY HISTORY",
        "The most important inspection/profile/query actions and what they established.",
        "## OPEN QUESTIONS",
        "What is still unresolved or still needs verification.",
        "## NEXT STEPS",
        "What the agent should do next.",
        "</instructions>",
        "<messages>",
        "Messages to summarize:",
        "{messages}",
        "</messages>",
    )
    return "\n".join(prompt_lines)