124 lines
4.3 KiB
Python
124 lines
4.3 KiB
Python
"""Middleware auditoria de inputs — espelho heurístico do VM122 (Spec 021)."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import re
|
|
from typing import Awaitable, Callable
|
|
|
|
from starlette.middleware.base import BaseHTTPMiddleware
|
|
from starlette.requests import Request
|
|
from starlette.responses import JSONResponse, Response
|
|
|
|
SQLI_PATTERNS = [
|
|
re.compile(r"'\s*or\s+", re.I),
|
|
re.compile(r"union\s+select", re.I),
|
|
re.compile(r";\s*drop\s+", re.I),
|
|
re.compile(r"1\s*=\s*1", re.I),
|
|
re.compile(r"--\s*$"),
|
|
]
|
|
XSS_PATTERNS = [
|
|
re.compile(r"<\s*script", re.I),
|
|
re.compile(r"javascript\s*:", re.I),
|
|
re.compile(r"onerror\s*=", re.I),
|
|
re.compile(r"onload\s*=", re.I),
|
|
]
|
|
PATH_PATTERNS = [
|
|
re.compile(r"\.\./"),
|
|
re.compile(r"%2e%2e", re.I),
|
|
]
|
|
|
|
SKIP_PATHS = frozenset({"/health", "/metrics", "/favicon.ico"})
|
|
AUDIT_FIELDS = frozenset({"domain", "email", "company", "subdomain", "hostname", "mx", "txt"})
|
|
|
|
OnBlockCallback = Callable[[str, str, str, dict], Awaitable[None] | None]
|
|
|
|
|
|
def audit_value(value: str, *, field: str = "") -> dict:
    """Screen a single input string against the heuristic attack signatures.

    Surrounding whitespace is stripped first; an empty (or ``None``) value is
    accepted outright.

    Args:
        value: The raw string to inspect.
        field: Label copied into a rejection result (e.g. the JSON key path).

    Returns:
        ``{"ok": True}`` when nothing matches; otherwise
        ``{"ok": False, "reason": ..., "severity": "high", "field": field}``
        where ``reason`` names the first rule that fired, checked in order:
        ``"oversize"`` (more than 2000 characters after stripping),
        ``"sql_injection_pattern"``, ``"xss_pattern"``, ``"path_traversal"``.
    """
    text = value.strip() if value else ""
    if not text:
        return {"ok": True}

    def rejected(reason: str) -> dict:
        return {"ok": False, "reason": reason, "severity": "high", "field": field}

    if len(text) > 2000:
        return rejected("oversize")

    # Rule families are evaluated in a fixed order; the first hit decides the reason.
    rule_sets = (
        ("sql_injection_pattern", SQLI_PATTERNS),
        ("xss_pattern", XSS_PATTERNS),
        ("path_traversal", PATH_PATTERNS),
    )
    for reason, patterns in rule_sets:
        if any(pattern.search(text) for pattern in patterns):
            return rejected(reason)
    return {"ok": True}
|
|
|
|
|
|
def _extract_strings(obj, prefix: str = "") -> list[tuple[str, str]]:
|
|
out: list[tuple[str, str]] = []
|
|
if isinstance(obj, dict):
|
|
for k, v in obj.items():
|
|
key = f"{prefix}.{k}" if prefix else str(k)
|
|
if isinstance(v, str):
|
|
out.append((key, v))
|
|
elif isinstance(v, (dict, list)):
|
|
out.extend(_extract_strings(v, key))
|
|
elif isinstance(obj, list):
|
|
for i, v in enumerate(obj):
|
|
key = f"{prefix}[{i}]"
|
|
if isinstance(v, str):
|
|
out.append((key, v))
|
|
elif isinstance(v, (dict, list)):
|
|
out.extend(_extract_strings(v, key))
|
|
return out
|
|
|
|
|
|
class SecurityAuditMiddleware(BaseHTTPMiddleware):
    """Reject JSON write requests whose string values match attack heuristics.

    For ``POST``/``PUT``/``PATCH`` requests whose path is not in ``SKIP_PATHS``,
    the body is parsed as JSON and each nested string value is screened with
    ``audit_value``. A value is skipped when its leaf key is not in
    ``AUDIT_FIELDS`` and it is shorter than 8 characters. The first failing
    value ends the request with HTTP 400 and, when configured, fires the
    ``on_block`` hook. Every other request is forwarded with the already-read
    body replayed to the application.
    """

    def __init__(self, app, on_block: OnBlockCallback | None = None):
        super().__init__(app)
        # Optional hook called as on_block(event, session_id, domain, details);
        # if it returns something other than None, that value is awaited.
        self.on_block = on_block

    async def dispatch(self, request: Request, call_next) -> Response:
        """Audit the request body, then either forward the request or answer 400."""
        if request.method not in ("POST", "PUT", "PATCH"):
            return await call_next(request)
        if request.url.path in SKIP_PATHS:
            return await call_next(request)

        body_bytes = await request.body()
        # The body stream has been consumed; every path from here on must hand
        # the application a request that can deliver the same bytes again.
        forwarded = self._replay_body(request, body_bytes)
        if not body_bytes:
            return await call_next(forwarded)

        try:
            payload = json.loads(body_bytes)
        except (ValueError, RecursionError):
            # ValueError covers JSONDecodeError and UnicodeDecodeError (invalid
            # UTF-8 bytes); RecursionError covers pathologically deep nesting.
            # Such bodies cannot be audited here, so the application decides,
            # exactly as for any other non-JSON body.
            return await call_next(forwarded)

        session_id = None
        domain = None
        if isinstance(payload, dict):
            session_id = payload.get("session_id") or payload.get("sessionId")
            domain = payload.get("domain")

        for field, value in _extract_strings(payload):
            # Leaf key without list indices, e.g. "a.b[0]" -> "b".
            base_field = field.split(".")[-1].split("[")[0]
            if base_field not in AUDIT_FIELDS and len(value) < 8:
                # Short values outside AUDIT_FIELDS are not audited.
                continue
            result = audit_value(value, field=field)
            if result.get("ok"):
                continue
            await self._notify_block(request, result, session_id, domain)
            return JSONResponse(
                status_code=400,
                content={"error": "input_blocked", "reason": result.get("reason"), "field": field},
            )

        return await call_next(forwarded)

    async def _notify_block(self, request: Request, result: dict, session_id, domain) -> None:
        """Invoke ``on_block`` (if set) for a blocked value, awaiting it when needed."""
        if not self.on_block:
            return
        maybe = self.on_block(
            "security.input_blocked",
            session_id or "",
            domain or "",
            {**result, "endpoint": request.url.path},
        )
        if maybe is not None:
            await maybe

    @staticmethod
    def _replay_body(request: Request, body: bytes) -> Request:
        """Return a Request on the same scope whose receive() yields *body* once.

        After the cached body has been delivered, further receive() calls are
        delegated to the original channel. Later ASGI messages (such as
        ``http.disconnect``) therefore still reach the application, instead of
        the body being re-sent on every call.
        """
        upstream_receive = request.receive
        delivered = False

        async def receive():
            nonlocal delivered
            if not delivered:
                delivered = True
                return {"type": "http.request", "body": body, "more_body": False}
            return await upstream_receive()

        return Request(request.scope, receive)
|