107 lines
3.1 KiB
Python
107 lines
3.1 KiB
Python
"""Single-use backup codes for Desk 2FA (Spec 004 extension)."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import hashlib
|
|
import secrets
|
|
import sqlite3
|
|
from datetime import datetime, timezone
|
|
|
|
BACKUP_CODE_COUNT = 10
|
|
_CHARS = "ABCDEFGHJKLMNPQRSTUVWXYZ23456789"
|
|
|
|
|
|
def _ensure_column(conn: sqlite3.Connection, table: str, column: str, ddl: str) -> None:
|
|
cols = {row[1] for row in conn.execute(f"PRAGMA table_info({table})").fetchall()}
|
|
if column not in cols:
|
|
conn.execute(f"ALTER TABLE {table} ADD COLUMN {ddl}")
|
|
|
|
|
|
def init_backup_schema(conn: sqlite3.Connection) -> None:
    """Create the backup-code storage and the recovery-OTP columns.

    Steps:
      * create ``desk_backup_codes`` and its per-username index, if absent;
      * add ``recovery_email_otp`` and ``recovery_email_otp_expires`` (both
        TEXT) to ``desk_users`` if they are not already present.

    Every step is guarded (``IF NOT EXISTS`` or an existence check), so the
    function can be called repeatedly. ``desk_users`` is not created here;
    it is assumed to already exist (NOTE(review): confirm with the caller's
    schema setup).
    """
    conn.execute(
        """
        CREATE TABLE IF NOT EXISTS desk_backup_codes (
            id INTEGER PRIMARY KEY,
            username TEXT NOT NULL,
            code_hash TEXT NOT NULL,
            used_at TEXT,
            created_at TEXT NOT NULL
        )
        """
    )
    conn.execute(
        "CREATE INDEX IF NOT EXISTS idx_backup_codes_user ON desk_backup_codes(username)"
    )
    recovery_columns = ("recovery_email_otp", "recovery_email_otp_expires")
    for name in recovery_columns:
        # Each column's DDL is simply "<name> TEXT".
        _ensure_column(conn, "desk_users", name, f"{name} TEXT")
|
|
|
|
|
|
def _normalize_code(code: str) -> str:
|
|
return code.strip().upper().replace(" ", "").replace("-", "")
|
|
|
|
|
|
def _format_code(raw: str) -> str:
|
|
return f"{raw[:4]}-{raw[4:]}"
|
|
|
|
|
|
def generate_backup_codes(count: int = BACKUP_CODE_COUNT) -> list[str]:
    """Return ``count`` distinct, freshly generated backup codes.

    Each code is 8 symbols drawn from ``_CHARS`` with the ``secrets`` CSPRNG
    and rendered as "XXXX-XXXX". Duplicates are discarded and redrawn, so the
    result contains no repeats. A non-positive ``count`` yields ``[]``.
    """
    # Dict keys act as an insertion-ordered set.
    picked: dict[str, None] = {}
    while len(picked) < count:
        candidate = _format_code("".join(secrets.choice(_CHARS) for _ in range(8)))
        if candidate not in picked:
            picked[candidate] = None
    return list(picked)
|
|
|
|
|
|
def hash_backup_code(username: str, code: str) -> str:
    """Return the hex SHA-256 digest under which ``code`` is stored for ``username``.

    The code is normalized first (trimmed, upper-cased, spaces and hyphens
    removed), so "abcd-efgh" and "ABCD EFGH" hash identically. The hashed
    payload is "<username>:<normalized code>", so identical codes belonging
    to different users produce different digests.

    NOTE(review): generated codes have 8 symbols from a 32-character alphabet
    (about 40 bits), and this is a single unkeyed SHA-256. A leaked table could
    be brute-forced offline. Moving to HMAC with a server secret, or to a slow
    KDF, would require re-issuing all existing codes.
    """
    payload = ":".join((username, _normalize_code(code)))
    digest = hashlib.sha256(payload.encode())
    return digest.hexdigest()
|
|
|
|
|
|
def store_backup_codes(conn: sqlite3.Connection, username: str, codes: list[str]) -> None:
    """Replace every backup code of ``username`` with hashes of ``codes``.

    All existing rows for the user, used or not, are deleted first. Each new
    row stores only ``hash_backup_code(username, code)``, and all rows share a
    single UTC ISO-8601 ``created_at`` timestamp. No commit is issued here;
    presumably the caller controls the transaction.
    """
    created = datetime.now(timezone.utc).isoformat()
    conn.execute("DELETE FROM desk_backup_codes WHERE username = ?", (username,))
    # Generator: each code is hashed just before its row is inserted.
    new_rows = ((username, hash_backup_code(username, c), created) for c in codes)
    conn.executemany(
        """
        INSERT INTO desk_backup_codes (username, code_hash, created_at)
        VALUES (?, ?, ?)
        """,
        new_rows,
    )
|
|
|
|
|
|
def consume_backup_code(conn: sqlite3.Connection, username: str, code: str) -> bool:
    """Mark one of ``username``'s unused backup codes as used if ``code`` matches.

    ``code`` is hashed with ``hash_backup_code``, which normalizes case,
    spaces and hyphens. The digest is compared against that user's rows whose
    ``used_at`` is still NULL. On a match, ``used_at`` is set to the current
    UTC ISO-8601 time. No commit is issued here.

    Args:
        conn: Open SQLite connection; any row factory works.
        username: Owner of the code.
        code: Code as typed by the user.

    Returns:
        True if a code was consumed. False if no unused matching code exists
        (unknown code, already used, or belonging to another user).
    """
    code_hash = hash_backup_code(username, code)
    now = datetime.now(timezone.utc).isoformat()
    # Match and mark in ONE statement. A separate SELECT followed by
    # "UPDATE ... WHERE id = ?" lets two concurrent requests both observe
    # used_at IS NULL and both succeed with the same single-use code. Here a
    # second attempt either matches no row or is rejected by SQLite's write
    # locking, so at most one caller gets True.
    cur = conn.execute(
        """
        UPDATE desk_backup_codes SET used_at = ?
        WHERE id = (
            SELECT id FROM desk_backup_codes
            WHERE username = ? AND code_hash = ? AND used_at IS NULL
            ORDER BY id
            LIMIT 1
        )
        """,
        (now, username, code_hash),
    )
    return cur.rowcount > 0
|
|
|
|
|
|
def count_remaining(conn: sqlite3.Connection, username: str) -> int:
    """Return how many of ``username``'s backup codes are still unused.

    A code counts as unused while its ``used_at`` is NULL. The count is read
    through the column alias ``c``, so the connection's row factory must
    support access by name (e.g. ``sqlite3.Row``). Presumably the caller
    configures this.
    """
    cursor = conn.execute(
        """
        SELECT COUNT(*) c FROM desk_backup_codes
        WHERE username = ? AND used_at IS NULL
        """,
        (username,),
    )
    result = cursor.fetchone()
    if not result:
        return 0
    return int(result["c"])
|