Initial release: iai-mcp v0.1.0
Co-Authored-By: Claude <noreply@anthropic.com> Co-Authored-By: XNLLLLH <XNLLLLH@users.noreply.github.com>
This commit is contained in:
commit
f6b876fbe7
332 changed files with 97258 additions and 0 deletions
351
bench/lme500/aggregate.py
Normal file
351
bench/lme500/aggregate.py
Normal file
|
|
@ -0,0 +1,351 @@
|
|||
"""bench/lme500/aggregate.py — post-process LongMemEval-S blind-run output.
|
||||
|
||||
Usage:
|
||||
python bench/lme500/aggregate.py \
|
||||
--in bench/lme500/output/lme500-v1.json \
|
||||
--report bench/lme500/output/lme500-v1-report.md \
|
||||
--summary bench/lme500/output/lme500-v1-summary.json
|
||||
|
||||
The --in path may be:
|
||||
- the final summary JSON ({"per_row": [...], ...} schema), or
|
||||
- the per-row JSONL checkpoint (one JSON dict per line — works on
|
||||
partial runs while the bench is still in progress).
|
||||
|
||||
Computes:
|
||||
- Overall R@5 / R@10 per prong (X = retrieve_recall, Y = recall_for_benchmark)
|
||||
- Architecture lift Y - X
|
||||
- Per-question-type stratification with n per bin (low-power flag if n<30)
|
||||
- Bootstrap 95% CI via percentile method (10000 resamples, seed=42)
|
||||
- Errors counted as miss for both prongs
|
||||
|
||||
Output:
|
||||
- Markdown report (--report)
|
||||
- Aggregated JSON summary (--summary)
|
||||
- One-line stderr summary at end
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import random
|
||||
import statistics
|
||||
import sys
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
||||
def load_rows(input_path: Path) -> list[dict[str, Any]]:
|
||||
"""Load per-row dicts from JSON, JSONL, or list-JSON.
|
||||
|
||||
Order of detection:
|
||||
1. JSONL: every non-empty line parses as a dict.
|
||||
2. JSON object with "per_row" key → return per_row.
|
||||
3. JSON list → return as-is.
|
||||
"""
|
||||
text = input_path.read_text(encoding="utf-8")
|
||||
stripped = text.strip()
|
||||
# Try JSON first
|
||||
if stripped.startswith("{"):
|
||||
try:
|
||||
data = json.loads(text)
|
||||
if isinstance(data, dict) and "per_row" in data:
|
||||
return list(data["per_row"])
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
if stripped.startswith("["):
|
||||
try:
|
||||
return list(json.loads(text))
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
# Fall back to JSONL
|
||||
rows: list[dict[str, Any]] = []
|
||||
for lineno, line in enumerate(text.splitlines(), 1):
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
rows.append(json.loads(line))
|
||||
except json.JSONDecodeError as exc:
|
||||
print(
|
||||
f"[aggregate] WARN: skipping corrupt line {lineno}: {exc}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
return rows
|
||||
|
||||
|
||||
def bootstrap_ci(
|
||||
values: list[float],
|
||||
n_resamples: int = 10000,
|
||||
seed: int = 42,
|
||||
) -> tuple[float, float, float]:
|
||||
"""Bootstrap mean + 95% percentile CI.
|
||||
|
||||
Returns (mean, ci_lo, ci_hi). Empty input → (0, 0, 0).
|
||||
"""
|
||||
if not values:
|
||||
return 0.0, 0.0, 0.0
|
||||
rng = random.Random(seed)
|
||||
n = len(values)
|
||||
means: list[float] = []
|
||||
for _ in range(n_resamples):
|
||||
s = 0.0
|
||||
for _ in range(n):
|
||||
s += values[rng.randrange(n)]
|
||||
means.append(s / n)
|
||||
means.sort()
|
||||
lo_idx = max(0, int(0.025 * n_resamples))
|
||||
hi_idx = min(n_resamples - 1, int(0.975 * n_resamples))
|
||||
return statistics.fmean(values), means[lo_idx], means[hi_idx]
|
||||
|
||||
|
||||
def _get_prong_value(row: dict[str, Any], prong: str, k: int) -> float:
|
||||
"""Extract r_at_<k>_<prong> from a row, treating error rows as 0."""
|
||||
if "error" in row and isinstance(row.get("error"), dict):
|
||||
return 0.0
|
||||
return float(row.get(f"r_at_{k}_{prong}", 0.0))
|
||||
|
||||
|
||||
def aggregate(rows: list[dict[str, Any]]) -> dict[str, Any]:
|
||||
"""Aggregate overall + per-type bootstrap CIs."""
|
||||
if not rows:
|
||||
return {"overall": {"n": 0, "n_errors": 0}, "per_type": {}}
|
||||
|
||||
by_type: dict[str, dict[str, list[float]]] = defaultdict(
|
||||
lambda: {"x5": [], "x10": [], "y5": [], "y10": []}
|
||||
)
|
||||
overall: dict[str, list[float]] = {"x5": [], "x10": [], "y5": [], "y10": []}
|
||||
n_errors = 0
|
||||
|
||||
for row in rows:
|
||||
is_error = "error" in row and isinstance(row.get("error"), dict)
|
||||
if is_error:
|
||||
n_errors += 1
|
||||
qtype = str(row.get("question_type", "unknown"))
|
||||
x5 = _get_prong_value(row, "retrieve", 5)
|
||||
x10 = _get_prong_value(row, "retrieve", 10)
|
||||
y5 = _get_prong_value(row, "pipeline", 5)
|
||||
y10 = _get_prong_value(row, "pipeline", 10)
|
||||
overall["x5"].append(x5)
|
||||
overall["x10"].append(x10)
|
||||
overall["y5"].append(y5)
|
||||
overall["y10"].append(y10)
|
||||
by_type[qtype]["x5"].append(x5)
|
||||
by_type[qtype]["x10"].append(x10)
|
||||
by_type[qtype]["y5"].append(y5)
|
||||
by_type[qtype]["y10"].append(y10)
|
||||
|
||||
def _prong_block(vals_5: list[float], vals_10: list[float]) -> dict:
|
||||
m5, lo5, hi5 = bootstrap_ci(vals_5)
|
||||
m10, lo10, hi10 = bootstrap_ci(vals_10)
|
||||
return {
|
||||
"r_at_5": {"mean": m5, "ci_lo": lo5, "ci_hi": hi5},
|
||||
"r_at_10": {"mean": m10, "ci_lo": lo10, "ci_hi": hi10},
|
||||
}
|
||||
|
||||
overall_block = {
|
||||
"n": len(rows),
|
||||
"n_errors": n_errors,
|
||||
"X_retrieve": _prong_block(overall["x5"], overall["x10"]),
|
||||
"Y_pipeline": _prong_block(overall["y5"], overall["y10"]),
|
||||
}
|
||||
overall_block["lift_Y_minus_X"] = {
|
||||
"r_at_5": (
|
||||
overall_block["Y_pipeline"]["r_at_5"]["mean"]
|
||||
- overall_block["X_retrieve"]["r_at_5"]["mean"]
|
||||
),
|
||||
"r_at_10": (
|
||||
overall_block["Y_pipeline"]["r_at_10"]["mean"]
|
||||
- overall_block["X_retrieve"]["r_at_10"]["mean"]
|
||||
),
|
||||
}
|
||||
|
||||
per_type_out: dict[str, dict[str, Any]] = {}
|
||||
for qt in sorted(by_type.keys()):
|
||||
data = by_type[qt]
|
||||
block = {
|
||||
"n": len(data["x5"]),
|
||||
"X_retrieve": _prong_block(data["x5"], data["x10"]),
|
||||
"Y_pipeline": _prong_block(data["y5"], data["y10"]),
|
||||
}
|
||||
block["lift_Y_minus_X"] = {
|
||||
"r_at_5": (
|
||||
block["Y_pipeline"]["r_at_5"]["mean"]
|
||||
- block["X_retrieve"]["r_at_5"]["mean"]
|
||||
),
|
||||
"r_at_10": (
|
||||
block["Y_pipeline"]["r_at_10"]["mean"]
|
||||
- block["X_retrieve"]["r_at_10"]["mean"]
|
||||
),
|
||||
}
|
||||
per_type_out[qt] = block
|
||||
|
||||
return {"overall": overall_block, "per_type": per_type_out}
|
||||
|
||||
|
||||
def format_markdown_report(agg: dict[str, Any], source_path: Path) -> str:
|
||||
overall = agg["overall"]
|
||||
lines: list[str] = []
|
||||
lines.append("# LongMemEval-S Aggregate Report")
|
||||
lines.append("")
|
||||
lines.append(f"- Source: `{source_path}`")
|
||||
lines.append(f"- n = {overall['n']}, errors = {overall['n_errors']}")
|
||||
lines.append(
|
||||
"- 95% CI via bootstrap percentile method (10000 resamples, seed=42)"
|
||||
)
|
||||
lines.append("")
|
||||
|
||||
if overall["n"] == 0:
|
||||
lines.append("**No rows loaded.**")
|
||||
return "\n".join(lines) + "\n"
|
||||
|
||||
lines.append("## Overall")
|
||||
lines.append("")
|
||||
lines.append("| Prong | R@5 | R@5 95% CI | R@10 | R@10 95% CI |")
|
||||
lines.append("|---|---|---|---|---|")
|
||||
x = overall["X_retrieve"]
|
||||
y = overall["Y_pipeline"]
|
||||
lift = overall["lift_Y_minus_X"]
|
||||
lines.append(
|
||||
f"| X (retrieve_recall — flat-cosine baseline) "
|
||||
f"| {x['r_at_5']['mean']:.3f} "
|
||||
f"| [{x['r_at_5']['ci_lo']:.3f}, {x['r_at_5']['ci_hi']:.3f}] "
|
||||
f"| {x['r_at_10']['mean']:.3f} "
|
||||
f"| [{x['r_at_10']['ci_lo']:.3f}, {x['r_at_10']['ci_hi']:.3f}] |"
|
||||
)
|
||||
lines.append(
|
||||
f"| Y (recall_for_benchmark — full graph-native pipeline) "
|
||||
f"| {y['r_at_5']['mean']:.3f} "
|
||||
f"| [{y['r_at_5']['ci_lo']:.3f}, {y['r_at_5']['ci_hi']:.3f}] "
|
||||
f"| {y['r_at_10']['mean']:.3f} "
|
||||
f"| [{y['r_at_10']['ci_lo']:.3f}, {y['r_at_10']['ci_hi']:.3f}] |"
|
||||
)
|
||||
lines.append(
|
||||
f"| **Architecture lift Y − X** "
|
||||
f"| **{lift['r_at_5']:+.3f}** "
|
||||
f"| — "
|
||||
f"| **{lift['r_at_10']:+.3f}** "
|
||||
f"| — |"
|
||||
)
|
||||
lines.append("")
|
||||
|
||||
lines.append("## Per question type")
|
||||
lines.append("")
|
||||
lines.append(
|
||||
"| Type | n | X R@5 | Y R@5 | Lift R@5 "
|
||||
"| X R@10 | Y R@10 | Lift R@10 |"
|
||||
)
|
||||
lines.append("|---|---|---|---|---|---|---|---|")
|
||||
for qt, block in agg["per_type"].items():
|
||||
n = block["n"]
|
||||
flag = " ⚠️" if n < 30 else ""
|
||||
x = block["X_retrieve"]
|
||||
y = block["Y_pipeline"]
|
||||
lift = block["lift_Y_minus_X"]
|
||||
lines.append(
|
||||
f"| `{qt}`{flag} | {n} "
|
||||
f"| {x['r_at_5']['mean']:.3f} | {y['r_at_5']['mean']:.3f} "
|
||||
f"| {lift['r_at_5']:+.3f} "
|
||||
f"| {x['r_at_10']['mean']:.3f} | {y['r_at_10']['mean']:.3f} "
|
||||
f"| {lift['r_at_10']:+.3f} |"
|
||||
)
|
||||
lines.append("")
|
||||
lines.append("⚠️ = n < 30, low statistical power for that bin.")
|
||||
lines.append("")
|
||||
lines.append("## Notes")
|
||||
lines.append("")
|
||||
lines.append(
|
||||
"- Errors (graph-build failures, malformed rows, etc.) are counted "
|
||||
"as miss for **both** prongs (R@k = 0)."
|
||||
)
|
||||
lines.append(
|
||||
"- Mean is the unweighted row average; CI is bootstrap percentile."
|
||||
)
|
||||
lines.append(
|
||||
"- Architecture lift = mean(Y) − mean(X). The CI of the lift "
|
||||
"itself is not computed here (would require paired bootstrap on "
|
||||
"the (Y_i, X_i) tuples — TODO if needed)."
|
||||
)
|
||||
return "\n".join(lines) + "\n"
|
||||
|
||||
|
||||
def main() -> int:
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument(
|
||||
"--in",
|
||||
dest="input",
|
||||
required=True,
|
||||
help="Path to per-row JSON / JSONL file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--report",
|
||||
default=None,
|
||||
help="Output path for markdown report; default: <input>-report.md",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--summary",
|
||||
default=None,
|
||||
help="Output path for aggregated JSON; default: <input>-summary.json",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
input_path = Path(args.input)
|
||||
if not input_path.exists():
|
||||
print(f"[aggregate] ERROR: {input_path} does not exist", file=sys.stderr)
|
||||
return 1
|
||||
rows = load_rows(input_path)
|
||||
if not rows:
|
||||
print(f"[aggregate] WARN: 0 rows loaded from {input_path}", file=sys.stderr)
|
||||
return 1
|
||||
|
||||
agg = aggregate(rows)
|
||||
|
||||
summary_path = (
|
||||
Path(args.summary)
|
||||
if args.summary
|
||||
else input_path.with_name(input_path.stem + "-summary.json")
|
||||
)
|
||||
summary_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(summary_path, "w", encoding="utf-8") as f:
|
||||
json.dump(agg, f, indent=2)
|
||||
|
||||
report_path = (
|
||||
Path(args.report)
|
||||
if args.report
|
||||
else input_path.with_name(input_path.stem + "-report.md")
|
||||
)
|
||||
report_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
report_path.write_text(format_markdown_report(agg, input_path), encoding="utf-8")
|
||||
|
||||
overall = agg["overall"]
|
||||
x = overall["X_retrieve"]
|
||||
y = overall["Y_pipeline"]
|
||||
lift = overall["lift_Y_minus_X"]
|
||||
print(
|
||||
f"[aggregate] n={overall['n']} errors={overall['n_errors']}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
print(
|
||||
f"[aggregate] X (retrieve) R@5={x['r_at_5']['mean']:.3f} "
|
||||
f"[{x['r_at_5']['ci_lo']:.3f},{x['r_at_5']['ci_hi']:.3f}] "
|
||||
f"R@10={x['r_at_10']['mean']:.3f}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
print(
|
||||
f"[aggregate] Y (pipeline) R@5={y['r_at_5']['mean']:.3f} "
|
||||
f"[{y['r_at_5']['ci_lo']:.3f},{y['r_at_5']['ci_hi']:.3f}] "
|
||||
f"R@10={y['r_at_10']['mean']:.3f}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
print(
|
||||
f"[aggregate] Lift Y − X R@5={lift['r_at_5']:+.3f} "
|
||||
f"R@10={lift['r_at_10']:+.3f}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
print(f"[aggregate] -> {summary_path}", file=sys.stderr)
|
||||
print(f"[aggregate] -> {report_path}", file=sys.stderr)
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
328
bench/lme500/debug_pipeline_loss.py
Normal file
328
bench/lme500/debug_pipeline_loss.py
Normal file
|
|
@ -0,0 +1,328 @@
|
|||
"""bench/lme500/debug_pipeline_loss.py
|
||||
|
||||
Trace WHICH pipeline stage drops the gold session in loss cases
|
||||
(rows where retrieve_recall hits in top-k but recall_for_benchmark does not).
|
||||
|
||||
Usage:
|
||||
python bench/lme500/debug_pipeline_loss.py <question_id> [<question_id> ...]
|
||||
|
||||
For each qid:
|
||||
- Loads the LongMemEval-S row from the pinned dataset.
|
||||
- Builds a fresh per-row store + runtime graph (same shape as the bench).
|
||||
- Runs retrieve_recall to confirm gold sessions are findable by flat cosine.
|
||||
- Runs recall_for_benchmark STAGE BY STAGE, recording at each cut whether the
|
||||
gold record IDs survived.
|
||||
|
||||
Stages traced:
|
||||
Stage 2 — community gate (top-3 communities by centroid cosine)
|
||||
Stage 3 — seeds (top-3 by cosine within gated candidates)
|
||||
Stage 4 — 2-hop spread + rich-club union
|
||||
Stage 5 — final recall_for_benchmark hits
|
||||
|
||||
Output is a per-stage table showing where gold drops.
|
||||
|
||||
Read-only — no src/iai_mcp changes. Calls private helpers _community_gate
|
||||
and _pick_seeds for stage-level inspection (debug-only path).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
os.environ.setdefault("TRANSFORMERS_VERBOSITY", "error")
|
||||
|
||||
import numpy as np
|
||||
|
||||
from iai_mcp.embed import embedder_for_store
|
||||
from iai_mcp.pipeline import (
|
||||
_collect_graph_pool,
|
||||
_community_gate,
|
||||
_pick_seeds,
|
||||
recall_for_benchmark,
|
||||
)
|
||||
from iai_mcp.retrieve import build_runtime_graph, recall as retrieve_recall
|
||||
from iai_mcp.store import MemoryStore
|
||||
from iai_mcp.types import MemoryRecord
|
||||
|
||||
from bench.adapters.longmemeval import LongMemEvalAdapter
|
||||
|
||||
|
||||
def _make_record(content: str, session_id: str, role: str, embedding: list[float]) -> MemoryRecord:
|
||||
now = datetime.now(timezone.utc)
|
||||
return MemoryRecord(
|
||||
id=uuid4(),
|
||||
tier="episodic",
|
||||
literal_surface=content,
|
||||
aaak_index="",
|
||||
embedding=embedding,
|
||||
community_id=None,
|
||||
centrality=0.0,
|
||||
detail_level=2,
|
||||
pinned=False,
|
||||
stability=0.0,
|
||||
difficulty=0.0,
|
||||
last_reviewed=None,
|
||||
never_decay=False,
|
||||
never_merge=False,
|
||||
provenance=[],
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
tags=["longmemeval", f"role:{role}", f"session:{session_id}"],
|
||||
language="en",
|
||||
)
|
||||
|
||||
|
||||
def find_row(qid: str):
|
||||
adapter = LongMemEvalAdapter()
|
||||
sessions = []
|
||||
question = None
|
||||
answer_session_ids = None
|
||||
qtype = None
|
||||
for lme_session in adapter.load_dataset(split="S"):
|
||||
q = lme_session.queries[0]
|
||||
if q["question_id"] == qid:
|
||||
sessions.append(lme_session)
|
||||
if question is None:
|
||||
question = q["query"]
|
||||
answer_session_ids = set(q.get("relevant_turn_ids", []))
|
||||
qtype = q.get("question_type", "?")
|
||||
return question, qtype, answer_session_ids, sessions
|
||||
|
||||
|
||||
def trace_one(qid: str) -> dict:
|
||||
"""Returns a dict with the stage-by-stage gold survival counts."""
|
||||
print(f"\n{'=' * 78}\n=== qid={qid} ===\n{'=' * 78}", flush=True)
|
||||
question, qtype, gold_session_ids, sessions = find_row(qid)
|
||||
if question is None:
|
||||
print(f" qid={qid} NOT FOUND in dataset", flush=True)
|
||||
return {}
|
||||
|
||||
print(f" type={qtype}", flush=True)
|
||||
print(f" question[0:120]={question[:120]!r}", flush=True)
|
||||
print(f" gold session_ids={gold_session_ids}", flush=True)
|
||||
print(f" haystack sessions={len(sessions)}", flush=True)
|
||||
|
||||
tmp_root = Path(tempfile.mkdtemp(prefix="lme_dbg_"))
|
||||
store_dir = tmp_root / f"row-{qid}"
|
||||
store_dir.mkdir(parents=True, exist_ok=True)
|
||||
store = MemoryStore(path=store_dir / "lancedb")
|
||||
asyncio.run(store.enable_async_writes(coalesce_ms=50, max_batch=128))
|
||||
embedder = embedder_for_store(store)
|
||||
|
||||
id_to_session: dict[UUID, str] = {}
|
||||
gold_record_ids: set[UUID] = set()
|
||||
n_inserted = 0
|
||||
for sess in sessions:
|
||||
for turn in sess.turns:
|
||||
content = str(turn.get("content", "")).strip()
|
||||
if not content:
|
||||
continue
|
||||
vec = embedder.embed(content)
|
||||
rec = _make_record(
|
||||
content=content,
|
||||
session_id=sess.session_id,
|
||||
role=str(turn.get("role", "user")),
|
||||
embedding=vec,
|
||||
)
|
||||
store.insert(rec)
|
||||
id_to_session[rec.id] = sess.session_id
|
||||
if sess.session_id in gold_session_ids:
|
||||
gold_record_ids.add(rec.id)
|
||||
n_inserted += 1
|
||||
|
||||
asyncio.run(store.disable_async_writes())
|
||||
print(f" records inserted: {n_inserted}", flush=True)
|
||||
print(f" gold records: {len(gold_record_ids)}", flush=True)
|
||||
|
||||
graph, assignment, rich_club = build_runtime_graph(store)
|
||||
print(f" graph nodes: {len(graph._nx.nodes)}", flush=True)
|
||||
print(f" communities: {len(assignment.mid_regions)}", flush=True)
|
||||
print(f" rich-club: {len(rich_club)}", flush=True)
|
||||
cue_emb = embedder.embed(question)
|
||||
|
||||
# --- Baseline: retrieve_recall ---
|
||||
resp_x = retrieve_recall(
|
||||
store=store,
|
||||
cue_embedding=cue_emb,
|
||||
cue_text=question,
|
||||
session_id=f"debug-{qid}",
|
||||
budget_tokens=1500,
|
||||
k_hits=10,
|
||||
k_anti=0,
|
||||
)
|
||||
x_ids = [h.record_id for h in resp_x.hits]
|
||||
x_sessions = [id_to_session.get(r, "?") for r in x_ids]
|
||||
x_gold_pos = [i for i, s in enumerate(x_sessions) if s in gold_session_ids]
|
||||
print(f"\n --- retrieve_recall (X) ---", flush=True)
|
||||
print(f" top-10 sessions: {x_sessions}", flush=True)
|
||||
print(f" gold hit positions: {x_gold_pos}", flush=True)
|
||||
|
||||
# --- recall_for_benchmark, stage by stage ---
|
||||
print(f"\n --- recall_for_benchmark (Y) stage-by-stage ---", flush=True)
|
||||
|
||||
gated = _community_gate(cue_emb, assignment, top_n=3)
|
||||
candidates_set: set[UUID] = set()
|
||||
for gc in gated:
|
||||
for cid in assignment.mid_regions.get(gc, []):
|
||||
candidates_set.add(cid)
|
||||
if not candidates_set:
|
||||
candidates_set = {UUID(n) for n in graph._nx.nodes()}
|
||||
print(f" Stage 2 (community gate): EMPTY, fallback to all nodes", flush=True)
|
||||
print(f" Stage 2 (community gate): top-3 communities = {gated}", flush=True)
|
||||
print(f" candidates after gate: {len(candidates_set)}", flush=True)
|
||||
gold_in_gate = gold_record_ids & candidates_set
|
||||
print(f" gold survives gate: {len(gold_in_gate)} / {len(gold_record_ids)}", flush=True)
|
||||
|
||||
centrality: dict[UUID, float] = {}
|
||||
for nid in graph._nx.nodes:
|
||||
n = graph._nx.nodes[nid]
|
||||
if "centrality" in n:
|
||||
try:
|
||||
centrality[UUID(nid)] = float(n["centrality"])
|
||||
except (TypeError, ValueError):
|
||||
centrality[UUID(nid)] = 0.0
|
||||
if not centrality:
|
||||
try:
|
||||
centrality = graph.centrality()
|
||||
except Exception:
|
||||
centrality = {}
|
||||
# (08-01): _pick_seeds now reads from a shared cosine array.
|
||||
# Build the same array the production pipeline builds.
|
||||
pool_ids, pool_embs = _collect_graph_pool(graph, None, store)
|
||||
cue_vec_norm = np.asarray(cue_emb, dtype=np.float32)
|
||||
cn = float(np.linalg.norm(cue_vec_norm))
|
||||
if cn > 0.0:
|
||||
cue_vec_norm = cue_vec_norm / cn
|
||||
if pool_embs.size:
|
||||
shared_cos = (pool_embs @ cue_vec_norm).astype(np.float32)
|
||||
else:
|
||||
shared_cos = np.empty(0, dtype=np.float32)
|
||||
id_to_idx = {rid: i for i, rid in enumerate(pool_ids)}
|
||||
cand_idx = np.array(
|
||||
[id_to_idx[c] for c in candidates_set if c in id_to_idx],
|
||||
dtype=np.int64,
|
||||
)
|
||||
centrality_arr = np.array(
|
||||
[centrality.get(rid, 0.0) for rid in pool_ids],
|
||||
dtype=np.float32,
|
||||
)
|
||||
seed_idx = _pick_seeds(cand_idx, shared_cos, centrality_arr, n=3)
|
||||
seeds = [pool_ids[int(i)] for i in seed_idx]
|
||||
print(f" Stage 3 (seeds, top-3 by cosine in gated): {len(seeds)}", flush=True)
|
||||
seeds_sessions = [id_to_session.get(s, "?") for s in seeds]
|
||||
print(f" seed sessions: {seeds_sessions}", flush=True)
|
||||
gold_in_seeds = gold_record_ids & set(seeds)
|
||||
print(f" gold in seeds: {len(gold_in_seeds)}", flush=True)
|
||||
|
||||
spread = graph.two_hop_neighborhood(seeds, top_k=5)
|
||||
reachable = set(seeds) | set(spread) | set(rich_club)
|
||||
print(f" Stage 4 (spread + rich-club union):", flush=True)
|
||||
print(f" seeds={len(seeds)} spread={len(spread)} rich={len(rich_club)} reachable={len(reachable)}", flush=True)
|
||||
gold_in_reachable = gold_record_ids & reachable
|
||||
print(f" gold in reachable: {len(gold_in_reachable)} / {len(gold_record_ids)}", flush=True)
|
||||
|
||||
resp_y = recall_for_benchmark(
|
||||
store=store,
|
||||
graph=graph,
|
||||
assignment=assignment,
|
||||
rich_club=rich_club,
|
||||
embedder=embedder,
|
||||
cue=question,
|
||||
session_id=f"debug-{qid}",
|
||||
k_hits=10,
|
||||
profile_state=None,
|
||||
turn=0,
|
||||
mode="concept",
|
||||
)
|
||||
y_ids = [h.record_id for h in resp_y.hits]
|
||||
y_sessions = [id_to_session.get(r, "?") for r in y_ids]
|
||||
y_gold_pos = [i for i, s in enumerate(y_sessions) if s in gold_session_ids]
|
||||
print(f" Stage 5 (rank + budget pack):", flush=True)
|
||||
print(f" final hits: {len(y_ids)}", flush=True)
|
||||
print(f" top-10 sessions: {y_sessions}", flush=True)
|
||||
print(f" gold hit positions: {y_gold_pos}", flush=True)
|
||||
|
||||
# ----- Verdict -----
|
||||
# verdict primary signal is whether gold lands in
|
||||
# recall_for_benchmark's top-10 — which is what matters for R@5/R@10.
|
||||
# Stage-2/3/4 stage-by-stage diagnostics still print above (useful when
|
||||
# gold is missed) but they observe the PRIVATE _community_gate /
|
||||
# _pick_seeds path. The redesign (08-CONTEXT.md D-02) makes the
|
||||
# community gate a soft-bias diagnostic rather than a hard filter, so a
|
||||
# "stage_2 missed" diagnostic with gold present in final hits means:
|
||||
# the gate's communities did not include gold, but the cosine top-K
|
||||
# candidate pool did, and Stage 5 ranking surfaced it.
|
||||
print(f"\n --- VERDICT ---", flush=True)
|
||||
if y_gold_pos:
|
||||
print(f" gold present in top-10 (positions {y_gold_pos}) — no_loss", flush=True)
|
||||
if not gold_in_gate:
|
||||
print(f" (gate would have killed it; augmentation rescued)", flush=True)
|
||||
verdict = "no_loss"
|
||||
elif not gold_in_gate:
|
||||
print(f" >>> GOLD KILLED at STAGE 2 (community gate) — augmentation also failed <<<", flush=True)
|
||||
verdict = "stage_2_community_gate"
|
||||
elif not gold_in_reachable:
|
||||
print(f" >>> GOLD KILLED at STAGE 3-4 (seeds + spread) <<<", flush=True)
|
||||
print(f" gold was {len(gold_in_gate)} candidate(s); none became "
|
||||
f"a seed and none was reached within 2 hops of the chosen seeds", flush=True)
|
||||
verdict = "stage_3_4_seeds_or_spread"
|
||||
else:
|
||||
print(f" >>> GOLD KILLED at STAGE 5 (rank + budget pack) <<<", flush=True)
|
||||
print(f" gold was reachable ({len(gold_in_reachable)}) but not in top-10 hits", flush=True)
|
||||
verdict = "stage_5_rank"
|
||||
|
||||
return {
|
||||
"qid": qid,
|
||||
"qtype": qtype,
|
||||
"verdict": verdict,
|
||||
"n_records": n_inserted,
|
||||
"n_communities": len(assignment.mid_regions),
|
||||
"n_rich_club": len(rich_club),
|
||||
"n_gold_records": len(gold_record_ids),
|
||||
"gold_in_gate": len(gold_in_gate),
|
||||
"gold_in_reachable": len(gold_in_reachable),
|
||||
"x_gold_pos": x_gold_pos,
|
||||
"y_gold_pos": y_gold_pos,
|
||||
}
|
||||
|
||||
|
||||
def main(qids: list[str]) -> int:
|
||||
summary = []
|
||||
for qid in qids:
|
||||
try:
|
||||
summary.append(trace_one(qid))
|
||||
except Exception as exc:
|
||||
print(f"\n qid={qid} TRACE FAILED: {type(exc).__name__}: {exc}", flush=True)
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
summary.append({"qid": qid, "verdict": "trace_failed"})
|
||||
|
||||
print("\n\n" + "=" * 78)
|
||||
print("SUMMARY")
|
||||
print("=" * 78)
|
||||
print(f"{'qid':16} {'qtype':28} {'verdict':32} gold(gate→reach)")
|
||||
print("-" * 100)
|
||||
for s in summary:
|
||||
if not s:
|
||||
continue
|
||||
gate = s.get("gold_in_gate", "?")
|
||||
reach = s.get("gold_in_reachable", "?")
|
||||
ngold = s.get("n_gold_records", "?")
|
||||
print(
|
||||
f"{s.get('qid', '?'):16} {s.get('qtype', '?'):28} "
|
||||
f"{s.get('verdict', '?'):32} "
|
||||
f"{gate}→{reach} (of {ngold})"
|
||||
)
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) < 2:
|
||||
print(__doc__, file=sys.stderr)
|
||||
sys.exit(1)
|
||||
sys.exit(main(sys.argv[1:]))
|
||||
Loading…
Add table
Add a link
Reference in a new issue