"""Middleware auditoria de inputs — espelho heurístico do VM122 (Spec 021).""" from __future__ import annotations import json import re from typing import Awaitable, Callable from starlette.middleware.base import BaseHTTPMiddleware from starlette.requests import Request from starlette.responses import JSONResponse, Response SQLI_PATTERNS = [ re.compile(r"'\s*or\s+", re.I), re.compile(r"union\s+select", re.I), re.compile(r";\s*drop\s+", re.I), re.compile(r"1\s*=\s*1", re.I), re.compile(r"--\s*$"), ] XSS_PATTERNS = [ re.compile(r"<\s*script", re.I), re.compile(r"javascript\s*:", re.I), re.compile(r"onerror\s*=", re.I), re.compile(r"onload\s*=", re.I), ] PATH_PATTERNS = [ re.compile(r"\.\./"), re.compile(r"%2e%2e", re.I), ] SKIP_PATHS = frozenset({"/health", "/metrics", "/favicon.ico"}) AUDIT_FIELDS = frozenset({"domain", "email", "company", "subdomain", "hostname", "mx", "txt"}) OnBlockCallback = Callable[[str, str, str, dict], Awaitable[None] | None] def audit_value(value: str, *, field: str = "") -> dict: text = (value or "").strip() if not text: return {"ok": True} if len(text) > 2000: return {"ok": False, "reason": "oversize", "severity": "high", "field": field} for pat in SQLI_PATTERNS: if pat.search(text): return {"ok": False, "reason": "sql_injection_pattern", "severity": "high", "field": field} for pat in XSS_PATTERNS: if pat.search(text): return {"ok": False, "reason": "xss_pattern", "severity": "high", "field": field} for pat in PATH_PATTERNS: if pat.search(text): return {"ok": False, "reason": "path_traversal", "severity": "high", "field": field} return {"ok": True} def _extract_strings(obj, prefix: str = "") -> list[tuple[str, str]]: out: list[tuple[str, str]] = [] if isinstance(obj, dict): for k, v in obj.items(): key = f"{prefix}.{k}" if prefix else str(k) if isinstance(v, str): out.append((key, v)) elif isinstance(v, (dict, list)): out.extend(_extract_strings(v, key)) elif isinstance(obj, list): for i, v in enumerate(obj): key = f"{prefix}[{i}]" if isinstance(v, str): out.append((key, v)) elif isinstance(v, (dict, list)): out.extend(_extract_strings(v, key)) return out class SecurityAuditMiddleware(BaseHTTPMiddleware): def __init__(self, app, on_block: OnBlockCallback | None = None): super().__init__(app) self.on_block = on_block async def dispatch(self, request: Request, call_next) -> Response: if request.method not in ("POST", "PUT", "PATCH"): return await call_next(request) if request.url.path in SKIP_PATHS: return await call_next(request) body_bytes = await request.body() if not body_bytes: return await call_next(request) try: payload = json.loads(body_bytes) except json.JSONDecodeError: return await call_next(request) session_id = None domain = None if isinstance(payload, dict): session_id = payload.get("session_id") or payload.get("sessionId") domain = payload.get("domain") for field, value in _extract_strings(payload): base_field = field.split(".")[-1].split("[")[0] if base_field not in AUDIT_FIELDS and len(value) < 8: continue result = audit_value(value, field=field) if not result.get("ok"): if self.on_block: maybe = self.on_block( "security.input_blocked", session_id or "", domain or "", {**result, "endpoint": request.url.path}, ) if maybe is not None: await maybe return JSONResponse( status_code=400, content={"error": "input_blocked", "reason": result.get("reason"), "field": field}, ) async def receive(): return {"type": "http.request", "body": body_bytes, "more_body": False} request = Request(request.scope, receive) return await call_next(request)