"""
Tests for utils/database.py - SQLite persistence layer.

Each test gets its own isolated database via the autouse `isolated_db`
fixture, which points the module at a throwaway SQLite file under pytest's
`tmp_path`, so tests never touch data/hits.db.
"""

import pytest

import utils.database as db_module
from utils.scorer import ScoredHit, CRITICAL, HIGH, MEDIUM, LOW


def make_hit(severity=LOW, url="testcorp.com", username="user", password="pass", raw=None):
    """Build a minimal ScoredHit for insertion tests.

    Args:
        severity: One of CRITICAL/HIGH/MEDIUM/LOW. It also selects a fixed
            score (40/30/20/10) so tests can assert on a known value.
        url: URL stored on the hit.
        username: Username stored on the hit.
        password: Password stored on the hit.
        raw: Raw source line. When None it is synthesised as
            "url|username|password". An explicit "" is kept as-is.

    Returns:
        A ScoredHit with a single placeholder reason.
    """
    scores = {CRITICAL: 40, HIGH: 30, MEDIUM: 20, LOW: 10}
    return ScoredHit(
        # `is None` rather than `or` so an explicit empty raw string is honoured.
        raw=raw if raw is not None else f"{url}|{username}|{password}",
        severity=severity,
        score=scores[severity],
        reasons=["Test reason"],
        url=url,
        username=username,
        password=password,
    )


@pytest.fixture(autouse=True)
def isolated_db(tmp_path, monkeypatch):
    """Redirect db_module.DB_FILE to a per-test temp file and initialise it.

    monkeypatch restores the original DB_FILE when the test finishes.
    """
    # NOTE(review): this relies on db_module reading DB_FILE at call time
    # (not binding it at import, e.g. as a default argument) - confirm.
    monkeypatch.setattr(db_module, "DB_FILE", tmp_path / "test_hits.db")
    db_module.init_db()


# ─── init_db ─────────────────────────────────────────────────────────────────

def test_init_db_is_idempotent():
    # The autouse fixture has already called init_db() once; repeat calls
    # against the same file must be harmless.
    db_module.init_db()
    db_module.init_db()  # must not raise


# ─── insert_hits ──────────────────────────────────────────────────────────────

def test_insert_returns_correct_row_count():
    hits = [make_hit(), make_hit(severity=CRITICAL)]
    count = db_module.insert_hits(hits, source="testchan", filename="combo.txt")
    assert count == 2


def test_insert_stores_all_fields():
    hit = make_hit(severity=HIGH, url="intranet.testcorp.com", username="jdoe", password="s3cr3t")
    db_module.insert_hits([hit], source="mychan", filename="creds.zip")
    rows = db_module.search("jdoe")
    assert len(rows) == 1
    row = rows[0]
    assert row["url"] == "intranet.testcorp.com"
    assert row["username"] == "jdoe"
    assert row["password"] == "s3cr3t"
    assert row["severity"] == HIGH
    assert row["score"] == 30
    assert row["source"] == "mychan"
    assert row["filename"] == "creds.zip"
    assert row["seen_before"] == 0


def test_insert_seen_before_flag():
    hit = make_hit()
    db_module.insert_hits([hit], source="chan", filename="f.txt", seen_before=True)
    rows = db_module.search("testcorp")
    # Guard the index below: exactly one row was inserted.
    assert len(rows) == 1
    assert rows[0]["seen_before"] == 1


# ─── search ───────────────────────────────────────────────────────────────────

def test_search_finds_by_username():
    db_module.insert_hits([make_hit(username="jdoe@testcorp.com")], source="c", filename="f.txt")
    results = db_module.search("jdoe")
    assert len(results) == 1
    assert results[0]["username"] == "jdoe@testcorp.com"


def test_search_finds_by_url():
    db_module.insert_hits([make_hit(url="admin.testcorp.com")], source="c", filename="f.txt")
    results = db_module.search("admin.testcorp")
    assert len(results) == 1
    assert results[0]["url"] == "admin.testcorp.com"


def test_search_finds_by_raw():
    db_module.insert_hits([make_hit(raw="raw_unique_token_xyz")], source="c", filename="f.txt")
    results = db_module.search("unique_token")
    assert len(results) == 1


def test_search_returns_empty_for_no_match():
    db_module.insert_hits([make_hit()], source="c", filename="f.txt")
    assert db_module.search("zzznomatch_xyz") == []


