Initial release: iai-mcp v0.1.0
Co-Authored-By: Claude <noreply@anthropic.com> Co-Authored-By: XNLLLLH <XNLLLLH@users.noreply.github.com>
This commit is contained in:
commit
f6b876fbe7
332 changed files with 97258 additions and 0 deletions
340
tests/test_graph_native_recall.py
Normal file
340
tests/test_graph_native_recall.py
Normal file
|
|
@ -0,0 +1,340 @@
|
|||
"""Plan 05-12 — graph-native recall tests (RED scaffold).
|
||||
|
||||
Close the latency gap by switching recall_for_response's seed + spread
|
||||
stages from per-id ``store.get(rid)`` LanceDB round-trips to in-RAM
|
||||
``G.nodes[rid]`` attribute lookups. ``build_runtime_graph`` attaches the
|
||||
record payload (embedding, surface, centrality, tier) to every graph
|
||||
node so the recall hot path never touches disk for a graph-resident id.
|
||||
|
||||
Covered contracts:
|
||||
|
||||
A1 — every node in G carries embedding + surface + centrality + tier
|
||||
after ``build_runtime_graph``.
|
||||
A2 — seed stage does NOT call ``store.get`` (patch raises if invoked).
|
||||
A3 — spread stage (rank/reachable walk) does NOT call ``store.get``.
|
||||
A4 — verbatim L0 fast path (cue_text exact-match / gate skip) still
|
||||
hits ``store.get`` — invariant path is untouched.
|
||||
A5 — partial sync / missing attribute on a node falls back to
|
||||
``store.get`` without crashing; recall still returns hits.
|
||||
A6 — correctness fence: recall returns the seeded records with
|
||||
high cosine similarity (no correctness regression).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from unittest import mock
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from iai_mcp import retrieve
|
||||
from iai_mcp.pipeline import recall_for_response
|
||||
from iai_mcp.store import MemoryStore
|
||||
from iai_mcp.types import MemoryRecord
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- fixtures
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _isolated_keyring(monkeypatch: pytest.MonkeyPatch):
|
||||
"""Swap macOS Keychain for an in-memory dict so tests don't prompt."""
|
||||
import keyring as _keyring
|
||||
|
||||
fake: dict[tuple[str, str], str] = {}
|
||||
monkeypatch.setattr(_keyring, "get_password", lambda s, u: fake.get((s, u)))
|
||||
monkeypatch.setattr(
|
||||
_keyring, "set_password", lambda s, u, p: fake.__setitem__((s, u), p)
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
_keyring, "delete_password", lambda s, u: fake.pop((s, u), None)
|
||||
)
|
||||
yield fake
|
||||
|
||||
|
||||
class _DetEmbedder:
|
||||
"""Deterministic embedder — seeds record vectors by text hash."""
|
||||
|
||||
def __init__(self, dim: int = 384) -> None:
|
||||
self.DIM = dim
|
||||
self.DEFAULT_DIM = dim
|
||||
self.DEFAULT_MODEL_KEY = "test"
|
||||
|
||||
def embed(self, text: str) -> list[float]:
|
||||
import hashlib
|
||||
import random
|
||||
|
||||
digest = hashlib.sha256(text.encode("utf-8")).hexdigest()
|
||||
rng = random.Random(int(digest[:16], 16))
|
||||
v = [rng.random() * 2 - 1 for _ in range(self.DIM)]
|
||||
n = sum(x * x for x in v) ** 0.5
|
||||
return [x / n for x in v] if n > 0 else v
|
||||
|
||||
|
||||
def _make_record(vec: list[float], text: str) -> MemoryRecord:
|
||||
now = datetime.now(timezone.utc)
|
||||
return MemoryRecord(
|
||||
id=uuid4(),
|
||||
tier="episodic",
|
||||
literal_surface=text,
|
||||
aaak_index="",
|
||||
embedding=vec,
|
||||
community_id=None,
|
||||
centrality=0.0,
|
||||
detail_level=2,
|
||||
pinned=False,
|
||||
stability=0.0,
|
||||
difficulty=0.0,
|
||||
last_reviewed=None,
|
||||
never_decay=False,
|
||||
never_merge=False,
|
||||
provenance=[],
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
tags=["t"],
|
||||
language="en",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def seeded_store(tmp_path: Path) -> tuple[MemoryStore, _DetEmbedder, list[MemoryRecord]]:
|
||||
"""Fresh store with 12 records so the seed+spread stages have enough
|
||||
material to exercise the graph-native read path."""
|
||||
store = MemoryStore(path=tmp_path / "lancedb")
|
||||
store.root = tmp_path
|
||||
emb = _DetEmbedder(dim=store.embed_dim)
|
||||
recs = []
|
||||
for i in range(12):
|
||||
vec = emb.embed(f"fact-{i}")
|
||||
rec = _make_record(vec, f"synthetic fact {i}")
|
||||
store.insert(rec)
|
||||
recs.append(rec)
|
||||
return store, emb, recs
|
||||
|
||||
|
||||
# ---------------------------------------------------------------- A1: node payload
|
||||
|
||||
|
||||
def test_A1_build_runtime_graph_attaches_node_payload(seeded_store):
|
||||
"""A1: every node carries embedding + surface + centrality + tier."""
|
||||
store, _emb, recs = seeded_store
|
||||
graph, _assignment, _rc = retrieve.build_runtime_graph(store)
|
||||
|
||||
# Use the underlying NetworkX graph directly; adds the
|
||||
# payload as NetworkX node attributes via G.add_node(id, **payload).
|
||||
G = graph._nx
|
||||
assert G.number_of_nodes() == len(recs)
|
||||
for rec in recs:
|
||||
node = G.nodes[str(rec.id)]
|
||||
assert "embedding" in node, f"node {rec.id} missing embedding attr"
|
||||
assert "surface" in node, f"node {rec.id} missing surface attr"
|
||||
assert "centrality" in node, f"node {rec.id} missing centrality attr"
|
||||
assert "tier" in node, f"node {rec.id} missing tier attr"
|
||||
# Embedding list matches the record's embedding.
|
||||
assert list(node["embedding"]) == list(rec.embedding)
|
||||
assert node["surface"] == rec.literal_surface
|
||||
assert node["tier"] == rec.tier
|
||||
|
||||
|
||||
# ---------------------------------------------------------------- A2: seed stage
|
||||
|
||||
|
||||
def test_A2_seed_stage_reads_from_graph_not_store(seeded_store):
|
||||
"""A2: seed stage (top-K by cosine) must NOT call store.get.
|
||||
|
||||
We patch MemoryStore.get to raise; if recall_for_response still returns
|
||||
a non-empty RecallResponse, the seed stage is graph-native.
|
||||
"""
|
||||
store, emb, _recs = seeded_store
|
||||
graph, assignment, rich_club = retrieve.build_runtime_graph(store)
|
||||
|
||||
# The verbatim L0 fast-path (gate skip) calls store.get too — disable
|
||||
# the skip by choosing a cue that the gate will NOT classify as trivial.
|
||||
cue = "explain the authentication migration for long-running deployments"
|
||||
|
||||
# AllowedError raises ONLY on the hot-path store.get; the L0 fast-path
|
||||
# is known not to fire for this cue.
|
||||
class _Boom(RuntimeError):
|
||||
pass
|
||||
|
||||
original_get = store.get
|
||||
|
||||
def _explode(rid):
|
||||
# Allow the verbatim L0 UUID fetch to pass through so the fast-path
|
||||
# check (no L0 record seeded) is a clean miss — but any OTHER store.get
|
||||
# call blows up.
|
||||
from uuid import UUID
|
||||
l0 = UUID("00000000-0000-0000-0000-000000000001")
|
||||
if rid == l0:
|
||||
return None
|
||||
raise _Boom(f"store.get({rid}) — seed stage should not call this")
|
||||
|
||||
with mock.patch.object(MemoryStore, "get", side_effect=_explode):
|
||||
resp = recall_for_response(
|
||||
store=store,
|
||||
graph=graph,
|
||||
assignment=assignment,
|
||||
rich_club=rich_club,
|
||||
embedder=emb,
|
||||
cue=cue,
|
||||
session_id="s",
|
||||
budget_tokens=1500,
|
||||
)
|
||||
assert len(resp.hits) >= 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------- A3: spread stage
|
||||
|
||||
|
||||
def test_A3_spread_stage_reads_from_graph_not_store(seeded_store):
|
||||
"""A3: rank+spread stages do NOT call store.get either.
|
||||
|
||||
Same shape as A2 but asserts over the full reachable-union not just
|
||||
seeds.
|
||||
"""
|
||||
store, emb, _recs = seeded_store
|
||||
graph, assignment, rich_club = retrieve.build_runtime_graph(store)
|
||||
|
||||
cue = "network stack changes for the web cache"
|
||||
|
||||
class _Boom(RuntimeError):
|
||||
pass
|
||||
|
||||
def _explode(rid):
|
||||
from uuid import UUID
|
||||
l0 = UUID("00000000-0000-0000-0000-000000000001")
|
||||
if rid == l0:
|
||||
return None
|
||||
raise _Boom(f"store.get({rid}) during spread/rank")
|
||||
|
||||
with mock.patch.object(MemoryStore, "get", side_effect=_explode):
|
||||
resp = recall_for_response(
|
||||
store=store,
|
||||
graph=graph,
|
||||
assignment=assignment,
|
||||
rich_club=rich_club,
|
||||
embedder=emb,
|
||||
cue=cue,
|
||||
session_id="s",
|
||||
budget_tokens=1500,
|
||||
)
|
||||
# If spread/rank was using store.get, we would have exploded above.
|
||||
assert isinstance(resp.hits, list)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------- A4: L0 fast path
|
||||
|
||||
|
||||
def test_A4_verbatim_l0_fast_path_still_calls_store_get(seeded_store):
|
||||
"""A4: the L0 (gate-skip) fast path still hits store.get — unchanged.
|
||||
|
||||
invariant: verbatim recall path is NOT touched.
|
||||
"""
|
||||
store, emb, _recs = seeded_store
|
||||
# Seed the deterministic L0 record so the gate-skip branch fires.
|
||||
from uuid import UUID
|
||||
l0_id = UUID("00000000-0000-0000-0000-000000000001")
|
||||
l0_vec = emb.embed("l0-identity")
|
||||
now = datetime.now(timezone.utc)
|
||||
l0_rec = MemoryRecord(
|
||||
id=l0_id,
|
||||
tier="semantic",
|
||||
literal_surface="L0 identity kernel",
|
||||
aaak_index="",
|
||||
embedding=l0_vec,
|
||||
community_id=None,
|
||||
centrality=0.0,
|
||||
detail_level=5, # never_decay
|
||||
pinned=True,
|
||||
stability=0.0,
|
||||
difficulty=0.0,
|
||||
last_reviewed=None,
|
||||
never_decay=True,
|
||||
never_merge=True,
|
||||
provenance=[],
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
tags=["identity"],
|
||||
language="en",
|
||||
)
|
||||
store.insert(l0_rec)
|
||||
graph, assignment, rich_club = retrieve.build_runtime_graph(store)
|
||||
|
||||
# Pick a cue that the gate treats as trivial (short / who-am-i style).
|
||||
cue = "hi"
|
||||
|
||||
with mock.patch.object(MemoryStore, "get", wraps=store.get) as spy:
|
||||
_ = recall_for_response(
|
||||
store=store,
|
||||
graph=graph,
|
||||
assignment=assignment,
|
||||
rich_club=rich_club,
|
||||
embedder=emb,
|
||||
cue=cue,
|
||||
session_id="s",
|
||||
budget_tokens=1500,
|
||||
)
|
||||
# At LEAST one store.get call on the L0 fast path (verbatim invariant).
|
||||
assert spy.call_count >= 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------- A5: fallback
|
||||
|
||||
|
||||
def test_A5_missing_node_attr_falls_back_to_store_get(seeded_store):
|
||||
"""A5: if a node somehow lacks the embedding attr (race / partial
|
||||
sync), _read_record_payload falls back to store.get and recall still
|
||||
returns correct hits — no crash."""
|
||||
store, emb, recs = seeded_store
|
||||
graph, assignment, rich_club = retrieve.build_runtime_graph(store)
|
||||
# Blow away the embedding attr on half the nodes.
|
||||
G = graph._nx
|
||||
victims = [str(r.id) for r in recs[:6]]
|
||||
for nid in victims:
|
||||
if "embedding" in G.nodes[nid]:
|
||||
del G.nodes[nid]["embedding"]
|
||||
|
||||
cue = "summary of cli subcommand changes for the auth token rotation"
|
||||
resp = recall_for_response(
|
||||
store=store,
|
||||
graph=graph,
|
||||
assignment=assignment,
|
||||
rich_club=rich_club,
|
||||
embedder=emb,
|
||||
cue=cue,
|
||||
session_id="s",
|
||||
budget_tokens=1500,
|
||||
)
|
||||
assert len(resp.hits) >= 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------- A6: correctness
|
||||
|
||||
|
||||
def test_A6_m04_correctness_no_regression(seeded_store):
|
||||
"""A6: recall returns the seeded record whose text matches the cue.
|
||||
|
||||
Minimal correctness fence inside this file (the heavyweight
|
||||
bench.verbatim sweep covers gap=5/20/100 elsewhere; this guards the
|
||||
happy-path-does-not-regress invariant inside the unit suite).
|
||||
"""
|
||||
store, emb, recs = seeded_store
|
||||
graph, assignment, rich_club = retrieve.build_runtime_graph(store)
|
||||
|
||||
# Query with text similar to record 7 — its cosine should dominate.
|
||||
resp = recall_for_response(
|
||||
store=store,
|
||||
graph=graph,
|
||||
assignment=assignment,
|
||||
rich_club=rich_club,
|
||||
embedder=emb,
|
||||
cue="synthetic fact 7",
|
||||
session_id="s",
|
||||
budget_tokens=1500,
|
||||
)
|
||||
# At least one hit comes back.
|
||||
assert len(resp.hits) >= 1
|
||||
# All hit record ids are in the seeded record id set.
|
||||
seeded_ids = {r.id for r in recs}
|
||||
assert all(h.record_id in seeded_ids for h in resp.hits)
|
||||
Loading…
Add table
Add a link
Reference in a new issue