diff --git a/decnet/collector/worker.py b/decnet/collector/worker.py index 32109aa..a6714bd 100644 --- a/decnet/collector/worker.py +++ b/decnet/collector/worker.py @@ -18,6 +18,7 @@ from pathlib import Path from typing import Any, Optional from decnet.logging import get_logger +from decnet.telemetry import traced as _traced logger = get_logger("collector") @@ -220,6 +221,7 @@ def _reopen_if_needed(path: Path, fh: Optional[Any]) -> Any: return open(path, "a", encoding="utf-8") +@_traced("collector.stream_container") def _stream_container(container_id: str, log_path: Path, json_path: Path) -> None: """Stream logs from one container and append to the host log files.""" import docker # type: ignore[import] diff --git a/decnet/engine/deployer.py b/decnet/engine/deployer.py index aa9252b..70c2357 100644 --- a/decnet/engine/deployer.py +++ b/decnet/engine/deployer.py @@ -12,6 +12,7 @@ from rich.console import Console from rich.table import Table from decnet.logging import get_logger +from decnet.telemetry import traced as _traced from decnet.config import DecnetConfig, clear_state, load_state, save_state from decnet.composer import write_compose from decnet.network import ( @@ -107,6 +108,7 @@ def _compose_with_retry( raise last_exc +@_traced("engine.deploy") def deploy(config: DecnetConfig, dry_run: bool = False, no_cache: bool = False, parallel: bool = False) -> None: log.info("deployment started n_deckies=%d interface=%s subnet=%s dry_run=%s", len(config.deckies), config.interface, config.subnet, dry_run) log.debug("deploy: deckies=%s", [d.name for d in config.deckies]) @@ -171,6 +173,7 @@ def deploy(config: DecnetConfig, dry_run: bool = False, no_cache: bool = False, _print_status(config) +@_traced("engine.teardown") def teardown(decky_id: str | None = None) -> None: log.info("teardown requested decky_id=%s", decky_id or "all") state = load_state() diff --git a/decnet/env.py b/decnet/env.py index f247352..f45d1d7 100644 --- a/decnet/env.py +++ b/decnet/env.py @@ 
-72,6 +72,11 @@ DECNET_ADMIN_USER: str = os.environ.get("DECNET_ADMIN_USER", "admin") DECNET_ADMIN_PASSWORD: str = os.environ.get("DECNET_ADMIN_PASSWORD", "admin") DECNET_DEVELOPER: bool = os.environ.get("DECNET_DEVELOPER", "False").lower() == "true" +# Tracing — set to "true" to enable OpenTelemetry distributed tracing. +# Separate from DECNET_DEVELOPER so tracing can be toggled independently. +DECNET_DEVELOPER_TRACING: bool = os.environ.get("DECNET_DEVELOPER_TRACING", "").lower() == "true" +DECNET_OTEL_ENDPOINT: str = os.environ.get("DECNET_OTEL_ENDPOINT", "http://localhost:4317") + # Database Options DECNET_DB_TYPE: str = os.environ.get("DECNET_DB_TYPE", "sqlite").lower() DECNET_DB_URL: Optional[str] = os.environ.get("DECNET_DB_URL") diff --git a/decnet/mutator/engine.py b/decnet/mutator/engine.py index 6ef916c..d011b19 100644 --- a/decnet/mutator/engine.py +++ b/decnet/mutator/engine.py @@ -15,6 +15,7 @@ from decnet.composer import write_compose from decnet.config import DeckyConfig, DecnetConfig from decnet.engine import _compose_with_retry from decnet.logging import get_logger +from decnet.telemetry import traced as _traced from pathlib import Path import anyio @@ -25,6 +26,7 @@ log = get_logger("mutator") console = Console() +@_traced("mutator.mutate_decky") async def mutate_decky(decky_name: str, repo: BaseRepository) -> bool: """ Perform an Intra-Archetype Shuffle for a specific decky. @@ -91,6 +93,7 @@ async def mutate_decky(decky_name: str, repo: BaseRepository) -> bool: return True +@_traced("mutator.mutate_all") async def mutate_all(repo: BaseRepository, force: bool = False) -> None: """ Check all deckies and mutate those that are due. 
diff --git a/decnet/prober/worker.py b/decnet/prober/worker.py index f9e0fce..48cc58e 100644 --- a/decnet/prober/worker.py +++ b/decnet/prober/worker.py @@ -30,6 +30,7 @@ from decnet.logging import get_logger from decnet.prober.hassh import hassh_server from decnet.prober.jarm import JARM_EMPTY_HASH, jarm_hash from decnet.prober.tcpfp import tcp_fingerprint +from decnet.telemetry import traced as _traced logger = get_logger("prober") @@ -219,6 +220,7 @@ def _discover_attackers(json_path: Path, position: int) -> tuple[set[str], int]: # ─── Probe cycle ───────────────────────────────────────────────────────────── +@_traced("prober.probe_cycle") def _probe_cycle( targets: set[str], probed: dict[str, dict[str, set[int]]], @@ -255,6 +257,7 @@ def _probe_cycle( _tcpfp_phase(ip, ip_probed, tcpfp_ports, log_path, json_path, timeout) +@_traced("prober.jarm_phase") def _jarm_phase( ip: str, ip_probed: dict[str, set[int]], @@ -296,6 +299,7 @@ def _jarm_phase( logger.warning("prober: JARM probe failed %s:%d: %s", ip, port, exc) +@_traced("prober.hassh_phase") def _hassh_phase( ip: str, ip_probed: dict[str, set[int]], @@ -342,6 +346,7 @@ def _hassh_phase( logger.warning("prober: HASSH probe failed %s:%d: %s", ip, port, exc) +@_traced("prober.tcpfp_phase") def _tcpfp_phase( ip: str, ip_probed: dict[str, set[int]], diff --git a/decnet/profiler/worker.py b/decnet/profiler/worker.py index 0cabec6..86fc81a 100644 --- a/decnet/profiler/worker.py +++ b/decnet/profiler/worker.py @@ -22,6 +22,7 @@ from decnet.correlation.engine import CorrelationEngine from decnet.correlation.parser import LogEvent from decnet.logging import get_logger from decnet.profiler.behavioral import build_behavior_record +from decnet.telemetry import traced as _traced from decnet.web.db.repository import BaseRepository logger = get_logger("attacker_worker") @@ -63,6 +64,7 @@ async def attacker_profile_worker(repo: BaseRepository, *, interval: int = 30) - logger.error("attacker worker: update failed: %s", exc) 
+@_traced("profiler.incremental_update") async def _incremental_update(repo: BaseRepository, state: _WorkerState) -> None: was_cold = not state.initialized affected_ips: set[str] = set() @@ -98,6 +100,7 @@ async def _incremental_update(repo: BaseRepository, state: _WorkerState) -> None logger.info("attacker worker: updated %d profiles (incremental)", len(affected_ips)) +@_traced("profiler.update_profiles") async def _update_profiles( repo: BaseRepository, state: _WorkerState, diff --git a/decnet/sniffer/worker.py b/decnet/sniffer/worker.py index 4f0cc43..dca71ab 100644 --- a/decnet/sniffer/worker.py +++ b/decnet/sniffer/worker.py @@ -21,6 +21,7 @@ from decnet.logging import get_logger from decnet.network import HOST_MACVLAN_IFACE from decnet.sniffer.fingerprint import SnifferEngine from decnet.sniffer.syslog import write_event +from decnet.telemetry import traced as _traced logger = get_logger("sniffer") @@ -52,6 +53,7 @@ def _interface_exists(iface: str) -> bool: return False +@_traced("sniffer.sniff_loop") def _sniff_loop( interface: str, log_path: Path, diff --git a/decnet/telemetry.py b/decnet/telemetry.py new file mode 100644 index 0000000..65cdc7e --- /dev/null +++ b/decnet/telemetry.py @@ -0,0 +1,371 @@ +""" +DECNET OpenTelemetry tracing integration. + +Controlled entirely by ``DECNET_DEVELOPER_TRACING``. When disabled (the +default), every public export is a zero-cost no-op: no OTEL SDK imports, no +monkey-patching, no middleware, and ``@traced`` returns the original function +object unwrapped. 
+""" + +from __future__ import annotations + +import asyncio +import functools +import inspect +from typing import Any, Callable, Optional, TypeVar, overload + +from decnet.env import DECNET_DEVELOPER_TRACING, DECNET_OTEL_ENDPOINT +from decnet.logging import get_logger + +log = get_logger("api") + +F = TypeVar("F", bound=Callable[..., Any]) + +_ENABLED: bool = DECNET_DEVELOPER_TRACING + +# --------------------------------------------------------------------------- +# Lazy OTEL imports — only when tracing is enabled +# --------------------------------------------------------------------------- + +_tracer_provider: Any = None # TracerProvider | None + + +def _init_provider() -> None: + """Initialise the global TracerProvider (called once from setup_tracing).""" + global _tracer_provider + + from opentelemetry import trace + from opentelemetry.sdk.trace import TracerProvider + from opentelemetry.sdk.trace.export import BatchSpanProcessor + from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter + from opentelemetry.sdk.resources import Resource + + resource = Resource.create({ + "service.name": "decnet", + "service.version": "0.2.0", + }) + _tracer_provider = TracerProvider(resource=resource) + exporter = OTLPSpanExporter(endpoint=DECNET_OTEL_ENDPOINT, insecure=True) + _tracer_provider.add_span_processor(BatchSpanProcessor(exporter)) + trace.set_tracer_provider(_tracer_provider) + log.info("OTEL tracing enabled endpoint=%s", DECNET_OTEL_ENDPOINT) + + +def setup_tracing(app: Any) -> None: + """Configure the OTEL TracerProvider and instrument FastAPI. + + Call once from the FastAPI lifespan, after DB init. No-op when + ``DECNET_DEVELOPER_TRACING`` is not ``"true"``. 
+ """ + if not _ENABLED: + return + + try: + _init_provider() + from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor + FastAPIInstrumentor.instrument_app(app) + log.info("FastAPI auto-instrumentation active") + except Exception as exc: + log.warning("OTEL setup failed — continuing without tracing: %s", exc) + + +def shutdown_tracing() -> None: + """Flush and shut down the tracer provider. Safe to call when disabled.""" + if _tracer_provider is not None: + try: + _tracer_provider.shutdown() + except Exception: + pass + + +# --------------------------------------------------------------------------- +# get_tracer — mirrors get_logger(component) pattern +# --------------------------------------------------------------------------- + +class _NoOpSpan: + """Minimal stand-in so ``with get_tracer(...).start_as_current_span(...)`` + works when tracing is disabled.""" + + def set_attribute(self, key: str, value: Any) -> None: + pass + + def set_status(self, *args: Any, **kwargs: Any) -> None: + pass + + def record_exception(self, exc: BaseException) -> None: + pass + + def __enter__(self) -> "_NoOpSpan": + return self + + def __exit__(self, *args: Any) -> None: + pass + + +class _NoOpTracer: + """Returned by ``get_tracer()`` when tracing is disabled.""" + + def start_as_current_span(self, name: str, **kwargs: Any) -> _NoOpSpan: + return _NoOpSpan() + + def start_span(self, name: str, **kwargs: Any) -> _NoOpSpan: + return _NoOpSpan() + + +_tracers: dict[str, Any] = {} + + +def get_tracer(component: str) -> Any: + """Return an OTEL Tracer (or a no-op stand-in) for *component*.""" + if not _ENABLED: + return _NoOpTracer() + + if component not in _tracers: + from opentelemetry import trace + _tracers[component] = trace.get_tracer(f"decnet.{component}") + return _tracers[component] + + +# --------------------------------------------------------------------------- +# @traced decorator — async + sync, zero overhead when disabled +# 
--------------------------------------------------------------------------- + +@overload +def traced(fn: F) -> F: ... +@overload +def traced(name: str) -> Callable[[F], F]: ... + + +def traced(fn: Any = None, *, name: str | None = None) -> Any: + """Decorator that wraps a function in an OTEL span. + + Usage:: + + @traced # span name = "module.func" + async def my_worker(): ... + + @traced("custom.span.name") # explicit span name + def my_sync_func(): ... + + When ``DECNET_DEVELOPER_TRACING`` is disabled the original function is + returned **unwrapped** — zero overhead on every call. + """ + # Handle @traced("name") vs @traced vs @traced(name="name") + if fn is None and name is not None: + # Called as @traced(name="name") — keyword form only; a positional @traced("name") binds the string to fn and is handled by the next branch + def decorator(f: F) -> F: + return _wrap(f, name) + return decorator + if fn is not None and isinstance(fn, str): + # Called as @traced("name") — fn is actually the name string + span_name = fn + def decorator(f: F) -> F: + return _wrap(f, span_name) + return decorator + if fn is not None and callable(fn): + # Called as @traced (no arguments) + return _wrap(fn, None) + # Fallback: @traced() with no args + def decorator(f: F) -> F: + return _wrap(f, name) + return decorator + + +def _wrap(fn: F, span_name: str | None) -> F: + """Wrap *fn* in a span.
Returns *fn* unchanged when tracing is off.""" + if not _ENABLED: + return fn + + resolved_name = span_name or f"{fn.__module__.rsplit('.', 1)[-1]}.{fn.__qualname__}" + + if inspect.iscoroutinefunction(fn): + @functools.wraps(fn) + async def async_wrapper(*args: Any, **kwargs: Any) -> Any: + tracer = get_tracer(fn.__module__.split(".")[-1]) + with tracer.start_as_current_span(resolved_name) as span: + try: + result = await fn(*args, **kwargs) + return result + except Exception as exc: + span.record_exception(exc) + raise + return async_wrapper # type: ignore[return-value] + else: + @functools.wraps(fn) + def sync_wrapper(*args: Any, **kwargs: Any) -> Any: + tracer = get_tracer(fn.__module__.split(".")[-1]) + with tracer.start_as_current_span(resolved_name) as span: + try: + result = fn(*args, **kwargs) + return result + except Exception as exc: + span.record_exception(exc) + raise + return sync_wrapper # type: ignore[return-value] + + +# --------------------------------------------------------------------------- +# TracedRepository — proxy wrapper for BaseRepository +# --------------------------------------------------------------------------- + +def wrap_repository(repo: Any) -> Any: + """Wrap *repo* in a tracing proxy. 
Returns *repo* unchanged when disabled.""" + if not _ENABLED: + return repo + + from decnet.web.db.repository import BaseRepository + + class TracedRepository(BaseRepository): + """Proxy that creates a DB span around every BaseRepository call.""" + + def __init__(self, inner: BaseRepository) -> None: + self._inner = inner + self._tracer = get_tracer("db") + + # --- Forward every ABC method through a span --- + + async def initialize(self) -> None: + with self._tracer.start_as_current_span("db.initialize"): + return await self._inner.initialize() + + async def add_log(self, log_data): + with self._tracer.start_as_current_span("db.add_log"): + return await self._inner.add_log(log_data) + + async def get_logs(self, limit=50, offset=0, search=None): + with self._tracer.start_as_current_span("db.get_logs") as span: + span.set_attribute("db.limit", limit) + span.set_attribute("db.offset", offset) + return await self._inner.get_logs(limit=limit, offset=offset, search=search) + + async def get_total_logs(self, search=None): + with self._tracer.start_as_current_span("db.get_total_logs"): + return await self._inner.get_total_logs(search=search) + + async def get_stats_summary(self): + with self._tracer.start_as_current_span("db.get_stats_summary"): + return await self._inner.get_stats_summary() + + async def get_deckies(self): + with self._tracer.start_as_current_span("db.get_deckies"): + return await self._inner.get_deckies() + + async def get_user_by_username(self, username): + with self._tracer.start_as_current_span("db.get_user_by_username"): + return await self._inner.get_user_by_username(username) + + async def get_user_by_uuid(self, uuid): + with self._tracer.start_as_current_span("db.get_user_by_uuid"): + return await self._inner.get_user_by_uuid(uuid) + + async def create_user(self, user_data): + with self._tracer.start_as_current_span("db.create_user"): + return await self._inner.create_user(user_data) + + async def update_user_password(self, uuid, password_hash, 
must_change_password=False): + with self._tracer.start_as_current_span("db.update_user_password"): + return await self._inner.update_user_password(uuid, password_hash, must_change_password) + + async def list_users(self): + with self._tracer.start_as_current_span("db.list_users"): + return await self._inner.list_users() + + async def delete_user(self, uuid): + with self._tracer.start_as_current_span("db.delete_user"): + return await self._inner.delete_user(uuid) + + async def update_user_role(self, uuid, role): + with self._tracer.start_as_current_span("db.update_user_role"): + return await self._inner.update_user_role(uuid, role) + + async def purge_logs_and_bounties(self): + with self._tracer.start_as_current_span("db.purge_logs_and_bounties"): + return await self._inner.purge_logs_and_bounties() + + async def add_bounty(self, bounty_data): + with self._tracer.start_as_current_span("db.add_bounty"): + return await self._inner.add_bounty(bounty_data) + + async def get_bounties(self, limit=50, offset=0, bounty_type=None, search=None): + with self._tracer.start_as_current_span("db.get_bounties") as span: + span.set_attribute("db.limit", limit) + span.set_attribute("db.offset", offset) + return await self._inner.get_bounties(limit=limit, offset=offset, bounty_type=bounty_type, search=search) + + async def get_total_bounties(self, bounty_type=None, search=None): + with self._tracer.start_as_current_span("db.get_total_bounties"): + return await self._inner.get_total_bounties(bounty_type=bounty_type, search=search) + + async def get_state(self, key): + with self._tracer.start_as_current_span("db.get_state") as span: + span.set_attribute("db.state_key", key) + return await self._inner.get_state(key) + + async def set_state(self, key, value): + with self._tracer.start_as_current_span("db.set_state") as span: + span.set_attribute("db.state_key", key) + return await self._inner.set_state(key, value) + + async def get_max_log_id(self): + with 
self._tracer.start_as_current_span("db.get_max_log_id"): + return await self._inner.get_max_log_id() + + async def get_logs_after_id(self, last_id, limit=500): + with self._tracer.start_as_current_span("db.get_logs_after_id") as span: + span.set_attribute("db.last_id", last_id) + span.set_attribute("db.limit", limit) + return await self._inner.get_logs_after_id(last_id, limit=limit) + + async def get_all_bounties_by_ip(self): + with self._tracer.start_as_current_span("db.get_all_bounties_by_ip"): + return await self._inner.get_all_bounties_by_ip() + + async def get_bounties_for_ips(self, ips): + with self._tracer.start_as_current_span("db.get_bounties_for_ips") as span: + span.set_attribute("db.ip_count", len(ips)) + return await self._inner.get_bounties_for_ips(ips) + + async def upsert_attacker(self, data): + with self._tracer.start_as_current_span("db.upsert_attacker"): + return await self._inner.upsert_attacker(data) + + async def upsert_attacker_behavior(self, attacker_uuid, data): + with self._tracer.start_as_current_span("db.upsert_attacker_behavior"): + return await self._inner.upsert_attacker_behavior(attacker_uuid, data) + + async def get_attacker_behavior(self, attacker_uuid): + with self._tracer.start_as_current_span("db.get_attacker_behavior"): + return await self._inner.get_attacker_behavior(attacker_uuid) + + async def get_behaviors_for_ips(self, ips): + with self._tracer.start_as_current_span("db.get_behaviors_for_ips") as span: + span.set_attribute("db.ip_count", len(ips)) + return await self._inner.get_behaviors_for_ips(ips) + + async def get_attacker_by_uuid(self, uuid): + with self._tracer.start_as_current_span("db.get_attacker_by_uuid"): + return await self._inner.get_attacker_by_uuid(uuid) + + async def get_attackers(self, limit=50, offset=0, search=None, sort_by="recent", service=None): + with self._tracer.start_as_current_span("db.get_attackers") as span: + span.set_attribute("db.limit", limit) + span.set_attribute("db.offset", offset) + 
return await self._inner.get_attackers(limit=limit, offset=offset, search=search, sort_by=sort_by, service=service) + + async def get_total_attackers(self, search=None, service=None): + with self._tracer.start_as_current_span("db.get_total_attackers"): + return await self._inner.get_total_attackers(search=search, service=service) + + async def get_attacker_commands(self, uuid, limit=50, offset=0, service=None): + with self._tracer.start_as_current_span("db.get_attacker_commands") as span: + span.set_attribute("db.limit", limit) + span.set_attribute("db.offset", offset) + return await self._inner.get_attacker_commands(uuid, limit=limit, offset=offset, service=service) + + # --- Catch-all for methods defined on concrete subclasses but not + # in the ABC (e.g. get_log_histogram). --- + + def __getattr__(self, name: str) -> Any: + return getattr(self._inner, name) + + return TracedRepository(repo) diff --git a/decnet/web/api.py b/decnet/web/api.py index aac1249..9e33c77 100644 --- a/decnet/web/api.py +++ b/decnet/web/api.py @@ -50,6 +50,10 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: log.error("DB failed to initialize after 5 attempts — startup may be degraded") await asyncio.sleep(0.5) + # Conditionally enable OpenTelemetry tracing + from decnet.telemetry import setup_tracing + setup_tracing(app) + # Start background tasks only if not in contract test mode if os.environ.get("DECNET_CONTRACT_TEST") != "true": # Start background ingestion task @@ -99,6 +103,8 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: pass except Exception as exc: log.warning("Task shutdown error: %s", exc) + from decnet.telemetry import shutdown_tracing + shutdown_tracing() log.info("API shutdown complete") diff --git a/decnet/web/db/factory.py b/decnet/web/db/factory.py index 2030be1..af5ff5c 100644 --- a/decnet/web/db/factory.py +++ b/decnet/web/db/factory.py @@ -22,8 +22,12 @@ def get_repository(**kwargs: Any) -> BaseRepository: if db_type == "sqlite": 
from decnet.web.db.sqlite.repository import SQLiteRepository - return SQLiteRepository(**kwargs) - if db_type == "mysql": + repo = SQLiteRepository(**kwargs) + elif db_type == "mysql": from decnet.web.db.mysql.repository import MySQLRepository - return MySQLRepository(**kwargs) - raise ValueError(f"Unsupported database type: {db_type}") + repo = MySQLRepository(**kwargs) + else: + raise ValueError(f"Unsupported database type: {db_type}") + + from decnet.telemetry import wrap_repository + return wrap_repository(repo) diff --git a/decnet/web/ingester.py b/decnet/web/ingester.py index 780cf7f..7a0a8ef 100644 --- a/decnet/web/ingester.py +++ b/decnet/web/ingester.py @@ -5,6 +5,7 @@ from typing import Any from pathlib import Path from decnet.logging import get_logger +from decnet.telemetry import traced as _traced from decnet.web.db.repository import BaseRepository logger = get_logger("api") @@ -83,6 +84,7 @@ async def log_ingestion_worker(repo: BaseRepository) -> None: await asyncio.sleep(1) +@_traced("ingester.extract_bounty") async def _extract_bounty(repo: BaseRepository, log_data: dict[str, Any]) -> None: """Detect and extract valuable artifacts (bounties) from log entries.""" _fields = log_data.get("fields") diff --git a/development/docker-compose.otel.yml b/development/docker-compose.otel.yml new file mode 100644 index 0000000..c56fb67 --- /dev/null +++ b/development/docker-compose.otel.yml @@ -0,0 +1,20 @@ +# DECNET OpenTelemetry development stack. 
+# +# Start: docker compose -f development/docker-compose.otel.yml up -d +# UI: http://localhost:16686 (Jaeger) +# Stop: docker compose -f development/docker-compose.otel.yml down +# +# Then run DECNET with tracing enabled: +# DECNET_DEVELOPER_TRACING=true decnet web + +services: + jaeger: + image: jaegertracing/all-in-one:latest + container_name: decnet-jaeger + restart: unless-stopped + ports: + - "4317:4317" # OTLP gRPC receiver + - "4318:4318" # OTLP HTTP receiver + - "16686:16686" # Jaeger UI + environment: + COLLECTOR_OTLP_ENABLED: "true" diff --git a/pyproject.toml b/pyproject.toml index b483548..2e7ac44 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,14 @@ dependencies = [ ] [project.optional-dependencies] +tracing = [ + "opentelemetry-api>=1.20.0", + "opentelemetry-sdk>=1.20.0", + "opentelemetry-exporter-otlp>=1.20.0", + "opentelemetry-instrumentation-fastapi>=0.41b0", +] dev = [ + "decnet[tracing]", "pytest>=9.0.3", "ruff>=0.15.10", "bandit>=1.9.4", diff --git a/tests/test_telemetry.py b/tests/test_telemetry.py new file mode 100644 index 0000000..8185256 --- /dev/null +++ b/tests/test_telemetry.py @@ -0,0 +1,250 @@ +""" +Tests for decnet.telemetry — OTEL tracing integration. + +Covers both the disabled path (default, zero overhead) and the enabled path +(with mocked OTEL SDK). 
+""" + +from __future__ import annotations + +import asyncio +import importlib +import os +import sys +from typing import Any +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _reload_telemetry(*, enabled: bool = False): + """(Re)import decnet.telemetry with DECNET_DEVELOPER_TRACING set accordingly.""" + env_val = "true" if enabled else "" + with patch.dict(os.environ, {"DECNET_DEVELOPER_TRACING": env_val}): + # Force the env module to re-evaluate + import decnet.env + old_tracing = decnet.env.DECNET_DEVELOPER_TRACING + decnet.env.DECNET_DEVELOPER_TRACING = enabled + + # Remove cached telemetry module so it re-evaluates _ENABLED + sys.modules.pop("decnet.telemetry", None) + import decnet.telemetry + importlib.reload(decnet.telemetry) + + # Restore after reload + decnet.env.DECNET_DEVELOPER_TRACING = old_tracing + return decnet.telemetry + + +# ═══════════════════════════════════════════════════════════════════════════ +# DISABLED PATH (default) — zero overhead +# ═══════════════════════════════════════════════════════════════════════════ + + +class TestTracingDisabled: + """When DECNET_DEVELOPER_TRACING is unset/false, everything is a no-op.""" + + def test_setup_tracing_is_noop(self): + mod = _reload_telemetry(enabled=False) + app = MagicMock() + mod.setup_tracing(app) + # FastAPIInstrumentor should NOT have been called + assert not any("opentelemetry" in str(c) for c in app.mock_calls) + + def test_get_tracer_returns_noop(self): + mod = _reload_telemetry(enabled=False) + tracer = mod.get_tracer("test") + assert isinstance(tracer, mod._NoOpTracer) + # NoOp span should work as context manager + with tracer.start_as_current_span("test") as span: + span.set_attribute("k", "v") + span.record_exception(RuntimeError("boom")) + + def 
test_traced_returns_original_function(self): + mod = _reload_telemetry(enabled=False) + + def my_func(x: int) -> int: + return x * 2 + + decorated = mod.traced(my_func) + # Must be the exact same function object — no wrapper overhead + assert decorated is my_func + assert decorated(5) == 10 + + def test_traced_with_name_returns_original(self): + mod = _reload_telemetry(enabled=False) + + @mod.traced("custom.name") + def my_func() -> str: + return "hello" + + # When disabled, @traced("name") still returns the original + assert my_func() == "hello" + assert my_func.__name__ == "my_func" + + def test_traced_async_returns_original(self): + mod = _reload_telemetry(enabled=False) + + async def my_async(x: int) -> int: + return x + 1 + + decorated = mod.traced(my_async) + assert decorated is my_async + + def test_wrap_repository_returns_original(self): + mod = _reload_telemetry(enabled=False) + repo = MagicMock() + result = mod.wrap_repository(repo) + assert result is repo + + def test_shutdown_tracing_noop(self): + mod = _reload_telemetry(enabled=False) + # Should not raise + mod.shutdown_tracing() + + +# ═══════════════════════════════════════════════════════════════════════════ +# ENABLED PATH — with mocked OTEL SDK +# ═══════════════════════════════════════════════════════════════════════════ + + +class TestTracingEnabled: + """When DECNET_DEVELOPER_TRACING=true, spans are created.""" + + @pytest.fixture(autouse=True) + def _mock_otel(self): + """Provide mock OTEL modules so we don't need the real SDK installed.""" + # Create mock OTEL modules + mock_trace = MagicMock() + mock_tracer = MagicMock() + mock_span = MagicMock() + mock_span.__enter__ = MagicMock(return_value=mock_span) + mock_span.__exit__ = MagicMock(return_value=False) + mock_tracer.start_as_current_span.return_value = mock_span + mock_trace.get_tracer.return_value = mock_tracer + + self.mock_trace = mock_trace + self.mock_tracer = mock_tracer + self.mock_span = mock_span + + mock_modules = { + 
"opentelemetry": MagicMock(trace=mock_trace), + "opentelemetry.trace": mock_trace, + "opentelemetry.sdk": MagicMock(), + "opentelemetry.sdk.trace": MagicMock(), + "opentelemetry.sdk.trace.export": MagicMock(), + "opentelemetry.sdk.resources": MagicMock(), + "opentelemetry.exporter": MagicMock(), + "opentelemetry.exporter.otlp": MagicMock(), + "opentelemetry.exporter.otlp.proto": MagicMock(), + "opentelemetry.exporter.otlp.proto.grpc": MagicMock(), + "opentelemetry.exporter.otlp.proto.grpc.trace_exporter": MagicMock(), + "opentelemetry.instrumentation": MagicMock(), + "opentelemetry.instrumentation.fastapi": MagicMock(), + } + + with patch.dict(sys.modules, mock_modules): + self.mod = _reload_telemetry(enabled=True) + yield + + def test_traced_sync_creates_span(self): + @self.mod.traced("test.sync_op") + def do_work(x: int) -> int: + return x * 3 + + result = do_work(7) + assert result == 21 + # The wrapper should have called start_as_current_span + # (via get_tracer which returns our mock) + + def test_traced_async_creates_span(self): + @self.mod.traced("test.async_op") + async def do_async(x: int) -> int: + return x + 10 + + result = asyncio.run(do_async(5)) + assert result == 15 + + def test_traced_preserves_function_name(self): + @self.mod.traced("custom.name") + def my_named_func(): + pass + + assert my_named_func.__name__ == "my_named_func" + + def test_traced_exception_recorded(self): + @self.mod.traced("test.error") + def fail(): + raise ValueError("boom") + + with pytest.raises(ValueError, match="boom"): + fail() + + def test_traced_async_exception_recorded(self): + @self.mod.traced("test.async_error") + async def fail_async(): + raise RuntimeError("async boom") + + with pytest.raises(RuntimeError, match="async boom"): + asyncio.run(fail_async()) + + def test_wrap_repository_delegates(self): + mock_repo = AsyncMock() + mock_repo.add_log = AsyncMock(return_value=None) + mock_repo.get_logs = AsyncMock(return_value=[]) + mock_repo.get_state = 
AsyncMock(return_value={"key": "val"}) + + wrapped = self.mod.wrap_repository(mock_repo) + assert wrapped is not mock_repo + + # Verify delegation works + asyncio.run(wrapped.add_log({"test": 1})) + mock_repo.add_log.assert_awaited_once_with({"test": 1}) + + def test_wrap_repository_getattr_fallback(self): + mock_repo = MagicMock() + mock_repo.custom_method = MagicMock(return_value=42) + + wrapped = self.mod.wrap_repository(mock_repo) + assert wrapped.custom_method() == 42 + + def test_get_tracer_returns_real_tracer(self): + tracer = self.mod.get_tracer("test_component") + # Should be the mock tracer from opentelemetry.trace.get_tracer + assert tracer is not None + assert not isinstance(tracer, self.mod._NoOpTracer) + + def test_setup_tracing_instruments_app(self): + app = MagicMock() + self.mod.setup_tracing(app) + # Should not raise — the mock OTEL modules handle everything + + +# ═══════════════════════════════════════════════════════════════════════════ +# NoOp classes +# ═══════════════════════════════════════════════════════════════════════════ + + +class TestNoOpClasses: + """NoOp tracer and span must satisfy the context-manager protocol.""" + + def test_noop_span_context_manager(self): + from decnet.telemetry import _NoOpSpan + span = _NoOpSpan() + with span as s: + assert s is span + s.set_attribute("key", "value") + s.set_status("ok") + s.record_exception(RuntimeError("test")) + + def test_noop_tracer(self): + from decnet.telemetry import _NoOpTracer + tracer = _NoOpTracer() + span = tracer.start_as_current_span("test") + assert hasattr(span, "__enter__") + span2 = tracer.start_span("test2") + assert hasattr(span2, "set_attribute")