def test_search_sorted_by_score_descending():
    # Insert the low-score hit first so insertion order differs from score order.
    db_module.insert_hits([make_hit(severity=LOW)], source="c", filename="f.txt")
    db_module.insert_hits(
        [make_hit(severity=CRITICAL, url="admin.testcorp.com")], source="c", filename="f.txt"
    )
    results = db_module.search("testcorp")
    # Check the count first: comparing results[0] to results[-1] alone would
    # pass vacuously if only one row came back.
    assert len(results) == 2
    assert [r["score"] for r in results] == [40, 10]


# ─── by_severity ──────────────────────────────────────────────────────────────

def test_by_severity_returns_correct_severity():
    db_module.insert_hits(
        [make_hit(severity=CRITICAL, url="admin.testcorp.com")], source="c", filename="f.txt"
    )
    db_module.insert_hits([make_hit(severity=LOW)], source="c", filename="f.txt")
    results = db_module.by_severity(CRITICAL)
    assert len(results) == 1
    assert results[0]["severity"] == CRITICAL


def test_by_severity_excludes_duplicates():
    """seen_before=1 rows must be invisible to by_severity - they are stored for stats only."""
    hit = make_hit(severity=HIGH, url="intranet.testcorp.com")
    db_module.insert_hits([hit], source="c", filename="f.txt", seen_before=True)
    assert db_module.by_severity(HIGH) == []


def test_by_severity_returns_empty_when_none():
    assert db_module.by_severity(CRITICAL) == []


# ─── stats ───────────────────────────────────────────────────────────────────

def test_stats_counts_by_severity():
    # Distinct URLs give each hit a distinct raw line.
    db_module.insert_hits(
        [make_hit(severity=CRITICAL, url="admin.testcorp.com")], source="c", filename="f.txt"
    )
    db_module.insert_hits(
        [make_hit(severity=HIGH, url="intranet.testcorp.com")], source="c", filename="f.txt"
    )
    db_module.insert_hits(
        [make_hit(severity=MEDIUM, url="app.testcorp.com")], source="c", filename="f.txt"
    )
    db_module.insert_hits([make_hit(severity=LOW)], source="c", filename="f.txt")
    s = db_module.stats()
    assert s["critical"] == 1
    assert s["high"] == 1
    assert s["medium"] == 1
    assert s["low"] == 1
    assert s["total"] == 4
    assert s["unique"] == 4
    assert s["duplicates"] == 0


def test_stats_separates_duplicates():
    hit = make_hit()
    db_module.insert_hits([hit], source="c", filename="f.txt", seen_before=False)
    db_module.insert_hits([hit], source="c", filename="f.txt", seen_before=True)
    s = db_module.stats()
    assert s["total"] == 2
    assert s["unique"] == 1
    assert s["duplicates"] == 1


def test_stats_severity_counts_exclude_duplicates():
    hit = make_hit(severity=CRITICAL, url="admin.testcorp.com")
    db_module.insert_hits([hit], source="c", filename="f.txt", seen_before=False)
    db_module.insert_hits([hit], source="c", filename="f.txt", seen_before=True)
    s = db_module.stats()
    assert s["critical"] == 1  # only the unique one


def test_stats_empty_db():
    s = db_module.stats()
    assert s["total"] == 0
    assert s["unique"] == 0
    assert s["top_source"] is None


def test_stats_top_source():
    db_module.insert_hits([make_hit()], source="channelA", filename="f.txt")
    db_module.insert_hits([make_hit()], source="channelA", filename="f.txt")
    db_module.insert_hits([make_hit()], source="channelB", filename="f.txt")
    s = db_module.stats()
    assert s["top_source"]["source"] == "channelA"


# ─── recent ───────────────────────────────────────────────────────────────────

def test_recent_respects_limit():
    for i in range(5):
        db_module.insert_hits(
            [make_hit(raw=f"testcorp.com|user{i}|pass")], source="c", filename="f.txt"
        )
    rows = db_module.recent(limit=3)
    assert len(rows) == 3


def test_recent_returns_all_when_under_limit():
    db_module.insert_hits([make_hit()], source="c", filename="f.txt")
    db_module.insert_hits([make_hit()], source="c", filename="f.txt")
    rows = db_module.recent(limit=50)
    assert len(rows) == 2