Files
DECNET/decnet/profiler/worker.py
anti bcf460d2a5 feat(profiler): write ASN + AS name onto attacker rows
Adds asn (int), as_name (varchar 128), asn_source (varchar 16) to
the Attacker SQLModel — direct columns, no _migrate_* helper per
feedback_no_new_migrations_prev1.

Profiler worker now calls decnet.asn.enrich_ip alongside the existing
geoip enrich_ip; both feed the upsert payload. Failure is total — if
either lookup throws or the IP is private/unannounced, the field stays
None and the row still writes.

Both lookups are independent: a CGNAT address can have a country (RIR
allocation) but no ASN (no BGP origin), and vice-versa for unrouted
RIR-allocated space. Storing them separately preserves that signal.
2026-04-25 04:01:28 -04:00

432 lines
16 KiB
Python

"""
Attacker profile builder — incremental background worker.
Maintains a persistent CorrelationEngine and a log-ID cursor across cycles.
On cold start (first cycle or process restart), performs one full build from
all stored logs. Subsequent cycles fetch only new logs via the cursor,
ingest them into the existing engine, and rebuild profiles for affected IPs
only.
Complexity per cycle: O(new_logs + affected_ips) instead of O(total_logs²).
"""
from __future__ import annotations
import asyncio
import contextlib
import json
import re
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any, Callable
from decnet.bus import topics as _topics
from decnet.bus.factory import get_bus
from decnet.bus.publish import (
make_thread_safe_publisher,
run_control_listener,
run_health_heartbeat,
)
from decnet.correlation.engine import CorrelationEngine
from decnet.correlation.parser import LogEvent
from decnet.asn import enrich_ip as enrich_ip_asn
from decnet.geoip import enrich_ip
from decnet.geoip.ptr import resolve_ptr_record
from decnet.logging import get_logger
from decnet.profiler.behavioral import build_behavior_record
from decnet.telemetry import traced as _traced, get_tracer as _get_tracer
from decnet.web.db.repository import BaseRepository
logger = get_logger("attacker_worker")
# Maximum number of log rows requested per repository fetch while draining
# new logs; a short batch signals that the backlog is exhausted.
_BATCH_SIZE = 500
# Key under which the log-ID cursor is persisted through the repository's
# get_state / set_state calls.
_STATE_KEY = "attacker_worker_cursor"
# Event types that indicate active command/query execution — the
# shell-family subset of INTERACTION_EVENT_TYPES in
# decnet/correlation/event_kinds.py. Kept here because this set is a
# stricter filter (commands that carry text to extract, vs. interactions
# like RCPT TO or file upload that don't). A test in
# tests/profiler/ asserts it's a subset of the canonical interaction
# set so they can't drift.
_COMMAND_EVENT_TYPES = frozenset({
    "command", "exec", "query", "input", "shell_input",
    "execute", "run", "sql_query", "redis_command",
})
# Fields that carry the executed command/query text. Order matters: the
# first field with a truthy value wins.
_COMMAND_FIELDS = ("command", "query", "input", "line", "sql", "cmd")
# SMTP events that carry a recipient email address. `rcpt_to` fires once per
# accepted RCPT (open-relay mode), `rcpt_denied` once per denied RCPT
# (harvester mode). `message_accepted` carries the comma-joined rcpt list
# on the final DATA commit — covered for replay safety, though every
# address it contains already arrived via `rcpt_to` earlier in the session.
_SMTP_RCPT_EVENTS = frozenset({"rcpt_to", "rcpt_denied", "message_accepted"})
# Pseudo-TLDs we never want to report on: the RFC 6761 special-use names
# plus common lab-only values. Matching happens on the *last* label only,
# so `mail.foo.example` and `host.test` are filtered, while `example.com`
# and `example.corp` are not (their last labels are `com` and `corp`).
_BLOCKED_TLDS = frozenset({"invalid", "test", "localhost", "local", "example"})
@dataclass
class _WorkerState:
    """Mutable state the profiler worker carries from one update cycle to the next."""

    # Correlation engine that ingests raw log lines and holds the per-IP
    # events profiles are built from.
    engine: CorrelationEngine = field(default_factory=CorrelationEngine)
    # ID of the last log row ingested; passed to the repository as the
    # "fetch rows after this ID" cursor.
    last_log_id: int = 0
    # False until the first update cycle has run or a saved cursor has been
    # restored; only used to pick the cold-start vs. incremental log line.
    initialized: bool = False
    # Optional bus hook — fires ``("scored", payload)`` per profile upsert.
    # None when the bus is disabled or unreachable.
    publish_attacker: Callable[[str, dict[str, Any]], None] | None = None
    # Set of IPs we've already tried to PTR-resolve in this worker's
    # lifetime. Bounds retry to once per worker boot so a persistently
    # NXDOMAIN-returning IP doesn't burn 2s of tick time on every cycle.
    ptr_attempted: set[str] = field(default_factory=set)
