Source code for sqldbagent.safety.execution

"""Guarded sync and async SQL execution services."""

from __future__ import annotations

from time import perf_counter

from sqlalchemy import Engine, text
from sqlalchemy.ext.asyncio import AsyncEngine

from sqldbagent.core.errors import ConfigurationError
from sqldbagent.core.models.query import QueryExecutionResult
from sqldbagent.core.serialization import to_jsonable
from sqldbagent.safety.guard import QueryGuardService


[docs] class SafeQueryService: """Execute guarded SQL after it passes the safety layer."""
[docs] def __init__( self, *, engine: Engine, guard: QueryGuardService, async_engine: AsyncEngine | None = None, write_engine: Engine | None = None, write_async_engine: AsyncEngine | None = None, ) -> None: """Initialize the safe query service. Args: engine: Sync SQLAlchemy engine. guard: Shared SQL guard service. async_engine: Optional async SQLAlchemy engine. write_engine: Optional writable sync engine used only when writable access is requested explicitly and allowed by policy. write_async_engine: Optional writable async engine used only when writable access is requested explicitly and allowed by policy. """ self._engine = engine self._guard = guard self._async_engine = async_engine self._write_engine = write_engine self._write_async_engine = write_async_engine
[docs] def run( self, sql: str, *, max_rows: int | None = None, access_mode: str = "read_only", ) -> QueryExecutionResult: """Guard and execute SQL synchronously. Args: sql: SQL text to execute. max_rows: Optional row-limit override. access_mode: Requested execution mode, either `read_only` or `writable`. Returns: QueryExecutionResult: Guard and execution result. """ guard_result = self._guard.guard( sql, max_rows=max_rows, access_mode=access_mode, ) if not guard_result.allowed or guard_result.normalized_sql is None: return QueryExecutionResult( mode="sync", guard=guard_result, summary=guard_result.summary, ) started_at = perf_counter() engine = self._resolve_sync_engine(access_mode=guard_result.access_mode) with engine.begin() as connection: result = connection.execute(text(guard_result.normalized_sql)) rows_affected = self._rows_affected(result) if result.returns_rows: columns = list(result.keys()) rows = [ {str(key): to_jsonable(value) for key, value in row.items()} for row in result.mappings().all() ] else: columns = [] rows = [] duration_ms = round((perf_counter() - started_at) * 1000, 3) return QueryExecutionResult( mode="sync", guard=guard_result, columns=columns, rows=rows, row_count=len(rows), rows_affected=rows_affected, truncated=bool( guard_result.max_rows is not None and len(rows) >= guard_result.max_rows ), duration_ms=duration_ms, summary=self._summarize_result( mode="sync", row_count=len(rows), rows_affected=rows_affected, duration_ms=duration_ms, guard_summary=guard_result.summary, ), )
[docs] async def run_async( self, sql: str, *, max_rows: int | None = None, access_mode: str = "read_only", ) -> QueryExecutionResult: """Guard and execute SQL asynchronously. Args: sql: SQL text to execute. max_rows: Optional row-limit override. access_mode: Requested execution mode, either `read_only` or `writable`. Returns: QueryExecutionResult: Guard and execution result. Raises: ConfigurationError: If no async engine is configured. """ if self._async_engine is None and access_mode != "writable": raise ConfigurationError("async query execution is not configured") if self._write_async_engine is None and access_mode == "writable": raise ConfigurationError("async writable query execution is not configured") guard_result = self._guard.guard( sql, max_rows=max_rows, access_mode=access_mode, ) if not guard_result.allowed or guard_result.normalized_sql is None: return QueryExecutionResult( mode="async", guard=guard_result, summary=guard_result.summary, ) started_at = perf_counter() engine = self._resolve_async_engine(access_mode=guard_result.access_mode) async with engine.begin() as connection: result = await connection.execute(text(guard_result.normalized_sql)) rows_affected = self._rows_affected(result) if result.returns_rows: columns = list(result.keys()) rows = [ {str(key): to_jsonable(value) for key, value in row.items()} for row in result.mappings().all() ] else: columns = [] rows = [] duration_ms = round((perf_counter() - started_at) * 1000, 3) return QueryExecutionResult( mode="async", guard=guard_result, columns=columns, rows=rows, row_count=len(rows), rows_affected=rows_affected, truncated=bool( guard_result.max_rows is not None and len(rows) >= guard_result.max_rows ), duration_ms=duration_ms, summary=self._summarize_result( mode="async", row_count=len(rows), rows_affected=rows_affected, duration_ms=duration_ms, guard_summary=guard_result.summary, ), )
def _summarize_result( self, *, mode: str, row_count: int, rows_affected: int | None, duration_ms: float, guard_summary: str | None, ) -> str: """Build a short human-readable summary for one execution.""" prefix = guard_summary or "Query executed." if rows_affected is not None and row_count == 0: return ( f"{prefix} Execution mode: {mode}. Affected {rows_affected} rows in " f"{duration_ms} ms." ) return ( f"{prefix} Execution mode: {mode}. Returned {row_count} rows in " f"{duration_ms} ms." ) def _resolve_sync_engine(self, *, access_mode: str) -> Engine: """Resolve the sync engine for the requested access mode.""" if access_mode == "writable": if self._write_engine is None: raise ConfigurationError("writable query execution is not configured") return self._write_engine return self._engine def _resolve_async_engine(self, *, access_mode: str) -> AsyncEngine: """Resolve the async engine for the requested access mode.""" if access_mode == "writable": if self._write_async_engine is None: raise ConfigurationError( "async writable query execution is not configured" ) return self._write_async_engine if self._async_engine is None: raise ConfigurationError("async query execution is not configured") return self._async_engine @staticmethod def _rows_affected(result: object) -> int | None: """Return rows-affected metadata when provided by SQLAlchemy.""" rowcount = getattr(result, "rowcount", None) if isinstance(rowcount, int) and rowcount >= 0: return rowcount return None