Initial release: iai-mcp v0.1.0
Co-Authored-By: Claude <noreply@anthropic.com> Co-Authored-By: XNLLLLH <XNLLLLH@users.noreply.github.com>
This commit is contained in:
commit
f6b876fbe7
332 changed files with 97258 additions and 0 deletions
768
bench/longmemeval_blind.py
Normal file
768
bench/longmemeval_blind.py
Normal file
|
|
@ -0,0 +1,768 @@
|
|||
"""Plan 05-11 blind-run orchestrator — / M-08.
|
||||
|
||||
Runs LongMemEval-S through IAI-MCP's public API (MemoryStore.insert +
|
||||
retrieve.recall) in strict blind mode: no per-dataset tuning, no
|
||||
hyperparameter sweep, no late adjustment after seeing numbers. This is
|
||||
the external honesty axis for Phase 5.
|
||||
|
||||
## Row-level protocol
|
||||
|
||||
One evaluation row in LongMemEval-S contains:
|
||||
|
||||
{ "question", "answer_session_ids" (gold),
|
||||
"haystack_session_ids", "haystack_sessions" (the full history) }
|
||||
|
||||
Per row the orchestrator does:
|
||||
|
||||
1. fresh tmp MemoryStore (per-row isolation; no cross-row leakage)
|
||||
2. enable async writes (Plan 05-10 — keeps RAM bounded on a
|
||||
16GB M1 laptop)
|
||||
3. embed + insert every turn of every haystack session; each record
|
||||
is tagged with ``session:<session_id>`` so the orchestrator can
|
||||
score at the dataset's native session-ID granularity.
|
||||
4. disable async writes (flushes the queue; the store now holds the
|
||||
full haystack).
|
||||
5. build_runtime_graph once (Plan 05-09 cache amortises cold start
|
||||
across rows via the shared runtime graph cache dir).
|
||||
6. call retrieve.recall for the eval query, with k_hits=10.
|
||||
7. compute R@5 / R@10 at session-ID granularity (the standard
|
||||
LongMemEval metric): a retrieved record "hits" if its ``session:``
|
||||
tag is in answer_session_ids. R@k is 1.0 if any top-k hits, else 0.
|
||||
8. measure per-query token cost via bench.tokens counters.
|
||||
|
||||
## CLI
|
||||
|
||||
python bench/longmemeval_blind.py \\
|
||||
--split S \\
|
||||
[--limit N] \\
|
||||
[--granularity {session, turn}] \\
|
||||
[--dataset {cleaned, raw}] \\
|
||||
[--qid-include csv] \\
|
||||
--out /tmp/p11_lme_full.json
|
||||
|
||||
Phase 9 added two methodology-alignment flags (--granularity, --dataset) and an optional per-question filter (--qid-include):
|
||||
|
||||
--granularity session (default; one record per session,
|
||||
content = "\\n".join(user-only turns))
|
||||
--granularity turn (v1/v2 reproducer; one record per turn)
|
||||
--dataset cleaned (default; xiaowu0162/longmemeval-cleaned)
|
||||
--dataset raw (v1/v2 reproducer; xiaowu0162/longmemeval
|
||||
rev 2ec2a557f339)
|
||||
--qid-include csv optional comma-separated question_ids; when
|
||||
set, only those rows run (used by smoke
|
||||
tests for per-qid baseline verification)
|
||||
|
||||
## Output JSON keys
|
||||
|
||||
{
|
||||
"split": "S",
|
||||
"dataset_id": "xiaowu0162/longmemeval-cleaned" | "xiaowu0162/longmemeval",
|
||||
"revision": "<40-hex>",
|
||||
"granularity": "session" | "turn",
|
||||
"dataset_choice": "cleaned" | "raw",
|
||||
"n_rows": int, # rows actually evaluated
|
||||
"r_at_5": float, # session-ID R@5, mean across rows
|
||||
"r_at_10": float, # session-ID R@10, mean across rows
|
||||
"token_p50": int, # per-query cue-text tokens, median
|
||||
"token_p95": int, # per-query cue-text tokens, p95
|
||||
"session_tokens_mean": float, # mean per-row inserted text tokens
|
||||
# (proxy for the rows' storage footprint)
|
||||
"errors": [{"question_id": str, "error_class": str, "error": str}],
|
||||
"hard_limit": int | null,
|
||||
"note": str
|
||||
}
|
||||
|
||||
## discipline
|
||||
|
||||
The run is ONE-SHOT. If a bug crashes a row, it's logged in ``errors``
|
||||
and counted as a MISS against R@k (not silently dropped). The published
|
||||
number is whatever came out. Disclosures (small-N, hardware limit,
|
||||
English-only embedder, etc.) live in the published bench report and
|
||||
05-11-SUMMARY.md — they don't get folded back into this script.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import statistics
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
import traceback
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
# Silence the "UNEXPECTED embeddings.position_ids" noise from
|
||||
# sentence-transformers so the blind-run stderr stays focused on errors.
|
||||
os.environ.setdefault("TRANSFORMERS_VERBOSITY", "error")
|
||||
|
||||
# IAI-MCP imports — public API only (plan directive).
|
||||
from iai_mcp.embed import Embedder, embedder_for_store
|
||||
from iai_mcp.pipeline import recall_for_benchmark
|
||||
from iai_mcp.retrieve import build_runtime_graph, recall as retrieve_recall
|
||||
from iai_mcp.store import MemoryStore
|
||||
from iai_mcp.types import MemoryRecord
|
||||
|
||||
# Adapter (ships alongside this script).
|
||||
from bench.adapters.longmemeval import (
|
||||
DATASET_ID,
|
||||
PINNED_REVISION,
|
||||
LMESession,
|
||||
LongMemEvalAdapter,
|
||||
)
|
||||
|
||||
# Token counter (reuses bench/tokens.py three-tier helper).
|
||||
from bench.tokens import _char4_count, _tiktoken_count
|
||||
|
||||
|
||||
def _count_tokens(text: str) -> int:
    """Return the token count of ``text``.

    ``_tiktoken_count`` is tried first; if it raises for any reason the
    ``_char4_count`` estimate is returned instead.
    """
    try:
        n_tokens = _tiktoken_count(text)
    except Exception:  # pragma: no cover
        # Deliberate best-effort fallback: any tiktoken failure degrades
        # to the char4 estimate rather than failing the row.
        n_tokens = _char4_count(text)
    return n_tokens
|
||||
|
||||
|
||||
def _percentile(xs: list[int], p: float) -> int:
|
||||
if not xs:
|
||||
return 0
|
||||
s = sorted(xs)
|
||||
k = max(0, min(len(s) - 1, int(round((len(s) - 1) * p / 100.0))))
|
||||
return s[k]
|
||||
|
||||
|
||||
def _make_record(
    content: str,
    session_id: str,
    role: str,
    embedding: list[float],
) -> MemoryRecord:
    """Build an episodic ``MemoryRecord`` for one piece of haystack text.

    The record gets a fresh random id, identical ``created_at`` /
    ``updated_at`` timestamps (current UTC time) and three tags:
    ``"longmemeval"``, ``"role:<role>"`` and ``"session:<session_id>"``.
    All remaining fields are filled with fixed values (see the call below).

    Args:
        content: text stored as ``literal_surface``.
        session_id: dataset session id, embedded in the ``session:`` tag.
        role: speaker role, embedded in the ``role:`` tag.
        embedding: precomputed embedding vector for ``content``.

    Returns:
        The newly constructed record (not yet inserted anywhere).
    """
    from uuid import uuid4

    stamp = datetime.now(timezone.utc)
    record_tags = [
        "longmemeval",
        f"role:{role}",
        f"session:{session_id}",
    ]
    return MemoryRecord(
        id=uuid4(),
        tier="episodic",
        literal_surface=content,
        aaak_index="",
        embedding=embedding,
        community_id=None,
        centrality=0.0,
        detail_level=2,
        pinned=False,
        stability=0.0,
        difficulty=0.0,
        last_reviewed=None,
        never_decay=False,
        never_merge=False,
        provenance=[],
        created_at=stamp,
        updated_at=stamp,
        tags=record_tags,
        language="en",
    )
|
||||
|
||||
|
||||
def _run_one_row(
    row_id: str,
    question: str,
    question_type: str,
    answer_session_ids: set[str],
    sessions: list[LMESession],
    tmp_root: Path,
    granularity: str = "turn",
    embedder_key: str = "bge-small-en-v1.5",
) -> dict[str, Any]:
    """Execute the per-row protocol for one evaluation question.

    Inserts the haystack into a fresh store, then scores two recall
    entry points against the same corpus: prong X = ``retrieve_recall``
    (baseline) and prong Y = ``recall_for_benchmark`` (graph-based
    pipeline). Both prongs share the insert phase and the same
    record-id -> session-id mapping, so the delta between them is
    attributable to the recall function only.

    Args:
        row_id: question id; also names the per-row store directory
            (``tmp_root / f"row-{row_id}"``) and the recall session id
            (``f"lme-{row_id}"``).
        question: query text, embedded once and passed to both prongs.
        question_type: copied verbatim into the result.
        answer_session_ids: gold session ids used for hit-at-k scoring.
        sessions: haystack sessions; each provides ``session_id`` and
            ``turns`` (mappings read through their ``content`` and
            ``role`` keys).
        tmp_root: parent directory of the per-row store. This function
            does not delete the row directory.
        granularity: corpus construction.
            "session" -> one record per session whose content is
                         "\\n".join(user-only turns); sessions with no
                         non-empty user turn are skipped.
            anything else (default "turn") -> one record per non-empty
                         turn, both roles.
        embedder_key: registry key handed to ``Embedder(model_key=...)``.

    Returns:
        A dict with ``question_id``, ``question_type``, the hit-at-k
        scores ``r_at_5_retrieve`` / ``r_at_10_retrieve`` /
        ``r_at_5_pipeline`` / ``r_at_10_pipeline`` (each 0.0 or 1.0),
        ``pipeline_error`` (``None`` on success), ``query_tokens``,
        ``inserted_text_tokens``, ``n_haystack_sessions``,
        ``n_turns_inserted`` (number of records inserted) and
        ``timing_seconds``.

    Raises:
        Anything raised by store setup, embedding, insertion or
        ``retrieve_recall`` propagates to the caller. Failures of
        ``build_runtime_graph`` and ``recall_for_benchmark`` are caught
        here, logged to stderr and reported through ``pipeline_error``
        (prong Y then scores 0.0).
    """
    # perf_counter is monotonic, so the phase durations below cannot be
    # distorted by wall-clock adjustments during a multi-hour run.
    t0 = time.perf_counter()

    # Fresh store in a per-row tmp dir.
    store_dir = tmp_root / f"row-{row_id}"
    store_dir.mkdir(parents=True, exist_ok=True)
    store = MemoryStore(path=store_dir / "lancedb")

    # Async writes: coalesce LanceDB appends across the row.
    # enable_async_writes is a coroutine — drive it from a fresh loop so
    # the surrounding orchestrator stays sync.
    asyncio.run(store.enable_async_writes(coalesce_ms=50, max_batch=128))

    # Count inserted tokens as a rough storage footprint.
    inserted_text_tokens = 0

    # Route through the explicit registry key so the embedder ablation
    # experiment can swap to all-MiniLM-L6-v2 without touching the
    # production-default resolver (embedder_for_store is kept imported
    # for backward-compat; it is not called on this path).
    embedder = Embedder(model_key=embedder_key)
    _ = embedder_for_store  # silence unused-import warning when the prod path is bypassed

    # --------- INSERT phase ---------
    # One pass over all haystack sessions for this row. Each MemoryRecord
    # is tagged with its session_id so R@k can score at the dataset's
    # native session granularity. ``granularity`` selects one of two paths:
    #   - "turn"    (v1/v2 baseline; one record per turn, both roles)
    #   - "session" (mempalace-aligned; one record per session, user-only
    #                turns joined with "\n")
    id_to_session: dict[str, str] = {}  # str(record.id) -> session_id
    if granularity == "session":
        for sess in sessions:
            user_turns = [
                str(turn.get("content", "")).strip()
                for turn in sess.turns
                if str(turn.get("role", "user")) == "user"
                and str(turn.get("content", "")).strip()
            ]
            if not user_turns:
                continue
            doc_text = "\n".join(user_turns)
            vec = embedder.embed(doc_text)
            rec = _make_record(
                content=doc_text,
                session_id=sess.session_id,
                role="user",
                embedding=vec,
            )
            store.insert(rec)
            id_to_session[str(rec.id)] = sess.session_id
            inserted_text_tokens += _count_tokens(doc_text)
    else:
        # Turn granularity (v1/v2 baseline).
        for sess in sessions:
            for turn in sess.turns:
                content = str(turn.get("content", "")).strip()
                if not content:
                    continue
                vec = embedder.embed(content)
                rec = _make_record(
                    content=content,
                    session_id=sess.session_id,
                    role=str(turn.get("role", "user")),
                    embedding=vec,
                )
                store.insert(rec)
                id_to_session[str(rec.id)] = sess.session_id
                inserted_text_tokens += _count_tokens(content)

    # Flush the async queue before recall. disable_async_writes is a
    # coroutine too — drive it from a fresh loop.
    asyncio.run(store.disable_async_writes())
    t_after_insert = time.perf_counter()

    # --------- Build runtime graph ---------
    # Capture the (graph, assignment, rich_club) tuple so prong Y can
    # reuse it. Prong X does not take the graph, so it still runs when
    # the build fails.
    graph = None
    assignment = None
    rich_club = None
    try:
        graph, assignment, rich_club = build_runtime_graph(store)
    except Exception as exc:  # pragma: no cover — cache helpers should be robust
        # Don't fail the row on graph build; recall_for_benchmark is
        # skipped for this row and counted as a miss for the Y prong.
        print(
            f"[LME] row={row_id} build_runtime_graph failed: "
            f"{type(exc).__name__}: {exc}",
            file=sys.stderr,
        )
    t_after_graph = time.perf_counter()

    # --------- Prong X: retrieve_recall (flat-cosine, baseline) ---------
    cue_embedding = embedder.embed(question)
    resp_x = retrieve_recall(
        store=store,
        cue_embedding=cue_embedding,
        cue_text=question,
        session_id=f"lme-{row_id}",
        budget_tokens=1500,
        k_hits=10,
        k_anti=0,
    )
    t_after_x = time.perf_counter()

    # --------- Prong Y: recall_for_benchmark (full graph-native architecture) ---------
    # Entry-point split: the bench harness uses the top-K contract
    # (k_hits=10, no budget_tokens). mode="concept" preserved verbatim — the
    # bench is concept-shaped per BENCH_PROTOCOL_lme500.md and the D-02
    # `_gate_bias_for_mode("concept") == 0.1` bias is what v2 measurements observe.
    resp_y = None
    pipeline_error: str | None = None
    if graph is not None:
        try:
            resp_y = recall_for_benchmark(
                store=store,
                graph=graph,
                assignment=assignment,
                rich_club=rich_club,
                embedder=embedder,
                cue=question,
                session_id=f"lme-{row_id}",
                k_hits=10,
                profile_state=None,
                turn=0,
                mode="concept",
            )
        except Exception as exc:
            pipeline_error = f"{type(exc).__name__}: {str(exc)[:200]}"
            print(
                f"[LME] row={row_id} recall_for_benchmark failed: "
                f"{pipeline_error}",
                file=sys.stderr,
            )
    else:
        pipeline_error = "graph_build_failed"
    t_after_y = time.perf_counter()

    def _retrieved_session_ids(resp) -> list[str | None]:
        """Map each hit, in rank order, to its session id (``None`` if unmapped)."""
        if resp is None:
            return []
        # Keep one slot per hit: dropping unmapped hits would compact the
        # ranking and let rank-6+ records slide into the top-5 window.
        return [id_to_session.get(str(hit.record_id)) for hit in resp.hits]

    sids_x = _retrieved_session_ids(resp_x)
    sids_y = _retrieved_session_ids(resp_y)

    # LongMemEval-standard R@k at session-ID granularity: hit-at-k.
    # R@k = 1.0 if any of the top-k retrieved records belongs to a gold
    # session, else 0.0. Aggregated across rows by the caller.
    def _hit_at_k(sids: list[str | None], k: int) -> float:
        top = sids[:k]
        return 1.0 if any(s in answer_session_ids for s in top) else 0.0

    r5_x = _hit_at_k(sids_x, 5)
    r10_x = _hit_at_k(sids_x, 10)
    # sids_y is empty when prong Y did not run, which scores 0.0.
    r5_y = _hit_at_k(sids_y, 5)
    r10_y = _hit_at_k(sids_y, 10)

    query_tokens = _count_tokens(question)

    return {
        "question_id": row_id,
        "question_type": question_type,
        # Prong X — retrieve_recall (flat-cosine baseline)
        "r_at_5_retrieve": r5_x,
        "r_at_10_retrieve": r10_x,
        # Prong Y — recall_for_benchmark (full graph-native pipeline; D-07)
        "r_at_5_pipeline": r5_y,
        "r_at_10_pipeline": r10_y,
        "pipeline_error": pipeline_error,
        # Shared
        "query_tokens": query_tokens,
        "inserted_text_tokens": inserted_text_tokens,
        "n_haystack_sessions": len(sessions),
        "n_turns_inserted": len(id_to_session),
        "timing_seconds": {
            "insert": round(t_after_insert - t0, 2),
            "graph": round(t_after_graph - t_after_insert, 2),
            "recall_retrieve": round(t_after_x - t_after_graph, 2),
            "recall_pipeline": round(t_after_y - t_after_x, 2),
            "total": round(t_after_y - t0, 2),
        },
    }
|
||||
|
||||
|
||||
def main(argv: list[str] | None = None) -> int:
    """Run the blind benchmark from the command line and write the summary JSON.

    Parses the CLI flags, groups the adapter's sessions into rows (one
    question plus its haystack sessions), evaluates every selected row
    with ``_run_one_row`` and writes the aggregated result to ``--out``.
    Each finished row (success or error) is appended to a JSONL
    checkpoint so an interrupted run can be resumed; on resume, only
    checkpoint rows that belong to the current row selection are counted.

    Args:
        argv: argument list for argparse; ``None`` means ``sys.argv[1:]``.

    Returns:
        ``0`` after the output file has been written.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--split",
        default="S",
        choices=["S", "M", "oracle"],
        help="LongMemEval split (Plan 05-11 runs S)",
    )
    parser.add_argument(
        "--limit",
        type=int,
        default=None,
        help=(
            "practical-cap on rows evaluated. LongMemEval-S = 500 rows; "
            "at ~500 turns/row and 11ms/embed on a 16GB M1 laptop, the "
            "full 500-row run is multi-hour. --limit lets the blind pilot "
            "finish; the SUMMARY discloses the cap honestly."
        ),
    )
    parser.add_argument(
        "--out",
        default="/tmp/p11_lme_full.json",
        help="output JSON path",
    )
    parser.add_argument(
        "--checkpoint",
        default=None,
        help=(
            "JSONL checkpoint path for crash-resume; default = <out>.jsonl. "
            "Each completed (or errored) row is appended with fsync as one "
            "JSON line. On restart, rows whose question_id already appears "
            "in the checkpoint are skipped."
        ),
    )
    # Granularity flag with mempalace-aligned default.
    parser.add_argument(
        "--granularity",
        choices=["session", "turn"],
        default="session",
        help=(
            "corpus-construction granularity. "
            "'session' (default, v3): one record per session, "
            "content = '\\n'.join(user-only turns) — matches mempalace's "
            "reference. 'turn': one record per turn (v1/v2 baseline; "
            "use with --dataset raw to reproduce v2's 0.956)."
        ),
    )
    # Dataset choice flag with mempalace-aligned default.
    parser.add_argument(
        "--dataset",
        choices=["cleaned", "raw"],
        default="cleaned",
        help=(
            "dataset variant. 'cleaned' (default, v3): "
            "xiaowu0162/longmemeval-cleaned, SHA pinned via repo_info(). "
            "'raw' (v1/v2 baseline): xiaowu0162/longmemeval rev "
            "2ec2a557f339... — use with --granularity turn to reproduce "
            "v2's 0.956."
        ),
    )
    # Per-qid filter for the v2-baseline smoke reproducer. Applied AFTER
    # --limit so a caller passing both flags gets a deterministic
    # intersection (limit narrows by row count, qid-include narrows by
    # id). Default None preserves v1/v2 behaviour.
    parser.add_argument(
        "--qid-include",
        default=None,
        help=(
            "comma-separated list of question_ids; if set, only these "
            "rows run (used by smoke tests for per-qid baseline "
            "verification). Applied after --limit."
        ),
    )
    # Bench-only embedder swap. Default preserves the v3 baseline
    # (bge-small-en-v1.5). all-MiniLM-L6-v2 is mempalace's ChromaDB
    # default — used for the embedder-axis ablation in v3.1. The
    # production embedder is unchanged regardless of this flag
    # (English-Only Brain lock from Plan 05-08; the Embedder.__init__
    # kwarg is the only entry point that surfaces the registry's
    # all-MiniLM-L6-v2 entry).
    parser.add_argument(
        "--embedder",
        choices=["bge-small-en-v1.5", "all-MiniLM-L6-v2"],
        default="bge-small-en-v1.5",
        help=(
            "embedder model_key. 'bge-small-en-v1.5' (default, v3 "
            "baseline) routes via the production English-only embedder. "
            "'all-MiniLM-L6-v2' (Phase 9.1 ablation) is mempalace's "
            "ChromaDB default — bench-only swap, production unchanged."
        ),
    )
    args = parser.parse_args(argv)

    print(
        f"[LME] blind run starting "
        f"split={args.split} limit={args.limit} "
        f"granularity={args.granularity} dataset={args.dataset} "
        f"embedder={args.embedder} "
        f"out={args.out}",
        file=sys.stderr,
        flush=True,
    )

    # Branch the adapter on --dataset.
    if args.dataset == "cleaned":
        from bench.adapters.longmemeval_cleaned import (
            CLEANED_DATASET_ID,
            CleanedLongMemEvalAdapter,
        )

        adapter = CleanedLongMemEvalAdapter()
        dataset_id_emit = CLEANED_DATASET_ID
        revision_emit = adapter.revision
    else:
        adapter = LongMemEvalAdapter()
        dataset_id_emit = DATASET_ID
        revision_emit = PINNED_REVISION

    # Adapter yields one LMESession per haystack session, but the
    # blind-run protocol needs rows (one question + all its haystack
    # sessions). Group by question_id (carried inside queries[0]).
    grouped: dict[str, dict[str, Any]] = {}
    row_order: list[str] = []
    for lme_session in adapter.load_dataset(split=args.split):
        q = lme_session.queries[0]
        qid = q["question_id"]
        if qid not in grouped:
            grouped[qid] = {
                "question": q["query"],
                "question_type": q.get("question_type", "unknown"),
                "answer_session_ids": set(q.get("relevant_turn_ids", [])),
                "sessions": [],
            }
            row_order.append(qid)
        grouped[qid]["sessions"].append(lme_session)

    if args.limit is not None:
        row_order = row_order[: args.limit]

    # --qid-include filter applied AFTER --limit so a caller passing both
    # flags gets a deterministic intersection. The default None path is a
    # no-op for backward compat.
    if args.qid_include is not None:
        wanted = {q.strip() for q in str(args.qid_include).split(",") if q.strip()}
        row_order = [qid for qid in row_order if qid in wanted]
        print(
            f"[LME] qid-include filter: kept {len(row_order)} of "
            f"{len(wanted)} requested qids",
            file=sys.stderr,
            flush=True,
        )

    tmp_root = Path(tempfile.mkdtemp(prefix="lme_blind_"))
    print(f"[LME] per-row stores rooted at {tmp_root}", file=sys.stderr, flush=True)

    per_row: list[dict[str, Any]] = []
    errors: list[dict[str, str]] = []
    # Track BOTH prongs (X = retrieve_recall, Y = recall_for_benchmark).
    r5_x_values: list[float] = []
    r10_x_values: list[float] = []
    r5_y_values: list[float] = []
    r10_y_values: list[float] = []
    query_tokens: list[int] = []
    session_tokens: list[int] = []

    # Per-row JSONL checkpoint for crash resume. Each row's full result
    # is appended with flush + fsync, so a kill at row N preserves rows
    # 1..N-1 fully. Restart skips rows already in the checkpoint
    # (matched by question_id).
    checkpoint_path = Path(args.checkpoint) if args.checkpoint else Path(str(args.out) + ".jsonl")
    checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
    selected_ids = set(row_order)
    completed_ids: set[str] = set()
    ignored_checkpoint_rows = 0
    if checkpoint_path.exists():
        with open(checkpoint_path, "r", encoding="utf-8") as cp_f:
            for line in cp_f:
                line = line.strip()
                if not line:
                    continue
                try:
                    rec = json.loads(line)
                except json.JSONDecodeError:
                    rec = None
                # A valid-JSON line that is not an object (e.g. ``null``)
                # is just as unusable as a line that does not parse.
                if not isinstance(rec, dict):
                    print(
                        f"[LME] WARN: skipping corrupt checkpoint line: {line[:80]!r}",
                        file=sys.stderr,
                        flush=True,
                    )
                    continue
                qid = rec.get("question_id")
                if not qid:
                    continue
                # Only rows of the current selection may enter the
                # aggregates, and each at most once; otherwise a stale or
                # shared checkpoint skews the means relative to n_rows.
                if qid not in selected_ids or qid in completed_ids:
                    ignored_checkpoint_rows += 1
                    continue
                completed_ids.add(qid)
                if isinstance(rec.get("error"), dict):
                    # Resumed error row: count as full miss for both prongs.
                    errors.append(
                        {
                            "question_id": qid,
                            "error_class": rec["error"].get("error_class", "Unknown"),
                            "error": rec["error"].get("error", ""),
                        }
                    )
                    r5_x_values.append(0.0)
                    r10_x_values.append(0.0)
                    r5_y_values.append(0.0)
                    r10_y_values.append(0.0)
                    query_tokens.append(0)
                    session_tokens.append(0)
                else:
                    # Resumed success row.
                    per_row.append(rec)
                    r5_x_values.append(float(rec.get("r_at_5_retrieve", 0.0)))
                    r10_x_values.append(float(rec.get("r_at_10_retrieve", 0.0)))
                    r5_y_values.append(float(rec.get("r_at_5_pipeline", 0.0)))
                    r10_y_values.append(float(rec.get("r_at_10_pipeline", 0.0)))
                    query_tokens.append(int(rec.get("query_tokens", 0)))
                    session_tokens.append(int(rec.get("inserted_text_tokens", 0)))
    if ignored_checkpoint_rows:
        print(
            f"[LME] resume: ignored {ignored_checkpoint_rows} checkpoint rows that are "
            f"duplicates or outside the current row selection",
            file=sys.stderr,
            flush=True,
        )
    if completed_ids:
        print(
            f"[LME] resume: {len(completed_ids)} rows already in checkpoint "
            f"{checkpoint_path}; processing {len(row_order) - len(completed_ids)} remaining",
            file=sys.stderr,
            flush=True,
        )
    else:
        print(
            f"[LME] checkpoint: writing per-row durable JSONL to {checkpoint_path}",
            file=sys.stderr,
            flush=True,
        )

    def _checkpoint_append(rec: dict[str, Any]) -> None:
        """Append one row record to the checkpoint, flush+fsync for durability."""
        with open(checkpoint_path, "a", encoding="utf-8") as cp_a:
            cp_a.write(json.dumps(rec) + "\n")
            cp_a.flush()
            os.fsync(cp_a.fileno())

    # Monotonic clock: only elapsed durations are derived from run_t0.
    run_t0 = time.perf_counter()
    try:
        for i, qid in enumerate(row_order):
            if qid in completed_ids:
                continue
            row = grouped[qid]
            try:
                res = _run_one_row(
                    row_id=qid,
                    question=row["question"],
                    question_type=row["question_type"],
                    answer_session_ids=row["answer_session_ids"],
                    sessions=row["sessions"],
                    tmp_root=tmp_root,
                    granularity=args.granularity,
                    embedder_key=args.embedder,
                )
            except Exception as exc:
                # T-05-11-04 mitigation: log + count as miss, do
                # NOT silently drop.
                err_payload = {
                    "error_class": type(exc).__name__,
                    "error": str(exc)[:500],
                }
                errors.append({"question_id": qid, **err_payload})
                # Counted as a full miss for both prongs — preserves
                # "count against R@5 as 0" from the plan text.
                r5_x_values.append(0.0)
                r10_x_values.append(0.0)
                r5_y_values.append(0.0)
                r10_y_values.append(0.0)
                query_tokens.append(0)
                session_tokens.append(0)
                # Persist the error row to checkpoint so a restart skips it.
                _checkpoint_append(
                    {
                        "question_id": qid,
                        "question_type": row.get("question_type", "unknown"),
                        "error": err_payload,
                    }
                )
                print(
                    f"[LME] ERROR row={qid}: {type(exc).__name__}: {exc}",
                    file=sys.stderr,
                    flush=True,
                )
                traceback.print_exc(file=sys.stderr)
            else:
                # Success bookkeeping lives outside the ``try`` so that a
                # failure here (e.g. the checkpoint write) is never
                # recorded as an error row on top of the success values
                # already appended for the same question.
                per_row.append(res)
                r5_x_values.append(res["r_at_5_retrieve"])
                r10_x_values.append(res["r_at_10_retrieve"])
                r5_y_values.append(res["r_at_5_pipeline"])
                r10_y_values.append(res["r_at_10_pipeline"])
                query_tokens.append(res["query_tokens"])
                session_tokens.append(res["inserted_text_tokens"])
                _checkpoint_append(res)
                elapsed = time.perf_counter() - run_t0
                print(
                    f"[LME] row {i+1}/{len(row_order)} qid={qid} "
                    f"qtype={res['question_type']} "
                    f"R@5_x={res['r_at_5_retrieve']:.0f} R@5_y={res['r_at_5_pipeline']:.0f} "
                    f"R@10_x={res['r_at_10_retrieve']:.0f} R@10_y={res['r_at_10_pipeline']:.0f} "
                    f"t_row={res['timing_seconds']['total']:.1f}s "
                    f"t_total={elapsed:.1f}s",
                    file=sys.stderr,
                    flush=True,
                )
            finally:
                # Free disk aggressively — the per-row store is no longer
                # needed once the row has been scored (or has failed).
                row_dir = tmp_root / f"row-{qid}"
                if row_dir.exists():
                    shutil.rmtree(row_dir, ignore_errors=True)
    finally:
        # Also runs on KeyboardInterrupt / fatal errors, so an aborted
        # run does not leave the temporary store tree behind.
        shutil.rmtree(tmp_root, ignore_errors=True)

    def _mean(xs: list[float]) -> float:
        return (sum(xs) / len(xs)) if xs else 0.0

    out = {
        "split": args.split,
        "dataset_id": dataset_id_emit,
        "revision": revision_emit,
        # Reproducibility fields:
        "granularity": args.granularity,
        "dataset_choice": args.dataset,
        # Embedder identity pinned for v3.1 ablation reproducibility.
        # Default "bge-small-en-v1.5" reproduces v3 baseline; "all-MiniLM-L6-v2"
        # is the embedder-axis ablation toggle (mempalace ChromaDB default).
        "embedder_model_key": args.embedder,
        "embedder_hf_id": Embedder(model_key=args.embedder).model_name,
        "n_rows": len(row_order),
        # Prong X — retrieve_recall (flat-cosine baseline)
        "r_at_5_retrieve": _mean(r5_x_values),
        "r_at_10_retrieve": _mean(r10_x_values),
        # Prong Y — recall_for_benchmark (full graph-native architecture; D-07)
        "r_at_5_pipeline": _mean(r5_y_values),
        "r_at_10_pipeline": _mean(r10_y_values),
        # Architecture lift (Y - X)
        "r_at_5_lift": _mean(r5_y_values) - _mean(r5_x_values),
        "r_at_10_lift": _mean(r10_y_values) - _mean(r10_x_values),
        "token_p50": _percentile(query_tokens, 50),
        "token_p95": _percentile(query_tokens, 95),
        "session_tokens_mean": (
            statistics.fmean(session_tokens) if session_tokens else 0.0
        ),
        "errors": errors,
        "hard_limit": args.limit,
        "metric_def": (
            "Session-ID hit-at-k: R@k = 1.0 if any of top-k retrieved records "
            "belongs to a gold session_id, else 0.0 (LongMemEval standard)."
        ),
        "per_row": per_row,
        "generated_at": datetime.now(timezone.utc).isoformat(),
        "total_wall_seconds": round(time.perf_counter() - run_t0, 2),
    }

    Path(args.out).parent.mkdir(parents=True, exist_ok=True)
    with open(args.out, "w", encoding="utf-8") as f:
        json.dump(out, f, indent=2)

    print(
        f"[LME] DONE n_rows={out['n_rows']} "
        f"R@5_retrieve={out['r_at_5_retrieve']:.3f} "
        f"R@5_pipeline={out['r_at_5_pipeline']:.3f} "
        f"lift_R@5={out['r_at_5_lift']:+.3f} "
        f"R@10_retrieve={out['r_at_10_retrieve']:.3f} "
        f"R@10_pipeline={out['r_at_10_pipeline']:.3f} "
        f"lift_R@10={out['r_at_10_lift']:+.3f} "
        f"errors={len(errors)} -> {args.out}",
        file=sys.stderr,
        flush=True,
    )
    return 0
|
||||
|
||||
|
||||
# Script entry point: run the CLI and use its return value as the
# process exit status.
if __name__ == "__main__":
    raise SystemExit(main())
|
||||
Loading…
Add table
Add a link
Reference in a new issue