Credential capture runs before the profiler mints an Attacker, so Credential.attacker_uuid is nullable on write. The profiler now backfills the FK after each successful upsert_attacker. Soft-fail posture matches the surrounding behavior + smtp rollups so a backfill error never blocks the next attacker.
443 lines
17 KiB
Python
443 lines
17 KiB
Python
"""
|
|
Attacker profile builder — incremental background worker.

Maintains a persistent CorrelationEngine and a log-ID cursor across cycles.
On cold start (no saved cursor), performs one full build from all stored
logs; when a saved cursor exists (e.g. after a process restart), ingestion
resumes from that log ID into a fresh engine instead. Subsequent cycles
fetch only new logs via the cursor, ingest them into the existing engine,
and rebuild profiles for affected IPs only.
|
|
|
|
Complexity per cycle: O(new_logs + affected_ips) instead of O(total_logs²).
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import contextlib
|
|
import json
|
|
import re
|
|
from dataclasses import dataclass, field
|
|
from datetime import datetime, timezone
|
|
from typing import Any, Callable
|
|
|
|
from decnet.bus import topics as _topics
|
|
from decnet.bus.factory import get_bus
|
|
from decnet.bus.publish import (
|
|
make_thread_safe_publisher,
|
|
run_control_listener,
|
|
run_health_heartbeat,
|
|
)
|
|
from decnet.correlation.engine import CorrelationEngine
|
|
from decnet.correlation.parser import LogEvent
|
|
from decnet.asn import enrich_ip as enrich_ip_asn
|
|
from decnet.geoip import enrich_ip
|
|
from decnet.geoip.ptr import resolve_ptr_record
|
|
from decnet.logging import get_logger
|
|
from decnet.profiler.behavioral import build_behavior_record
|
|
from decnet.telemetry import traced as _traced, get_tracer as _get_tracer
|
|
from decnet.web.db.repository import BaseRepository
|
|
|
|
# Module-level logger, obtained from decnet.logging under the name
# "attacker_worker"; every log call in this file goes through it.
logger = get_logger("attacker_worker")
|
|
|
|
# Maximum number of log rows requested per repository fetch. A batch shorter
# than this is treated as the end of the backlog for the current cycle.
_BATCH_SIZE = 500
|
|
# Key under which the ingestion cursor ({"last_log_id": ...}) is read and
# written through repo.get_state / repo.set_state.
_STATE_KEY = "attacker_worker_cursor"
|
|
|
|
# Event types that signal active command/query execution: the shell-family
# subset of INTERACTION_EVENT_TYPES in decnet/correlation/event_kinds.py.
# It is kept here because it is a stricter filter (events that carry command
# text worth extracting, as opposed to interactions such as RCPT TO or a file
# upload that carry none). A test in tests/profiler/ asserts that it stays a
# subset of the canonical interaction set so the two can't drift.
_COMMAND_EVENT_TYPES = frozenset({
    "command",
    "exec",
    "query",
    "input",
    "shell_input",
    "execute",
    "run",
    "sql_query",
    "redis_command",
})

# Field names that may carry the executed command/query text. Order matters:
# the first field with a truthy value wins.
_COMMAND_FIELDS = (
    "command",
    "query",
    "input",
    "line",
    "sql",
    "cmd",
)

# SMTP events that carry a recipient email address. `rcpt_to` fires once per
# accepted RCPT (open-relay mode), `rcpt_denied` once per denied RCPT
# (harvester mode). `message_accepted` carries the comma-joined rcpt list on
# the final DATA commit; it is covered for replay safety, though every address
# it contains already arrived via `rcpt_to` earlier in the session.
_SMTP_RCPT_EVENTS = frozenset({
    "rcpt_to",
    "rcpt_denied",
    "message_accepted",
})

# Pseudo-TLDs we never want to report on: the RFC 6761 special-use names plus
# common lab-only values. Matching happens on the *last* label only, so
# `foo.example` is filtered but `example.com` is not.
_BLOCKED_TLDS = frozenset({
    "invalid",
    "test",
    "localhost",
    "local",
    "example",
})
|
|
|
|
|
|
@dataclass
class _WorkerState:
    """Mutable state the profile worker carries from one cycle to the next."""

    # Correlation engine that receives every ingested log line.
    engine: CorrelationEngine = field(default_factory=CorrelationEngine)

    # ID of the last log row ingested; used as the fetch cursor.
    last_log_id: int = 0

    # Set once a saved cursor has been restored or one update pass has run.
    initialized: bool = False

    # Optional bus hook, called as ``("scored", payload)`` after each profile
    # upsert. Left as None when the bus is disabled or unreachable.
    publish_attacker: Callable[[str, dict[str, Any]], None] | None = None

    # IPs for which a PTR lookup has already been attempted during this
    # worker's lifetime. Limits lookups to one per IP per worker boot so a
    # persistently NXDOMAIN-returning IP doesn't burn 2s of tick time on
    # every cycle.
    ptr_attempted: set[str] = field(default_factory=set)
|
|
|
|
|
|
async def attacker_profile_worker(repo: BaseRepository, *, interval: int = 30) -> None:
    """Periodically updates the Attacker table incrementally. Designed to run as an asyncio Task.

    Args:
        repo: Repository used to load/store the cursor and to read logs and
            write profiles (via ``_incremental_update``).
        interval: Seconds to wait between update passes. The wait happens
            *before* each pass, so the first pass also runs after one interval.

    The loop ends when the shutdown event handed to the control listener is
    set. A failing update pass is logged and the loop carries on. On exit
    (normal, error or cancellation) the helper tasks are cancelled and the
    bus connection, if any, is closed.
    """
    logger.info("attacker profile worker started interval=%ds", interval)

    # Optional bus wiring — correlator-family publishes ride on the profiler
    # worker because CorrelationEngine lives inside it. If the bus is off or
    # unreachable the engine runs with a no-op publisher and downstream
    # degrades to DB-only.
    bus = None
    try:
        bus = get_bus(client_name="profiler")
        await bus.connect()
    except Exception as exc:
        logger.warning("profiler: bus unavailable, continuing without publish: %s", exc)
        bus = None

    # Everything after the connect attempt runs under one try/finally so the
    # bus connection and any helper task already started are released even if
    # setup (publisher creation, cursor load, task creation) raises.
    helper_tasks: list[asyncio.Task[Any]] = []
    try:
        loop = asyncio.get_running_loop()
        raw_publish = make_thread_safe_publisher(bus, loop) if bus is not None else None

        def _publish_attacker(event_type: str, payload: dict[str, Any]) -> None:
            # Silently drop the event when no bus is available.
            if raw_publish is None:
                return
            raw_publish(_topics.attacker(event_type), payload, event_type)

        state = _WorkerState(
            engine=CorrelationEngine(publish_fn=_publish_attacker),
            publish_attacker=_publish_attacker,
        )

        # Resume from the persisted cursor when one exists; otherwise the
        # first update pass starts from log ID 0.
        saved_cursor = await repo.get_state(_STATE_KEY)
        if saved_cursor:
            state.last_log_id = saved_cursor.get("last_log_id", 0)
            state.initialized = True
            logger.info("attacker worker: resumed from cursor last_log_id=%d", state.last_log_id)

        # Workers panel wiring: heartbeat + bus-driven stop. Main loop is
        # pure asyncio sleep/await, so an event-based control listener
        # drops in cleanly without a SIGTERM self-signal.
        shutdown = asyncio.Event()
        helper_tasks.append(asyncio.create_task(run_health_heartbeat(bus, "profiler")))
        helper_tasks.append(
            asyncio.create_task(run_control_listener(bus, "profiler", shutdown)),
        )

        while not shutdown.is_set():
            # Sleep for one interval, waking early if shutdown is requested.
            try:
                await asyncio.wait_for(shutdown.wait(), timeout=interval)
            except asyncio.TimeoutError:
                pass  # normal tick
            if shutdown.is_set():
                break
            try:
                await _incremental_update(repo, state)
            except Exception as exc:
                # Best-effort: one failed pass must not kill the worker.
                logger.error("attacker worker: update failed: %s", exc)
    finally:
        for task in helper_tasks:
            task.cancel()
            with contextlib.suppress(Exception, asyncio.CancelledError):
                await task
        if bus is not None:
            with contextlib.suppress(Exception):
                await bus.close()
|
|
|
|
|
|
@_traced("profiler.incremental_update")
async def _incremental_update(repo: BaseRepository, state: _WorkerState) -> None:
    """Run one update pass: ingest new logs, refresh affected profiles, save the cursor.

    Logs are pulled in batches of ``_BATCH_SIZE`` starting after
    ``state.last_log_id``. Each row's ``raw_line`` is fed to the engine and the
    attacker IP of every parsed event is remembered. Profiles are rebuilt only
    for those IPs, after which the cursor is persisted under ``_STATE_KEY``.

    NOTE(review): ``state.last_log_id`` advances in memory while ingesting,
    but the cursor is persisted only after ``_update_profiles`` returns. If
    that call raises, the IPs collected in this pass are not carried over to
    the next one — confirm this is acceptable.
    """
    cold_start = not state.initialized
    touched_ips: set[str] = set()

    while (rows := await repo.get_logs_after_id(state.last_log_id, limit=_BATCH_SIZE)):
        for row in rows:
            parsed = state.engine.ingest(row["raw_line"])
            if parsed and parsed.attacker_ip:
                touched_ips.add(parsed.attacker_ip)
            # Advance the in-memory cursor row by row.
            state.last_log_id = row["id"]

        await asyncio.sleep(0)  # yield to event loop after each batch

        # A short batch means the backlog is drained for this pass.
        if len(rows) < _BATCH_SIZE:
            break

    state.initialized = True

    if touched_ips:
        await _update_profiles(repo, state, touched_ips)

    await repo.set_state(_STATE_KEY, {"last_log_id": state.last_log_id})

    if not touched_ips:
        return

    if cold_start:
        logger.info("attacker worker: cold start rebuilt %d profiles", len(touched_ips))
    else:
        logger.info("attacker worker: updated %d profiles (incremental)", len(touched_ips))
|
|
|
|
|
|
# Upper bound on PTR lookups in flight at once; used as the semaphore size
# in _resolve_ptrs_for.
_PTR_CONCURRENCY = 10
|
|
|
|
|
|
async def _resolve_ptrs_for(ips: list[str]) -> dict[str, Any]:
    """Look up the PTR record of every address in *ips* concurrently.

    Returns a mapping ``{ip: ptr_or_None}`` with one entry per input address;
    an empty input yields an empty dict. At most ``_PTR_CONCURRENCY`` lookups
    run at the same time (enforced with an asyncio semaphore) — cold-start
    could see hundreds of fresh IPs and we don't want to hammer the OS
    resolver.
    """
    if not ips:
        return {}

    gate = asyncio.Semaphore(_PTR_CONCURRENCY)

    async def _lookup(addr: str) -> Any:
        # Hold a semaphore slot for the duration of one lookup.
        async with gate:
            return await resolve_ptr_record(addr)

    # gather() returns results in input order, so they zip back onto *ips*.
    answers = await asyncio.gather(*map(_lookup, ips))
    return dict(zip(ips, answers))
|
|
|
|
|
|
@_traced("profiler.update_profiles")
async def _update_profiles(
    repo: BaseRepository,
    state: _WorkerState,
    ips: set[str],
) -> None:
    """Rebuild and persist the attacker profile of every IP in *ips*.

    For each IP that has events in the engine this function:

    1. builds the profile record and upserts it (errors here propagate);
    2. backfills ``Credential.attacker_uuid`` for that IP;
    3. publishes a ``"scored"`` event when a publisher is configured;
    4. upserts the behavioral rollup;
    5. bumps one SMTP target row per extracted victim domain.

    Steps 2-5 are soft-fail: their exceptions are logged (and recorded on the
    span, except for the publish step) and never stop processing of the
    remaining steps or IPs.
    """
    traversal_by_ip = {t.attacker_ip: t for t in state.engine.traversals(min_deckies=2)}
    bounties_by_ip = await repo.get_bounties_for_ips(ips)

    # PTR resolution: one shot per IP per worker lifetime. IPs are marked as
    # attempted *before* the lookup runs, so they are not retried later.
    never_tried = [ip for ip in ips if ip not in state.ptr_attempted]
    state.ptr_attempted.update(never_tried)
    ptr_by_ip = await _resolve_ptrs_for(never_tried)

    # Record fields that are copied verbatim into the "scored" bus payload.
    scored_fields = (
        "event_count",
        "service_count",
        "decky_count",
        "bounty_count",
        "credential_count",
        "is_traversal",
    )

    tracer = _get_tracer("profiler")
    for ip in ips:
        events = state.engine._events.get(ip, [])
        if not events:
            continue

        with tracer.start_as_current_span("profiler.process_ip") as span:
            span.set_attribute("attacker_ip", ip)
            span.set_attribute("event_count", len(events))

            traversal = traversal_by_ip.get(ip)
            bounties = bounties_by_ip.get(ip, [])
            commands = _extract_commands_from_events(events)

            # Pass ptr_record only when this cycle actually looked it up.
            # Omitting the kwarg leaves the key out of the record; per the
            # original comments the upsert is then expected to preserve the
            # stored value.
            ptr_kwargs = {"ptr_record": ptr_by_ip[ip]} if ip in ptr_by_ip else {}
            record = _build_record(
                ip, events, traversal, bounties, commands,
                **ptr_kwargs,
            )
            attacker_uuid = await repo.upsert_attacker(record)

            # Backfill Credential.attacker_uuid for credential rows captured
            # before the profiler had minted this Attacker. Soft-fail so a
            # backfill error never blocks the next attacker.
            try:
                await repo.update_credential_attacker_uuid(ip, attacker_uuid)
            except Exception as exc:
                span.record_exception(exc)
                logger.error("attacker worker: credential backfill failed for %s: %s", ip, exc)

            span.set_attribute("is_traversal", traversal is not None)
            span.set_attribute("bounty_count", len(bounties))
            span.set_attribute("command_count", len(commands))

            if state.publish_attacker is not None:
                try:
                    payload: dict[str, Any] = {"attacker_ip": ip}
                    payload.update((name, record[name]) for name in scored_fields)
                    state.publish_attacker("scored", payload)
                except Exception as exc:
                    logger.warning("attacker worker: scored publish failed for %s: %s", ip, exc)

            # Behavioral / fingerprint rollup lives in a sibling table so
            # failures here never block the core attacker profile upsert.
            try:
                behavior = build_behavior_record(events)
                await repo.upsert_attacker_behavior(attacker_uuid, behavior)
            except Exception as exc:
                span.record_exception(exc)
                logger.error("attacker worker: behavior upsert failed for %s: %s", ip, exc)

            # SMTP victim-domain tracking: one increment per
            # (attacker, domain) pair, with the same soft-fail posture.
            try:
                for domain in _extract_smtp_domains(events):
                    await repo.increment_smtp_target(attacker_uuid, domain)
            except Exception as exc:
                span.record_exception(exc)
                logger.error("attacker worker: smtp target upsert failed for %s: %s", ip, exc)
|
|
|
|
|
|
# Sentinel used as the default for ``ptr_record`` in _build_record so that
# "argument not passed" can be told apart from an explicit ``None``.
_UNSET = object()
|
|
|
|
|
|
def _build_record(
    ip: str,
    events: list[LogEvent],
    traversal: Any,
    bounties: list[dict[str, Any]],
    commands: list[dict[str, Any]],
    *,
    ptr_record: Any = _UNSET,
) -> dict[str, Any]:
    """Assemble the attacker-profile dict that is handed to the upsert.

    Args:
        ip: Attacker IP address.
        events: Events recorded for *ip*; must be non-empty (``min``/``max``
            over the timestamps raise ``ValueError`` otherwise).
        traversal: Traversal object for *ip* (its ``deckies`` and ``path``
            attributes are read) or a falsy value when there is none.
        bounties: Bounty dicts for *ip*; ``bounty_type`` selects the
            fingerprint and credential subsets.
        commands: Already-extracted command dicts, stored as JSON.
        ptr_record: PTR lookup result. When not passed, the ``ptr_record``
            key is left out of the record altogether.

    Returns:
        The record dict; list-valued fields are JSON-encoded strings.
    """
    service_names = sorted({ev.service for ev in events})

    if traversal:
        decky_order = traversal.deckies
    else:
        decky_order = _first_contact_deckies(events)

    fingerprint_bounties = [b for b in bounties if b.get("bounty_type") == "fingerprint"]
    credential_total = len([b for b in bounties if b.get("bounty_type") == "credential"])

    country_code, country_source = enrich_ip(ip)
    asn, as_name, asn_source = enrich_ip_asn(ip)

    timestamps = [ev.timestamp for ev in events]
    distinct_deckies = {ev.decky for ev in events}

    record: dict[str, Any] = {
        "ip": ip,
        "first_seen": min(timestamps),
        "last_seen": max(timestamps),
        "event_count": len(events),
        "service_count": len(service_names),
        "decky_count": len(distinct_deckies),
        "services": json.dumps(service_names),
        "deckies": json.dumps(decky_order),
        "traversal_path": traversal.path if traversal else None,
        "is_traversal": traversal is not None,
        "bounty_count": len(bounties),
        "credential_count": credential_total,
        "fingerprints": json.dumps(fingerprint_bounties),
        "commands": json.dumps(commands),
        "country_code": country_code,
        "country_source": country_source,
        "asn": asn,
        "as_name": as_name,
        "asn_source": asn_source,
        "updated_at": datetime.now(timezone.utc),
    }

    # Leave the key out entirely when the caller supplied no PTR value. Per
    # the original comment this lets the upsert's attribute-merge keep
    # whatever is already stored, without a "None means preserve vs.
    # overwrite" convention.
    if ptr_record is not _UNSET:
        record["ptr_record"] = ptr_record
    return record
|
|
|
|
|
|
def _first_contact_deckies(events: list[LogEvent]) -> list[str]:
    """Return unique deckies in first-contact order (for non-traversal attackers).

    Args:
        events: Events for one attacker, in any order. Each must expose
            ``timestamp`` (used for ordering) and ``decky``.

    Returns:
        Every distinct decky exactly once, ordered by the timestamp of the
        earliest event that touched it. Events with equal timestamps keep
        their input order (the sort is stable). Empty input gives ``[]``.
    """
    chronological = sorted(events, key=lambda ev: ev.timestamp)
    # dicts preserve insertion order, so fromkeys() is an ordered de-dup with
    # O(1) membership instead of a list scan per event.
    return list(dict.fromkeys(ev.decky for ev in chronological))
|
|
|
|
|
|
def _extract_commands_from_events(events: list[LogEvent]) -> list[dict[str, Any]]:
    """
    Collect the commands an attacker executed, in event order.

    Only events whose type is in ``_COMMAND_EVENT_TYPES`` are considered. The
    command text is the first truthy value among ``_COMMAND_FIELDS`` in
    ``event.fields`` (already a dict, so no JSON parsing is needed); events
    without any such value are skipped. Each result entry holds ``service``,
    ``decky``, ``command`` and an ISO-formatted ``timestamp``.
    """
    extracted: list[dict[str, Any]] = []
    for ev in events:
        if ev.event_type in _COMMAND_EVENT_TYPES:
            # First non-empty command-bearing field wins.
            text = next(
                (str(value) for value in map(ev.fields.get, _COMMAND_FIELDS) if value),
                None,
            )
            if text:
                extracted.append({
                    "service": ev.service,
                    "decky": ev.decky,
                    "command": text,
                    "timestamp": ev.timestamp.isoformat(),
                })
    return extracted
|
|
|
|
|
|
# Envelope-address matcher. Group 1 is the local part, group 2 the domain.
# The final label must be alphabetic (two or more letters) or a punycode
# A-label ("xn--..."), and the two lookaheads force the match to end on a
# label boundary. Without them the pattern could stop in the middle of a
# label and report a truncated domain (e.g. "example.xn" for
# "example.xn--p1ai", or "example.com" for "example.com1").
_SMTP_ADDR_RE = re.compile(
    r"<?([^\s<>@]+)@"
    r"([A-Za-z0-9.-]+\.(?:[A-Za-z]{2,}|[Xx][Nn]--[A-Za-z0-9-]*[A-Za-z0-9]))"
    r"(?![A-Za-z0-9-])(?!\.[A-Za-z0-9])"
    r">?"
)
|
|
|
|
|
|
def _normalize_smtp_domain(raw: str) -> str | None:
    """Pull the lowercased domain out of an envelope-address fragment.

    Returns None when *raw* is empty, when no email address is found in it,
    or when the last label of the domain is in ``_BLOCKED_TLDS``. The
    local-part (the bit before `@`) is intentionally discarded — only the
    targeted organisation's domain is kept, no user-identifying data.
    """
    if not raw:
        return None

    found = _SMTP_ADDR_RE.search(raw.strip())
    if found is None:
        return None

    host = found.group(2).lower().strip(".")
    if not host:
        return None

    # Only the right-most label decides whether the domain is blocked.
    _, _, last_label = host.rpartition(".")
    return None if last_label in _BLOCKED_TLDS else host
|
|
|
|
|
|
def _extract_smtp_domains(events: list[LogEvent]) -> set[str]:
    """Return the set of victim domains an attacker addressed over SMTP.

    Only events with service ``"smtp"`` and a type in ``_SMTP_RCPT_EVENTS``
    are inspected. ``message_accepted`` events contribute every address of
    their comma-joined ``rcpt_to`` field; the other event types contribute
    their ``value`` field. Each address goes through
    ``_normalize_smtp_domain`` and rejected ones are dropped.

    The result is a set on purpose: the caller needs to know *which* domains
    were seen, not how often they appear in *events*.
    """
    seen: set[str] = set()
    for ev in events:
        if not (ev.service == "smtp" and ev.event_type in _SMTP_RCPT_EVENTS):
            continue

        if ev.event_type == "message_accepted":
            joined = ev.fields.get("rcpt_to", "")
            addresses = joined.split(",") if joined else []
        else:
            addresses = [ev.fields.get("value", "")]

        # filter(None, ...) discards addresses that did not normalize.
        seen.update(filter(None, map(_normalize_smtp_domain, addresses)))
    return seen
|