- Rename project to stealergram throughout - Add pyproject.toml (replaces requirements.txt split, folds pytest.ini) - Replace all em-dashes with hyphens across all source files Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
159 lines
5.2 KiB
Python
159 lines
5.2 KiB
Python
"""
|
|
web/db.py - SQLite user store for the web frontend.
|
|
|
|
Tables:
|
|
users - credentials + role + active flag
|
|
refresh_tokens - JTI-indexed refresh token revocation list
|
|
|
|
Bootstrap: on first init, creates a superadmin from WEB_ADMIN_USER / WEB_ADMIN_PASS
|
|
env vars (required only on first run if the DB doesn't exist yet).
|
|
"""
|
|
|
|
import os
|
|
import sqlite3
|
|
import uuid
|
|
from contextlib import contextmanager
|
|
from datetime import datetime, timezone
|
|
from pathlib import Path
|
|
|
|
from web.auth import hash_password
|
|
|
|
DB_FILE = Path("./data/web.db")
|
|
|
|
_SCHEMA = """
|
|
CREATE TABLE IF NOT EXISTS users (
|
|
id TEXT PRIMARY KEY,
|
|
username TEXT UNIQUE NOT NULL,
|
|
password_hash TEXT NOT NULL,
|
|
role TEXT NOT NULL CHECK(role IN ('superadmin','admin','reader')),
|
|
created_at TEXT NOT NULL,
|
|
is_active INTEGER NOT NULL DEFAULT 1
|
|
);
|
|
|
|
CREATE TABLE IF NOT EXISTS refresh_tokens (
|
|
jti TEXT PRIMARY KEY,
|
|
user_id TEXT NOT NULL,
|
|
expires_at TEXT NOT NULL,
|
|
revoked INTEGER NOT NULL DEFAULT 0
|
|
);
|
|
"""
|
|
|
|
|
|
@contextmanager
def get_conn():
    """Yield a fresh SQLite connection to DB_FILE, committing on clean exit.

    A new connection is opened on every call. The parent directory of DB_FILE
    is created if it is missing, and rows are returned as sqlite3.Row objects
    (indexable by column name). If the ``with`` body raises, the commit is
    skipped, so uncommitted changes are dropped when the connection closes.
    The connection is closed in every case.
    """
    db_path = DB_FILE
    db_path.parent.mkdir(exist_ok=True, parents=True)
    connection = sqlite3.connect(db_path, check_same_thread=False)
    try:
        connection.row_factory = sqlite3.Row
        yield connection
        connection.commit()
    finally:
        connection.close()
|
|
|
|
|
|
def init_db() -> None:
    """Create the schema and bootstrap a superadmin when no users exist.

    Runs _SCHEMA (safe to repeat thanks to IF NOT EXISTS). If the users table
    is then empty, inserts a 'superadmin' user whose username comes from
    WEB_ADMIN_USER (falling back to "admin" when unset or empty) and whose
    password, hashed with hash_password(), comes from WEB_ADMIN_PASS.

    Raises:
        RuntimeError: the users table is empty and WEB_ADMIN_PASS is unset or
            empty.
    """
    with get_conn() as conn:
        conn.executescript(_SCHEMA)

        # Bootstrap superadmin only if the users table is empty.
        row = conn.execute("SELECT COUNT(*) FROM users").fetchone()
        if row[0] == 0:
            # `or` rather than a .get() default, so that a variable that is set
            # but empty (e.g. "WEB_ADMIN_USER=") still falls back to "admin"
            # instead of creating a superadmin with an empty username.
            admin_user = os.environ.get("WEB_ADMIN_USER") or "admin"
            admin_pass = os.environ.get("WEB_ADMIN_PASS")
            if not admin_pass:
                raise RuntimeError(
                    "WEB_ADMIN_PASS env var is required on first run to bootstrap the superadmin. "
                    "Add WEB_ADMIN_PASS=<password> (and optionally WEB_ADMIN_USER=<username>) "
                    "to your .env file, then restart."
                )
            conn.execute(
                "INSERT INTO users (id, username, password_hash, role, created_at) VALUES (?,?,?,?,?)",
                (
                    str(uuid.uuid4()),
                    admin_user,
                    hash_password(admin_pass),
                    "superadmin",
                    datetime.now(timezone.utc).isoformat(),
                ),
            )
|
|
|
|
|
|
# ─── User queries ─────────────────────────────────────────────────────────────
|
|
|
|
def get_user_by_username(username: str) -> sqlite3.Row | None:
    """Look up an active user by username.

    Returns the full users row, or None when no user with that name has
    is_active = 1 (deactivated users are treated as absent).
    """
    sql = "SELECT * FROM users WHERE username = ? AND is_active = 1"
    with get_conn() as conn:
        cursor = conn.execute(sql, (username,))
        user = cursor.fetchone()
    return user
|
|
|
|
|
|
def get_user_by_id(user_id: str) -> sqlite3.Row | None:
    """Fetch the users row whose id is *user_id*, or None if there is none.

    Unlike get_user_by_username(), this does not filter on is_active, so
    deactivated users are returned as well.
    """
    with get_conn() as conn:
        found = conn.execute("SELECT * FROM users WHERE id = ?", (user_id,)).fetchone()
    return found
|
|
|
|
|
|
def list_users() -> list[sqlite3.Row]:
    """Return every users row, active or not, ordered by the created_at text."""
    with get_conn() as conn:
        # Iterating the cursor yields the same rows fetchall() would.
        rows = list(conn.execute("SELECT * FROM users ORDER BY created_at"))
    return rows