async def attacker_profile_worker(repo: BaseRepository, *, interval: int = 30) -> None:
    """Periodically update the Attacker table incrementally.

    Designed to run as an asyncio Task. Every *interval* seconds the worker
    runs one ``_incremental_update`` cycle until a shutdown is signalled by
    the control listener. A failing cycle is logged and the loop continues.

    Args:
        repo: Repository used for the persisted cursor, logs and upserts.
        interval: Seconds to wait between update cycles.

    All resources acquired here (bus connection, heartbeat and control
    tasks) are released in a single ``finally`` so that an error or a
    cancellation during start-up — e.g. while loading the saved cursor —
    cannot leak the bus connection or a background task.
    """
    logger.info("attacker profile worker started interval=%ds", interval)

    # Optional bus wiring — correlator-family publishes ride on the profiler
    # worker because CorrelationEngine lives inside it. If the bus is off or
    # unreachable the engine runs with a no-op publisher and downstream
    # degrades to DB-only.
    bus = None
    try:
        bus = get_bus(client_name="profiler")
        await bus.connect()
    except Exception as exc:
        logger.warning("profiler: bus unavailable, continuing without publish: %s", exc)
        bus = None

    # Background tasks are tracked in a list so the cleanup below only
    # touches the ones that were actually created.
    background: list[asyncio.Task] = []
    try:
        loop = asyncio.get_running_loop()
        raw_publish = make_thread_safe_publisher(bus, loop) if bus is not None else None

        def _publish_attacker(event_type: str, payload: dict[str, Any]) -> None:
            if raw_publish is None:
                return
            raw_publish(_topics.attacker(event_type), payload, event_type)

        state = _WorkerState(
            engine=CorrelationEngine(publish_fn=_publish_attacker),
            publish_attacker=_publish_attacker,
        )

        # NOTE(review): resuming from a saved cursor starts with an empty
        # in-memory engine, so profiles rebuilt after a restart are computed
        # only from logs newer than the cursor. The module docstring
        # describes a full rebuild on process restart — confirm which
        # behaviour is intended.
        saved_cursor = await repo.get_state(_STATE_KEY)
        if saved_cursor:
            state.last_log_id = saved_cursor.get("last_log_id", 0)
            state.initialized = True
            logger.info(
                "attacker worker: resumed from cursor last_log_id=%d",
                state.last_log_id,
            )

        # Workers panel wiring: heartbeat + bus-driven stop. Main loop is
        # pure asyncio sleep/await, so an event-based control listener
        # drops in cleanly without a SIGTERM self-signal.
        # ``bus`` may be None here; assumes both helpers tolerate that —
        # TODO confirm against decnet.bus.publish.
        shutdown = asyncio.Event()
        background.append(
            asyncio.create_task(run_health_heartbeat(bus, "profiler")),
        )
        background.append(
            asyncio.create_task(run_control_listener(bus, "profiler", shutdown)),
        )

        while not shutdown.is_set():
            # Sleep for one interval, waking early if shutdown is requested.
            try:
                await asyncio.wait_for(shutdown.wait(), timeout=interval)
            except asyncio.TimeoutError:
                pass  # normal tick
            if shutdown.is_set():
                break
            try:
                await _incremental_update(repo, state)
            except Exception as exc:
                logger.error("attacker worker: update failed: %s", exc)
    finally:
        for task in background:
            task.cancel()
            with contextlib.suppress(Exception, asyncio.CancelledError):
                await task
        if bus is not None:
            with contextlib.suppress(Exception):
                await bus.close()
