"""Trajectory metrics M1..M6 (LEARN-07, D-32) -- Task 4. Every session_exit writes one `trajectory_metric` event per metric. The CLI aggregator reads these events via aggregate_trajectory. Metrics (all computed in session-local scope): - M1: clarifying questions per session (decreasing over time) - M2: retrieval precision@5 (growing) - M3: tokens per session (decreasing) - M4: profile-vector variance (decreasing -> converged by session ~30) - M5: curiosity question frequency (entropy dropping) - M6: context-repeat rate (> 90% by session ~20) Plan 02-03 scope: event emission + basic aggregation. wires the CLI aggregator + synthetic-corpus benchmark. """ from __future__ import annotations from datetime import datetime, timezone from typing import Any from iai_mcp.events import query_events, write_event from iai_mcp.store import MemoryStore METRIC_NAMES: list[str] = ["m1", "m2", "m3", "m4", "m5", "m6"] # ---------------------------------------------------------------- emit def record_session_metrics( store: MemoryStore, session_id: str, metrics: dict[str, float], ) -> None: """Emit one `trajectory_metric` event per valid metric key in `metrics`. Keys outside METRIC_NAMES are ignored silently -- this is a public API; strict validation would force every test harness to chase whitespace in metric names. """ for m, v in metrics.items(): if m not in METRIC_NAMES: continue try: value = float(v) except (TypeError, ValueError): continue write_event( store, kind="trajectory_metric", data={"metric": m, "value": value}, severity="info", session_id=session_id, ) def aggregate_trajectory( store: MemoryStore, since: datetime | None = None, ) -> dict[str, list[tuple[datetime, float]]]: """CLI support: group all trajectory_metric events by metric. Returns {"m1": [(ts, value), ...], ..., "m6": [...]}. """ events = query_events( store, kind="trajectory_metric", since=since, limit=10000, ) out: dict[str, list[tuple[datetime, float]]] = {m: [] for m in METRIC_NAMES} for e in events: m = e["data"].get("metric") v = e["data"].get("value") if m in METRIC_NAMES and v is not None: try: out[m].append((e["ts"], float(v))) except (TypeError, ValueError): continue return out # ---------------------------------------------------------------- individual signals def compute_m1_clarifying_questions_per_session( store: MemoryStore, session_id: str, ) -> float: """M1: count of curiosity_question events for a session.""" events = query_events(store, kind="curiosity_question", limit=1000) count = sum(1 for e in events if e.get("session_id") == session_id) return float(count) def compute_m3_token_budget( store: MemoryStore, session_id: str, ) -> float: """M3: mean of session_start_tokens events for this session.""" events = query_events(store, kind="session_start_tokens", limit=100) session_events = [e for e in events if e.get("session_id") == session_id] if not session_events: return 0.0 total = 0.0 for e in session_events: try: total += float(e["data"].get("tokens", 0)) except (TypeError, ValueError): continue return total / len(session_events) def compute_m5_curiosity_frequency( store: MemoryStore, session_id: str, ) -> float: """M5: sum of curiosity_silent_log + curiosity_question events per session.""" silent = query_events(store, kind="curiosity_silent_log", limit=1000) questions = query_events(store, kind="curiosity_question", limit=1000) total = 0 for ev_list in (silent, questions): total += sum(1 for e in ev_list if e.get("session_id") == session_id) return float(total) def compute_session_metrics_snapshot( store: MemoryStore, session_id: str, ) -> dict[str, float]: """Produce a partial snapshot of M1..M6 from the current event stream. scope: M1/M3/M5 are computable from the event stream. promotion: M2/M4/M6 are now LIVE (read retrieval_used / profile_updated / session_started events emitted by retrieve.py / profile.py / session.py respectively). """ return { "m1": compute_m1_clarifying_questions_per_session(store, session_id), "m2": m2_precision_at_5_live(store), "m3": compute_m3_token_budget(store, session_id), "m4": m4_profile_variance_live(store), "m5": compute_m5_curiosity_frequency(store, session_id), "m6": m6_context_repeat_rate_live(store), } # -------------------------------------------------- M2/M4/M6 LIVE # Backward-compat synthetic constants (Phase 2 baseline; bench compares # live vs synthetic to prove the promotion is real -- see test_trajectory_live_smoke.py). M2_SYNTHETIC_CONSTANT: float = 0.0 M4_SYNTHETIC_CONSTANT: float = 0.0 M6_SYNTHETIC_CONSTANT: float = 0.0 def m2_precision_at_5_synthetic() -> float: """Pre-Plan-03-02 placeholder. Kept for trajectory bench comparison.""" return M2_SYNTHETIC_CONSTANT def m4_profile_variance_synthetic() -> float: """Pre-Plan-03-02 placeholder. Kept for trajectory bench comparison.""" return M4_SYNTHETIC_CONSTANT def m6_context_repeat_rate_synthetic() -> float: """Pre-Plan-03-02 placeholder. Kept for trajectory bench comparison.""" return M6_SYNTHETIC_CONSTANT def m2_precision_at_5_live( store: MemoryStore, *, window: int = 100, ) -> float: """M2 LIVE: precision@5 over the last ``window`` retrieval_used events. Each ``retrieval_used`` event carries ``hit_ids`` (list of UUID strings) and optionally a ``ground_truth`` list. When ground_truth is present, count hits in the top-5 that intersect ground_truth and divide by 5. When absent, fall back to the **hit-presence rate** -- (# events with at least one hit) / (# events) -- which is a coarse but honest proxy and never returns the synthetic 0.0 when the system is actually retrieving. The fallback path is what makes the live value differ from the synthetic constant in production -- the metric stops being a flat zero the moment retrieve.recall starts returning hits. """ events = query_events(store, kind="retrieval_used", limit=window) if not events: return 0.0 precisions: list[float] = [] for ev in events: data = ev.get("data") or {} hits = data.get("hit_ids") or [] ground_truth = set(data.get("ground_truth") or []) top5 = list(hits)[:5] if ground_truth: tp = sum(1 for h in top5 if h in ground_truth) precisions.append(tp / 5.0) else: # Fallback: hit-presence at top-5 (1.0 if any hit, else 0.0). precisions.append(1.0 if top5 else 0.0) if not precisions: return 0.0 return sum(precisions) / len(precisions) def m4_profile_variance_live( store: MemoryStore, *, n_updates: int = 20, ) -> float: """M4 LIVE: variance over the last N profile_updated events per knob. Aggregates the most recent ``n_updates`` ``profile_updated`` events, groups by knob, computes per-knob variance over the new values (only for numeric knobs -- bool/enum knobs are skipped), and returns the mean variance across knobs. Returns 0.0 when no events exist (back-compat with the synthetic baseline). """ events = query_events(store, kind="profile_updated", limit=n_updates * 5) if not events: return 0.0 per_knob: dict[str, list[float]] = {} for ev in events[:n_updates]: data = ev.get("data") or {} knob = data.get("knob") new_val = data.get("new") if knob is None or new_val is None: continue # Skip bool/enum knobs explicitly: bool is a subclass of int, so # float(True/False) succeeds; we want only int/float values. if isinstance(new_val, bool) or not isinstance(new_val, (int, float)): continue per_knob.setdefault(str(knob), []).append(float(new_val)) if not per_knob: return 0.0 variances: list[float] = [] for _knob, vals in per_knob.items(): if len(vals) < 2: variances.append(0.0) continue mean = sum(vals) / len(vals) var = sum((v - mean) ** 2 for v in vals) / len(vals) variances.append(var) if not variances: return 0.0 return sum(variances) / len(variances) def m6_context_repeat_rate_live( store: MemoryStore, *, window_days: int = 30, ) -> float: """M6 LIVE: context-repeat-rate over the last ``window_days`` of session_started. Reads ``kind='session_started'`` events with ``data.session_state_hash``, counts unique vs total hashes, and returns the *repeat rate*: repeat_rate = (total - unique) / total A value near 0.0 means every session looked novel; near 1.0 means heavy context reuse (which is the continuity ideal at session ~20+). """ from datetime import datetime, timedelta, timezone since = datetime.now(timezone.utc) - timedelta(days=window_days) events = query_events( store, kind="session_started", since=since, limit=10000, ) if not events: return 0.0 hashes: list[str] = [] for ev in events: data = ev.get("data") or {} hsh = data.get("session_state_hash") if hsh: hashes.append(str(hsh)) if not hashes: return 0.0 total = len(hashes) unique = len(set(hashes)) return (total - unique) / total def m2(store: MemoryStore) -> float: """Public M2 entry point (always live).""" return m2_precision_at_5_live(store) def m4(store: MemoryStore) -> float: """Public M4 entry point (always live).""" return m4_profile_variance_live(store) def m6(store: MemoryStore) -> float: """Public M6 entry point (always live).""" return m6_context_repeat_rate_live(store)