fix(planner): guard apply_payload and reset_to_defaults with a lock

Concurrent PUT requests could observe a half-updated planner between
the four sequential global assignments. Added _planner_lock so the
rebind is atomic; same lock wraps reset_to_defaults.
This commit is contained in:
2026-04-30 21:15:12 -04:00
parent f597d70430
commit c7fcd86be4

View File

@@ -20,6 +20,7 @@ persona outside its window is never considered.
from __future__ import annotations
import secrets
import threading
from datetime import datetime
from typing import Any, Optional, Sequence
@@ -74,6 +75,7 @@ _USER_CLASS_WEIGHTS: tuple[tuple[ContentClass, int], ...] = _DEFAULT_USER_CLASS_
_SYSTEM_CLASS_WEIGHTS: tuple[tuple[ContentClass, int], ...] = _DEFAULT_SYSTEM_CLASS_WEIGHTS
_CANARY_CLASS_WEIGHTS: tuple[tuple[ContentClass, int], ...] = _DEFAULT_CANARY_CLASS_WEIGHTS
_CANARY_PROBABILITY: float = _DEFAULT_CANARY_PROBABILITY
_planner_lock = threading.Lock()
def _serialize_weights(
@@ -185,20 +187,22 @@ def apply_payload(payload: dict[str, Any]) -> None:
raise ValueError("canary_probability must be in [0.0, 1.0]")
new_prob = float(prob)
_USER_CLASS_WEIGHTS = new_user
_SYSTEM_CLASS_WEIGHTS = new_system
_CANARY_CLASS_WEIGHTS = new_canary
_CANARY_PROBABILITY = new_prob
with _planner_lock:
_USER_CLASS_WEIGHTS = new_user
_SYSTEM_CLASS_WEIGHTS = new_system
_CANARY_CLASS_WEIGHTS = new_canary
_CANARY_PROBABILITY = new_prob
def reset_to_defaults() -> None:
"""Restore hardcoded defaults. Used by tests and the API reset path."""
global _USER_CLASS_WEIGHTS, _SYSTEM_CLASS_WEIGHTS
global _CANARY_CLASS_WEIGHTS, _CANARY_PROBABILITY
_USER_CLASS_WEIGHTS = _DEFAULT_USER_CLASS_WEIGHTS
_SYSTEM_CLASS_WEIGHTS = _DEFAULT_SYSTEM_CLASS_WEIGHTS
_CANARY_CLASS_WEIGHTS = _DEFAULT_CANARY_CLASS_WEIGHTS
_CANARY_PROBABILITY = _DEFAULT_CANARY_PROBABILITY
with _planner_lock:
_USER_CLASS_WEIGHTS = _DEFAULT_USER_CLASS_WEIGHTS
_SYSTEM_CLASS_WEIGHTS = _DEFAULT_SYSTEM_CLASS_WEIGHTS
_CANARY_CLASS_WEIGHTS = _DEFAULT_CANARY_CLASS_WEIGHTS
_CANARY_PROBABILITY = _DEFAULT_CANARY_PROBABILITY
def _weighted_pick(