Initial release: iai-mcp v0.1.0

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: XNLLLLH <XNLLLLH@users.noreply.github.com>
This commit is contained in:
Areg Noya 2026-05-06 01:04:47 -07:00
commit f6b876fbe7
332 changed files with 97258 additions and 0 deletions

351
bench/lme500/aggregate.py Normal file
View file

@ -0,0 +1,351 @@
"""bench/lme500/aggregate.py — post-process LongMemEval-S blind-run output.
Usage:
python bench/lme500/aggregate.py \
--in bench/lme500/output/lme500-v1.json \
--report bench/lme500/output/lme500-v1-report.md \
--summary bench/lme500/output/lme500-v1-summary.json
The --in path may be:
- the final summary JSON ({"per_row": [...], ...} schema), or
- the per-row JSONL checkpoint (one JSON dict per line works on
partial runs while the bench is still in progress).
Computes:
- Overall R@5 / R@10 per prong (X = retrieve_recall, Y = recall_for_benchmark)
- Architecture lift Y - X
- Per-question-type stratification with n per bin (low-power flag if n<30)
- Bootstrap 95% CI via percentile method (10000 resamples, seed=42)
- Errors counted as miss for both prongs
Output:
- Markdown report (--report)
- Aggregated JSON summary (--summary)
- One-line stderr summary at end
"""
from __future__ import annotations
import argparse
import json
import random
import statistics
import sys
from collections import defaultdict
from pathlib import Path
from typing import Any
def load_rows(input_path: Path) -> list[dict[str, Any]]:
"""Load per-row dicts from JSON, JSONL, or list-JSON.
Order of detection:
1. JSONL: every non-empty line parses as a dict.
2. JSON object with "per_row" key return per_row.
3. JSON list return as-is.
"""
text = input_path.read_text(encoding="utf-8")
stripped = text.strip()
# Try JSON first
if stripped.startswith("{"):
try:
data = json.loads(text)
if isinstance(data, dict) and "per_row" in data:
return list(data["per_row"])
except json.JSONDecodeError:
pass
if stripped.startswith("["):
try:
return list(json.loads(text))
except json.JSONDecodeError:
pass
# Fall back to JSONL
rows: list[dict[str, Any]] = []
for lineno, line in enumerate(text.splitlines(), 1):
line = line.strip()
if not line:
continue
try:
rows.append(json.loads(line))
except json.JSONDecodeError as exc:
print(
f"[aggregate] WARN: skipping corrupt line {lineno}: {exc}",
file=sys.stderr,
)
return rows
def bootstrap_ci(
values: list[float],
n_resamples: int = 10000,
seed: int = 42,
) -> tuple[float, float, float]:
"""Bootstrap mean + 95% percentile CI.
Returns (mean, ci_lo, ci_hi). Empty input (0, 0, 0).
"""
if not values:
return 0.0, 0.0, 0.0
rng = random.Random(seed)
n = len(values)
means: list[float] = []
for _ in range(n_resamples):
s = 0.0
for _ in range(n):
s += values[rng.randrange(n)]
means.append(s / n)
means.sort()
lo_idx = max(0, int(0.025 * n_resamples))
hi_idx = min(n_resamples - 1, int(0.975 * n_resamples))
return statistics.fmean(values), means[lo_idx], means[hi_idx]
def _get_prong_value(row: dict[str, Any], prong: str, k: int) -> float:
"""Extract r_at_<k>_<prong> from a row, treating error rows as 0."""
if "error" in row and isinstance(row.get("error"), dict):
return 0.0
return float(row.get(f"r_at_{k}_{prong}", 0.0))
def aggregate(rows: list[dict[str, Any]]) -> dict[str, Any]:
"""Aggregate overall + per-type bootstrap CIs."""
if not rows:
return {"overall": {"n": 0, "n_errors": 0}, "per_type": {}}
by_type: dict[str, dict[str, list[float]]] = defaultdict(
lambda: {"x5": [], "x10": [], "y5": [], "y10": []}
)
overall: dict[str, list[float]] = {"x5": [], "x10": [], "y5": [], "y10": []}
n_errors = 0
for row in rows:
is_error = "error" in row and isinstance(row.get("error"), dict)
if is_error:
n_errors += 1
qtype = str(row.get("question_type", "unknown"))
x5 = _get_prong_value(row, "retrieve", 5)
x10 = _get_prong_value(row, "retrieve", 10)
y5 = _get_prong_value(row, "pipeline", 5)
y10 = _get_prong_value(row, "pipeline", 10)
overall["x5"].append(x5)
overall["x10"].append(x10)
overall["y5"].append(y5)
overall["y10"].append(y10)
by_type[qtype]["x5"].append(x5)
by_type[qtype]["x10"].append(x10)
by_type[qtype]["y5"].append(y5)
by_type[qtype]["y10"].append(y10)
def _prong_block(vals_5: list[float], vals_10: list[float]) -> dict:
m5, lo5, hi5 = bootstrap_ci(vals_5)
m10, lo10, hi10 = bootstrap_ci(vals_10)
return {
"r_at_5": {"mean": m5, "ci_lo": lo5, "ci_hi": hi5},
"r_at_10": {"mean": m10, "ci_lo": lo10, "ci_hi": hi10},
}
overall_block = {
"n": len(rows),
"n_errors": n_errors,
"X_retrieve": _prong_block(overall["x5"], overall["x10"]),
"Y_pipeline": _prong_block(overall["y5"], overall["y10"]),
}
overall_block["lift_Y_minus_X"] = {
"r_at_5": (
overall_block["Y_pipeline"]["r_at_5"]["mean"]
- overall_block["X_retrieve"]["r_at_5"]["mean"]
),
"r_at_10": (
overall_block["Y_pipeline"]["r_at_10"]["mean"]
- overall_block["X_retrieve"]["r_at_10"]["mean"]
),
}
per_type_out: dict[str, dict[str, Any]] = {}
for qt in sorted(by_type.keys()):
data = by_type[qt]
block = {
"n": len(data["x5"]),
"X_retrieve": _prong_block(data["x5"], data["x10"]),
"Y_pipeline": _prong_block(data["y5"], data["y10"]),
}
block["lift_Y_minus_X"] = {
"r_at_5": (
block["Y_pipeline"]["r_at_5"]["mean"]
- block["X_retrieve"]["r_at_5"]["mean"]
),
"r_at_10": (
block["Y_pipeline"]["r_at_10"]["mean"]
- block["X_retrieve"]["r_at_10"]["mean"]
),
}
per_type_out[qt] = block
return {"overall": overall_block, "per_type": per_type_out}
def format_markdown_report(agg: dict[str, Any], source_path: Path) -> str:
overall = agg["overall"]
lines: list[str] = []
lines.append("# LongMemEval-S Aggregate Report")
lines.append("")
lines.append(f"- Source: `{source_path}`")
lines.append(f"- n = {overall['n']}, errors = {overall['n_errors']}")
lines.append(
"- 95% CI via bootstrap percentile method (10000 resamples, seed=42)"
)
lines.append("")
if overall["n"] == 0:
lines.append("**No rows loaded.**")
return "\n".join(lines) + "\n"
lines.append("## Overall")
lines.append("")
lines.append("| Prong | R@5 | R@5 95% CI | R@10 | R@10 95% CI |")
lines.append("|---|---|---|---|---|")
x = overall["X_retrieve"]
y = overall["Y_pipeline"]
lift = overall["lift_Y_minus_X"]
lines.append(
f"| X (retrieve_recall — flat-cosine baseline) "
f"| {x['r_at_5']['mean']:.3f} "
f"| [{x['r_at_5']['ci_lo']:.3f}, {x['r_at_5']['ci_hi']:.3f}] "
f"| {x['r_at_10']['mean']:.3f} "
f"| [{x['r_at_10']['ci_lo']:.3f}, {x['r_at_10']['ci_hi']:.3f}] |"
)
lines.append(
f"| Y (recall_for_benchmark — full graph-native pipeline) "
f"| {y['r_at_5']['mean']:.3f} "
f"| [{y['r_at_5']['ci_lo']:.3f}, {y['r_at_5']['ci_hi']:.3f}] "
f"| {y['r_at_10']['mean']:.3f} "
f"| [{y['r_at_10']['ci_lo']:.3f}, {y['r_at_10']['ci_hi']:.3f}] |"
)
lines.append(
f"| **Architecture lift Y X** "
f"| **{lift['r_at_5']:+.3f}** "
f"| — "
f"| **{lift['r_at_10']:+.3f}** "
f"| — |"
)
lines.append("")
lines.append("## Per question type")
lines.append("")
lines.append(
"| Type | n | X R@5 | Y R@5 | Lift R@5 "
"| X R@10 | Y R@10 | Lift R@10 |"
)
lines.append("|---|---|---|---|---|---|---|---|")
for qt, block in agg["per_type"].items():
n = block["n"]
flag = " ⚠️" if n < 30 else ""
x = block["X_retrieve"]
y = block["Y_pipeline"]
lift = block["lift_Y_minus_X"]
lines.append(
f"| `{qt}`{flag} | {n} "
f"| {x['r_at_5']['mean']:.3f} | {y['r_at_5']['mean']:.3f} "
f"| {lift['r_at_5']:+.3f} "
f"| {x['r_at_10']['mean']:.3f} | {y['r_at_10']['mean']:.3f} "
f"| {lift['r_at_10']:+.3f} |"
)
lines.append("")
lines.append("⚠️ = n < 30, low statistical power for that bin.")
lines.append("")
lines.append("## Notes")
lines.append("")
lines.append(
"- Errors (graph-build failures, malformed rows, etc.) are counted "
"as miss for **both** prongs (R@k = 0)."
)
lines.append(
"- Mean is the unweighted row average; CI is bootstrap percentile."
)
lines.append(
"- Architecture lift = mean(Y) mean(X). The CI of the lift "
"itself is not computed here (would require paired bootstrap on "
"the (Y_i, X_i) tuples — TODO if needed)."
)
return "\n".join(lines) + "\n"
def main() -> int:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--in",
dest="input",
required=True,
help="Path to per-row JSON / JSONL file",
)
parser.add_argument(
"--report",
default=None,
help="Output path for markdown report; default: <input>-report.md",
)
parser.add_argument(
"--summary",
default=None,
help="Output path for aggregated JSON; default: <input>-summary.json",
)
args = parser.parse_args()
input_path = Path(args.input)
if not input_path.exists():
print(f"[aggregate] ERROR: {input_path} does not exist", file=sys.stderr)
return 1
rows = load_rows(input_path)
if not rows:
print(f"[aggregate] WARN: 0 rows loaded from {input_path}", file=sys.stderr)
return 1
agg = aggregate(rows)
summary_path = (
Path(args.summary)
if args.summary
else input_path.with_name(input_path.stem + "-summary.json")
)
summary_path.parent.mkdir(parents=True, exist_ok=True)
with open(summary_path, "w", encoding="utf-8") as f:
json.dump(agg, f, indent=2)
report_path = (
Path(args.report)
if args.report
else input_path.with_name(input_path.stem + "-report.md")
)
report_path.parent.mkdir(parents=True, exist_ok=True)
report_path.write_text(format_markdown_report(agg, input_path), encoding="utf-8")
overall = agg["overall"]
x = overall["X_retrieve"]
y = overall["Y_pipeline"]
lift = overall["lift_Y_minus_X"]
print(
f"[aggregate] n={overall['n']} errors={overall['n_errors']}",
file=sys.stderr,
)
print(
f"[aggregate] X (retrieve) R@5={x['r_at_5']['mean']:.3f} "
f"[{x['r_at_5']['ci_lo']:.3f},{x['r_at_5']['ci_hi']:.3f}] "
f"R@10={x['r_at_10']['mean']:.3f}",
file=sys.stderr,
)
print(
f"[aggregate] Y (pipeline) R@5={y['r_at_5']['mean']:.3f} "
f"[{y['r_at_5']['ci_lo']:.3f},{y['r_at_5']['ci_hi']:.3f}] "
f"R@10={y['r_at_10']['mean']:.3f}",
file=sys.stderr,
)
print(
f"[aggregate] Lift Y X R@5={lift['r_at_5']:+.3f} "
f"R@10={lift['r_at_10']:+.3f}",
file=sys.stderr,
)
print(f"[aggregate] -> {summary_path}", file=sys.stderr)
print(f"[aggregate] -> {report_path}", file=sys.stderr)
return 0
if __name__ == "__main__":
raise SystemExit(main())