@_traced("profiler.incremental_update")
async def _incremental_update(repo: BaseRepository, state: _WorkerState) -> None:
    """Ingest all log rows newer than the cursor and refresh touched profiles.

    Rows are pulled in batches of ``_BATCH_SIZE`` and fed to the engine; the
    in-memory cursor advances per row. Once the backlog is drained, profiles
    are rebuilt for every attacker IP seen in this pass and the cursor is
    persisted through ``repo.set_state``.
    """
    cold_start = not state.initialized
    touched: set[str] = set()

    rows = await repo.get_logs_after_id(state.last_log_id, limit=_BATCH_SIZE)
    while rows:
        for row in rows:
            parsed = state.engine.ingest(row["raw_line"])
            if parsed and parsed.attacker_ip:
                touched.add(parsed.attacker_ip)
            # Advance for every row, parsed or not, so unparseable lines
            # are never fetched again.
            state.last_log_id = row["id"]
        await asyncio.sleep(0)  # yield to event loop after each batch
        if len(rows) < _BATCH_SIZE:
            break  # short batch: backlog exhausted
        rows = await repo.get_logs_after_id(state.last_log_id, limit=_BATCH_SIZE)

    state.initialized = True

    # NOTE(review): if _update_profiles raises, the in-memory cursor has
    # already moved past these rows and ``touched`` is discarded, so those
    # IPs are not refreshed until new logs arrive for them.
    if touched:
        await _update_profiles(repo, state, touched)
    await repo.set_state(_STATE_KEY, {"last_log_id": state.last_log_id})
    if not touched:
        return

    message = (
        "attacker worker: cold start rebuilt %d profiles"
        if cold_start
        else "attacker worker: updated %d profiles (incremental)"
    )
    logger.info(message, len(touched))
