Initial release: iai-mcp v0.1.0
Co-Authored-By: Claude <noreply@anthropic.com> Co-Authored-By: XNLLLLH <XNLLLLH@users.noreply.github.com>
This commit is contained in:
commit
f6b876fbe7
332 changed files with 97258 additions and 0 deletions
328
bench/lme500/debug_pipeline_loss.py
Normal file
328
bench/lme500/debug_pipeline_loss.py
Normal file
|
|
@ -0,0 +1,328 @@
|
|||
"""bench/lme500/debug_pipeline_loss.py
|
||||
|
||||
Trace WHICH pipeline stage drops the gold session in loss cases
|
||||
(rows where retrieve_recall hits in top-k but recall_for_benchmark does not).
|
||||
|
||||
Usage:
|
||||
python bench/lme500/debug_pipeline_loss.py <question_id> [<question_id> ...]
|
||||
|
||||
For each qid:
|
||||
- Loads the LongMemEval-S row from the pinned dataset.
|
||||
- Builds a fresh per-row store + runtime graph (same shape as the bench).
|
||||
- Runs retrieve_recall to confirm gold sessions are findable by flat cosine.
|
||||
- Runs recall_for_benchmark STAGE BY STAGE, recording at each cut whether the
|
||||
gold record IDs survived.
|
||||
|
||||
Stages traced:
|
||||
Stage 2 — community gate (top-3 communities by centroid cosine)
|
||||
Stage 3 — seeds (top-3 by cosine within gated candidates)
|
||||
Stage 4 — 2-hop spread + rich-club union
|
||||
Stage 5 — final recall_for_benchmark hits
|
||||
|
||||
Output is a per-stage table showing where gold drops.
|
||||
|
||||
Read-only — no src/iai_mcp changes. Calls private helpers _community_gate
|
||||
and _pick_seeds for stage-level inspection (debug-only path).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
os.environ.setdefault("TRANSFORMERS_VERBOSITY", "error")
|
||||
|
||||
import numpy as np
|
||||
|
||||
from iai_mcp.embed import embedder_for_store
|
||||
from iai_mcp.pipeline import (
|
||||
_collect_graph_pool,
|
||||
_community_gate,
|
||||
_pick_seeds,
|
||||
recall_for_benchmark,
|
||||
)
|
||||
from iai_mcp.retrieve import build_runtime_graph, recall as retrieve_recall
|
||||
from iai_mcp.store import MemoryStore
|
||||
from iai_mcp.types import MemoryRecord
|
||||
|
||||
from bench.adapters.longmemeval import LongMemEvalAdapter
|
||||
|
||||
|
||||
def _make_record(content: str, session_id: str, role: str, embedding: list[float]) -> MemoryRecord:
|
||||
now = datetime.now(timezone.utc)
|
||||
return MemoryRecord(
|
||||
id=uuid4(),
|
||||
tier="episodic",
|
||||
literal_surface=content,
|
||||
aaak_index="",
|
||||
embedding=embedding,
|
||||
community_id=None,
|
||||
centrality=0.0,
|
||||
detail_level=2,
|
||||
pinned=False,
|
||||
stability=0.0,
|
||||
difficulty=0.0,
|
||||
last_reviewed=None,
|
||||
never_decay=False,
|
||||
never_merge=False,
|
||||
provenance=[],
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
tags=["longmemeval", f"role:{role}", f"session:{session_id}"],
|
||||
language="en",
|
||||
)
|
||||
|
||||
|
||||
def find_row(qid: str):
|
||||
adapter = LongMemEvalAdapter()
|
||||
sessions = []
|
||||
question = None
|
||||
answer_session_ids = None
|
||||
qtype = None
|
||||
for lme_session in adapter.load_dataset(split="S"):
|
||||
q = lme_session.queries[0]
|
||||
if q["question_id"] == qid:
|
||||
sessions.append(lme_session)
|
||||
if question is None:
|
||||
question = q["query"]
|
||||
answer_session_ids = set(q.get("relevant_turn_ids", []))
|
||||
qtype = q.get("question_type", "?")
|
||||
return question, qtype, answer_session_ids, sessions
|
||||
|
||||
|
||||
def trace_one(qid: str) -> dict:
|
||||
"""Returns a dict with the stage-by-stage gold survival counts."""
|
||||
print(f"\n{'=' * 78}\n=== qid={qid} ===\n{'=' * 78}", flush=True)
|
||||
question, qtype, gold_session_ids, sessions = find_row(qid)
|
||||
if question is None:
|
||||
print(f" qid={qid} NOT FOUND in dataset", flush=True)
|
||||
return {}
|
||||
|
||||
print(f" type={qtype}", flush=True)
|
||||
print(f" question[0:120]={question[:120]!r}", flush=True)
|
||||
print(f" gold session_ids={gold_session_ids}", flush=True)
|
||||
print(f" haystack sessions={len(sessions)}", flush=True)
|
||||
|
||||
tmp_root = Path(tempfile.mkdtemp(prefix="lme_dbg_"))
|
||||
store_dir = tmp_root / f"row-{qid}"
|
||||
store_dir.mkdir(parents=True, exist_ok=True)
|
||||
store = MemoryStore(path=store_dir / "lancedb")
|
||||
asyncio.run(store.enable_async_writes(coalesce_ms=50, max_batch=128))
|
||||
embedder = embedder_for_store(store)
|
||||
|
||||
id_to_session: dict[UUID, str] = {}
|
||||
gold_record_ids: set[UUID] = set()
|
||||
n_inserted = 0
|
||||
for sess in sessions:
|
||||
for turn in sess.turns:
|
||||
content = str(turn.get("content", "")).strip()
|
||||
if not content:
|
||||
continue
|
||||
vec = embedder.embed(content)
|
||||
rec = _make_record(
|
||||
content=content,
|
||||
session_id=sess.session_id,
|
||||
role=str(turn.get("role", "user")),
|
||||
embedding=vec,
|
||||
)
|
||||
store.insert(rec)
|
||||
id_to_session[rec.id] = sess.session_id
|
||||
if sess.session_id in gold_session_ids:
|
||||
gold_record_ids.add(rec.id)
|
||||
n_inserted += 1
|
||||
|
||||
asyncio.run(store.disable_async_writes())
|
||||
print(f" records inserted: {n_inserted}", flush=True)
|
||||
print(f" gold records: {len(gold_record_ids)}", flush=True)
|
||||
|
||||
graph, assignment, rich_club = build_runtime_graph(store)
|
||||
print(f" graph nodes: {len(graph._nx.nodes)}", flush=True)
|
||||
print(f" communities: {len(assignment.mid_regions)}", flush=True)
|
||||
print(f" rich-club: {len(rich_club)}", flush=True)
|
||||
cue_emb = embedder.embed(question)
|
||||
|
||||
# --- Baseline: retrieve_recall ---
|
||||
resp_x = retrieve_recall(
|
||||
store=store,
|
||||
cue_embedding=cue_emb,
|
||||
cue_text=question,
|
||||
session_id=f"debug-{qid}",
|
||||
budget_tokens=1500,
|
||||
k_hits=10,
|
||||
k_anti=0,
|
||||
)
|
||||
x_ids = [h.record_id for h in resp_x.hits]
|
||||
x_sessions = [id_to_session.get(r, "?") for r in x_ids]
|
||||
x_gold_pos = [i for i, s in enumerate(x_sessions) if s in gold_session_ids]
|
||||
print(f"\n --- retrieve_recall (X) ---", flush=True)
|
||||
print(f" top-10 sessions: {x_sessions}", flush=True)
|
||||
print(f" gold hit positions: {x_gold_pos}", flush=True)
|
||||
|
||||
# --- recall_for_benchmark, stage by stage ---
|
||||
print(f"\n --- recall_for_benchmark (Y) stage-by-stage ---", flush=True)
|
||||
|
||||
gated = _community_gate(cue_emb, assignment, top_n=3)
|
||||
candidates_set: set[UUID] = set()
|
||||
for gc in gated:
|
||||
for cid in assignment.mid_regions.get(gc, []):
|
||||
candidates_set.add(cid)
|
||||
if not candidates_set:
|
||||
candidates_set = {UUID(n) for n in graph._nx.nodes()}
|
||||
print(f" Stage 2 (community gate): EMPTY, fallback to all nodes", flush=True)
|
||||
print(f" Stage 2 (community gate): top-3 communities = {gated}", flush=True)
|
||||
print(f" candidates after gate: {len(candidates_set)}", flush=True)
|
||||
gold_in_gate = gold_record_ids & candidates_set
|
||||
print(f" gold survives gate: {len(gold_in_gate)} / {len(gold_record_ids)}", flush=True)
|
||||
|
||||
centrality: dict[UUID, float] = {}
|
||||
for nid in graph._nx.nodes:
|
||||
n = graph._nx.nodes[nid]
|
||||
if "centrality" in n:
|
||||
try:
|
||||
centrality[UUID(nid)] = float(n["centrality"])
|
||||
except (TypeError, ValueError):
|
||||
centrality[UUID(nid)] = 0.0
|
||||
if not centrality:
|
||||
try:
|
||||
centrality = graph.centrality()
|
||||
except Exception:
|
||||
centrality = {}
|
||||
# (08-01): _pick_seeds now reads from a shared cosine array.
|
||||
# Build the same array the production pipeline builds.
|
||||
pool_ids, pool_embs = _collect_graph_pool(graph, None, store)
|
||||
cue_vec_norm = np.asarray(cue_emb, dtype=np.float32)
|
||||
cn = float(np.linalg.norm(cue_vec_norm))
|
||||
if cn > 0.0:
|
||||
cue_vec_norm = cue_vec_norm / cn
|
||||
if pool_embs.size:
|
||||
shared_cos = (pool_embs @ cue_vec_norm).astype(np.float32)
|
||||
else:
|
||||
shared_cos = np.empty(0, dtype=np.float32)
|
||||
id_to_idx = {rid: i for i, rid in enumerate(pool_ids)}
|
||||
cand_idx = np.array(
|
||||
[id_to_idx[c] for c in candidates_set if c in id_to_idx],
|
||||
dtype=np.int64,
|
||||
)
|
||||
centrality_arr = np.array(
|
||||
[centrality.get(rid, 0.0) for rid in pool_ids],
|
||||
dtype=np.float32,
|
||||
)
|
||||
seed_idx = _pick_seeds(cand_idx, shared_cos, centrality_arr, n=3)
|
||||
seeds = [pool_ids[int(i)] for i in seed_idx]
|
||||
print(f" Stage 3 (seeds, top-3 by cosine in gated): {len(seeds)}", flush=True)
|
||||
seeds_sessions = [id_to_session.get(s, "?") for s in seeds]
|
||||
print(f" seed sessions: {seeds_sessions}", flush=True)
|
||||
gold_in_seeds = gold_record_ids & set(seeds)
|
||||
print(f" gold in seeds: {len(gold_in_seeds)}", flush=True)
|
||||
|
||||
spread = graph.two_hop_neighborhood(seeds, top_k=5)
|
||||
reachable = set(seeds) | set(spread) | set(rich_club)
|
||||
print(f" Stage 4 (spread + rich-club union):", flush=True)
|
||||
print(f" seeds={len(seeds)} spread={len(spread)} rich={len(rich_club)} reachable={len(reachable)}", flush=True)
|
||||
gold_in_reachable = gold_record_ids & reachable
|
||||
print(f" gold in reachable: {len(gold_in_reachable)} / {len(gold_record_ids)}", flush=True)
|
||||
|
||||
resp_y = recall_for_benchmark(
|
||||
store=store,
|
||||
graph=graph,
|
||||
assignment=assignment,
|
||||
rich_club=rich_club,
|
||||
embedder=embedder,
|
||||
cue=question,
|
||||
session_id=f"debug-{qid}",
|
||||
k_hits=10,
|
||||
profile_state=None,
|
||||
turn=0,
|
||||
mode="concept",
|
||||
)
|
||||
y_ids = [h.record_id for h in resp_y.hits]
|
||||
y_sessions = [id_to_session.get(r, "?") for r in y_ids]
|
||||
y_gold_pos = [i for i, s in enumerate(y_sessions) if s in gold_session_ids]
|
||||
print(f" Stage 5 (rank + budget pack):", flush=True)
|
||||
print(f" final hits: {len(y_ids)}", flush=True)
|
||||
print(f" top-10 sessions: {y_sessions}", flush=True)
|
||||
print(f" gold hit positions: {y_gold_pos}", flush=True)
|
||||
|
||||
# ----- Verdict -----
|
||||
# verdict primary signal is whether gold lands in
|
||||
# recall_for_benchmark's top-10 — which is what matters for R@5/R@10.
|
||||
# Stage-2/3/4 stage-by-stage diagnostics still print above (useful when
|
||||
# gold is missed) but they observe the PRIVATE _community_gate /
|
||||
# _pick_seeds path. The redesign (08-CONTEXT.md D-02) makes the
|
||||
# community gate a soft-bias diagnostic rather than a hard filter, so a
|
||||
# "stage_2 missed" diagnostic with gold present in final hits means:
|
||||
# the gate's communities did not include gold, but the cosine top-K
|
||||
# candidate pool did, and Stage 5 ranking surfaced it.
|
||||
print(f"\n --- VERDICT ---", flush=True)
|
||||
if y_gold_pos:
|
||||
print(f" gold present in top-10 (positions {y_gold_pos}) — no_loss", flush=True)
|
||||
if not gold_in_gate:
|
||||
print(f" (gate would have killed it; augmentation rescued)", flush=True)
|
||||
verdict = "no_loss"
|
||||
elif not gold_in_gate:
|
||||
print(f" >>> GOLD KILLED at STAGE 2 (community gate) — augmentation also failed <<<", flush=True)
|
||||
verdict = "stage_2_community_gate"
|
||||
elif not gold_in_reachable:
|
||||
print(f" >>> GOLD KILLED at STAGE 3-4 (seeds + spread) <<<", flush=True)
|
||||
print(f" gold was {len(gold_in_gate)} candidate(s); none became "
|
||||
f"a seed and none was reached within 2 hops of the chosen seeds", flush=True)
|
||||
verdict = "stage_3_4_seeds_or_spread"
|
||||
else:
|
||||
print(f" >>> GOLD KILLED at STAGE 5 (rank + budget pack) <<<", flush=True)
|
||||
print(f" gold was reachable ({len(gold_in_reachable)}) but not in top-10 hits", flush=True)
|
||||
verdict = "stage_5_rank"
|
||||
|
||||
return {
|
||||
"qid": qid,
|
||||
"qtype": qtype,
|
||||
"verdict": verdict,
|
||||
"n_records": n_inserted,
|
||||
"n_communities": len(assignment.mid_regions),
|
||||
"n_rich_club": len(rich_club),
|
||||
"n_gold_records": len(gold_record_ids),
|
||||
"gold_in_gate": len(gold_in_gate),
|
||||
"gold_in_reachable": len(gold_in_reachable),
|
||||
"x_gold_pos": x_gold_pos,
|
||||
"y_gold_pos": y_gold_pos,
|
||||
}
|
||||
|
||||
|
||||
def main(qids: list[str]) -> int:
|
||||
summary = []
|
||||
for qid in qids:
|
||||
try:
|
||||
summary.append(trace_one(qid))
|
||||
except Exception as exc:
|
||||
print(f"\n qid={qid} TRACE FAILED: {type(exc).__name__}: {exc}", flush=True)
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
summary.append({"qid": qid, "verdict": "trace_failed"})
|
||||
|
||||
print("\n\n" + "=" * 78)
|
||||
print("SUMMARY")
|
||||
print("=" * 78)
|
||||
print(f"{'qid':16} {'qtype':28} {'verdict':32} gold(gate→reach)")
|
||||
print("-" * 100)
|
||||
for s in summary:
|
||||
if not s:
|
||||
continue
|
||||
gate = s.get("gold_in_gate", "?")
|
||||
reach = s.get("gold_in_reachable", "?")
|
||||
ngold = s.get("n_gold_records", "?")
|
||||
print(
|
||||
f"{s.get('qid', '?'):16} {s.get('qtype', '?'):28} "
|
||||
f"{s.get('verdict', '?'):32} "
|
||||
f"{gate}→{reach} (of {ngold})"
|
||||
)
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) < 2:
|
||||
print(__doc__, file=sys.stderr)
|
||||
sys.exit(1)
|
||||
sys.exit(main(sys.argv[1:]))
|
||||
Loading…
Add table
Add a link
Reference in a new issue