341 lines
11 KiB
Python
341 lines
11 KiB
Python
|
|
"""Plan 05-12 — graph-native recall tests (RED scaffold).

Close the latency gap by switching recall_for_response's seed + spread
stages from per-id ``store.get(rid)`` LanceDB round-trips to in-RAM
``G.nodes[rid]`` attribute lookups. ``build_runtime_graph`` attaches the
record payload (embedding, surface, centrality, tier) to every graph
node so the recall hot path never touches disk for a graph-resident id.

Covered contracts:

A1 — every node in G carries embedding + surface + centrality + tier
     after ``build_runtime_graph``.
A2 — seed stage does NOT call ``store.get`` (patch raises if invoked).
A3 — spread stage (rank/reachable walk) does NOT call ``store.get``.
A4 — verbatim L0 fast path (cue_text exact-match / gate skip) still
     hits ``store.get`` — invariant path is untouched.
A5 — partial sync / missing attribute on a node falls back to
     ``store.get`` without crashing; recall still returns hits.
A6 — correctness fence: recall returns the seeded records with
     high cosine similarity (no correctness regression).
"""
|
||
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
from datetime import datetime, timezone
|
||
|
|
from pathlib import Path
|
||
|
|
from unittest import mock
|
||
|
|
from uuid import uuid4
|
||
|
|
|
||
|
|
import pytest
|
||
|
|
|
||
|
|
from iai_mcp import retrieve
|
||
|
|
from iai_mcp.pipeline import recall_for_response
|
||
|
|
from iai_mcp.store import MemoryStore
|
||
|
|
from iai_mcp.types import MemoryRecord
|
||
|
|
|
||
|
|
|
||
|
|
# --------------------------------------------------------------------------- fixtures
|
||
|
|
|
||
|
|
|
||
|
|
@pytest.fixture(autouse=True)
def _isolated_keyring(monkeypatch: pytest.MonkeyPatch):
    """Redirect ``keyring`` password calls to an in-memory dict.

    Autouse: applies to every test in this module so none of them reach the
    real OS credential store (the original note calls out macOS Keychain
    prompts). Yields the backing dict, keyed by ``(service, username)``.
    """
    import keyring as _keyring

    vault: dict[tuple[str, str], str] = {}

    def _get(s, u):
        # Missing entries read as None, like a dict ``.get``.
        return vault.get((s, u))

    def _set(s, u, p):
        vault[(s, u)] = p

    def _delete(s, u):
        # Deleting an absent entry is a no-op rather than an error.
        return vault.pop((s, u), None)

    replacements = (
        ("get_password", _get),
        ("set_password", _set),
        ("delete_password", _delete),
    )
    for attr_name, stand_in in replacements:
        monkeypatch.setattr(_keyring, attr_name, stand_in)

    yield vault
|
||
|
|
|
||
|
|
|
||
|
|
class _DetEmbedder:
    """Deterministic stand-in embedder: the vector depends only on the text.

    The first 16 hex digits of the text's SHA-256 seed a private
    ``random.Random``; ``DIM`` samples in [-1, 1) are drawn and L2-normalised.
    """

    def __init__(self, dim: int = 384) -> None:
        # All three attributes are set from the constructor argument; the
        # duplicated dimension attributes presumably mirror the real
        # embedder's interface — TODO confirm against the production class.
        self.DIM = dim
        self.DEFAULT_DIM = dim
        self.DEFAULT_MODEL_KEY = "test"

    def embed(self, text: str) -> list[float]:
        """Return the unit-length pseudo-random vector derived from *text*."""
        import hashlib
        import random

        hex_digest = hashlib.sha256(text.encode("utf-8")).hexdigest()
        seed = int(hex_digest[:16], 16)
        rng = random.Random(seed)

        raw = [rng.random() * 2 - 1 for _ in range(self.DIM)]
        norm = sum(x * x for x in raw) ** 0.5
        if n_positive := norm > 0:
            return [x / norm for x in raw]
        # Zero norm (e.g. DIM == 0): hand back the raw vector unscaled.
        return raw
|
||
|
|
|
||
|
|
|
||
|
|
def _make_record(vec: list[float], text: str) -> MemoryRecord:
    """Build an unpinned episodic ``MemoryRecord`` for *text* with vector *vec*.

    The record gets a fresh random id, one UTC timestamp shared by
    ``created_at`` and ``updated_at``, and fixed values for every other field.
    """
    stamp = datetime.now(timezone.utc)
    fields = dict(
        id=uuid4(),
        tier="episodic",
        literal_surface=text,
        aaak_index="",
        embedding=vec,
        community_id=None,
        centrality=0.0,
        detail_level=2,
        pinned=False,
        stability=0.0,
        difficulty=0.0,
        last_reviewed=None,
        never_decay=False,
        never_merge=False,
        provenance=[],
        created_at=stamp,
        updated_at=stamp,
        tags=["t"],
        language="en",
    )
    return MemoryRecord(**fields)
