Co-Authored-By: Claude <noreply@anthropic.com> Co-Authored-By: XNLLLLH <XNLLLLH@users.noreply.github.com>
768 lines
30 KiB
Python
768 lines
30 KiB
Python
"""Plan 05-11 blind-run orchestrator — / M-08.
|
||
|
||
Runs LongMemEval-S through IAI-MCP's public API (MemoryStore.insert +
|
||
retrieve.recall) in strict blind mode: no per-dataset tuning, no
|
||
hyperparameter sweep, no late adjustment after seeing numbers. This is
|
||
the external honesty axis for Phase 5.
|
||
|
||
## Row-level protocol
|
||
|
||
One evaluation row in LongMemEval-S contains:
|
||
|
||
{ "question", "answer_session_ids" (gold),
|
||
"haystack_session_ids", "haystack_sessions" (the full history) }
|
||
|
||
Per row the orchestrator does:
|
||
|
||
1. fresh tmp MemoryStore (per-row isolation; no cross-row leakage)
|
||
2. enable async writes (Plan 05-10 — keeps RAM bounded on a
|
||
16GB M1 laptop)
|
||
3. embed + insert every turn of every haystack session; each record
|
||
is tagged with ``session:<session_id>`` so the orchestrator can
|
||
score at the dataset's native session-ID granularity.
|
||
4. disable async writes (flushes the queue; the store now holds the
|
||
full haystack).
|
||
5. build_runtime_graph once (Plan 05-09 cache amortises cold start
|
||
across rows via the shared runtime graph cache dir).
|
||
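6. run both recall prongs for the eval query, each with k_hits=10:
prong X = retrieve.recall (flat-cosine baseline) and prong Y =
pipeline.recall_for_benchmark (full graph-native architecture).
7. compute R@5 / R@10 at session-ID granularity (the standard
LongMemEval metric) for each prong: a retrieved record "hits" if its
session id is in answer_session_ids. R@k is 1.0 if any top-k hit is a
gold session, else 0.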
|
||
8. measure per-query token cost via bench.tokens counters.
|
||
|
||
## CLI
|
||
|
||
python bench/longmemeval_blind.py \\
--split S \\
[--limit N] \\
[--granularity {session, turn}] \\
[--dataset {cleaned, raw}] \\
[--qid-include csv] \\
[--embedder {bge-small-en-v1.5, all-MiniLM-L6-v2}] \\
[--checkpoint path.jsonl] \\
--out /tmp/p11_lme_full.json
|
||
|
||
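Phase 9 added two methodology-alignment flags (--granularity and --dataset);
the remaining flags below select rows, swap the embedder, or set the
checkpoint path: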
|
||
|
||
--granularity session (default; one record per session,
|
||
content = "\\n".join(user-only turns))
|
||
--granularity turn (v1/v2 reproducer; one record per turn)
|
||
--dataset cleaned (default; xiaowu0162/longmemeval-cleaned)
|
||
--dataset raw (v1/v2 reproducer; xiaowu0162/longmemeval
|
||
rev 2ec2a557f339)
|
||
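--qid-include csv optional comma-separated question_ids; when
set, only those rows run (used by smoke
tests for per-qid baseline verification)
--embedder KEY bge-small-en-v1.5 (default) or all-MiniLM-L6-v2
(bench-only embedder ablation; production
embedder unchanged)
--checkpoint PATH JSONL crash-resume file (default <out>.jsonl);
rows already recorded in it are skipped on restart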
|
||
|
||
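## Output JSON keys

{
"split": "S",
"dataset_id": "xiaowu0162/longmemeval-cleaned" | "xiaowu0162/longmemeval",
"revision": "<40-hex>",
"granularity": "session" | "turn",
"dataset_choice": "cleaned" | "raw",
"embedder_model_key": "bge-small-en-v1.5" | "all-MiniLM-L6-v2",
"embedder_hf_id": str,
"n_rows": int, # rows selected for the run (errored rows included)
"r_at_5_retrieve": float, # prong X session-ID R@5, mean across rows
"r_at_10_retrieve": float, # prong X session-ID R@10, mean across rows
"r_at_5_pipeline": float, # prong Y session-ID R@5, mean across rows
"r_at_10_pipeline": float, # prong Y session-ID R@10, mean across rows
"r_at_5_lift": float, # r_at_5_pipeline - r_at_5_retrieve
"r_at_10_lift": float, # r_at_10_pipeline - r_at_10_retrieve
"token_p50": int, # per-query cue-text tokens, median
"token_p95": int, # per-query cue-text tokens, p95
"session_tokens_mean": float, # mean per-row inserted text tokens
# (proxy for the rows' storage footprint)
# (errored rows contribute 0 to both token stats)
"errors": [{"question_id": str, "error_class": str, "error": str}],
"hard_limit": int | null,
"metric_def": str, # human-readable R@k definition
"per_row": [dict], # per-row results of successful rows
"generated_at": str, # ISO-8601 UTC timestamp
"total_wall_seconds": float
}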
|
||
|
||
## Blind-run discipline
|
||
|
||
The run is ONE-SHOT. If a bug crashes a row, it's logged in ``errors``
|
||
and counted as a MISS against R@k (not silently dropped). The published
|
||
number is whatever came out. Disclosures (small-N, hardware limit,
|
||
English-only embedder, etc.) live in the published bench report and
|
||
05-11-SUMMARY.md — they don't get folded back into this script.
|
||
"""
|
||
from __future__ import annotations
|
||
|
||
import argparse
|
||
import asyncio
|
||
import json
|
||
import os
|
||
import shutil
|
||
import statistics
|
||
import sys
|
||
import tempfile
|
||
import time
|
||
import traceback
|
||
from collections import defaultdict
|
||
from datetime import datetime, timezone
|
||
from pathlib import Path
|
||
from typing import Any
|
||
from uuid import UUID
|
||
|
||
# sentence-transformers prints "UNEXPECTED embeddings.position_ids" noise at
# load time; default the transformers log level to "error" so the blind-run
# stderr only carries real problems. A value already set in the environment
# is left untouched.
if "TRANSFORMERS_VERBOSITY" not in os.environ:
    os.environ["TRANSFORMERS_VERBOSITY"] = "error"
|
||
|
||
# IAI-MCP imports — public API only (plan directive).
|
||
from iai_mcp.embed import Embedder, embedder_for_store
|
||
from iai_mcp.pipeline import recall_for_benchmark
|
||
from iai_mcp.retrieve import build_runtime_graph, recall as retrieve_recall
|
||
from iai_mcp.store import MemoryStore
|
||
from iai_mcp.types import MemoryRecord
|
||
|
||
# Adapter (ships alongside this script).
|
||
from bench.adapters.longmemeval import (
|
||
DATASET_ID,
|
||
PINNED_REVISION,
|
||
LMESession,
|
||
LongMemEvalAdapter,
|
||
)
|
||
|
||
# Token counter (reuses bench/tokens.py three-tier helper).
|
||
from bench.tokens import _char4_count, _tiktoken_count
|
||
|
||
|
||
def _count_tokens(text: str) -> int:
    """Return the token count of *text*.

    Uses the tiktoken cl100k proxy when it works; any failure there falls
    back to the char/4 estimate from ``bench.tokens``.
    """
    try:
        n_tokens = _tiktoken_count(text)
    except Exception:  # pragma: no cover
        n_tokens = _char4_count(text)
    return n_tokens
|
||
|
||
|
||
def _percentile(xs: list[int], p: float) -> int:
|
||
if not xs:
|
||
return 0
|
||
s = sorted(xs)
|
||
k = max(0, min(len(s) - 1, int(round((len(s) - 1) * p / 100.0))))
|
||
return s[k]
|
||
|
||
|
||
def _make_record(
    content: str,
    session_id: str,
    role: str,
    embedding: list[float],
) -> MemoryRecord:
    """Wrap one LongMemEval text chunk as a fresh episodic MemoryRecord.

    The record gets a random UUID, identical created/updated timestamps
    (UTC now), neutral scheduling fields, and three tags: ``longmemeval``,
    ``role:<role>`` and ``session:<session_id>`` (the latter is what lets
    R@k be scored at the dataset's session-ID granularity).
    """
    from uuid import uuid4

    timestamp = datetime.now(timezone.utc)
    record_tags = ["longmemeval", f"role:{role}", f"session:{session_id}"]

    return MemoryRecord(
        id=uuid4(),
        tier="episodic",
        literal_surface=content,
        aaak_index="",
        embedding=embedding,
        community_id=None,
        centrality=0.0,
        detail_level=2,
        pinned=False,
        stability=0.0,
        difficulty=0.0,
        last_reviewed=None,
        never_decay=False,
        never_merge=False,
        provenance=[],
        created_at=timestamp,
        updated_at=timestamp,
        tags=record_tags,
        language="en",
    )
|
||
|
||
|
||
def _run_one_row(
    row_id: str,
    question: str,
    question_type: str,
    answer_session_ids: set[str],
    sessions: list[LMESession],
    tmp_root: Path,
    granularity: str = "turn",
    embedder_key: str = "bge-small-en-v1.5",
) -> dict[str, Any]:
    """Execute the per-row protocol for one LongMemEval question.

    Returns a dict with r_at_5/r_at_10 for BOTH retrieve_recall (prong X,
    flat-cosine baseline, matches Phase 5 n=30) AND recall_for_benchmark
    (prong Y, full graph-native architecture; Phase 8 entry-point split),
    plus token counts and timing info.

    bench/lme500 protocol: both prongs share the same insert phase and the
    same hit -> session_id mapping, so the architecture-vs-baseline delta
    is attributable to the recall function only, not retrieval-side
    variance.

    ``granularity`` controls corpus construction:
        "turn"    -> one record per non-empty turn, both roles (v1/v2
                     baseline; ~500 records/row)
        "session" -> one record per session whose content is
                     "\\n".join(user-only turns), matching mempalace's
                     reference verbatim (~53 records/row)
    Any other value takes the "turn" path.

    R@k is scored positionally over the first k hits returned by each
    recall function. A hit whose record_id is not one of this row's
    inserted records keeps its rank slot and never counts as a gold hit,
    so it cannot promote a lower-ranked hit into the top-k.

    Errors: exceptions raised during the insert phase or by prong X
    propagate to the caller (which logs the row and scores it as a miss).
    A build_runtime_graph or recall_for_benchmark failure is caught here;
    prong Y then scores 0.0 and the reason is returned in
    ``pipeline_error``.
    """
    t0 = time.time()

    # Fresh store in a per-row tmp dir (per-row isolation).
    store_dir = tmp_root / f"row-{row_id}"
    store_dir.mkdir(parents=True, exist_ok=True)
    store = MemoryStore(path=store_dir / "lancedb")

    # async writes: coalesce LanceDB appends across the row.
    # enable_async_writes is a coroutine — drive it from a fresh loop so
    # the surrounding orchestrator stays sync.
    asyncio.run(store.enable_async_writes(coalesce_ms=50, max_batch=128))

    # count inserted tokens as a rough storage footprint.
    inserted_text_tokens = 0

    # route through the explicit registry key so the
    # embedder ablation experiment can swap to all-MiniLM-L6-v2 without
    # touching the production-default resolver (embedder_for_store kept
    # imported for backward-compat; not called on this path).
    embedder = Embedder(model_key=embedder_key)
    _ = embedder_for_store  # silence unused-import warning when the prod path is bypassed

    # --------- INSERT phase ---------
    # One pass over all haystack sessions for this row. Each record is
    # mapped back to its session_id so R@k can score at the dataset's
    # native session granularity. Two corpus shapes:
    #   - "turn"    (v1/v2 baseline; one record per turn, both roles)
    #   - "session" (mempalace-aligned; one record per session, user-only
    #                turns joined with "\n"; ~10x fewer records per row)
    id_to_session: dict[str, str] = {}  # str(record.id) -> session_id
    if granularity == "session":
        # Session-granularity (D-01, mempalace-aligned): ONE record per
        # session, content = "\n".join(user-only turns). Skip sessions
        # with no user turns. Verbatim shape match with mempalace's
        # benchmarks/longmemeval_bench.py reference loop.
        for sess in sessions:
            user_turns = [
                str(turn.get("content", "")).strip()
                for turn in sess.turns
                if str(turn.get("role", "user")) == "user"
                and str(turn.get("content", "")).strip()
            ]
            if not user_turns:
                continue
            doc_text = "\n".join(user_turns)
            vec = embedder.embed(doc_text)
            rec = _make_record(
                content=doc_text,
                session_id=sess.session_id,
                role="user",
                embedding=vec,
            )
            store.insert(rec)
            id_to_session[str(rec.id)] = sess.session_id
            inserted_text_tokens += _count_tokens(doc_text)
    else:
        # Turn-granularity (v1/v2 baseline; same loop body as v1/v2).
        for sess in sessions:
            for turn in sess.turns:
                content = str(turn.get("content", "")).strip()
                if not content:
                    continue
                vec = embedder.embed(content)
                rec = _make_record(
                    content=content,
                    session_id=sess.session_id,
                    role=str(turn.get("role", "user")),
                    embedding=vec,
                )
                store.insert(rec)
                id_to_session[str(rec.id)] = sess.session_id
                inserted_text_tokens += _count_tokens(content)

    # Flush the async queue before recall. disable_async_writes is a
    # coroutine too — drive from a fresh loop.
    asyncio.run(store.disable_async_writes())
    t_after_insert = time.time()

    # --------- Build runtime graph (Plan 05-09 cache warms cold-start) ---------
    # bench/lme500: capture the (graph, assignment, rich_club) tuple so
    # recall_for_benchmark (prong Y) can reuse it. retrieve_recall (prong X)
    # is unaffected by graph build success/failure.
    graph = None
    assignment = None
    rich_club = None
    try:
        graph, assignment, rich_club = build_runtime_graph(store)
    except Exception as exc:  # pragma: no cover — cache helpers should be robust
        # Don't fail the row on graph build; retrieve_recall is still
        # callable from the flat store. recall_for_benchmark will be skipped
        # for this row and counted as miss for the Y prong.
        print(
            f"[LME] row={row_id} build_runtime_graph failed: "
            f"{type(exc).__name__}: {exc}",
            file=sys.stderr,
        )
    t_after_graph = time.time()

    # --------- Prong X: retrieve_recall (flat-cosine, baseline) ---------
    cue_embedding = embedder.embed(question)
    resp_x = retrieve_recall(
        store=store,
        cue_embedding=cue_embedding,
        cue_text=question,
        session_id=f"lme-{row_id}",
        budget_tokens=1500,
        k_hits=10,
        k_anti=0,
    )
    t_after_x = time.time()

    # --------- Prong Y: recall_for_benchmark (full graph-native architecture) ---------
    # entry-point split: bench harness uses the top-K contract
    # (k_hits=10, no budget_tokens). mode="concept" preserved verbatim — the
    # bench is concept-shaped per BENCH_PROTOCOL_lme500.md and the D-02
    # `_gate_bias_for_mode("concept") == 0.1` bias is what v2 measurements observe.
    resp_y = None
    pipeline_error: str | None = None
    if graph is not None:
        try:
            resp_y = recall_for_benchmark(
                store=store,
                graph=graph,
                assignment=assignment,
                rich_club=rich_club,
                embedder=embedder,
                cue=question,
                session_id=f"lme-{row_id}",
                k_hits=10,
                profile_state=None,
                turn=0,
                mode="concept",
            )
        except Exception as exc:
            pipeline_error = f"{type(exc).__name__}: {str(exc)[:200]}"
            print(
                f"[LME] row={row_id} recall_for_benchmark failed: "
                f"{pipeline_error}",
                file=sys.stderr,
            )
    else:
        pipeline_error = "graph_build_failed"
    t_after_y = time.time()

    def _retrieved_session_ids(resp) -> list[str | None]:
        """Map each hit, in rank order, to its haystack session_id.

        Hits that do not correspond to a record inserted for this row map
        to ``None`` but keep their position, so ``[:k]`` below really is
        the first k hits the recall function returned.
        """
        if resp is None:
            return []
        return [id_to_session.get(str(hit.record_id)) for hit in resp.hits]

    sids_x = _retrieved_session_ids(resp_x)
    sids_y = _retrieved_session_ids(resp_y)

    # LongMemEval-standard R@k at session-ID granularity: hit-at-k.
    # R@k = 1.0 if any of the top-k retrieved records belongs to a gold
    # session, else 0.0. Aggregated across rows by the caller.
    def _hit_at_k(sids: list[str | None], k: int) -> float:
        top = sids[:k]
        return 1.0 if any(s is not None and s in answer_session_ids for s in top) else 0.0

    r5_x = _hit_at_k(sids_x, 5)
    r10_x = _hit_at_k(sids_x, 10)
    r5_y = _hit_at_k(sids_y, 5) if resp_y is not None else 0.0
    r10_y = _hit_at_k(sids_y, 10) if resp_y is not None else 0.0

    query_tokens = _count_tokens(question)

    return {
        "question_id": row_id,
        "question_type": question_type,
        # Prong X — retrieve_recall (flat-cosine baseline, line-by-line)
        "r_at_5_retrieve": r5_x,
        "r_at_10_retrieve": r10_x,
        # Prong Y — recall_for_benchmark (full graph-native pipeline; D-07)
        "r_at_5_pipeline": r5_y,
        "r_at_10_pipeline": r10_y,
        "pipeline_error": pipeline_error,
        # Shared
        "query_tokens": query_tokens,
        "inserted_text_tokens": inserted_text_tokens,
        "n_haystack_sessions": len(sessions),
        # Number of records inserted (turns or sessions, per granularity).
        "n_turns_inserted": len(id_to_session),
        "timing_seconds": {
            "insert": round(t_after_insert - t0, 2),
            "graph": round(t_after_graph - t_after_insert, 2),
            "recall_retrieve": round(t_after_x - t_after_graph, 2),
            "recall_pipeline": round(t_after_y - t_after_x, 2),
            "total": round(t_after_y - t0, 2),
        },
    }
|
||
|
||
|
||
def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the blind-run CLI parser (flags are described in the module docstring)."""
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--split",
        default="S",
        choices=["S", "M", "oracle"],
        help="LongMemEval split (Plan 05-11 runs S)",
    )
    parser.add_argument(
        "--limit",
        type=int,
        default=None,
        help=(
            "practical-cap on rows evaluated. LongMemEval-S = 500 rows; "
            "at ~500 turns/row and 11ms/embed on a 16GB M1 laptop, the "
            "full 500-row run is multi-hour. --limit lets the blind pilot "
            "finish; the SUMMARY discloses the cap honestly."
        ),
    )
    parser.add_argument(
        "--out",
        default="/tmp/p11_lme_full.json",
        help="output JSON path",
    )
    parser.add_argument(
        "--checkpoint",
        default=None,
        help=(
            "JSONL checkpoint path for crash-resume; default = <out>.jsonl. "
            "Each completed (or errored) row is appended with fsync as one "
            "JSON line. On restart, rows whose question_id already appears "
            "in the checkpoint are skipped."
        ),
    )
    # granularity flag with mempalace-aligned default.
    parser.add_argument(
        "--granularity",
        choices=["session", "turn"],
        default="session",
        help=(
            "corpus-construction granularity. "
            "'session' (default, v3): one record per session, "
            "content = '\\n'.join(user-only turns) — matches mempalace's "
            "reference. 'turn': one record per turn (v1/v2 baseline; "
            "use with --dataset raw to reproduce v2's 0.956)."
        ),
    )
    # dataset choice flag with mempalace-aligned default.
    parser.add_argument(
        "--dataset",
        choices=["cleaned", "raw"],
        default="cleaned",
        help=(
            "dataset variant. 'cleaned' (default, v3): "
            "xiaowu0162/longmemeval-cleaned, SHA pinned via repo_info(). "
            "'raw' (v1/v2 baseline): xiaowu0162/longmemeval rev "
            "2ec2a557f339... — use with --granularity turn to reproduce "
            "v2's 0.956."
        ),
    )
    # Step B: per-qid filter for the v2-baseline smoke reproducer. Applied
    # AFTER --limit so a caller passing both flags gets a deterministic
    # intersection (limit narrows by row count, qid-include narrows by id).
    # Default None preserves v1/v2 behaviour.
    parser.add_argument(
        "--qid-include",
        default=None,
        help=(
            "comma-separated list of question_ids; if set, only these "
            "rows run (used by smoke tests for per-qid baseline "
            "verification). Applied after --limit."
        ),
    )
    # bench-only embedder swap. Default preserves the v3 baseline
    # (bge-small-en-v1.5); all-MiniLM-L6-v2 is mempalace's ChromaDB default,
    # used for the v3.1 embedder-axis ablation. Production embedder is
    # unchanged regardless of this flag.
    parser.add_argument(
        "--embedder",
        choices=["bge-small-en-v1.5", "all-MiniLM-L6-v2"],
        default="bge-small-en-v1.5",
        help=(
            "embedder model_key. 'bge-small-en-v1.5' (default, v3 "
            "baseline) routes via the production English-only embedder. "
            "'all-MiniLM-L6-v2' (Phase 9.1 ablation) is mempalace's "
            "ChromaDB default — bench-only swap, production unchanged."
        ),
    )
    return parser


def _group_rows(adapter: Any, split: str) -> tuple[dict[str, dict[str, Any]], list[str]]:
    """Group the adapter's per-session stream into per-question rows.

    The adapter yields one LMESession per haystack session, but the
    blind-run protocol needs rows (one question + all its haystack
    sessions). Sessions are grouped by question_id (carried inside
    ``queries[0]``). Returns ``(rows_by_qid, qids_in_first_seen_order)``.
    """
    grouped: dict[str, dict[str, Any]] = {}
    row_order: list[str] = []
    for lme_session in adapter.load_dataset(split=split):
        q = lme_session.queries[0]
        qid = q["question_id"]
        row = grouped.get(qid)
        if row is None:
            row = {
                "question": q["query"],
                "question_type": q.get("question_type", "unknown"),
                "answer_session_ids": set(q.get("relevant_turn_ids", [])),
                "sessions": [],
            }
            grouped[qid] = row
            row_order.append(qid)
        row["sessions"].append(lme_session)
    return grouped, row_order


def _load_checkpoint(checkpoint_path: Path) -> dict[str, dict[str, Any]]:
    """Read a JSONL checkpoint into ``{question_id: record}``.

    Missing file -> empty dict. Blank lines and lines without a
    question_id are ignored; undecodable or non-object lines are reported
    on stderr and skipped. If a question_id appears more than once the
    last record wins, so no row is ever counted twice on resume.
    """
    records: dict[str, dict[str, Any]] = {}
    if not checkpoint_path.exists():
        return records
    with open(checkpoint_path, "r", encoding="utf-8") as cp_f:
        for line in cp_f:
            line = line.strip()
            if not line:
                continue
            try:
                rec = json.loads(line)
            except json.JSONDecodeError:
                rec = None
            if not isinstance(rec, dict):
                print(
                    f"[LME] WARN: skipping corrupt checkpoint line: {line[:80]!r}",
                    file=sys.stderr,
                    flush=True,
                )
                continue
            qid = rec.get("question_id")
            if not qid:
                continue
            records[qid] = rec
    return records


def main(argv: list[str] | None = None) -> int:
    """Run the blind LongMemEval evaluation and write the summary JSON.

    Selects rows (--limit, then --qid-include), resumes from the JSONL
    checkpoint when present, evaluates every remaining row with
    ``_run_one_row`` (errors are logged and scored as misses for both
    prongs), and writes aggregate R@5/R@10 for both prongs plus token
    stats to ``--out``. Aggregates cover exactly the selected rows, so
    they always agree with ``n_rows``. Returns the process exit code (0).
    """
    args = _build_arg_parser().parse_args(argv)

    print(
        f"[LME] blind run starting "
        f"split={args.split} limit={args.limit} "
        f"granularity={args.granularity} dataset={args.dataset} "
        f"embedder={args.embedder} "
        f"out={args.out}",
        file=sys.stderr,
        flush=True,
    )

    # branch the adapter on --dataset.
    if args.dataset == "cleaned":
        from bench.adapters.longmemeval_cleaned import (
            CLEANED_DATASET_ID,
            CleanedLongMemEvalAdapter,
        )
        adapter = CleanedLongMemEvalAdapter()
        dataset_id_emit = CLEANED_DATASET_ID
        revision_emit = adapter.revision
    else:
        adapter = LongMemEvalAdapter()
        dataset_id_emit = DATASET_ID
        revision_emit = PINNED_REVISION

    grouped, row_order = _group_rows(adapter, args.split)

    if args.limit is not None:
        row_order = row_order[: args.limit]

    # Step B: --qid-include filter applied AFTER --limit so a caller passing
    # both flags gets a deterministic intersection. None is a no-op.
    if args.qid_include is not None:
        wanted = {q.strip() for q in str(args.qid_include).split(",") if q.strip()}
        row_order = [qid for qid in row_order if qid in wanted]
        print(
            f"[LME] qid-include filter: kept {len(row_order)} of "
            f"{len(wanted)} requested qids",
            file=sys.stderr,
            flush=True,
        )

    # Resolve the embedder identity up front: a bad key then fails before
    # the (multi-hour) row loop instead of after it.
    embedder_hf_id = Embedder(model_key=args.embedder).model_name

    per_row: list[dict[str, Any]] = []
    errors: list[dict[str, str]] = []
    # bench/lme500: track BOTH prongs (X = retrieve_recall, Y = recall_for_benchmark).
    r5_x_values: list[float] = []
    r10_x_values: list[float] = []
    r5_y_values: list[float] = []
    r10_y_values: list[float] = []
    query_tokens: list[int] = []
    session_tokens: list[int] = []

    def _tally(rec: dict[str, Any]) -> None:
        """Fold one row record (fresh or resumed) into the aggregates.

        Error records (``"error"`` holding a dict) count as a full miss for
        both prongs with zero token counts — T-05-11-04: log + count as
        miss, never silently drop.
        """
        err = rec.get("error")
        if isinstance(err, dict):
            errors.append(
                {
                    "question_id": rec["question_id"],
                    "error_class": err.get("error_class", "Unknown"),
                    "error": err.get("error", ""),
                }
            )
            r5_x_values.append(0.0)
            r10_x_values.append(0.0)
            r5_y_values.append(0.0)
            r10_y_values.append(0.0)
            query_tokens.append(0)
            session_tokens.append(0)
        else:
            per_row.append(rec)
            r5_x_values.append(float(rec.get("r_at_5_retrieve", 0.0)))
            r10_x_values.append(float(rec.get("r_at_10_retrieve", 0.0)))
            r5_y_values.append(float(rec.get("r_at_5_pipeline", 0.0)))
            r10_y_values.append(float(rec.get("r_at_10_pipeline", 0.0)))
            query_tokens.append(int(rec.get("query_tokens", 0)))
            session_tokens.append(int(rec.get("inserted_text_tokens", 0)))

    # bench/lme500: per-row JSONL checkpoint for crash resume. Each row's
    # full result is appended with flush + fsync, so a kill at row N
    # preserves rows 1..N-1. Restart skips rows already in the checkpoint
    # (matched by question_id); checkpoint rows outside the current
    # selection are ignored so they cannot skew the aggregates.
    checkpoint_path = Path(args.checkpoint) if args.checkpoint else Path(str(args.out) + ".jsonl")
    checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
    checkpointed = _load_checkpoint(checkpoint_path)
    selected = set(row_order)
    completed_ids = {qid for qid in checkpointed if qid in selected}
    n_ignored = len(checkpointed) - len(completed_ids)

    for qid in row_order:
        if qid in completed_ids:
            _tally(checkpointed[qid])

    if completed_ids:
        print(
            f"[LME] resume: {len(completed_ids)} rows already in checkpoint "
            f"{checkpoint_path}; processing {len(row_order) - len(completed_ids)} remaining",
            file=sys.stderr,
            flush=True,
        )
    else:
        print(
            f"[LME] checkpoint: writing per-row durable JSONL to {checkpoint_path}",
            file=sys.stderr,
            flush=True,
        )
    if n_ignored:
        print(
            f"[LME] resume: ignoring {n_ignored} checkpoint rows outside the "
            f"current row selection",
            file=sys.stderr,
            flush=True,
        )

    def _checkpoint_append(rec: dict[str, Any]) -> None:
        """Append one row record to the checkpoint, flush+fsync for durability."""
        with open(checkpoint_path, "a", encoding="utf-8") as cp_a:
            cp_a.write(json.dumps(rec) + "\n")
            cp_a.flush()
            os.fsync(cp_a.fileno())

    tmp_root = Path(tempfile.mkdtemp(prefix="lme_blind_"))
    print(f"[LME] per-row stores rooted at {tmp_root}", file=sys.stderr, flush=True)

    run_t0 = time.time()
    try:
        for i, qid in enumerate(row_order):
            if qid in completed_ids:
                continue
            row = grouped[qid]
            try:
                res = _run_one_row(
                    row_id=qid,
                    question=row["question"],
                    question_type=row["question_type"],
                    answer_session_ids=row["answer_session_ids"],
                    sessions=row["sessions"],
                    tmp_root=tmp_root,
                    granularity=args.granularity,
                    embedder_key=args.embedder,
                )
            except Exception as exc:
                # T-05-11-04 mitigation: log + count as miss for both prongs
                # ("count against R@5 as 0" from the plan text); do NOT
                # silently drop. Persisted so a restart skips the row.
                err_rec = {
                    "question_id": qid,
                    "question_type": row.get("question_type", "unknown"),
                    "error": {
                        "error_class": type(exc).__name__,
                        "error": str(exc)[:500],
                    },
                }
                _tally(err_rec)
                _checkpoint_append(err_rec)
                print(
                    f"[LME] ERROR row={qid}: {type(exc).__name__}: {exc}",
                    file=sys.stderr,
                    flush=True,
                )
                traceback.print_exc(file=sys.stderr)
            else:
                # Outside the try: a checkpoint/logging failure here must
                # not be re-tallied as a second (error) result for this row.
                _tally(res)
                _checkpoint_append(res)
                elapsed = time.time() - run_t0
                print(
                    f"[LME] row {i+1}/{len(row_order)} qid={qid} "
                    f"qtype={res['question_type']} "
                    f"R@5_x={res['r_at_5_retrieve']:.0f} R@5_y={res['r_at_5_pipeline']:.0f} "
                    f"R@10_x={res['r_at_10_retrieve']:.0f} R@10_y={res['r_at_10_pipeline']:.0f} "
                    f"t_row={res['timing_seconds']['total']:.1f}s "
                    f"t_total={elapsed:.1f}s",
                    file=sys.stderr,
                    flush=True,
                )
            finally:
                # Free disk aggressively — many rows × ~500 turns per store
                # adds up even on 64GB.
                row_dir = tmp_root / f"row-{qid}"
                if row_dir.exists():
                    shutil.rmtree(row_dir, ignore_errors=True)
    finally:
        # Also runs on KeyboardInterrupt, so no tmp tree is left behind.
        shutil.rmtree(tmp_root, ignore_errors=True)

    def _mean(xs: list[float]) -> float:
        return (sum(xs) / len(xs)) if xs else 0.0

    r5_x_mean = _mean(r5_x_values)
    r10_x_mean = _mean(r10_x_values)
    r5_y_mean = _mean(r5_y_values)
    r10_y_mean = _mean(r10_y_values)

    out = {
        "split": args.split,
        "dataset_id": dataset_id_emit,
        "revision": revision_emit,
        # reproducibility fields:
        "granularity": args.granularity,
        "dataset_choice": args.dataset,
        # embedder identity pinned for v3.1 ablation reproducibility.
        "embedder_model_key": args.embedder,
        "embedder_hf_id": embedder_hf_id,
        "n_rows": len(row_order),
        # Prong X — retrieve_recall (flat-cosine baseline, line-by-line)
        "r_at_5_retrieve": r5_x_mean,
        "r_at_10_retrieve": r10_x_mean,
        # Prong Y — recall_for_benchmark (full graph-native architecture; D-07)
        "r_at_5_pipeline": r5_y_mean,
        "r_at_10_pipeline": r10_y_mean,
        # Architecture lift (Y - X)
        "r_at_5_lift": r5_y_mean - r5_x_mean,
        "r_at_10_lift": r10_y_mean - r10_x_mean,
        "token_p50": _percentile(query_tokens, 50),
        "token_p95": _percentile(query_tokens, 95),
        "session_tokens_mean": (
            statistics.fmean(session_tokens) if session_tokens else 0.0
        ),
        "errors": errors,
        "hard_limit": args.limit,
        "metric_def": (
            "Session-ID hit-at-k: R@k = 1.0 if any of top-k retrieved records "
            "belongs to a gold session_id, else 0.0 (LongMemEval standard)."
        ),
        "per_row": per_row,
        "generated_at": datetime.now(timezone.utc).isoformat(),
        "total_wall_seconds": round(time.time() - run_t0, 2),
    }

    Path(args.out).parent.mkdir(parents=True, exist_ok=True)
    with open(args.out, "w", encoding="utf-8") as f:
        json.dump(out, f, indent=2)

    print(
        f"[LME] DONE n_rows={out['n_rows']} "
        f"R@5_retrieve={out['r_at_5_retrieve']:.3f} "
        f"R@5_pipeline={out['r_at_5_pipeline']:.3f} "
        f"lift_R@5={out['r_at_5_lift']:+.3f} "
        f"R@10_retrieve={out['r_at_10_retrieve']:.3f} "
        f"R@10_pipeline={out['r_at_10_pipeline']:.3f} "
        f"lift_R@10={out['r_at_10_lift']:+.3f} "
        f"errors={len(errors)} -> {args.out}",
        file=sys.stderr,
        flush=True,
    )
    return 0
|
||
|
||
|
||
if __name__ == "__main__":
    # Script entry point: propagate main()'s return value as the exit status.
    raise SystemExit(main())
|