"""Single-use backup codes for Desk 2FA (Spec 004 extension)."""

from __future__ import annotations

import hashlib
import secrets
import sqlite3
from datetime import datetime, timezone

BACKUP_CODE_COUNT = 10
# Code alphabet: 32 symbols with the look-alikes I, O, 0 and 1 omitted.
_CHARS = "ABCDEFGHJKLMNPQRSTUVWXYZ23456789"


def _ensure_column(conn: sqlite3.Connection, table: str, column: str, ddl: str) -> None:
    """Run ``ALTER TABLE <table> ADD COLUMN <ddl>`` unless *column* already exists.

    *table* and *ddl* are interpolated directly into SQL, because identifiers
    cannot be bound as parameters. They must therefore be trusted,
    code-supplied values.
    """
    # Column 1 of each PRAGMA table_info row is the column name.
    cols = {row[1] for row in conn.execute(f"PRAGMA table_info({table})").fetchall()}
    if column not in cols:
        conn.execute(f"ALTER TABLE {table} ADD COLUMN {ddl}")


def init_backup_schema(conn: sqlite3.Connection) -> None:
    """Create the backup-code table and index, and add the recovery-OTP columns.

    Safe to call repeatedly: every step is ``IF NOT EXISTS`` or guarded by
    ``_ensure_column``. ``desk_users`` is presumably created elsewhere.
    If it does not exist, the ``ALTER TABLE`` raises ``sqlite3.OperationalError``.
    """
    conn.execute(
        """
        CREATE TABLE IF NOT EXISTS desk_backup_codes (
            id INTEGER PRIMARY KEY,
            username TEXT NOT NULL,
            code_hash TEXT NOT NULL,
            used_at TEXT,
            created_at TEXT NOT NULL
        )
        """
    )
    conn.execute(
        "CREATE INDEX IF NOT EXISTS idx_backup_codes_user ON desk_backup_codes(username)"
    )
    for col, ddl in [
        ("recovery_email_otp", "recovery_email_otp TEXT"),
        ("recovery_email_otp_expires", "recovery_email_otp_expires TEXT"),
    ]:
        _ensure_column(conn, "desk_users", col, ddl)


def _normalize_code(code: str) -> str:
    """Canonicalize user input: trim, uppercase, and drop spaces and hyphens."""
    return code.strip().upper().replace(" ", "").replace("-", "")


def _format_code(raw: str) -> str:
    """Render an 8-character raw code as ``XXXX-XXXX`` for display."""
    return f"{raw[:4]}-{raw[4:]}"


def generate_backup_codes(count: int = BACKUP_CODE_COUNT) -> list[str]:
    """Return *count* distinct random codes formatted as ``XXXX-XXXX``.

    Characters are drawn with ``secrets.choice`` from ``_CHARS``.
    A non-positive *count* yields an empty list.
    """
    codes: list[str] = []
    seen: set[str] = set()
    while len(codes) < count:
        raw = "".join(secrets.choice(_CHARS) for _ in range(8))
        formatted = _format_code(raw)
        # Reject duplicates within the batch so every issued code is unique.
        if formatted not in seen:
            seen.add(formatted)
            codes.append(formatted)
    return codes


def hash_backup_code(username: str, code: str) -> str:
    """Return the hex SHA-256 of ``"<username>:<normalized code>"``.

    Normalization makes ``abcd efgh`` and ``ABCD-EFGH`` hash identically.

    NOTE(review): this is a single fast, unkeyed SHA-256 over a code space of
    32**8 (~2**40) values. If the database leaks, the codes are cheap to
    brute-force offline. A keyed or slow KDF (``hashlib.scrypt`` or HMAC with a
    server secret) would be stronger. Changing it requires migrating or
    re-issuing the stored hashes.
    """
    norm = _normalize_code(code)
    return hashlib.sha256(f"{username}:{norm}".encode()).hexdigest()


def store_backup_codes(conn: sqlite3.Connection, username: str, codes: list[str]) -> None:
    """Replace all of *username*'s backup codes with hashes of *codes*.

    Existing codes for the user, used or not, are deleted first. Only hashes
    are persisted. This function does not commit; the caller controls the
    transaction.
    """
    now = datetime.now(timezone.utc).isoformat()
    conn.execute("DELETE FROM desk_backup_codes WHERE username = ?", (username,))
    conn.executemany(
        """
        INSERT INTO desk_backup_codes (username, code_hash, created_at)
        VALUES (?, ?, ?)
        """,
        ((username, hash_backup_code(username, code), now) for code in codes),
    )


def consume_backup_code(conn: sqlite3.Connection, username: str, code: str) -> bool:
    """Mark a matching unused code as used, returning True on success.

    Returns False if no unused code matches, or if another caller consumed
    the same code first. Works with any ``row_factory``. Does not commit.
    """
    h = hash_backup_code(username, code)
    row = conn.execute(
        """
        SELECT id FROM desk_backup_codes
        WHERE username = ? AND code_hash = ? AND used_at IS NULL
        """,
        (username, h),
    ).fetchone()
    if row is None:
        return False
    now = datetime.now(timezone.utc).isoformat()
    # Compare-and-set: the row is updated only while still unused. If a
    # concurrent request consumed it between the SELECT and here, rowcount
    # is 0 and this attempt fails, so one code cannot be used twice.
    cur = conn.execute(
        "UPDATE desk_backup_codes SET used_at = ? WHERE id = ? AND used_at IS NULL",
        (now, row[0]),
    )
    return cur.rowcount == 1


def count_remaining(conn: sqlite3.Connection, username: str) -> int:
    """Return how many unused backup codes *username* has left."""
    row = conn.execute(
        """
        SELECT COUNT(*) FROM desk_backup_codes
        WHERE username = ? AND used_at IS NULL
        """,
        (username,),
    ).fetchone()
    # Positional access works with both tuple rows and sqlite3.Row.
    return int(row[0]) if row else 0