|
||
|
|
|
||
|
|
|
||
|
|
@pytest.fixture
def seeded_store(tmp_path: Path) -> tuple[MemoryStore, _DetEmbedder, list[MemoryRecord]]:
    """Return a fresh store seeded with 12 records, plus embedder and records.

    Twelve records give the seed + spread stages enough material to exercise
    the graph-native read path.

    Each record's embedding is derived from the same text that is stored as
    its ``literal_surface``. ``_DetEmbedder`` is hash-based, so only an
    identical string reproduces a vector; embedding a different string (the
    previous ``"fact-{i}"``) meant a cue such as ``"synthetic fact 7"`` could
    never line up with record 7's vector.

    Returns:
        ``(store, embedder, records)`` — records are in insertion order.
    """
    store = MemoryStore(path=tmp_path / "lancedb")
    store.root = tmp_path
    # Match the embedder's dimensionality to whatever the store reports.
    emb = _DetEmbedder(dim=store.embed_dim)

    recs: list[MemoryRecord] = []
    for i in range(12):
        text = f"synthetic fact {i}"
        rec = _make_record(emb.embed(text), text)
        store.insert(rec)
        recs.append(rec)
    return store, emb, recs
|
||
|
|
|
||
|
|
|
||
|
|
# ---------------------------------------------------------------- A1: node payload
|
||
|
|
|
||
|
|
|
||
|
|
def test_A1_build_runtime_graph_attaches_node_payload(seeded_store):
    """A1: each graph node exposes embedding, surface, centrality and tier."""
    store, _emb, recs = seeded_store
    graph, _assignment, _rc = retrieve.build_runtime_graph(store)

    # Inspect the wrapped graph object directly. Per the original note this is
    # a NetworkX graph and the payload is expected to live in its node
    # attributes — confirm against build_runtime_graph.
    G = graph._nx
    assert G.number_of_nodes() == len(recs)

    required_attrs = ("embedding", "surface", "centrality", "tier")
    for rec in recs:
        node = G.nodes[str(rec.id)]  # nodes are keyed by the stringified id
        for attr in required_attrs:
            assert attr in node, f"node {rec.id} missing {attr} attr"
        # The payload values must mirror the record (centrality is only
        # checked for presence here).
        assert list(node["embedding"]) == list(rec.embedding)
        assert node["surface"] == rec.literal_surface
        assert node["tier"] == rec.tier
|
||
|
|
|
||
|
|
|
||
|
|
# ---------------------------------------------------------------- A2: seed stage
|
||
|
|
|
||
|
|
|
||
|
|
def test_A2_seed_stage_reads_from_graph_not_store(seeded_store):
    """A2: seed stage (top-K by cosine) must NOT call store.get.

    ``MemoryStore.get`` is patched to raise for every id except the
    deterministic L0 id; if ``recall_for_response`` still returns a
    non-empty response, the seed stage is graph-native.
    """
    from uuid import UUID

    store, emb, _recs = seeded_store
    graph, assignment, rich_club = retrieve.build_runtime_graph(store)

    # The verbatim L0 fast-path (gate skip) calls store.get too — avoid it by
    # choosing a cue that the gate is not expected to classify as trivial.
    cue = "explain the authentication migration for long-running deployments"

    # Built once here instead of on every patched call.
    l0_id = UUID("00000000-0000-0000-0000-000000000001")

    class _Boom(RuntimeError):
        """Raised when the hot path reaches store.get for a non-L0 id."""

    def _explode(rid):
        # Let the verbatim L0 UUID fetch through as a clean miss (no L0
        # record is seeded); any OTHER store.get call blows up.
        if rid == l0_id:
            return None
        raise _Boom(f"store.get({rid}) — seed stage should not call this")

    with mock.patch.object(MemoryStore, "get", side_effect=_explode):
        resp = recall_for_response(
            store=store,
            graph=graph,
            assignment=assignment,
            rich_club=rich_club,
            embedder=emb,
            cue=cue,
            session_id="s",
            budget_tokens=1500,
        )
    assert len(resp.hits) >= 1
|
||
|
|
|
||
|
|
|
||
|
|
# ---------------------------------------------------------------- A3: spread stage
|
||
|
|
|
||
|
|
|
||
|
|
def test_A3_spread_stage_reads_from_graph_not_store(seeded_store):
    """A3: the rank + spread stages must not call store.get either.

    Mirrors A2, but the guarantee covers the whole reachable union rather
    than only the seeds. Only the absence of an exception and the type of
    ``resp.hits`` are asserted.
    """
    from uuid import UUID

    store, emb, _recs = seeded_store
    graph, assignment, rich_club = retrieve.build_runtime_graph(store)

    cue = "network stack changes for the web cache"

    class _Boom(RuntimeError):
        pass

    def _explode(rid):
        # The deterministic L0 id is answered with a miss; anything else
        # fails loudly.
        if rid != UUID("00000000-0000-0000-0000-000000000001"):
            raise _Boom(f"store.get({rid}) during spread/rank")
        return None

    recall_kwargs = dict(
        store=store,
        graph=graph,
        assignment=assignment,
        rich_club=rich_club,
        embedder=emb,
        cue=cue,
        session_id="s",
        budget_tokens=1500,
    )
    with mock.patch.object(MemoryStore, "get", side_effect=_explode):
        resp = recall_for_response(**recall_kwargs)

    # Reaching this point means spread/rank never tripped the patched get.
    assert isinstance(resp.hits, list)
