fix: clean up db layer — model_dump, timezone-aware timestamps, unified histogram, async load_state
This commit is contained in:
@@ -1,6 +1,6 @@
|
|||||||
from datetime import datetime
|
from datetime import datetime, timezone
|
||||||
from typing import Optional, Any, List
|
from typing import Optional, Any, List
|
||||||
from sqlmodel import SQLModel, Field, Column, JSON
|
from sqlmodel import SQLModel, Field
|
||||||
from pydantic import BaseModel, Field as PydanticField
|
from pydantic import BaseModel, Field as PydanticField
|
||||||
|
|
||||||
# --- Database Tables (SQLModel) ---
|
# --- Database Tables (SQLModel) ---
|
||||||
@@ -16,7 +16,7 @@ class User(SQLModel, table=True):
|
|||||||
class Log(SQLModel, table=True):
|
class Log(SQLModel, table=True):
|
||||||
__tablename__ = "logs"
|
__tablename__ = "logs"
|
||||||
id: Optional[int] = Field(default=None, primary_key=True)
|
id: Optional[int] = Field(default=None, primary_key=True)
|
||||||
timestamp: datetime = Field(default_factory=datetime.utcnow, index=True)
|
timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc), index=True)
|
||||||
decky: str = Field(index=True)
|
decky: str = Field(index=True)
|
||||||
service: str = Field(index=True)
|
service: str = Field(index=True)
|
||||||
event_type: str = Field(index=True)
|
event_type: str = Field(index=True)
|
||||||
@@ -28,7 +28,7 @@ class Log(SQLModel, table=True):
|
|||||||
class Bounty(SQLModel, table=True):
|
class Bounty(SQLModel, table=True):
|
||||||
__tablename__ = "bounty"
|
__tablename__ = "bounty"
|
||||||
id: Optional[int] = Field(default=None, primary_key=True)
|
id: Optional[int] = Field(default=None, primary_key=True)
|
||||||
timestamp: datetime = Field(default_factory=datetime.utcnow, index=True)
|
timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc), index=True)
|
||||||
decky: str = Field(index=True)
|
decky: str = Field(index=True)
|
||||||
service: str = Field(index=True)
|
service: str = Field(index=True)
|
||||||
attacker_ip: str = Field(index=True)
|
attacker_ip: str = Field(index=True)
|
||||||
|
|||||||
@@ -4,9 +4,8 @@ import uuid
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Optional, List
|
from typing import Any, Optional, List
|
||||||
|
|
||||||
from sqlalchemy import func, select, desc, asc, text, or_, update
|
from sqlalchemy import func, select, desc, asc, text, or_, update, literal_column
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||||
from sqlmodel import col
|
|
||||||
|
|
||||||
from decnet.config import load_state, _ROOT
|
from decnet.config import load_state, _ROOT
|
||||||
from decnet.env import DECNET_ADMIN_USER, DECNET_ADMIN_PASSWORD
|
from decnet.env import DECNET_ADMIN_USER, DECNET_ADMIN_PASSWORD
|
||||||
@@ -31,59 +30,81 @@ class SQLiteRepository(BaseRepository):
|
|||||||
"""Initialize the database schema synchronously."""
|
"""Initialize the database schema synchronously."""
|
||||||
init_db(self.db_path)
|
init_db(self.db_path)
|
||||||
|
|
||||||
# Ensure default admin exists via sync SQLAlchemy connection
|
|
||||||
from sqlalchemy import select
|
|
||||||
from decnet.web.db.sqlite.database import get_sync_engine
|
from decnet.web.db.sqlite.database import get_sync_engine
|
||||||
engine = get_sync_engine(self.db_path)
|
engine = get_sync_engine(self.db_path)
|
||||||
with engine.connect() as conn:
|
with engine.connect() as conn:
|
||||||
# Ensure admin exists via sync SQLAlchemy connection
|
result = conn.execute(
|
||||||
from sqlalchemy import text
|
text("SELECT uuid FROM users WHERE username = :u"),
|
||||||
result = conn.execute(text("SELECT uuid FROM users WHERE username = :u"), {"u": DECNET_ADMIN_USER})
|
{"u": DECNET_ADMIN_USER},
|
||||||
|
)
|
||||||
if not result.fetchone():
|
if not result.fetchone():
|
||||||
print(f"DEBUG: Creating admin user '{DECNET_ADMIN_USER}' with password '{DECNET_ADMIN_PASSWORD}'")
|
|
||||||
conn.execute(
|
conn.execute(
|
||||||
text("INSERT INTO users (uuid, username, password_hash, role, must_change_password) "
|
text(
|
||||||
"VALUES (:uuid, :u, :p, :r, :m)"),
|
"INSERT INTO users (uuid, username, password_hash, role, must_change_password) "
|
||||||
|
"VALUES (:uuid, :u, :p, :r, :m)"
|
||||||
|
),
|
||||||
{
|
{
|
||||||
"uuid": str(uuid.uuid4()),
|
"uuid": str(uuid.uuid4()),
|
||||||
"u": DECNET_ADMIN_USER,
|
"u": DECNET_ADMIN_USER,
|
||||||
"p": get_password_hash(DECNET_ADMIN_PASSWORD),
|
"p": get_password_hash(DECNET_ADMIN_PASSWORD),
|
||||||
"r": "admin",
|
"r": "admin",
|
||||||
"m": 1
|
"m": 1,
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
conn.commit()
|
conn.commit()
|
||||||
else:
|
|
||||||
print(f"DEBUG: Admin user '{DECNET_ADMIN_USER}' already exists")
|
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
"""Async warm-up / verification."""
|
"""Async warm-up / verification."""
|
||||||
async with self.session_factory() as session:
|
async with self.session_factory() as session:
|
||||||
await session.exec(text("SELECT 1"))
|
await session.execute(text("SELECT 1"))
|
||||||
|
|
||||||
def reinitialize(self) -> None:
|
async def reinitialize(self) -> None:
|
||||||
self._initialize_sync()
|
"""Initialize the database schema asynchronously (useful for tests)."""
|
||||||
|
from sqlmodel import SQLModel
|
||||||
|
async with self.engine.begin() as conn:
|
||||||
|
await conn.run_sync(SQLModel.metadata.create_all)
|
||||||
|
|
||||||
|
async with self.session_factory() as session:
|
||||||
|
result = await session.execute(
|
||||||
|
select(User).where(User.username == DECNET_ADMIN_USER)
|
||||||
|
)
|
||||||
|
if not result.scalar_one_or_none():
|
||||||
|
session.add(User(
|
||||||
|
uuid=str(uuid.uuid4()),
|
||||||
|
username=DECNET_ADMIN_USER,
|
||||||
|
password_hash=get_password_hash(DECNET_ADMIN_PASSWORD),
|
||||||
|
role="admin",
|
||||||
|
must_change_password=True,
|
||||||
|
))
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------ logs
|
||||||
|
|
||||||
async def add_log(self, log_data: dict[str, Any]) -> None:
|
async def add_log(self, log_data: dict[str, Any]) -> None:
|
||||||
# Convert dict to model
|
|
||||||
data = log_data.copy()
|
data = log_data.copy()
|
||||||
if "fields" in data and isinstance(data["fields"], dict):
|
if "fields" in data and isinstance(data["fields"], dict):
|
||||||
data["fields"] = json.dumps(data["fields"])
|
data["fields"] = json.dumps(data["fields"])
|
||||||
|
|
||||||
if "timestamp" in data and isinstance(data["timestamp"], str):
|
if "timestamp" in data and isinstance(data["timestamp"], str):
|
||||||
try:
|
try:
|
||||||
data["timestamp"] = datetime.fromisoformat(data["timestamp"].replace('Z', '+00:00'))
|
data["timestamp"] = datetime.fromisoformat(
|
||||||
|
data["timestamp"].replace("Z", "+00:00")
|
||||||
|
)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
log = Log(**data)
|
|
||||||
async with self.session_factory() as session:
|
async with self.session_factory() as session:
|
||||||
session.add(log)
|
session.add(Log(**data))
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
def _apply_filters(self, statement, search: Optional[str], start_time: Optional[str], end_time: Optional[str]):
|
def _apply_filters(
|
||||||
import shlex
|
self,
|
||||||
|
statement,
|
||||||
|
search: Optional[str],
|
||||||
|
start_time: Optional[str],
|
||||||
|
end_time: Optional[str],
|
||||||
|
):
|
||||||
import re
|
import re
|
||||||
|
import shlex
|
||||||
|
|
||||||
if start_time:
|
if start_time:
|
||||||
statement = statement.where(Log.timestamp >= start_time)
|
statement = statement.where(Log.timestamp >= start_time)
|
||||||
@@ -94,7 +115,7 @@ class SQLiteRepository(BaseRepository):
|
|||||||
try:
|
try:
|
||||||
tokens = shlex.split(search)
|
tokens = shlex.split(search)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
tokens = search.split(" ")
|
tokens = search.split()
|
||||||
|
|
||||||
core_fields = {
|
core_fields = {
|
||||||
"decky": Log.decky,
|
"decky": Log.decky,
|
||||||
@@ -102,7 +123,7 @@ class SQLiteRepository(BaseRepository):
|
|||||||
"event": Log.event_type,
|
"event": Log.event_type,
|
||||||
"attacker": Log.attacker_ip,
|
"attacker": Log.attacker_ip,
|
||||||
"attacker-ip": Log.attacker_ip,
|
"attacker-ip": Log.attacker_ip,
|
||||||
"attacker_ip": Log.attacker_ip
|
"attacker_ip": Log.attacker_ip,
|
||||||
}
|
}
|
||||||
|
|
||||||
for token in tokens:
|
for token in tokens:
|
||||||
@@ -111,9 +132,10 @@ class SQLiteRepository(BaseRepository):
|
|||||||
if key in core_fields:
|
if key in core_fields:
|
||||||
statement = statement.where(core_fields[key] == val)
|
statement = statement.where(core_fields[key] == val)
|
||||||
else:
|
else:
|
||||||
key_safe = re.sub(r'[^a-zA-Z0-9_]', '', key)
|
key_safe = re.sub(r"[^a-zA-Z0-9_]", "", key)
|
||||||
# SQLite json_extract via text()
|
statement = statement.where(
|
||||||
statement = statement.where(text(f"json_extract(fields, '$.{key_safe}') = :val")).params(val=val)
|
text(f"json_extract(fields, '$.{key_safe}') = :val")
|
||||||
|
).params(val=val)
|
||||||
else:
|
else:
|
||||||
lk = f"%{token}%"
|
lk = f"%{token}%"
|
||||||
statement = statement.where(
|
statement = statement.where(
|
||||||
@@ -121,7 +143,7 @@ class SQLiteRepository(BaseRepository):
|
|||||||
Log.raw_line.like(lk),
|
Log.raw_line.like(lk),
|
||||||
Log.decky.like(lk),
|
Log.decky.like(lk),
|
||||||
Log.service.like(lk),
|
Log.service.like(lk),
|
||||||
Log.attacker_ip.like(lk)
|
Log.attacker_ip.like(lk),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return statement
|
return statement
|
||||||
@@ -132,14 +154,19 @@ class SQLiteRepository(BaseRepository):
|
|||||||
offset: int = 0,
|
offset: int = 0,
|
||||||
search: Optional[str] = None,
|
search: Optional[str] = None,
|
||||||
start_time: Optional[str] = None,
|
start_time: Optional[str] = None,
|
||||||
end_time: Optional[str] = None
|
end_time: Optional[str] = None,
|
||||||
) -> List[dict]:
|
) -> List[dict]:
|
||||||
statement = select(Log).order_by(desc(Log.timestamp)).offset(offset).limit(limit)
|
statement = (
|
||||||
|
select(Log)
|
||||||
|
.order_by(desc(Log.timestamp))
|
||||||
|
.offset(offset)
|
||||||
|
.limit(limit)
|
||||||
|
)
|
||||||
statement = self._apply_filters(statement, search, start_time, end_time)
|
statement = self._apply_filters(statement, search, start_time, end_time)
|
||||||
|
|
||||||
async with self.session_factory() as session:
|
async with self.session_factory() as session:
|
||||||
results = await session.execute(statement)
|
results = await session.execute(statement)
|
||||||
return [log.dict() for log in results.scalars().all()]
|
return [log.model_dump() for log in results.scalars().all()]
|
||||||
|
|
||||||
async def get_max_log_id(self) -> int:
|
async def get_max_log_id(self) -> int:
|
||||||
async with self.session_factory() as session:
|
async with self.session_factory() as session:
|
||||||
@@ -153,20 +180,22 @@ class SQLiteRepository(BaseRepository):
|
|||||||
limit: int = 50,
|
limit: int = 50,
|
||||||
search: Optional[str] = None,
|
search: Optional[str] = None,
|
||||||
start_time: Optional[str] = None,
|
start_time: Optional[str] = None,
|
||||||
end_time: Optional[str] = None
|
end_time: Optional[str] = None,
|
||||||
) -> List[dict]:
|
) -> List[dict]:
|
||||||
statement = select(Log).where(Log.id > last_id).order_by(asc(Log.id)).limit(limit)
|
statement = (
|
||||||
|
select(Log).where(Log.id > last_id).order_by(asc(Log.id)).limit(limit)
|
||||||
|
)
|
||||||
statement = self._apply_filters(statement, search, start_time, end_time)
|
statement = self._apply_filters(statement, search, start_time, end_time)
|
||||||
|
|
||||||
async with self.session_factory() as session:
|
async with self.session_factory() as session:
|
||||||
results = await session.execute(statement)
|
results = await session.execute(statement)
|
||||||
return [log.dict() for log in results.scalars().all()]
|
return [log.model_dump() for log in results.scalars().all()]
|
||||||
|
|
||||||
async def get_total_logs(
|
async def get_total_logs(
|
||||||
self,
|
self,
|
||||||
search: Optional[str] = None,
|
search: Optional[str] = None,
|
||||||
start_time: Optional[str] = None,
|
start_time: Optional[str] = None,
|
||||||
end_time: Optional[str] = None
|
end_time: Optional[str] = None,
|
||||||
) -> int:
|
) -> int:
|
||||||
statement = select(func.count()).select_from(Log)
|
statement = select(func.count()).select_from(Log)
|
||||||
statement = self._apply_filters(statement, search, start_time, end_time)
|
statement = self._apply_filters(statement, search, start_time, end_time)
|
||||||
@@ -180,130 +209,99 @@ class SQLiteRepository(BaseRepository):
|
|||||||
search: Optional[str] = None,
|
search: Optional[str] = None,
|
||||||
start_time: Optional[str] = None,
|
start_time: Optional[str] = None,
|
||||||
end_time: Optional[str] = None,
|
end_time: Optional[str] = None,
|
||||||
interval_minutes: int = 15
|
interval_minutes: int = 15,
|
||||||
) -> List[dict]:
|
) -> List[dict]:
|
||||||
# raw SQL for time bucketing as it is very engine specific
|
bucket_seconds = interval_minutes * 60
|
||||||
_where_stmt = select(Log)
|
bucket_expr = literal_column(
|
||||||
_where_stmt = self._apply_filters(_where_stmt, search, start_time, end_time)
|
f"datetime((strftime('%s', timestamp) / {bucket_seconds}) * {bucket_seconds}, 'unixepoch')"
|
||||||
|
).label("bucket_time")
|
||||||
|
|
||||||
# Extract WHERE clause from compiled statement
|
statement = select(bucket_expr, func.count().label("count")).select_from(Log)
|
||||||
# For simplicity in this migration, we'll use a semi-raw approach for the complex histogram query
|
statement = self._apply_filters(statement, search, start_time, end_time)
|
||||||
# but bind parameters from the filtered statement
|
statement = statement.group_by(literal_column("bucket_time")).order_by(
|
||||||
|
literal_column("bucket_time")
|
||||||
# SQLite specific bucket logic
|
)
|
||||||
bucket_expr = f"(strftime('%s', timestamp) / {interval_minutes * 60}) * {interval_minutes * 60}"
|
|
||||||
|
|
||||||
# We'll use session.execute with a text statement for the grouping but reuse the WHERE logic if possible
|
|
||||||
# Or just build the query fully
|
|
||||||
|
|
||||||
where_clause, params = self._build_where_clause_legacy(search, start_time, end_time)
|
|
||||||
|
|
||||||
query = f"""
|
|
||||||
SELECT
|
|
||||||
datetime({bucket_expr}, 'unixepoch') as bucket_time,
|
|
||||||
COUNT(*) as count
|
|
||||||
FROM logs
|
|
||||||
{where_clause}
|
|
||||||
GROUP BY bucket_time
|
|
||||||
ORDER BY bucket_time ASC
|
|
||||||
"""
|
|
||||||
|
|
||||||
async with self.session_factory() as session:
|
async with self.session_factory() as session:
|
||||||
results = await session.execute(text(query), params)
|
results = await session.execute(statement)
|
||||||
return [{"time": r[0], "count": r[1]} for r in results.all()]
|
return [{"time": r[0], "count": r[1]} for r in results.all()]
|
||||||
|
|
||||||
def _build_where_clause_legacy(self, search, start_time, end_time):
|
|
||||||
# Re-using the logic from the previous iteration for the raw query part
|
|
||||||
import shlex
|
|
||||||
import re
|
|
||||||
where_clauses = []
|
|
||||||
params = {}
|
|
||||||
|
|
||||||
if start_time:
|
|
||||||
where_clauses.append("timestamp >= :start_time")
|
|
||||||
params["start_time"] = start_time
|
|
||||||
if end_time:
|
|
||||||
where_clauses.append("timestamp <= :end_time")
|
|
||||||
params["end_time"] = end_time
|
|
||||||
|
|
||||||
if search:
|
|
||||||
try: tokens = shlex.split(search)
|
|
||||||
except: tokens = search.split(" ")
|
|
||||||
core_fields = {"decky": "decky", "service": "service", "event": "event_type", "attacker": "attacker_ip"}
|
|
||||||
for i, token in enumerate(tokens):
|
|
||||||
if ":" in token:
|
|
||||||
k, v = token.split(":", 1)
|
|
||||||
if k in core_fields:
|
|
||||||
where_clauses.append(f"{core_fields[k]} = :val_{i}")
|
|
||||||
params[f"val_{i}"] = v
|
|
||||||
else:
|
|
||||||
ks = re.sub(r'[^a-zA-Z0-9_]', '', k)
|
|
||||||
where_clauses.append(f"json_extract(fields, '$.{ks}') = :val_{i}")
|
|
||||||
params[f"val_{i}"] = v
|
|
||||||
else:
|
|
||||||
where_clauses.append(f"(raw_line LIKE :lk_{i} OR decky LIKE :lk_{i} OR service LIKE :lk_{i} OR attacker_ip LIKE :lk_{i})")
|
|
||||||
params[f"lk_{i}"] = f"%{token}%"
|
|
||||||
|
|
||||||
where = " WHERE " + " AND ".join(where_clauses) if where_clauses else ""
|
|
||||||
return where, params
|
|
||||||
|
|
||||||
async def get_stats_summary(self) -> dict[str, Any]:
|
async def get_stats_summary(self) -> dict[str, Any]:
|
||||||
async with self.session_factory() as session:
|
async with self.session_factory() as session:
|
||||||
total_logs = (await session.execute(select(func.count()).select_from(Log))).scalar() or 0
|
total_logs = (
|
||||||
unique_attackers = (await session.execute(select(func.count(func.distinct(Log.attacker_ip))))).scalar() or 0
|
await session.execute(select(func.count()).select_from(Log))
|
||||||
active_deckies = (await session.execute(select(func.count(func.distinct(Log.decky))))).scalar() or 0
|
).scalar() or 0
|
||||||
|
unique_attackers = (
|
||||||
|
await session.execute(
|
||||||
|
select(func.count(func.distinct(Log.attacker_ip)))
|
||||||
|
)
|
||||||
|
).scalar() or 0
|
||||||
|
active_deckies = (
|
||||||
|
await session.execute(
|
||||||
|
select(func.count(func.distinct(Log.decky)))
|
||||||
|
)
|
||||||
|
).scalar() or 0
|
||||||
|
|
||||||
_state = load_state()
|
_state = await asyncio.to_thread(load_state)
|
||||||
deployed_deckies = len(_state[0].deckies) if _state else 0
|
deployed_deckies = len(_state[0].deckies) if _state else 0
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"total_logs": total_logs,
|
"total_logs": total_logs,
|
||||||
"unique_attackers": unique_attackers,
|
"unique_attackers": unique_attackers,
|
||||||
"active_deckies": active_deckies,
|
"active_deckies": active_deckies,
|
||||||
"deployed_deckies": deployed_deckies
|
"deployed_deckies": deployed_deckies,
|
||||||
}
|
}
|
||||||
|
|
||||||
async def get_deckies(self) -> List[dict]:
|
async def get_deckies(self) -> List[dict]:
|
||||||
_state = load_state()
|
_state = await asyncio.to_thread(load_state)
|
||||||
return [_d.model_dump() for _d in _state[0].deckies] if _state else []
|
return [_d.model_dump() for _d in _state[0].deckies] if _state else []
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------ users
|
||||||
|
|
||||||
async def get_user_by_username(self, username: str) -> Optional[dict]:
|
async def get_user_by_username(self, username: str) -> Optional[dict]:
|
||||||
async with self.session_factory() as session:
|
async with self.session_factory() as session:
|
||||||
statement = select(User).where(User.username == username)
|
result = await session.execute(
|
||||||
results = await session.execute(statement)
|
select(User).where(User.username == username)
|
||||||
user = results.scalar_one_or_none()
|
)
|
||||||
return user.dict() if user else None
|
user = result.scalar_one_or_none()
|
||||||
|
return user.model_dump() if user else None
|
||||||
|
|
||||||
async def get_user_by_uuid(self, uuid: str) -> Optional[dict]:
|
async def get_user_by_uuid(self, uuid: str) -> Optional[dict]:
|
||||||
async with self.session_factory() as session:
|
async with self.session_factory() as session:
|
||||||
statement = select(User).where(User.uuid == uuid)
|
result = await session.execute(
|
||||||
results = await session.execute(statement)
|
select(User).where(User.uuid == uuid)
|
||||||
user = results.scalar_one_or_none()
|
)
|
||||||
return user.dict() if user else None
|
user = result.scalar_one_or_none()
|
||||||
|
return user.model_dump() if user else None
|
||||||
|
|
||||||
async def create_user(self, user_data: dict[str, Any]) -> None:
|
async def create_user(self, user_data: dict[str, Any]) -> None:
|
||||||
user = User(**user_data)
|
|
||||||
async with self.session_factory() as session:
|
async with self.session_factory() as session:
|
||||||
session.add(user)
|
session.add(User(**user_data))
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
async def update_user_password(self, uuid: str, password_hash: str, must_change_password: bool = False) -> None:
|
async def update_user_password(
|
||||||
|
self, uuid: str, password_hash: str, must_change_password: bool = False
|
||||||
|
) -> None:
|
||||||
async with self.session_factory() as session:
|
async with self.session_factory() as session:
|
||||||
statement = update(User).where(User.uuid == uuid).values(
|
await session.execute(
|
||||||
password_hash=password_hash,
|
update(User)
|
||||||
must_change_password=must_change_password
|
.where(User.uuid == uuid)
|
||||||
|
.values(
|
||||||
|
password_hash=password_hash,
|
||||||
|
must_change_password=must_change_password,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
await session.execute(statement)
|
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------- bounties
|
||||||
|
|
||||||
async def add_bounty(self, bounty_data: dict[str, Any]) -> None:
|
async def add_bounty(self, bounty_data: dict[str, Any]) -> None:
|
||||||
data = bounty_data.copy()
|
data = bounty_data.copy()
|
||||||
if "payload" in data and isinstance(data["payload"], dict):
|
if "payload" in data and isinstance(data["payload"], dict):
|
||||||
data["payload"] = json.dumps(data["payload"])
|
data["payload"] = json.dumps(data["payload"])
|
||||||
|
|
||||||
bounty = Bounty(**data)
|
|
||||||
async with self.session_factory() as session:
|
async with self.session_factory() as session:
|
||||||
session.add(bounty)
|
session.add(Bounty(**data))
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
def _apply_bounty_filters(self, statement, bounty_type: Optional[str], search: Optional[str]):
|
def _apply_bounty_filters(self, statement, bounty_type: Optional[str], search: Optional[str]):
|
||||||
@@ -316,7 +314,7 @@ class SQLiteRepository(BaseRepository):
|
|||||||
Bounty.decky.like(lk),
|
Bounty.decky.like(lk),
|
||||||
Bounty.service.like(lk),
|
Bounty.service.like(lk),
|
||||||
Bounty.attacker_ip.like(lk),
|
Bounty.attacker_ip.like(lk),
|
||||||
Bounty.payload.like(lk)
|
Bounty.payload.like(lk),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return statement
|
return statement
|
||||||
@@ -326,23 +324,31 @@ class SQLiteRepository(BaseRepository):
|
|||||||
limit: int = 50,
|
limit: int = 50,
|
||||||
offset: int = 0,
|
offset: int = 0,
|
||||||
bounty_type: Optional[str] = None,
|
bounty_type: Optional[str] = None,
|
||||||
search: Optional[str] = None
|
search: Optional[str] = None,
|
||||||
) -> List[dict]:
|
) -> List[dict]:
|
||||||
statement = select(Bounty).order_by(desc(Bounty.timestamp)).offset(offset).limit(limit)
|
statement = (
|
||||||
|
select(Bounty)
|
||||||
|
.order_by(desc(Bounty.timestamp))
|
||||||
|
.offset(offset)
|
||||||
|
.limit(limit)
|
||||||
|
)
|
||||||
statement = self._apply_bounty_filters(statement, bounty_type, search)
|
statement = self._apply_bounty_filters(statement, bounty_type, search)
|
||||||
|
|
||||||
async with self.session_factory() as session:
|
async with self.session_factory() as session:
|
||||||
results = await session.execute(statement)
|
results = await session.execute(statement)
|
||||||
items = results.scalars().all()
|
|
||||||
final = []
|
final = []
|
||||||
for item in items:
|
for item in results.scalars().all():
|
||||||
d = item.dict()
|
d = item.model_dump()
|
||||||
try: d["payload"] = json.loads(d["payload"])
|
try:
|
||||||
except: pass
|
d["payload"] = json.loads(d["payload"])
|
||||||
|
except (json.JSONDecodeError, TypeError):
|
||||||
|
pass
|
||||||
final.append(d)
|
final.append(d)
|
||||||
return final
|
return final
|
||||||
|
|
||||||
async def get_total_bounties(self, bounty_type: Optional[str] = None, search: Optional[str] = None) -> int:
|
async def get_total_bounties(
|
||||||
|
self, bounty_type: Optional[str] = None, search: Optional[str] = None
|
||||||
|
) -> int:
|
||||||
statement = select(func.count()).select_from(Bounty)
|
statement = select(func.count()).select_from(Bounty)
|
||||||
statement = self._apply_bounty_filters(statement, bounty_type, search)
|
statement = self._apply_bounty_filters(statement, bounty_type, search)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user