Source code for sqldbagent.safety.guard

"""SQL query guard service."""

from __future__ import annotations

from typing import Any

from sqldbagent.core.config import SafetySettings
from sqldbagent.core.enums import Dialect
from sqldbagent.safety.models import QueryGuardResult
from sqldbagent.safety.policies import should_apply_row_limit, to_sqlglot_dialect


class QueryGuardService:
    """Guard SQL through AST inspection and normalization.

    The service parses incoming SQL with SQLGlot, rejects anything that is
    not exactly one acceptable statement under the active policy, and (when
    guarding) rewrites the statement so that it carries a row limit.
    """

    # SQLGlot expression class names that are never allowed to appear anywhere
    # in a statement tree. Names missing from the installed SQLGlot version
    # are skipped at check time.
    _DISALLOWED_NODE_NAMES: tuple[str, ...] = (
        "Delete",
        "Update",
        "Insert",
        "Merge",
        "Create",
        "Drop",
        "Alter",
        "Command",
        "Copy",
        "Grant",
        "Revoke",
        "TruncateTable",
        "Use",
        "Call",
    )

    def __init__(self, policy: SafetySettings, dialect: Dialect) -> None:
        """Initialize the query guard.

        Args:
            policy: Safety policy settings.
            dialect: Datasource dialect.
        """
        self._policy = policy
        self._dialect = dialect
        self._sqlglot_dialect = to_sqlglot_dialect(dialect)

    def lint(self, sql: str) -> QueryGuardResult:
        """Parse and normalize SQL without applying guard rewrites.

        Args:
            sql: SQL text to lint.

        Returns:
            QueryGuardResult: Lint result.
        """
        return self._evaluate(sql, apply_guard=False)

    def guard(self, sql: str, *, max_rows: int | None = None) -> QueryGuardResult:
        """Parse, validate, and normalize SQL under the active policy.

        Args:
            sql: SQL text to guard.
            max_rows: Optional row-limit override for this evaluation.

        Returns:
            QueryGuardResult: Guard result.
        """
        return self._evaluate(sql, apply_guard=True, max_rows=max_rows)

    def _evaluate(
        self,
        sql: str,
        *,
        apply_guard: bool,
        max_rows: int | None = None,
    ) -> QueryGuardResult:
        """Evaluate a SQL statement.

        Args:
            sql: SQL text to evaluate.
            apply_guard: Whether guard rewrites should be applied.
            max_rows: Optional row-limit override.

        Returns:
            QueryGuardResult: Evaluation result. Parse failures, a statement
            count other than one, and policy violations are all reported as
            ``allowed=False`` results rather than raised.
        """
        sqlglot = self._load_sqlglot()
        exp = sqlglot.exp

        # A per-call override is applied to a copy so the shared policy
        # object is never mutated.
        if max_rows is None:
            policy = self._policy
        else:
            policy = self._policy.model_copy(update={"max_rows": max_rows})

        try:
            statements = [
                statement
                for statement in sqlglot.parse(sql, dialect=self._sqlglot_dialect)
                if statement is not None
            ]
        except (sqlglot.errors.ParseError, sqlglot.errors.TokenError) as exc:
            # TokenError is raised by the tokenizer (e.g. unterminated string
            # literal) before parsing starts and is not a ParseError subclass.
            return QueryGuardResult(
                allowed=False,
                dialect=self._dialect.value,
                original_sql=sql,
                reasons=[str(exc)],
                summary="Query failed to parse.",
            )

        if len(statements) != 1:
            return QueryGuardResult(
                allowed=False,
                dialect=self._dialect.value,
                original_sql=sql,
                reasons=["exactly one SQL statement is required"],
                summary=self._describe_statement_count(len(statements)),
            )

        statement = statements[0]
        statement_type = statement.__class__.__name__.upper()
        referenced_schemas, referenced_tables = self._collect_references(
            statement,
            exp=exp,
        )
        reasons = self._collect_reasons(
            statement,
            exp=exp,
            policy=policy,
            referenced_schemas=referenced_schemas,
        )
        if reasons:
            return QueryGuardResult(
                allowed=False,
                dialect=self._dialect.value,
                original_sql=sql,
                statement_type=statement_type,
                normalized_sql=statement.sql(dialect=self._sqlglot_dialect),
                max_rows=policy.max_rows,
                referenced_schemas=referenced_schemas,
                referenced_tables=referenced_tables,
                reasons=reasons,
                summary=self._summarize_result(
                    allowed=False,
                    statement_type=statement_type,
                    referenced_tables=referenced_tables,
                    reasons=reasons,
                ),
            )

        # Work on a copy so the parsed statement is left untouched.
        guarded = statement.copy()
        row_limit_applied = False
        if apply_guard:
            limit_expression = guarded.args.get("limit")
            has_limit = limit_expression is not None
            needs_limit = should_apply_row_limit(policy, has_limit)
            if not needs_limit and has_limit:
                # An existing limit is kept only when it is a literal integer
                # that does not exceed the policy maximum.
                current_limit = self._extract_limit_value(limit_expression, exp=exp)
                needs_limit = current_limit is None or current_limit > policy.max_rows
            if needs_limit:
                guarded = guarded.limit(policy.max_rows)
                row_limit_applied = True

        return QueryGuardResult(
            allowed=True,
            dialect=self._dialect.value,
            original_sql=sql,
            statement_type=statement_type,
            normalized_sql=guarded.sql(dialect=self._sqlglot_dialect),
            row_limit_applied=row_limit_applied,
            max_rows=policy.max_rows,
            referenced_schemas=referenced_schemas,
            referenced_tables=referenced_tables,
            summary=self._summarize_result(
                allowed=True,
                statement_type=statement_type,
                referenced_tables=referenced_tables,
                reasons=[],
            ),
        )

    @staticmethod
    def _describe_statement_count(count: int) -> str:
        """Build the rejection summary for a statement count other than one.

        Args:
            count: Number of statements found in the input.

        Returns:
            str: Summary distinguishing empty input from multiple statements.
        """
        if count == 0:
            return "Query rejected because no SQL statement was provided."
        return "Query rejected because multiple statements were provided."

    def _collect_reasons(
        self,
        statement: Any,
        *,
        exp: Any,
        policy: SafetySettings,
        referenced_schemas: list[str],
    ) -> list[str]:
        """Collect validation failures for a statement.

        Args:
            statement: SQLGlot expression tree.
            exp: SQLGlot expressions module.
            policy: Effective safety policy.
            referenced_schemas: Referenced schemas discovered in the statement.

        Returns:
            list[str]: Validation failure reasons; empty when the statement
            passes every check.
        """
        reasons: list[str] = []

        if policy.read_only and not isinstance(statement, exp.Query):
            reasons.append("only read-only query statements are allowed")

        # An empty allow-list means no schema restriction is enforced.
        if policy.allowed_schemas:
            denied_schemas = sorted(
                {
                    schema_name
                    for schema_name in referenced_schemas
                    if schema_name not in policy.allowed_schemas
                }
            )
            if denied_schemas:
                reasons.append(
                    "query references disallowed schemas: " + ", ".join(denied_schemas)
                )

        # Searched through the whole tree, so a disallowed operation nested
        # inside an otherwise acceptable statement is still reported.
        for node_name in self._DISALLOWED_NODE_NAMES:
            node_type = getattr(exp, node_name, None)
            if node_type is not None and statement.find(node_type):
                reasons.append(
                    f"disallowed SQL operation detected: {node_name.lower()}"
                )
        return reasons

    def _collect_references(
        self, statement: Any, *, exp: Any
    ) -> tuple[list[str], list[str]]:
        """Collect referenced schemas and tables from a SQLGlot statement.

        Args:
            statement: SQLGlot expression tree.
            exp: SQLGlot expressions module.

        Returns:
            tuple[list[str], list[str]]: Sorted, de-duplicated schema names
            and table names. Tables with a schema are reported as
            ``schema.table``; others by bare name.
        """
        referenced_schemas: set[str] = set()
        referenced_tables: set[str] = set()
        for table in statement.find_all(exp.Table):
            schema_name = table.db
            table_name = table.name
            if schema_name:
                referenced_schemas.add(schema_name)
                referenced_tables.add(f"{schema_name}.{table_name}")
            else:
                referenced_tables.add(table_name)
        return sorted(referenced_schemas), sorted(referenced_tables)

    def _extract_limit_value(self, limit_expression: Any, *, exp: Any) -> int | None:
        """Extract a literal row limit when present.

        Args:
            limit_expression: SQLGlot limit expression.
            exp: SQLGlot expressions module.

        Returns:
            int | None: Literal limit value when it can be determined;
            ``None`` when there is no limit or it is not an integer literal.
        """
        if limit_expression is None:
            return None
        expression = getattr(limit_expression, "expression", None)
        if isinstance(expression, exp.Literal) and expression.is_int:
            return int(expression.this)
        return None

    def _load_sqlglot(self) -> Any:
        """Import SQLGlot lazily.

        Returns:
            Any: SQLGlot module.
        """
        # Imported at call time rather than at module import.
        from sqldbagent.adapters.shared import require_dependency

        return require_dependency("sqlglot", "sqlglot")

    def _summarize_result(
        self,
        *,
        allowed: bool,
        statement_type: str | None,
        referenced_tables: list[str],
        reasons: list[str],
    ) -> str:
        """Build a short human-readable summary for a guard evaluation.

        Args:
            allowed: Whether the statement was accepted.
            statement_type: Upper-cased statement class name, if known.
            referenced_tables: Tables referenced by the statement.
            reasons: Rejection reasons (used only when not allowed).

        Returns:
            str: One-line summary of the outcome.
        """
        label = statement_type or "statement"
        if allowed:
            table_text = (
                ", ".join(referenced_tables) if referenced_tables else "no tables"
            )
            return (
                f"{label} accepted for {self._dialect.value}; "
                f"references {table_text}."
            )
        return f"{label} rejected for {self._dialect.value}: " + "; ".join(reasons)