_PTR_CONCURRENCY = 10
async def _resolve_ptrs_for(ips: list[str]) -> dict[str, Any]:
"""Resolve PTR for each *ip* concurrently, bounded.
Returns ``{ip: ptr_or_None}`` for every input. Uses an asyncio
semaphore to cap parallel lookups — cold-start could see hundreds
of fresh IPs and we don't want to hammer the OS resolver.
"""
if not ips:
return {}
sem = asyncio.Semaphore(_PTR_CONCURRENCY)
async def _one(ip: str) -> tuple[str, Any]:
async with sem:
return ip, await resolve_ptr_record(ip)
results = await asyncio.gather(*(_one(ip) for ip in ips))
return dict(results)
@_traced("profiler.update_profiles")
async def _update_profiles(
    repo: BaseRepository,
    state: _WorkerState,
    ips: set[str],
) -> None:
    """Rebuild and persist the attacker profile for each address in *ips*.

    For every IP with events in the engine this upserts the core attacker
    row, optionally publishes a ``scored`` event, then updates the
    behaviour rollup and the SMTP target rows. The last two are soft-fail:
    their errors are logged and recorded on the span but never stop the
    loop. An error from the core upsert itself propagates to the caller.
    """
    traversal_by_ip = {
        t.attacker_ip: t for t in state.engine.traversals(min_deckies=2)
    }
    bounties_by_ip = await repo.get_bounties_for_ips(ips)

    # PTR resolution: one shot per IP per worker lifetime. IPs are marked
    # as attempted *before* resolving, so they are not retried later in
    # this process regardless of the outcome.
    pending_ptr = [ip for ip in ips if ip not in state.ptr_attempted]
    state.ptr_attempted.update(pending_ptr)
    ptr_by_ip = await _resolve_ptrs_for(pending_ptr)

    # Keys copied from the record into the "scored" bus payload, in order.
    scored_keys = (
        "event_count",
        "service_count",
        "decky_count",
        "bounty_count",
        "credential_count",
        "is_traversal",
    )

    tracer = _get_tracer("profiler")
    for ip in ips:
        events = state.engine._events.get(ip, [])
        if not events:
            continue

        with tracer.start_as_current_span("profiler.process_ip") as span:
            span.set_attribute("attacker_ip", ip)
            span.set_attribute("event_count", len(events))

            traversal = traversal_by_ip.get(ip)
            bounties = bounties_by_ip.get(ip, [])
            commands = _extract_commands_from_events(events)

            # Pass ptr_record only when it was resolved in this cycle. For
            # IPs attempted in a prior cycle the kwarg is omitted so the
            # upsert preserves whatever is already stored.
            ptr_kwargs = {"ptr_record": ptr_by_ip[ip]} if ip in ptr_by_ip else {}
            record = _build_record(
                ip, events, traversal, bounties, commands, **ptr_kwargs,
            )
            attacker_uuid = await repo.upsert_attacker(record)

            span.set_attribute("is_traversal", traversal is not None)
            span.set_attribute("bounty_count", len(bounties))
            span.set_attribute("command_count", len(commands))

            if state.publish_attacker is not None:
                try:
                    payload: dict[str, Any] = {"attacker_ip": ip}
                    payload.update((key, record[key]) for key in scored_keys)
                    state.publish_attacker("scored", payload)
                except Exception as exc:
                    logger.warning("attacker worker: scored publish failed for %s: %s", ip, exc)

            # Behavioral / fingerprint rollup lives in a sibling table so
            # failures here never block the core attacker profile upsert.
            try:
                behavior = build_behavior_record(events)
                await repo.upsert_attacker_behavior(attacker_uuid, behavior)
            except Exception as exc:
                span.record_exception(exc)
                logger.error("attacker worker: behavior upsert failed for %s: %s", ip, exc)

            # SMTP victim-domain tracking — one call per (attacker, domain)
            # pair found in the RCPT events. Same soft-fail posture as the
            # behavior rollup: errors must not block the next attacker.
            try:
                for domain in _extract_smtp_domains(events):
                    await repo.increment_smtp_target(attacker_uuid, domain)
            except Exception as exc:
                span.record_exception(exc)
                logger.error("attacker worker: smtp target upsert failed for %s: %s", ip, exc)
_UNSET = object()  # sentinel — distinguishes "not passed" from "None"


def _build_record(
    ip: str,
    events: list[LogEvent],
    traversal: Any,
    bounties: list[dict[str, Any]],
    commands: list[dict[str, Any]],
    *,
    ptr_record: Any = _UNSET,
) -> dict[str, Any]:
    """Assemble the attacker-row payload for *ip*.

    Args:
        ip: Attacker address the record describes.
        events: Events recorded for this IP. Must be non-empty — the
            first/last-seen timestamps are the min/max over it.
        traversal: Object exposing ``deckies`` and ``path`` when the
            attacker traversed several deckies, otherwise a falsy value.
        bounties: Bounty dicts for this IP; ``bounty_type`` selects
            fingerprints and credentials.
        commands: Extracted command dicts, stored JSON-encoded.
        ptr_record: PTR value to store. When not passed, the key is left
            out of the record entirely.

    Returns:
        A dict of column values for the attacker upsert.

    GeoIP and ASN enrichment are independent and best-effort: if either
    lookup raises, its fields are stored as ``None`` and the rest of the
    record is still produced, so an enrichment failure can never block
    the profile write.
    """
    services = sorted({e.service for e in events})
    deckies = traversal.deckies if traversal else _first_contact_deckies(events)
    fingerprints = [b for b in bounties if b.get("bounty_type") == "fingerprint"]
    credential_count = sum(1 for b in bounties if b.get("bounty_type") == "credential")
    timestamps = [e.timestamp for e in events]

    # Each lookup is guarded separately so a GeoIP failure does not cost us
    # the ASN data and vice versa.
    try:
        country_code, country_source = enrich_ip(ip)
    except Exception as exc:
        logger.warning("attacker worker: geoip enrichment failed for %s: %s", ip, exc)
        country_code = country_source = None
    try:
        asn, as_name, asn_source = enrich_ip_asn(ip)
    except Exception as exc:
        logger.warning("attacker worker: asn enrichment failed for %s: %s", ip, exc)
        asn = as_name = asn_source = None

    record: dict[str, Any] = {
        "ip": ip,
        "first_seen": min(timestamps),
        "last_seen": max(timestamps),
        "event_count": len(events),
        "service_count": len(services),
        "decky_count": len({e.decky for e in events}),
        "services": json.dumps(services),
        "deckies": json.dumps(deckies),
        "traversal_path": traversal.path if traversal else None,
        "is_traversal": traversal is not None,
        "bounty_count": len(bounties),
        "credential_count": credential_count,
        "fingerprints": json.dumps(fingerprints),
        "commands": json.dumps(commands),
        "country_code": country_code,
        "country_source": country_source,
        "asn": asn,
        "as_name": as_name,
        "asn_source": asn_source,
        "updated_at": datetime.now(timezone.utc),
    }
    # ptr_record is omitted from the dict entirely when the caller didn't
    # supply one — lets the upsert's attribute-merge preserve any value
    # already stored on the row without us having to think about "None
    # means preserve vs. overwrite".
    if ptr_record is not _UNSET:
        record["ptr_record"] = ptr_record
    return record
