From c603531fd2e3e9899c9d79da31095cff3e8e474e Mon Sep 17 00:00:00 2001 From: anti Date: Wed, 15 Apr 2026 12:51:11 -0400 Subject: [PATCH] feat: add MySQL backend support for DECNET database - Implement MySQLRepository extending BaseRepository - Add SQLAlchemy/SQLModel ORM abstraction layer (sqlmodel_repo.py) - Support connection pooling and tuning via DECNET_DB_URL env var - Cross-compatible with SQLite backend via factory pattern - Prepared for production deployment with MySQL SIEM/ELK integration --- decnet/web/db/mysql/__init__.py | 0 .../__pycache__/__init__.cpython-314.pyc | Bin 0 -> 154 bytes .../__pycache__/database.cpython-314.pyc | Bin 0 -> 4804 bytes .../__pycache__/repository.cpython-314.pyc | Bin 0 -> 6174 bytes decnet/web/db/mysql/database.py | 98 +++ decnet/web/db/mysql/repository.py | 87 +++ decnet/web/db/sqlmodel_repo.py | 637 ++++++++++++++++++ 7 files changed, 822 insertions(+) create mode 100644 decnet/web/db/mysql/__init__.py create mode 100644 decnet/web/db/mysql/__pycache__/__init__.cpython-314.pyc create mode 100644 decnet/web/db/mysql/__pycache__/database.cpython-314.pyc create mode 100644 decnet/web/db/mysql/__pycache__/repository.cpython-314.pyc create mode 100644 decnet/web/db/mysql/database.py create mode 100644 decnet/web/db/mysql/repository.py create mode 100644 decnet/web/db/sqlmodel_repo.py diff --git a/decnet/web/db/mysql/__init__.py b/decnet/web/db/mysql/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/decnet/web/db/mysql/__pycache__/__init__.cpython-314.pyc b/decnet/web/db/mysql/__pycache__/__init__.cpython-314.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9e6b21b730194dae007a0f25139d7ef940c6b7ab GIT binary patch literal 154 zcmdPq7&-6`tikmw)<~5-Zl&4s0{ANXK@Zl(tpN5*4YEEqTQ-No!@fBA3#}t6h3_ zXY~lJ1+G|p?)HT z`-$G63wA>4l|;CYwC+3LKDrQ*Vhd3zz7UfV3vsDK>P-8kF5v3sTpoC@Q}XO1$-e6I z9zw*ktMlh)sA8=eMXDN0nxWFNQnXFBni9n`RZTBhRH48#su!2ll~p>5tl1`~vZQE+ zWz(jRs-+oA)LvF;p`d8y%BpoqFBE8rX_r+7(#A}~P>Z%^8dNjNCec9o%6g)LQARKunf#j-$xZPHA+ zY7{etCDoSshUDgk3gVFzyf&U4JD1JL<7ecTr5O|G2 z8u&5QxJ)lA%*u%Hdb`ZY+4&p@Jy=r9N>xWSb+f4G%cf-)3du*g=VqlxxQ`w^eEczt z=SI(FcZw~kMMHg@%$Mi0(tpW1H#$H6%B(cr&aSRh>{W_JV2Wm`7Ofh(YFV_hYOCO4 zbXL_cz{5!~DL4_uFicxPXIkgrJ~UfFTG4?PFI7!jl`DGHN_w0~TfjKoMW!n7tld3Y zN1TvcQGk@e9}dVX;)UA>hk>Y(JbB&o;!g-E!P+1>FDwg^AbBLO~DDS_7fodNHhW@G1v%lGDIQ+ zjW8!gT#axtN&?~gF(e;=JW}-I+!|QBz>59{AApT&Bu7pN2U?Go_Mml#$OV^|%%L40 z8FTH@TF=lM$6CAELHq`mt7oRD_(T?JM72bH5aDWmh~o@Bc++ zuuY~n?<;fdzQw*Fp0{-K<$YUzT11;(2)PzLA)NPpkCa-LMnw4fZ*`Dktyd=?&$guU z+?oai+YN|vb|{S3=-FL+$p0eI_0N`6_yXwK!bL))n9H#&T!^_l$ot&Yiob2uyR{mS z;`u<7NQt~3=8ilR2%rU+JM*3>IU+*VcU3#IR7@(=jEPAj| zI9(_l0$|8AFX=+z)7o#6v}76wY`O@A@ghYPD*!K0W>caQ--axN7bm<0DeaFCOB^5K zNBb%QhZAVlq4sq|38!mYunPqa;T70*i4TaQSb@&~u@|ZW){2Q3E>eW>?6;6H34uSF!x$6JD)?a5G`9_wsl3bm*IgYb=VOuT3A2!Q`QX^ zCwy{AH30G1h_iPQdk~qIP#KUvb&?yw8cA*O9EVq|Ct>;mw5$(J|3O@T?zhojM{mD+ zmmaUv<9`zWW#N8){M3+;n^?R2)y<<90nVtU+9{d^wV)4y&jVCY%|3+7bo(Xn(IMzn z5Doz&v1MsNB*qrR9S7237eZ*OAxN!t@;yB3Mq4JGc7%*0oOXnCGU#|r%Lzhvp)nKS z23i=$3)tj%6=T&A^NL4hlc38(do>jR6wvxEncLkDPw@$c@bP%#D%jBlW#U?o4lVpRSE> zMiQ-uN8e>P)~>JJc=P(3_5H`!0~`I{ZE_^~?LMKUI&M+9`W-qx)2C{F7kf!(jit?%wUs9$)utbf35udKeSSr#~E=z84>8c!b!f@WAU$hz;Tmi(iT)K5#R%5!+LnxR-e1 z=F=OAAAc@4r28r}NISJNws#^{3bS)-Sz3@V;1&OxJzW4C20-9&YHP z7kGn$X8XDZCGCOI=a%DU1C=NHz5ya05Atq`SA_gl)6kL#X?_9Y zj1L!HVCSA{4wmM$;zg>^;&%TJIh5IM>7mHdPk;KeojJyfEv{m;+O*8f6_C)X%T~&X z?a0Waj{yJ@_AE|^aKfNYdKfCh6AO#Sjv^Cj9X{{@dlA^Ut%rfAkH^($-Wbx}sc_5i-iU zdButH2XLTSE$XTh<9Berkt;3vQY7_i!5k;ebD;)p!J1D0e+$s zNe98^Q66o4-uMtK{L@I(*l=KWVVo1eFB+*D9)j zR>88>O0&^8U{S54tHW*454u3la0s+1;aPVWV*$^vWku6rvTRW+!)H)FXwvrLvSQIi zRW;~EO~*%2;vvq4@q;N5$_#fUclS$yn@s=#32^vr1Xn{t@ zqo-~>Rb%t~XZE_avhnyKz; z->M!kiXv+^TCq{24BE{Ovr)u8Od`)Jk$A`h*^QLhMMV2(uP3p$G?R@IMI+vf%#4+2 zA97Arm;V?t%C-C6bMHO(@7#0G`Rcr0S1S-m;a~h-yHiKVe`3W;K9^wDe+Ag^jSWq&JFno16E)#XaxsDR%lSL1gD)Jsu>Jh;lWy~)~Wl3 z>IUnrdWZH8H4HXdjSd|c+B4W>H92%}C^Fb=HMf&ya*#-&79t5Zn@@1=o~)=;(?R0l z$|npV!pPj%wP8^!mUXqLmMq26jFML)Ush}mOMsaf@k73(nf>J7oMoFI&jrUYawSwwS zYTY)(vaaNzdK=bKMQ4-UCN3ftG(}hQ7L-#{Dm9%&1nAH#wc9)^2(O7}i>ftqHYXRg zX{uPNtXP(ke;0UVXRA5!KjwcdR!o>Zm(xm91}!?Lmd*T(T2yj5vBz0Y5XBkA6iWuM zWu$>CBLlfyJh3(Zv}(z^F)hz%re(lN#hh48+vdM4jgH9I(?j3O4vnN>>Ry;Nm%|1b zTgI{jBN=FZUDQgJI;~PT)_arq8?hv&!t(8=0+w=0(=Dl}-WmWH(bqOl-uPe9FDTVmA&CV|uCwc% z3)|AHN++!3YW5^vJ*e5XJ5e;FIA_l4xm+Thhc`g*hlz(DY?fIUftV-b7qj0HBEt5E%aYQ0-{+qCFS+fP-iLQ6KEHcILz1Z3oQ)y%}D z4%ugVXN;oStCTFQH)|NW*?TFKOsBHF1vOt%t=`+}WN%@z*V$z+-F$tC@|+!(<*FK} zvTVn89XzqUcR#FUwmC=HAiPg_j@g5eH7q_YIJI@?C3!rDm;6A~5E5F~ftV)~o^0TP zBr?wQGUMz;h6mY{HhU(j#*RR1X8T;#IVW&VC!ze_tf~Xdfbw?RIg}@Dg34vkP`_h) zWV=DVQ`V>|8zs=*Vo|fwv;|mzvla^5?`j}LD@;S!QeRbLc}mHnxr%c(53ga$nc1#& z`Q>+SD|Fg4aUWe>N9bN08I)zMq*)-+)|W-<_Bi590P{TgJleGsy|D0|#ps1cZvb_V z|FfFU+mEe=8~*(253hb4eq4FMv7&H~|JPuQ#$d<}HA4|OYKA?pQnOwLM8eTMz(Sib zi(=*+pP{WlyQ<{up5i6beUNov2XfAn4HEVZFwdR*5$~2;djS?F1tI@FA+_sY!*w1a z?aqEO)z0lj2F0~2ov7CP7@4cE6;5WomMid7-hmN!n0ap4Cy+8T&SW;#WNVs0CgW}3}NXTsRbXpYm~Rg&5<2A@>|d7_r`hU?lZ#v(8w^JufD?W zlmYg4kN_F}IsE<m+$Bw+d-UL;+`Rw&7+8` zc96B%#8I1}CR&s((Y0iubJJ=GBu>xUeJ`1^S3sh~Io$w9Z=O$V4n{-gSdwR85*FM0 z|ABvW>}Lo6nqTn$P2Df~|29F@$os8rO!o5K$R%ksP{0Z_ehaal2}q#p;n-Uj^lC7pZMI{=};|;e8VG zZz03utc6LW^~1jVeJf3!OHG|CO+_YtrKov z{RZKh>o@$Qaqs=$kA$@VscqZ{l7`5`X)RIUyth(IP^_+#Np9`dMfbm-0jz5?`fDOOjojk*Q z!XzWXpL0O{-GOAx_p2Hv84dm_ycg*%0rO*xkpJ4(hPk*O^L{4T8vONn9_S}*4Geu! z%Ou|frGaF&XG&A`f_X&j?o8Yjy5n(p?q}ZvV9hv68)qFYrTB>5 z2iN63xIOYaCLk1Nw5-ln;2qhAUV3>i@i-?%>@iPX_(78vqkWH)k57IoEJm&_*AFiS zhqt`7CwTN}a9C#jh%@e$0F?!>#!5a3@Us-`^})eI4wArc1BK>iXavUFf@vw#k}a*M z+Cjj`F4cfvA{5RIPpJei*0ecjaWKeT0Y5q3E@*k{=C|JDDOkch=j?NX9Q3t3$L!x5 z_OJP%T=SFWLu&z~f}}}Y3n3+teFxTRkO~ta^1K!(R|MOE1B&4Lc;AOW5Uw%Zy7O`V z7IUJi9kk8c7f67Thv{aavV+&_W9$w`?_py9O; z_~`cF#W&8mxuUn%4C2{cX#zqDZra>I0PuCA*Sb~Sn7z?)0_`!Ra$r1=K^(m+Z9Dzu ziM)9;h%a}g|BD~iAQ?oWyOMY2Z-j&X?NF{991mtt_jjf7fVU?AMmtA@5~M7QhcbwA zcO`G#26A~ABv3P6v(v4X!&xDN=zUj`0u$|C(|A}4dbi<$x&p`%KwUU<(Bs-FrH~}N zS$&oxOg1d}U{z+jGwV5>jKZ0MV3hw;x zzzw@_`dh%Ae+Rg$d4{{Z`^JtyL{HDo2t*Is0I&wda$LMLG%nud0${mWY!}jUQq^eW=`lI+1=@o3Pal;cb z28SqsZ(`TUO1Z4hfd$pU@|bpTni`d|JUM3v4O&nsP!z9#9z|{rVS^t=fW@){5TOCc zLFo{5#cN#)eI9wgi`fOt1|YNRs^>6O02u{y^IQm&><}E?a>8GBp!x-A2dHXRbPKE> zU|Wjz)@Ci!#E062$&2;4KA`cXV32)qw+ncQ#j%iSxE{nLFX~9!{>A-oEX*$NJ@)A0 zzwAAAukJ;Vv>aZo>wBDD?7RA#6N{N!i=*m1u6d@0ob@bCw<%ff**9|=W( z{qXAxegDY+lkf}Sm*JJuZ!MjE>o;#MpB`K8%&v6imO68vDTU=u^c?z+S`HrtYOC zfG=_Q|Ewv#aAP&r@o@B`(S@Pq*r}E1se4y|t^MG=}YPLBNcI?98`Pi zWzctOS~|qgN(OWXWVvAEWtn#3v?np^!)yXFyIGb|N3}c#GpRO-ei>NN2skaWC-=ApzbPgNj#1zV=* zN^QS2SB4NIMgK|B%>Zzi)fF8-uIAi_?L2Wr_U^-vZuKW12AOc;g__~SXq;$`-2xmS z7*$UI0X*=z>BJu?e0vi+;)cNz5C%&WCKC>$>Dr3et}cw@CGcZ%65`ZuSZLo)7H=gP z<7gC}sw1LZva^Hd+>1I-i@*Ypl9=dM);NY?p0hO!|5YntI=& str: + """Compose an async SQLAlchemy URL for MySQL using the aiomysql driver. + + Component args override env vars. Password is percent-encoded so special + characters (``@``, ``:``, ``/``…) don't break URL parsing. + """ + host = host or os.environ.get("DECNET_DB_HOST", "localhost") + port = port or int(os.environ.get("DECNET_DB_PORT", "3306")) + database = database or os.environ.get("DECNET_DB_NAME", "decnet") + user = user or os.environ.get("DECNET_DB_USER", "decnet") + + if password is None: + password = os.environ.get("DECNET_DB_PASSWORD", "") + + # Allow empty passwords during tests (pytest sets PYTEST_* env vars). + # Outside tests, an empty MySQL password is almost never intentional. + if not password and not any(k.startswith("PYTEST") for k in os.environ): + raise ValueError( + "DECNET_DB_PASSWORD is not set. Either export it, set DECNET_DB_URL, " + "or run under pytest for an empty-password default." + ) + + pw_enc = quote_plus(password) + user_enc = quote_plus(user) + return f"mysql+aiomysql://{user_enc}:{pw_enc}@{host}:{port}/{database}" + + +def resolve_url(url: Optional[str] = None) -> str: + """Pick a connection URL: explicit arg → DECNET_DB_URL env → built from components.""" + if url: + return url + env_url = os.environ.get("DECNET_DB_URL") + if env_url: + return env_url + return build_mysql_url() + + +def get_async_engine( + url: Optional[str] = None, + *, + pool_size: int = DEFAULT_POOL_SIZE, + max_overflow: int = DEFAULT_MAX_OVERFLOW, + pool_recycle: int = DEFAULT_POOL_RECYCLE, + pool_pre_ping: bool = DEFAULT_POOL_PRE_PING, + echo: bool = False, +) -> AsyncEngine: + """Create an AsyncEngine for MySQL. + + Defaults tuned for a dashboard workload: a modest pool, hourly recycle + to sidestep MySQL's idle-connection reaper, and pre-ping to fail fast + if a pooled connection has been killed server-side. + """ + dsn = resolve_url(url) + return create_async_engine( + dsn, + echo=echo, + pool_size=pool_size, + max_overflow=max_overflow, + pool_recycle=pool_recycle, + pool_pre_ping=pool_pre_ping, + ) diff --git a/decnet/web/db/mysql/repository.py b/decnet/web/db/mysql/repository.py new file mode 100644 index 0000000..533b061 --- /dev/null +++ b/decnet/web/db/mysql/repository.py @@ -0,0 +1,87 @@ +""" +MySQL implementation of :class:`BaseRepository`. + +Inherits the portable SQLModel query code from :class:`SQLModelRepository` +and only overrides the two places where MySQL's SQL dialect differs from +SQLite's: + +* :meth:`_migrate_attackers_table` — uses ``information_schema`` (MySQL + has no ``PRAGMA``). +* :meth:`get_log_histogram` — uses ``FROM_UNIXTIME`` / + ``UNIX_TIMESTAMP`` + integer division for bucketing. +""" +from __future__ import annotations + +from typing import List, Optional + +from sqlalchemy import func, select, text, literal_column +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker +from sqlmodel.sql.expression import SelectOfScalar + +from decnet.web.db.models import Log +from decnet.web.db.mysql.database import get_async_engine +from decnet.web.db.sqlmodel_repo import SQLModelRepository + + +class MySQLRepository(SQLModelRepository): + """MySQL backend — uses ``aiomysql``.""" + + def __init__(self, url: Optional[str] = None, **engine_kwargs) -> None: + self.engine = get_async_engine(url=url, **engine_kwargs) + self.session_factory = async_sessionmaker( + self.engine, class_=AsyncSession, expire_on_commit=False + ) + + async def _migrate_attackers_table(self) -> None: + """Drop the legacy (pre-UUID) ``attackers`` table if it exists without a ``uuid`` column. + + MySQL exposes column metadata via ``information_schema.COLUMNS``. + ``DATABASE()`` scopes the lookup to the currently connected schema. + """ + async with self.engine.begin() as conn: + rows = (await conn.execute(text( + "SELECT COLUMN_NAME FROM information_schema.COLUMNS " + "WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = 'attackers'" + ))).fetchall() + if rows and not any(r[0] == "uuid" for r in rows): + await conn.execute(text("DROP TABLE attackers")) + + def _json_field_equals(self, key: str): + # MySQL 5.7+ exposes JSON_EXTRACT; quoted string result returned for + # TEXT-stored JSON, same behavior we rely on in SQLite. + return text(f"JSON_UNQUOTE(JSON_EXTRACT(fields, '$.{key}')) = :val") + + async def get_log_histogram( + self, + search: Optional[str] = None, + start_time: Optional[str] = None, + end_time: Optional[str] = None, + interval_minutes: int = 15, + ) -> List[dict]: + bucket_seconds = max(interval_minutes, 1) * 60 + # Truncate each timestamp to the start of its bucket: + # FROM_UNIXTIME( (UNIX_TIMESTAMP(timestamp) DIV N) * N ) + # DIV is MySQL's integer division operator. + bucket_expr = literal_column( + f"FROM_UNIXTIME((UNIX_TIMESTAMP(timestamp) DIV {bucket_seconds}) * {bucket_seconds})" + ).label("bucket_time") + + statement: SelectOfScalar = select(bucket_expr, func.count().label("count")).select_from(Log) + statement = self._apply_filters(statement, search, start_time, end_time) + statement = statement.group_by(literal_column("bucket_time")).order_by( + literal_column("bucket_time") + ) + + async with self.session_factory() as session: + results = await session.execute(statement) + # Normalize to ISO string for API parity with the SQLite backend + # (SQLite's datetime() returns a string already; FROM_UNIXTIME + # returns a datetime). + out: List[dict] = [] + for r in results.all(): + ts = r[0] + out.append({ + "time": ts.isoformat(sep=" ") if hasattr(ts, "isoformat") else ts, + "count": r[1], + }) + return out diff --git a/decnet/web/db/sqlmodel_repo.py b/decnet/web/db/sqlmodel_repo.py new file mode 100644 index 0000000..e50b652 --- /dev/null +++ b/decnet/web/db/sqlmodel_repo.py @@ -0,0 +1,637 @@ +""" +Shared SQLModel-based repository implementation. + +Contains all dialect-portable query code used by the SQLite and MySQL +backends. Dialect-specific behavior lives in subclasses: + +* engine/session construction (``__init__``) +* ``_migrate_attackers_table`` (legacy schema check; DDL introspection + is not portable) +* ``get_log_histogram`` (date-bucket expression differs per dialect) +""" +from __future__ import annotations + +import asyncio +import json +import uuid +from datetime import datetime, timezone +from typing import Any, Optional, List + +from sqlalchemy import func, select, desc, asc, text, or_, update +from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker +from sqlmodel.sql.expression import SelectOfScalar + +from decnet.config import load_state +from decnet.env import DECNET_ADMIN_USER, DECNET_ADMIN_PASSWORD +from decnet.web.auth import get_password_hash +from decnet.web.db.repository import BaseRepository +from decnet.web.db.models import User, Log, Bounty, State, Attacker, AttackerBehavior + + +class SQLModelRepository(BaseRepository): + """Concrete SQLModel/SQLAlchemy-async repository. + + Subclasses provide ``self.engine`` (AsyncEngine) and ``self.session_factory`` + in ``__init__``, and override the few dialect-specific helpers. + """ + + engine: AsyncEngine + session_factory: async_sessionmaker[AsyncSession] + + # ------------------------------------------------------------ lifecycle + + async def initialize(self) -> None: + """Create tables if absent and seed the admin user.""" + from sqlmodel import SQLModel + await self._migrate_attackers_table() + async with self.engine.begin() as conn: + await conn.run_sync(SQLModel.metadata.create_all) + await self._ensure_admin_user() + + async def reinitialize(self) -> None: + """Re-create schema (for tests / reset flows). Does NOT drop existing tables.""" + from sqlmodel import SQLModel + async with self.engine.begin() as conn: + await conn.run_sync(SQLModel.metadata.create_all) + await self._ensure_admin_user() + + async def _ensure_admin_user(self) -> None: + async with self.session_factory() as session: + result = await session.execute( + select(User).where(User.username == DECNET_ADMIN_USER) + ) + if not result.scalar_one_or_none(): + session.add(User( + uuid=str(uuid.uuid4()), + username=DECNET_ADMIN_USER, + password_hash=get_password_hash(DECNET_ADMIN_PASSWORD), + role="admin", + must_change_password=True, + )) + await session.commit() + + async def _migrate_attackers_table(self) -> None: + """Legacy-schema cleanup. Override per dialect (DDL introspection is non-portable).""" + return None + + # ---------------------------------------------------------------- logs + + async def add_log(self, log_data: dict[str, Any]) -> None: + data = log_data.copy() + if "fields" in data and isinstance(data["fields"], dict): + data["fields"] = json.dumps(data["fields"]) + if "timestamp" in data and isinstance(data["timestamp"], str): + try: + data["timestamp"] = datetime.fromisoformat( + data["timestamp"].replace("Z", "+00:00") + ) + except ValueError: + pass + + async with self.session_factory() as session: + session.add(Log(**data)) + await session.commit() + + def _apply_filters( + self, + statement: SelectOfScalar, + search: Optional[str], + start_time: Optional[str], + end_time: Optional[str], + ) -> SelectOfScalar: + import re + import shlex + + if start_time: + statement = statement.where(Log.timestamp >= start_time) + if end_time: + statement = statement.where(Log.timestamp <= end_time) + + if search: + try: + tokens = shlex.split(search) + except ValueError: + tokens = search.split() + + core_fields = { + "decky": Log.decky, + "service": Log.service, + "event": Log.event_type, + "attacker": Log.attacker_ip, + "attacker-ip": Log.attacker_ip, + "attacker_ip": Log.attacker_ip, + } + + for token in tokens: + if ":" in token: + key, val = token.split(":", 1) + if key in core_fields: + statement = statement.where(core_fields[key] == val) + else: + key_safe = re.sub(r"[^a-zA-Z0-9_]", "", key) + if key_safe: + statement = statement.where( + self._json_field_equals(key_safe) + ).params(val=val) + else: + lk = f"%{token}%" + statement = statement.where( + or_( + Log.raw_line.like(lk), + Log.decky.like(lk), + Log.service.like(lk), + Log.attacker_ip.like(lk), + ) + ) + return statement + + def _json_field_equals(self, key: str): + """Return a text() predicate that matches rows where fields->key == :val. + + Both SQLite and MySQL expose a ``JSON_EXTRACT`` function; MySQL also + exposes the same function under ``json_extract`` (case-insensitive). + The ``:val`` parameter is bound separately and must be supplied with + ``.params(val=...)`` by the caller, which keeps us safe from injection. + """ + return text(f"JSON_EXTRACT(fields, '$.{key}') = :val") + + async def get_logs( + self, + limit: int = 50, + offset: int = 0, + search: Optional[str] = None, + start_time: Optional[str] = None, + end_time: Optional[str] = None, + ) -> List[dict]: + statement = ( + select(Log) + .order_by(desc(Log.timestamp)) + .offset(offset) + .limit(limit) + ) + statement = self._apply_filters(statement, search, start_time, end_time) + + async with self.session_factory() as session: + results = await session.execute(statement) + return [log.model_dump(mode="json") for log in results.scalars().all()] + + async def get_max_log_id(self) -> int: + async with self.session_factory() as session: + result = await session.execute(select(func.max(Log.id))) + val = result.scalar() + return val if val is not None else 0 + + async def get_logs_after_id( + self, + last_id: int, + limit: int = 50, + search: Optional[str] = None, + start_time: Optional[str] = None, + end_time: Optional[str] = None, + ) -> List[dict]: + statement = ( + select(Log).where(Log.id > last_id).order_by(asc(Log.id)).limit(limit) + ) + statement = self._apply_filters(statement, search, start_time, end_time) + + async with self.session_factory() as session: + results = await session.execute(statement) + return [log.model_dump(mode="json") for log in results.scalars().all()] + + async def get_total_logs( + self, + search: Optional[str] = None, + start_time: Optional[str] = None, + end_time: Optional[str] = None, + ) -> int: + statement = select(func.count()).select_from(Log) + statement = self._apply_filters(statement, search, start_time, end_time) + + async with self.session_factory() as session: + result = await session.execute(statement) + return result.scalar() or 0 + + async def get_log_histogram( + self, + search: Optional[str] = None, + start_time: Optional[str] = None, + end_time: Optional[str] = None, + interval_minutes: int = 15, + ) -> List[dict]: + """Dialect-specific — override per backend.""" + raise NotImplementedError + + async def get_stats_summary(self) -> dict[str, Any]: + async with self.session_factory() as session: + total_logs = ( + await session.execute(select(func.count()).select_from(Log)) + ).scalar() or 0 + unique_attackers = ( + await session.execute( + select(func.count(func.distinct(Log.attacker_ip))) + ) + ).scalar() or 0 + + _state = await asyncio.to_thread(load_state) + deployed_deckies = len(_state[0].deckies) if _state else 0 + + return { + "total_logs": total_logs, + "unique_attackers": unique_attackers, + "active_deckies": deployed_deckies, + "deployed_deckies": deployed_deckies, + } + + async def get_deckies(self) -> List[dict]: + _state = await asyncio.to_thread(load_state) + return [_d.model_dump() for _d in _state[0].deckies] if _state else [] + + # --------------------------------------------------------------- users + + async def get_user_by_username(self, username: str) -> Optional[dict]: + async with self.session_factory() as session: + result = await session.execute( + select(User).where(User.username == username) + ) + user = result.scalar_one_or_none() + return user.model_dump() if user else None + + async def get_user_by_uuid(self, uuid: str) -> Optional[dict]: + async with self.session_factory() as session: + result = await session.execute( + select(User).where(User.uuid == uuid) + ) + user = result.scalar_one_or_none() + return user.model_dump() if user else None + + async def create_user(self, user_data: dict[str, Any]) -> None: + async with self.session_factory() as session: + session.add(User(**user_data)) + await session.commit() + + async def update_user_password( + self, uuid: str, password_hash: str, must_change_password: bool = False + ) -> None: + async with self.session_factory() as session: + await session.execute( + update(User) + .where(User.uuid == uuid) + .values( + password_hash=password_hash, + must_change_password=must_change_password, + ) + ) + await session.commit() + + async def list_users(self) -> list[dict]: + async with self.session_factory() as session: + result = await session.execute(select(User)) + return [u.model_dump() for u in result.scalars().all()] + + async def delete_user(self, uuid: str) -> bool: + async with self.session_factory() as session: + result = await session.execute(select(User).where(User.uuid == uuid)) + user = result.scalar_one_or_none() + if not user: + return False + await session.delete(user) + await session.commit() + return True + + async def update_user_role(self, uuid: str, role: str) -> None: + async with self.session_factory() as session: + await session.execute( + update(User).where(User.uuid == uuid).values(role=role) + ) + await session.commit() + + async def purge_logs_and_bounties(self) -> dict[str, int]: + async with self.session_factory() as session: + logs_deleted = (await session.execute(text("DELETE FROM logs"))).rowcount + bounties_deleted = (await session.execute(text("DELETE FROM bounty"))).rowcount + # attacker_behavior has FK → attackers.uuid; delete children first. + await session.execute(text("DELETE FROM attacker_behavior")) + attackers_deleted = (await session.execute(text("DELETE FROM attackers"))).rowcount + await session.commit() + return { + "logs": logs_deleted, + "bounties": bounties_deleted, + "attackers": attackers_deleted, + } + + # ------------------------------------------------------------ bounties + + async def add_bounty(self, bounty_data: dict[str, Any]) -> None: + data = bounty_data.copy() + if "payload" in data and isinstance(data["payload"], dict): + data["payload"] = json.dumps(data["payload"]) + + async with self.session_factory() as session: + session.add(Bounty(**data)) + await session.commit() + + def _apply_bounty_filters( + self, + statement: SelectOfScalar, + bounty_type: Optional[str], + search: Optional[str], + ) -> SelectOfScalar: + if bounty_type: + statement = statement.where(Bounty.bounty_type == bounty_type) + if search: + lk = f"%{search}%" + statement = statement.where( + or_( + Bounty.decky.like(lk), + Bounty.service.like(lk), + Bounty.attacker_ip.like(lk), + Bounty.payload.like(lk), + ) + ) + return statement + + async def get_bounties( + self, + limit: int = 50, + offset: int = 0, + bounty_type: Optional[str] = None, + search: Optional[str] = None, + ) -> List[dict]: + statement = ( + select(Bounty) + .order_by(desc(Bounty.timestamp)) + .offset(offset) + .limit(limit) + ) + statement = self._apply_bounty_filters(statement, bounty_type, search) + + async with self.session_factory() as session: + results = await session.execute(statement) + final = [] + for item in results.scalars().all(): + d = item.model_dump(mode="json") + try: + d["payload"] = json.loads(d["payload"]) + except (json.JSONDecodeError, TypeError): + pass + final.append(d) + return final + + async def get_total_bounties( + self, bounty_type: Optional[str] = None, search: Optional[str] = None + ) -> int: + statement = select(func.count()).select_from(Bounty) + statement = self._apply_bounty_filters(statement, bounty_type, search) + + async with self.session_factory() as session: + result = await session.execute(statement) + return result.scalar() or 0 + + async def get_state(self, key: str) -> Optional[dict[str, Any]]: + async with self.session_factory() as session: + statement = select(State).where(State.key == key) + result = await session.execute(statement) + state = result.scalar_one_or_none() + if state: + return json.loads(state.value) + return None + + async def set_state(self, key: str, value: Any) -> None: # noqa: ANN401 + async with self.session_factory() as session: + statement = select(State).where(State.key == key) + result = await session.execute(statement) + state = result.scalar_one_or_none() + + value_json = json.dumps(value) + if state: + state.value = value_json + session.add(state) + else: + session.add(State(key=key, value=value_json)) + + await session.commit() + + # ----------------------------------------------------------- attackers + + async def get_all_logs_raw(self) -> List[dict[str, Any]]: + async with self.session_factory() as session: + result = await session.execute( + select( + Log.id, + Log.raw_line, + Log.attacker_ip, + Log.service, + Log.event_type, + Log.decky, + Log.timestamp, + Log.fields, + ) + ) + return [ + { + "id": r.id, + "raw_line": r.raw_line, + "attacker_ip": r.attacker_ip, + "service": r.service, + "event_type": r.event_type, + "decky": r.decky, + "timestamp": r.timestamp, + "fields": r.fields, + } + for r in result.all() + ] + + async def get_all_bounties_by_ip(self) -> dict[str, List[dict[str, Any]]]: + from collections import defaultdict + async with self.session_factory() as session: + result = await session.execute( + select(Bounty).order_by(asc(Bounty.timestamp)) + ) + grouped: dict[str, List[dict[str, Any]]] = defaultdict(list) + for item in result.scalars().all(): + d = item.model_dump(mode="json") + try: + d["payload"] = json.loads(d["payload"]) + except (json.JSONDecodeError, TypeError): + pass + grouped[item.attacker_ip].append(d) + return dict(grouped) + + async def get_bounties_for_ips(self, ips: set[str]) -> dict[str, List[dict[str, Any]]]: + from collections import defaultdict + async with self.session_factory() as session: + result = await session.execute( + select(Bounty).where(Bounty.attacker_ip.in_(ips)).order_by(asc(Bounty.timestamp)) + ) + grouped: dict[str, List[dict[str, Any]]] = defaultdict(list) + for item in result.scalars().all(): + d = item.model_dump(mode="json") + try: + d["payload"] = json.loads(d["payload"]) + except (json.JSONDecodeError, TypeError): + pass + grouped[item.attacker_ip].append(d) + return dict(grouped) + + async def upsert_attacker(self, data: dict[str, Any]) -> str: + async with self.session_factory() as session: + result = await session.execute( + select(Attacker).where(Attacker.ip == data["ip"]) + ) + existing = result.scalar_one_or_none() + if existing: + for k, v in data.items(): + setattr(existing, k, v) + session.add(existing) + row_uuid = existing.uuid + else: + row_uuid = str(uuid.uuid4()) + data = {**data, "uuid": row_uuid} + session.add(Attacker(**data)) + await session.commit() + return row_uuid + + async def upsert_attacker_behavior( + self, + attacker_uuid: str, + data: dict[str, Any], + ) -> None: + async with self.session_factory() as session: + result = await session.execute( + select(AttackerBehavior).where( + AttackerBehavior.attacker_uuid == attacker_uuid + ) + ) + existing = result.scalar_one_or_none() + payload = {**data, "updated_at": datetime.now(timezone.utc)} + if existing: + for k, v in payload.items(): + setattr(existing, k, v) + session.add(existing) + else: + session.add(AttackerBehavior(attacker_uuid=attacker_uuid, **payload)) + await session.commit() + + async def get_attacker_behavior( + self, + attacker_uuid: str, + ) -> Optional[dict[str, Any]]: + async with self.session_factory() as session: + result = await session.execute( + select(AttackerBehavior).where( + AttackerBehavior.attacker_uuid == attacker_uuid + ) + ) + row = result.scalar_one_or_none() + if not row: + return None + return self._deserialize_behavior(row.model_dump(mode="json")) + + async def get_behaviors_for_ips( + self, + ips: set[str], + ) -> dict[str, dict[str, Any]]: + if not ips: + return {} + async with self.session_factory() as session: + result = await session.execute( + select(Attacker.ip, AttackerBehavior) + .join(AttackerBehavior, Attacker.uuid == AttackerBehavior.attacker_uuid) + .where(Attacker.ip.in_(ips)) + ) + out: dict[str, dict[str, Any]] = {} + for ip, row in result.all(): + out[ip] = self._deserialize_behavior(row.model_dump(mode="json")) + return out + + @staticmethod + def _deserialize_behavior(d: dict[str, Any]) -> dict[str, Any]: + for key in ("tcp_fingerprint", "timing_stats", "phase_sequence"): + if isinstance(d.get(key), str): + try: + d[key] = json.loads(d[key]) + except (json.JSONDecodeError, TypeError): + pass + return d + + @staticmethod + def _deserialize_attacker(d: dict[str, Any]) -> dict[str, Any]: + for key in ("services", "deckies", "fingerprints", "commands"): + if isinstance(d.get(key), str): + try: + d[key] = json.loads(d[key]) + except (json.JSONDecodeError, TypeError): + pass + return d + + async def get_attacker_by_uuid(self, uuid: str) -> Optional[dict[str, Any]]: + async with self.session_factory() as session: + result = await session.execute( + select(Attacker).where(Attacker.uuid == uuid) + ) + attacker = result.scalar_one_or_none() + if not attacker: + return None + return self._deserialize_attacker(attacker.model_dump(mode="json")) + + async def get_attackers( + self, + limit: int = 50, + offset: int = 0, + search: Optional[str] = None, + sort_by: str = "recent", + service: Optional[str] = None, + ) -> List[dict[str, Any]]: + order = { + "active": desc(Attacker.event_count), + "traversals": desc(Attacker.is_traversal), + }.get(sort_by, desc(Attacker.last_seen)) + + statement = select(Attacker).order_by(order).offset(offset).limit(limit) + if search: + statement = statement.where(Attacker.ip.like(f"%{search}%")) + if service: + statement = statement.where(Attacker.services.like(f'%"{service}"%')) + + async with self.session_factory() as session: + result = await session.execute(statement) + return [ + self._deserialize_attacker(a.model_dump(mode="json")) + for a in result.scalars().all() + ] + + async def get_total_attackers( + self, search: Optional[str] = None, service: Optional[str] = None + ) -> int: + statement = select(func.count()).select_from(Attacker) + if search: + statement = statement.where(Attacker.ip.like(f"%{search}%")) + if service: + statement = statement.where(Attacker.services.like(f'%"{service}"%')) + + async with self.session_factory() as session: + result = await session.execute(statement) + return result.scalar() or 0 + + async def get_attacker_commands( + self, + uuid: str, + limit: int = 50, + offset: int = 0, + service: Optional[str] = None, + ) -> dict[str, Any]: + async with self.session_factory() as session: + result = await session.execute( + select(Attacker.commands).where(Attacker.uuid == uuid) + ) + raw = result.scalar_one_or_none() + if raw is None: + return {"total": 0, "data": []} + + commands: list = json.loads(raw) if isinstance(raw, str) else raw + if service: + commands = [c for c in commands if c.get("service") == service] + + total = len(commands) + page = commands[offset: offset + limit] + return {"total": total, "data": page}