|
|
|
|
|
|
def create_user(username: str, password: str, role: str) -> str:
    """Insert a new (active) user and return its generated id.

    The id is a random uuid4 string, the password is stored via
    hash_password(), and created_at is the current UTC time in ISO-8601 form.
    The schema rejects a duplicate username (UNIQUE) or a role other than
    'superadmin', 'admin' or 'reader' (CHECK); sqlite3 reports either as an
    IntegrityError.
    """
    new_id = str(uuid.uuid4())
    created_at = datetime.now(timezone.utc).isoformat()
    with get_conn() as conn:
        record = (new_id, username, hash_password(password), role, created_at)
        conn.execute(
            "INSERT INTO users (id, username, password_hash, role, created_at) VALUES (?,?,?,?,?)",
            record,
        )
    return new_id
|
|
|
|
|
|
# Columns update_user() may write. SQL placeholders cannot stand in for column
# names, so the names are interpolated into the statement text; limiting them
# to this fixed set stops caller-supplied keyword names from injecting SQL.
# "id" is left out on purpose: it is the row selector, and refresh_tokens rows
# record it as user_id.
_UPDATABLE_USER_COLUMNS = frozenset(
    {"username", "password_hash", "role", "created_at", "is_active"}
)


def update_user(user_id: str, **fields) -> None:
    """Set the given columns on the user row whose id is *user_id*.

    Keyword names are users column names from _UPDATABLE_USER_COLUMNS. A
    plain-text ``password`` keyword is hashed with hash_password() and written
    to ``password_hash`` instead. Calling with no fields is a no-op.

    Raises:
        ValueError: a keyword is neither ``password`` nor an updatable column.
            This is checked before any hashing or database access.
    """
    # Validate first so an invalid call never pays for password hashing.
    unknown = sorted(fields.keys() - _UPDATABLE_USER_COLUMNS - {"password"})
    if unknown:
        raise ValueError(
            f"update_user: unknown or read-only user field(s): {', '.join(unknown)}"
        )
    if "password" in fields:
        fields["password_hash"] = hash_password(fields.pop("password"))
    if not fields:
        return
    # Safe to interpolate: every key is now a member of the fixed allow-list.
    assignments = ", ".join(f"{column} = ?" for column in fields)
    with get_conn() as conn:
        conn.execute(
            f"UPDATE users SET {assignments} WHERE id = ?",
            (*fields.values(), user_id),
        )
|
|
|
|
|
|
def deactivate_user(user_id: str) -> None:
    """Deactivate a user and revoke every refresh token issued to them.

    Setting is_active = 0 hides the user from get_user_by_username(). However,
    is_refresh_token_valid() only inspects the refresh_tokens row. Without the
    second UPDATE, the user's outstanding refresh tokens would keep passing
    validation. Both updates run on one connection and are committed together
    when get_conn() exits.
    """
    with get_conn() as conn:
        conn.execute("UPDATE users SET is_active = 0 WHERE id = ?", (user_id,))
        conn.execute(
            "UPDATE refresh_tokens SET revoked = 1 WHERE user_id = ?", (user_id,)
        )
|
|
|
|
|
|
# ─── Refresh token queries ────────────────────────────────────────────────────
|
|
|
|
def store_refresh_token(jti: str, user_id: str, expires_at: datetime) -> None:
    """Persist a newly issued refresh token so it can later be checked or revoked.

    *expires_at* is saved as ``expires_at.isoformat()``, so a naive datetime is
    stored without a UTC offset. The ``revoked`` column takes its default (0).
    Reusing an existing *jti* violates the primary key, which sqlite3 reports
    as an IntegrityError.
    """
    insert_sql = "INSERT INTO refresh_tokens (jti, user_id, expires_at) VALUES (?,?,?)"
    with get_conn() as conn:
        conn.execute(insert_sql, (jti, user_id, expires_at.isoformat()))
|
|
|
|
|
|
def is_refresh_token_valid(jti: str) -> bool:
    """Return True if refresh token *jti* is known, not revoked, and unexpired.

    Unknown and revoked tokens are invalid. The stored ``expires_at`` text is
    parsed with ``datetime.fromisoformat``. If it carries no UTC offset (which
    happens when a naive datetime was passed to store_refresh_token), it is
    treated as UTC. Otherwise comparing it with the aware current time would
    raise TypeError.
    """
    with get_conn() as conn:
        row = conn.execute(
            "SELECT revoked, expires_at FROM refresh_tokens WHERE jti = ?", (jti,)
        ).fetchone()

    if row is None or row["revoked"]:
        return False

    expires = datetime.fromisoformat(row["expires_at"])
    if expires.tzinfo is None:
        # NOTE(review): assumes naive expiries were meant as UTC, matching the
        # timezone.utc timestamps used elsewhere in this module; confirm
        # against the token-issuing caller.
        expires = expires.replace(tzinfo=timezone.utc)
    return datetime.now(timezone.utc) < expires
|
|
|
|
|
|
def revoke_refresh_token(jti: str) -> None:
    """Flag the refresh token *jti* as revoked.

    An unknown *jti* matches no row; the UPDATE then changes nothing and no
    error is raised.
    """
    statement = "UPDATE refresh_tokens SET revoked = 1 WHERE jti = ?"
    with get_conn() as conn:
        conn.execute(statement, (jti,))
|