|
||
|
|
|
||
|
|
|
||
|
|
# ---------------------------------------------------------------- A4: L0 fast path
|
||
|
|
|
||
|
|
|
||
|
|
def test_A4_verbatim_l0_fast_path_still_calls_store_get(seeded_store):
    """A4: the L0 (gate-skip) fast path keeps going through store.get.

    Invariant under test: the verbatim recall path is left untouched.
    """
    from uuid import UUID

    store, emb, _recs = seeded_store

    # Insert the record with the deterministic L0 id so the gate-skip branch
    # has something to fetch.
    identity_vec = emb.embed("l0-identity")
    stamp = datetime.now(timezone.utc)
    identity_fields = dict(
        id=UUID("00000000-0000-0000-0000-000000000001"),
        tier="semantic",
        literal_surface="L0 identity kernel",
        aaak_index="",
        embedding=identity_vec,
        community_id=None,
        centrality=0.0,
        detail_level=5,  # the original note tags this value "never_decay"
        pinned=True,
        stability=0.0,
        difficulty=0.0,
        last_reviewed=None,
        never_decay=True,
        never_merge=True,
        provenance=[],
        created_at=stamp,
        updated_at=stamp,
        tags=["identity"],
        language="en",
    )
    store.insert(MemoryRecord(**identity_fields))
    graph, assignment, rich_club = retrieve.build_runtime_graph(store)

    # A short cue the gate is expected to treat as trivial (who-am-i style).
    cue = "hi"

    # ``wraps`` keeps the real lookup working while the spy counts calls.
    with mock.patch.object(MemoryStore, "get", wraps=store.get) as spy:
        recall_for_response(
            store=store,
            graph=graph,
            assignment=assignment,
            rich_club=rich_club,
            embedder=emb,
            cue=cue,
            session_id="s",
            budget_tokens=1500,
        )

    # The verbatim invariant: the fast path made at least one store.get call.
    assert spy.call_count >= 1
|
||
|
|
|
||
|
|
|
||
|
|
# ---------------------------------------------------------------- A5: fallback
|
||
|
|
|
||
|
|
|
||
|
|
def test_A5_missing_node_attr_falls_back_to_store_get(seeded_store):
    """A5: nodes without an embedding attr must not break recall.

    Simulates a race / partial sync by stripping ``embedding`` from half of
    the graph nodes. Recall is expected to fall back to store.get for those
    ids and still return hits; only the non-empty result is asserted.
    """
    store, emb, recs = seeded_store
    graph, assignment, rich_club = retrieve.build_runtime_graph(store)

    # Strip the embedding attr from the first six seeded records' nodes.
    G = graph._nx
    for victim in recs[:6]:
        attrs = G.nodes[str(victim.id)]
        if "embedding" in attrs:
            del attrs["embedding"]

    recall_kwargs = dict(
        store=store,
        graph=graph,
        assignment=assignment,
        rich_club=rich_club,
        embedder=emb,
        cue="summary of cli subcommand changes for the auth token rotation",
        session_id="s",
        budget_tokens=1500,
    )
    resp = recall_for_response(**recall_kwargs)
    assert len(resp.hits) >= 1
|
||
|
|
|
||
|
|
|
||
|
|
# ---------------------------------------------------------------- A6: correctness
|
||
|
|
|
||
|
|
|
||
|
|
def test_A6_m04_correctness_no_regression(seeded_store):
    """A6: happy-path recall returns only seeded records.

    Minimal correctness fence for the unit suite; per the original note the
    heavyweight bench.verbatim sweep (gap=5/20/100) lives elsewhere. The cue
    equals record 7's ``literal_surface``, but the assertions only require a
    non-empty result whose ids all belong to the seeded set — they do not
    check that record 7 itself is returned.
    """
    store, emb, recs = seeded_store
    graph, assignment, rich_club = retrieve.build_runtime_graph(store)

    recall_kwargs = dict(
        store=store,
        graph=graph,
        assignment=assignment,
        rich_club=rich_club,
        embedder=emb,
        cue="synthetic fact 7",
        session_id="s",
        budget_tokens=1500,
    )
    resp = recall_for_response(**recall_kwargs)

    # Something must come back ...
    assert len(resp.hits) >= 1

    # ... and nothing outside the records this test seeded.
    known_ids = {rec.id for rec in recs}
    assert all(hit.record_id in known_ids for hit in resp.hits)
|