View file

@ -0,0 +1,328 @@
"""bench/lme500/debug_pipeline_loss.py
Trace WHICH pipeline stage drops the gold session in loss cases
(rows where retrieve_recall hits in top-k but recall_for_benchmark does not).
Usage:
python bench/lme500/debug_pipeline_loss.py <question_id> [<question_id> ...]
For each qid:
- Loads the LongMemEval-S row from the pinned dataset.
- Builds a fresh per-row store + runtime graph (same shape as the bench).
- Runs retrieve_recall to confirm gold sessions are findable by flat cosine.
- Runs recall_for_benchmark STAGE BY STAGE, recording at each cut whether the
gold record IDs survived.
Stages traced:
Stage 2 community gate (top-3 communities by centroid cosine)
Stage 3 seeds (top-3 by cosine within gated candidates)
Stage 4 2-hop spread + rich-club union
Stage 5 final recall_for_benchmark hits
Output is a per-stage table showing where gold drops.
Read-only no src/iai_mcp changes. Calls private helpers _community_gate
and _pick_seeds for stage-level inspection (debug-only path).
"""
from __future__ import annotations
import asyncio
import os
import sys
import tempfile
from datetime import datetime, timezone
from pathlib import Path
from uuid import UUID, uuid4
os.environ.setdefault("TRANSFORMERS_VERBOSITY", "error")
import numpy as np
from iai_mcp.embed import embedder_for_store
from iai_mcp.pipeline import (
_collect_graph_pool,
_community_gate,
_pick_seeds,
recall_for_benchmark,
)
from iai_mcp.retrieve import build_runtime_graph, recall as retrieve_recall
from iai_mcp.store import MemoryStore
from iai_mcp.types import MemoryRecord
from bench.adapters.longmemeval import LongMemEvalAdapter
def _make_record(content: str, session_id: str, role: str, embedding: list[float]) -> MemoryRecord:
now = datetime.now(timezone.utc)
return MemoryRecord(
id=uuid4(),
tier="episodic",
literal_surface=content,
aaak_index="",
embedding=embedding,
community_id=None,
centrality=0.0,
detail_level=2,
pinned=False,
stability=0.0,
difficulty=0.0,
last_reviewed=None,
never_decay=False,
never_merge=False,
provenance=[],
created_at=now,
updated_at=now,
tags=["longmemeval", f"role:{role}", f"session:{session_id}"],
language="en",
)
def find_row(qid: str):
adapter = LongMemEvalAdapter()
sessions = []
question = None
answer_session_ids = None
qtype = None
for lme_session in adapter.load_dataset(split="S"):
q = lme_session.queries[0]
if q["question_id"] == qid:
sessions.append(lme_session)
if question is None:
question = q["query"]
answer_session_ids = set(q.get("relevant_turn_ids", []))
qtype = q.get("question_type", "?")
return question, qtype, answer_session_ids, sessions
def trace_one(qid: str) -> dict:
"""Returns a dict with the stage-by-stage gold survival counts."""
print(f"\n{'=' * 78}\n=== qid={qid} ===\n{'=' * 78}", flush=True)
question, qtype, gold_session_ids, sessions = find_row(qid)
if question is None:
print(f" qid={qid} NOT FOUND in dataset", flush=True)
return {}
print(f" type={qtype}", flush=True)
print(f" question[0:120]={question[:120]!r}", flush=True)
print(f" gold session_ids={gold_session_ids}", flush=True)
print(f" haystack sessions={len(sessions)}", flush=True)
tmp_root = Path(tempfile.mkdtemp(prefix="lme_dbg_"))
store_dir = tmp_root / f"row-{qid}"
store_dir.mkdir(parents=True, exist_ok=True)
store = MemoryStore(path=store_dir / "lancedb")
asyncio.run(store.enable_async_writes(coalesce_ms=50, max_batch=128))
embedder = embedder_for_store(store)
id_to_session: dict[UUID, str] = {}
gold_record_ids: set[UUID] = set()
n_inserted = 0
for sess in sessions:
for turn in sess.turns:
content = str(turn.get("content", "")).strip()
if not content:
continue
vec = embedder.embed(content)
rec = _make_record(
content=content,
session_id=sess.session_id,
role=str(turn.get("role", "user")),
embedding=vec,
)
store.insert(rec)
id_to_session[rec.id] = sess.session_id
if sess.session_id in gold_session_ids:
gold_record_ids.add(rec.id)
n_inserted += 1
asyncio.run(store.disable_async_writes())
print(f" records inserted: {n_inserted}", flush=True)
print(f" gold records: {len(gold_record_ids)}", flush=True)
graph, assignment, rich_club = build_runtime_graph(store)
print(f" graph nodes: {len(graph._nx.nodes)}", flush=True)
print(f" communities: {len(assignment.mid_regions)}", flush=True)
print(f" rich-club: {len(rich_club)}", flush=True)
cue_emb = embedder.embed(question)
# --- Baseline: retrieve_recall ---
resp_x = retrieve_recall(
store=store,
cue_embedding=cue_emb,
cue_text=question,
session_id=f"debug-{qid}",
budget_tokens=1500,
k_hits=10,
k_anti=0,
)
x_ids = [h.record_id for h in resp_x.hits]
x_sessions = [id_to_session.get(r, "?") for r in x_ids]
x_gold_pos = [i for i, s in enumerate(x_sessions) if s in gold_session_ids]
print(f"\n --- retrieve_recall (X) ---", flush=True)
print(f" top-10 sessions: {x_sessions}", flush=True)
print(f" gold hit positions: {x_gold_pos}", flush=True)
# --- recall_for_benchmark, stage by stage ---
print(f"\n --- recall_for_benchmark (Y) stage-by-stage ---", flush=True)
gated = _community_gate(cue_emb, assignment, top_n=3)
candidates_set: set[UUID] = set()
for gc in gated:
for cid in assignment.mid_regions.get(gc, []):
candidates_set.add(cid)
if not candidates_set:
candidates_set = {UUID(n) for n in graph._nx.nodes()}
print(f" Stage 2 (community gate): EMPTY, fallback to all nodes", flush=True)
print(f" Stage 2 (community gate): top-3 communities = {gated}", flush=True)
print(f" candidates after gate: {len(candidates_set)}", flush=True)
gold_in_gate = gold_record_ids & candidates_set
print(f" gold survives gate: {len(gold_in_gate)} / {len(gold_record_ids)}", flush=True)
centrality: dict[UUID, float] = {}
for nid in graph._nx.nodes:
n = graph._nx.nodes[nid]
if "centrality" in n:
try:
centrality[UUID(nid)] = float(n["centrality"])
except (TypeError, ValueError):
centrality[UUID(nid)] = 0.0
if not centrality:
try:
centrality = graph.centrality()
except Exception:
centrality = {}
# (08-01): _pick_seeds now reads from a shared cosine array.
# Build the same array the production pipeline builds.
pool_ids, pool_embs = _collect_graph_pool(graph, None, store)
cue_vec_norm = np.asarray(cue_emb, dtype=np.float32)
cn = float(np.linalg.norm(cue_vec_norm))
if cn > 0.0:
cue_vec_norm = cue_vec_norm / cn
if pool_embs.size:
shared_cos = (pool_embs @ cue_vec_norm).astype(np.float32)
else:
shared_cos = np.empty(0, dtype=np.float32)
id_to_idx = {rid: i for i, rid in enumerate(pool_ids)}
cand_idx = np.array(
[id_to_idx[c] for c in candidates_set if c in id_to_idx],
dtype=np.int64,
)
centrality_arr = np.array(
[centrality.get(rid, 0.0) for rid in pool_ids],
dtype=np.float32,
)
seed_idx = _pick_seeds(cand_idx, shared_cos, centrality_arr, n=3)
seeds = [pool_ids[int(i)] for i in seed_idx]
print(f" Stage 3 (seeds, top-3 by cosine in gated): {len(seeds)}", flush=True)
seeds_sessions = [id_to_session.get(s, "?") for s in seeds]
print(f" seed sessions: {seeds_sessions}", flush=True)
gold_in_seeds = gold_record_ids & set(seeds)
print(f" gold in seeds: {len(gold_in_seeds)}", flush=True)
spread = graph.two_hop_neighborhood(seeds, top_k=5)
reachable = set(seeds) | set(spread) | set(rich_club)
print(f" Stage 4 (spread + rich-club union):", flush=True)
print(f" seeds={len(seeds)} spread={len(spread)} rich={len(rich_club)} reachable={len(reachable)}", flush=True)
gold_in_reachable = gold_record_ids & reachable
print(f" gold in reachable: {len(gold_in_reachable)} / {len(gold_record_ids)}", flush=True)
resp_y = recall_for_benchmark(
store=store,
graph=graph,
assignment=assignment,
rich_club=rich_club,
embedder=embedder,
cue=question,
session_id=f"debug-{qid}",
k_hits=10,
profile_state=None,
turn=0,
mode="concept",
)
y_ids = [h.record_id for h in resp_y.hits]
y_sessions = [id_to_session.get(r, "?") for r in y_ids]
y_gold_pos = [i for i, s in enumerate(y_sessions) if s in gold_session_ids]
print(f" Stage 5 (rank + budget pack):", flush=True)
print(f" final hits: {len(y_ids)}", flush=True)
print(f" top-10 sessions: {y_sessions}", flush=True)
print(f" gold hit positions: {y_gold_pos}", flush=True)
# ----- Verdict -----
# verdict primary signal is whether gold lands in
# recall_for_benchmark's top-10 — which is what matters for R@5/R@10.
# Stage-2/3/4 stage-by-stage diagnostics still print above (useful when
# gold is missed) but they observe the PRIVATE _community_gate /
# _pick_seeds path. The redesign (08-CONTEXT.md D-02) makes the
# community gate a soft-bias diagnostic rather than a hard filter, so a
# "stage_2 missed" diagnostic with gold present in final hits means:
# the gate's communities did not include gold, but the cosine top-K
# candidate pool did, and Stage 5 ranking surfaced it.
print(f"\n --- VERDICT ---", flush=True)
if y_gold_pos:
print(f" gold present in top-10 (positions {y_gold_pos}) — no_loss", flush=True)
if not gold_in_gate:
print(f" (gate would have killed it; augmentation rescued)", flush=True)
verdict = "no_loss"
elif not gold_in_gate:
print(f" >>> GOLD KILLED at STAGE 2 (community gate) — augmentation also failed <<<", flush=True)
verdict = "stage_2_community_gate"
elif not gold_in_reachable:
print(f" >>> GOLD KILLED at STAGE 3-4 (seeds + spread) <<<", flush=True)
print(f" gold was {len(gold_in_gate)} candidate(s); none became "
f"a seed and none was reached within 2 hops of the chosen seeds", flush=True)
verdict = "stage_3_4_seeds_or_spread"
else:
print(f" >>> GOLD KILLED at STAGE 5 (rank + budget pack) <<<", flush=True)
print(f" gold was reachable ({len(gold_in_reachable)}) but not in top-10 hits", flush=True)
verdict = "stage_5_rank"
return {
"qid": qid,
"qtype": qtype,
"verdict": verdict,
"n_records": n_inserted,
"n_communities": len(assignment.mid_regions),
"n_rich_club": len(rich_club),
"n_gold_records": len(gold_record_ids),
"gold_in_gate": len(gold_in_gate),
"gold_in_reachable": len(gold_in_reachable),
"x_gold_pos": x_gold_pos,
"y_gold_pos": y_gold_pos,
}
def main(qids: list[str]) -> int:
summary = []
for qid in qids:
try:
summary.append(trace_one(qid))
except Exception as exc:
print(f"\n qid={qid} TRACE FAILED: {type(exc).__name__}: {exc}", flush=True)
import traceback
traceback.print_exc()
summary.append({"qid": qid, "verdict": "trace_failed"})
print("\n\n" + "=" * 78)
print("SUMMARY")
print("=" * 78)
print(f"{'qid':16} {'qtype':28} {'verdict':32} gold(gate→reach)")
print("-" * 100)
for s in summary:
if not s:
continue
gate = s.get("gold_in_gate", "?")
reach = s.get("gold_in_reachable", "?")
ngold = s.get("n_gold_records", "?")
print(
f"{s.get('qid', '?'):16} {s.get('qtype', '?'):28} "
f"{s.get('verdict', '?'):32} "
f"{gate}{reach} (of {ngold})"
)
return 0
if __name__ == "__main__":
if len(sys.argv) < 2:
print(__doc__, file=sys.stderr)
sys.exit(1)
sys.exit(main(sys.argv[1:]))