Files
stealergram/web/db.py
anti 741e6bb0d3 Rename to stealergram, add pyproject.toml, purge em-dashes
- Rename project to stealergram throughout
- Add pyproject.toml (replaces requirements.txt split, folds pytest.ini)
- Replace all em-dashes with hyphens across all source files

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-19 10:06:30 -04:00

159 lines
5.2 KiB
Python

"""
web/db.py - SQLite user store for the web frontend.
Tables:
users - credentials + role + active flag
refresh_tokens - JTI-indexed refresh token revocation list
Bootstrap: when the users table is empty, init_db() creates a superadmin from the
WEB_ADMIN_USER (default "admin") / WEB_ADMIN_PASS env vars. WEB_ADMIN_PASS is
required only for that bootstrap, i.e. until the first user exists.
"""
import os
import sqlite3
import uuid
from contextlib import contextmanager
from datetime import datetime, timezone
from pathlib import Path
from web.auth import hash_password
# Location of the SQLite user store. It is relative, so it resolves against the
# process's working directory; get_conn() creates the "data" directory on
# demand. (pathlib collapses "./data/web.db" to this same "data/web.db".)
DB_FILE = Path("data") / "web.db"

# DDL run by init_db() through executescript(). Every statement is
# CREATE TABLE IF NOT EXISTS, so re-running it is harmless, but it also never
# alters a table that already exists. The CHECK on users.role restricts
# accounts to the three known roles.
_SCHEMA = """
CREATE TABLE IF NOT EXISTS users (
id TEXT PRIMARY KEY,
username TEXT UNIQUE NOT NULL,
password_hash TEXT NOT NULL,
role TEXT NOT NULL CHECK(role IN ('superadmin','admin','reader')),
created_at TEXT NOT NULL,
is_active INTEGER NOT NULL DEFAULT 1
);
CREATE TABLE IF NOT EXISTS refresh_tokens (
jti TEXT PRIMARY KEY,
user_id TEXT NOT NULL,
expires_at TEXT NOT NULL,
revoked INTEGER NOT NULL DEFAULT 0
);
"""
@contextmanager
def get_conn():
    """Yield a fresh SQLite connection to DB_FILE for one unit of work.

    Each call opens its own connection, with rows returned as sqlite3.Row.
    When the caller's ``with`` body finishes without raising, pending changes
    are committed. The connection is always closed afterwards. Because close()
    does not commit, work from a body that raised is discarded.
    """
    db_dir = DB_FILE.parent
    db_dir.mkdir(parents=True, exist_ok=True)
    # NOTE(review): check_same_thread=False presumably lets the web framework
    # use the connection from a thread other than the one that opened it.
    connection = sqlite3.connect(DB_FILE, check_same_thread=False)
    connection.row_factory = sqlite3.Row  # allows row["column_name"] access
    try:
        yield connection
        connection.commit()  # only reached when the caller's block did not raise
    finally:
        connection.close()
def init_db() -> None:
    """Create the schema and bootstrap a superadmin on first run.

    Runs _SCHEMA (idempotent). If the users table is then empty, a single
    'superadmin' account is inserted using these environment variables:

    * WEB_ADMIN_USER - username. Falls back to "admin" when it is unset *or
      empty*. An empty value typically comes from compose-style
      ``VAR=${VAR}`` interpolation of an unset host variable, and must not
      produce a superadmin with a blank username.
    * WEB_ADMIN_PASS - password, hashed with hash_password before storage.
      Required for the bootstrap.

    Raises:
        RuntimeError: the users table is empty and WEB_ADMIN_PASS is unset
            or empty.
    """
    with get_conn() as conn:
        conn.executescript(_SCHEMA)
        (user_count,) = conn.execute("SELECT COUNT(*) FROM users").fetchone()
        if user_count:
            # Already bootstrapped; existing accounts are never touched.
            return
        # `or` instead of a get() default so that "" also falls back.
        admin_user = os.environ.get("WEB_ADMIN_USER") or "admin"
        admin_pass = os.environ.get("WEB_ADMIN_PASS")
        if not admin_pass:
            raise RuntimeError(
                "WEB_ADMIN_PASS env var is required on first run to bootstrap the superadmin. "
                "Add WEB_ADMIN_PASS=<password> (and optionally WEB_ADMIN_USER=<username>) "
                "to your .env file, then restart."
            )
        # Insert on the same connection that ran the COUNT(*) check, rather than
        # going through create_user(), which would open a second connection.
        conn.execute(
            "INSERT INTO users (id, username, password_hash, role, created_at) VALUES (?,?,?,?,?)",
            (
                str(uuid.uuid4()),
                admin_user,
                hash_password(admin_pass),
                "superadmin",
                datetime.now(timezone.utc).isoformat(),
            ),
        )
# ─── User queries ─────────────────────────────────────────────────────────────
def get_user_by_username(username: str) -> sqlite3.Row | None:
with get_conn() as conn:
return conn.execute(
"SELECT * FROM users WHERE username = ? AND is_active = 1", (username,)
).fetchone()
def get_user_by_id(user_id: str) -> sqlite3.Row | None:
with get_conn() as conn:
return conn.execute("SELECT * FROM users WHERE id = ?", (user_id,)).fetchone()
def list_users() -> list[sqlite3.Row]:
    """Return every user row, active or not, ordered by created_at.

    created_at is ISO-8601 text, so this lexical ordering is chronological
    for the UTC timestamps written by this module.
    """
    sql = "SELECT * FROM users ORDER BY created_at"
    with get_conn() as conn:
        everyone = conn.execute(sql).fetchall()
    return everyone