def _first_contact_deckies(events: list[LogEvent]) -> list[str]:
"""Return unique deckies in first-contact order (for non-traversal attackers)."""
seen: list[str] = []
for e in sorted(events, key=lambda x: x.timestamp):
if e.decky not in seen:
seen.append(e.decky)
return seen
def _extract_commands_from_events(events: list[LogEvent]) -> list[dict[str, Any]]:
"""
Extract executed commands from LogEvent objects.
Works directly on LogEvent.fields (already a dict), so no JSON parsing needed.
"""
commands: list[dict[str, Any]] = []
for event in events:
if event.event_type not in _COMMAND_EVENT_TYPES:
continue
cmd_text: str | None = None
for key in _COMMAND_FIELDS:
val = event.fields.get(key)
if val:
cmd_text = str(val)
break
if not cmd_text:
continue
commands.append({
"service": event.service,
"decky": event.decky,
"command": cmd_text,
"timestamp": event.timestamp.isoformat(),
})
return commands
_SMTP_ADDR_RE = re.compile(r"<?([^\s<>@]+)@([A-Za-z0-9.-]+\.[A-Za-z]{2,})>?")
def _normalize_smtp_domain(raw: str) -> str | None:
"""Extract a lowercased domain from an envelope-address fragment.
Returns None when the input doesn't look like an email address or the
resulting TLD is on the blocklist. Local-parts (the bit before `@`)
are intentionally dropped — this table stores no user-identifying
data, only the targeted organisation's domain.
"""
if not raw:
return None
match = _SMTP_ADDR_RE.search(raw.strip())
if not match:
return None
domain = match.group(2).lower().strip(".")
if not domain:
return None
tld = domain.rsplit(".", 1)[-1]
if tld in _BLOCKED_TLDS:
return None
return domain
def _extract_smtp_domains(events: list[LogEvent]) -> set[str]:
"""Collect the set of victim domains an attacker targeted via SMTP.
Deduped at the attacker level — repeated hits on the same domain
within a single batch collapse to one upsert, and the per-row count
is bumped by ``increment_smtp_target`` on each call. The set return
type is intentional: we care about *which* domains were seen, not
the per-batch frequency (which the DB aggregates over time).
"""
domains: set[str] = set()
for event in events:
if event.service != "smtp" or event.event_type not in _SMTP_RCPT_EVENTS:
continue
if event.event_type == "message_accepted":
raw_list = event.fields.get("rcpt_to", "")
candidates = raw_list.split(",") if raw_list else []
else:
candidates = [event.fields.get("value", "")]
for candidate in candidates:
domain = _normalize_smtp_domain(candidate)
if domain:
domains.add(domain)
return domains