def create_user(username: str, password: str, role: str) -> str:
    """Insert a new user and return its freshly generated UUID4 id.

    The password is stored only as hash_password(password). is_active takes
    the schema default (1). The schema raises sqlite3.IntegrityError if the
    username is already taken or role is not 'superadmin', 'admin' or 'reader'.
    """
    new_id = str(uuid.uuid4())
    created = datetime.now(timezone.utc).isoformat()
    with get_conn() as conn:
        params = (new_id, username, hash_password(password), role, created)
        conn.execute(
            "INSERT INTO users (id, username, password_hash, role, created_at) VALUES (?,?,?,?,?)",
            params,
        )
    return new_id
# Columns update_user() may write. "id" is deliberately absent because it is
# the row key used in the WHERE clause. Any other key is rejected before it
# can reach the SQL text.
_UPDATABLE_USER_COLUMNS = frozenset(
    {"username", "password_hash", "role", "created_at", "is_active"}
)


def update_user(user_id: str, **fields) -> None:
    """Update selected columns of one user row.

    Args:
        user_id: id of the row to update. An unknown id is a silent no-op.
        **fields: column=value pairs. A plain-text ``password`` key is also
            accepted; it is hashed with hash_password and stored as
            ``password_hash``, overriding any ``password_hash`` passed with it.
            Calling with no fields does nothing.

    Raises:
        ValueError: a key is not an updatable users column. Column names are
            interpolated into the SQL text (placeholders cannot bind
            identifiers), so they must come from a fixed allowlist. Otherwise
            ``update_user(uid, **untrusted_dict)`` would be SQL injection.
    """
    # Validate before hashing: password hashing may be deliberately slow.
    unknown = set(fields) - _UPDATABLE_USER_COLUMNS - {"password"}
    if unknown:
        raise ValueError(f"cannot update unknown user column(s): {sorted(unknown)}")
    if "password" in fields:
        fields["password_hash"] = hash_password(fields.pop("password"))
    if not fields:
        return
    # Safe to interpolate: every key was checked against the allowlist above.
    assignments = ", ".join(f"{column} = ?" for column in fields)
    with get_conn() as conn:
        conn.execute(
            f"UPDATE users SET {assignments} WHERE id = ?",
            (*fields.values(), user_id),
        )
def deactivate_user(user_id: str) -> None:
    """Soft-delete a user and revoke all of their refresh tokens.

    The users row is kept with is_active = 0, which makes
    get_user_by_username() stop returning it. The user's refresh tokens are
    marked revoked in the same transaction. is_refresh_token_valid() checks
    only the revoked flag and the expiry, so without this step a deactivated
    account could keep refreshing until its tokens expire, unless the refresh
    endpoint separately re-checks is_active (not visible here).
    """
    with get_conn() as conn:
        conn.execute("UPDATE users SET is_active = 0 WHERE id = ?", (user_id,))
        conn.execute(
            "UPDATE refresh_tokens SET revoked = 1 WHERE user_id = ?", (user_id,)
        )
# ─── Refresh token queries ────────────────────────────────────────────────────
def store_refresh_token(jti: str, user_id: str, expires_at: datetime) -> None:
    """Record a newly issued refresh token so it can later be checked or revoked.

    expires_at is stored as its isoformat() string, and the row starts with
    revoked = 0 (the schema default). Pass a timezone-aware datetime: validity
    is later checked against datetime.now(timezone.utc).
    """
    sql = "INSERT INTO refresh_tokens (jti, user_id, expires_at) VALUES (?,?,?)"
    with get_conn() as conn:
        params = (jti, user_id, expires_at.isoformat())
        conn.execute(sql, params)
def is_refresh_token_valid(jti: str) -> bool:
    """Return True if refresh token *jti* is known, not revoked, and not expired.

    An unknown jti is invalid. A stored expiry without a UTC offset (the
    isoformat() of a naive datetime) is interpreted as UTC. Comparing it
    directly with the aware ``datetime.now(timezone.utc)`` would raise
    TypeError instead of returning a result.
    """
    with get_conn() as conn:
        row = conn.execute(
            "SELECT revoked, expires_at FROM refresh_tokens WHERE jti = ?", (jti,)
        ).fetchone()
    if row is None or row["revoked"]:
        return False
    expires = datetime.fromisoformat(row["expires_at"])
    if expires.tzinfo is None:
        # Naive timestamp: assume UTC so aware/naive comparison cannot fail.
        expires = expires.replace(tzinfo=timezone.utc)
    return datetime.now(timezone.utc) < expires
def revoke_refresh_token(jti: str) -> None:
    """Mark refresh token *jti* as revoked; an unknown jti is a silent no-op.

    The row is kept with revoked = 1 rather than deleted, so lookups by jti
    keep finding it and reporting it as revoked.
    """
    statement = "UPDATE refresh_tokens SET revoked = 1 WHERE jti = ?"
    with get_conn() as conn:
        conn.execute(statement, (jti,))