Initial release: iai-mcp v0.1.0
Co-Authored-By: Claude <noreply@anthropic.com> Co-Authored-By: XNLLLLH <XNLLLLH@users.noreply.github.com>
This commit is contained in:
commit
f6b876fbe7
332 changed files with 97258 additions and 0 deletions
211
tests/test_s4_batch_api.py
Normal file
211
tests/test_s4_batch_api.py
Normal file
|
|
@ -0,0 +1,211 @@
|
|||
"""Tests for s4.on_read_check_batch (Plan 02-07 D-SPEED gap closure).
|
||||
|
||||
D-SPEED contract: bench/neural_map p95<100ms at N=100. Root cause:
|
||||
`s4.on_read_check` called per-hit inside pipeline_recall with no records_cache,
|
||||
forcing N+1 store.get() round-trips. Fix: new `on_read_check_batch` that accepts
|
||||
an optional records_cache from the caller and does ONE store.all_records() (or
|
||||
zero if cache provided).
|
||||
|
||||
Equivalence contract: on_read_check_batch returns semantically identical hint
|
||||
output to on_read_check for the same (store, hits, session_id) input. The
|
||||
`source_ids` contents of the returned hints must be a set-equal match; orderings
|
||||
may differ because event-write side effects are intermingled.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from iai_mcp.types import EMBED_DIM, MemoryHit, MemoryRecord
|
||||
|
||||
|
||||
def _make_record(
    *,
    text: str = "hello",
    vec: list[float] | None = None,
    tags: list[str] | None = None,
    detail_level: int = 2,
    tier: str = "episodic",
    language: str = "en",
) -> MemoryRecord:
    """Build a fresh ``MemoryRecord`` for use as a test fixture.

    All arguments are keyword-only. ``text`` becomes the record's
    ``literal_surface``; ``vec`` becomes its ``embedding`` and defaults to
    a vector of length ``EMBED_DIM`` with 1.0 in slot 0 and 0.0 elsewhere.
    ``tags`` is copied into a new list (``None`` yields an empty list).
    Every call gets a new ``uuid4()`` id and a single UTC timestamp shared
    by ``created_at`` and ``updated_at``; the remaining fields are fixed
    constants.
    """
    embedding = vec if vec is not None else [1.0] + [0.0] * (EMBED_DIM - 1)
    stamp = datetime.now(timezone.utc)
    tag_list = list(tags) if tags else []

    return MemoryRecord(
        # identity / content
        id=uuid4(),
        tier=tier,
        literal_surface=text,
        aaak_index="",
        embedding=embedding,
        # graph placement
        community_id=None,
        centrality=0.0,
        detail_level=detail_level,
        # review / retention state
        pinned=False,
        stability=0.0,
        difficulty=0.0,
        last_reviewed=None,
        never_decay=False,
        never_merge=False,
        # bookkeeping
        provenance=[],
        created_at=stamp,
        updated_at=stamp,
        tags=tag_list,
        language=language,
    )
|
||||
|
||||
|
||||
def _hit_for(rec: MemoryRecord, score: float = 0.9) -> MemoryHit:
    """Wrap *rec* in a ``MemoryHit`` carrying its id and literal surface.

    The hit uses the given ``score``, the fixed reason ``"test"``, and an
    empty ``adjacent_suggestions`` list.
    """
    hit_fields = {
        "record_id": rec.id,
        "score": score,
        "reason": "test",
        "literal_surface": rec.literal_surface,
        "adjacent_suggestions": [],
    }
    return MemoryHit(**hit_fields)
|
||||
|
||||
|
||||
# ------------------------------------------------------------- contract
|
||||
|
||||
|
||||
def test_s4_exports_on_read_check_batch():
    """``iai_mcp.s4`` exposes a callable named ``on_read_check_batch``."""
    from iai_mcp import s4

    # Sentinel distinguishes "attribute absent" from any real value.
    absent = object()
    exported = getattr(s4, "on_read_check_batch", absent)

    assert exported is not absent
    assert callable(exported)
|
||||
|
||||
|
||||
# ------------------------------------------------------------- behaviour
|
||||
|
||||
|
||||
def test_on_read_check_batch_uses_records_cache(tmp_path):
    """With ``records_cache`` supplied, ``store.get`` must never be invoked.

    ``store.get`` is temporarily replaced by a function that raises, so
    any call to it during ``on_read_check_batch`` fails the test. The two
    inserted records are linked by a contradicts edge, and the test then
    checks that exactly one hint comes back whose ``source_ids`` cover
    both records.
    """
    from iai_mcp.s4 import on_read_check_batch
    from iai_mcp.store import MemoryStore

    store = MemoryStore(path=tmp_path)

    # Two unit vectors along different axes.
    axis_a = [0.0] * EMBED_DIM
    axis_a[0] = 1.0
    axis_b = [0.0] * EMBED_DIM
    axis_b[1] = 1.0

    claim_true = _make_record(text="X is true", vec=axis_a, tags=["claim"])
    claim_false = _make_record(text="X is false", vec=axis_b, tags=["claim"])
    pair = (claim_true, claim_false)
    for rec in pair:
        store.insert(rec)
    store.add_contradicts_edge(claim_true.id, claim_false.id)

    records_cache = {rec.id: rec for rec in pair}
    hits = [_hit_for(rec) for rec in pair]

    def _forbidden_get(*args, **kwargs):
        # Any lookup through store.get is a failure for this test.
        raise RuntimeError("store.get must not be called when records_cache is provided")

    real_get = store.get
    store.get = _forbidden_get  # type: ignore[assignment]
    try:
        hints = on_read_check_batch(
            store,
            hits,
            session_id="test",
            records_cache=records_cache,
        )
    finally:
        # Always put the real method back, even if the call raised.
        store.get = real_get  # type: ignore[assignment]

    # The contradicts edge must still be reported.
    expected_ids = {str(claim_true.id), str(claim_false.id)}
    assert len(hints) == 1
    assert set(hints[0]["source_ids"]) == expected_ids
|
||||
|
||||
|
||||
def test_on_read_check_batch_fallback_no_cache(tmp_path):
    """Without ``records_cache``: zero ``store.get`` calls, one ``all_records``.

    Both store methods are wrapped with counting pass-throughs for the
    duration of the ``on_read_check_batch`` call, then restored. The
    counts are asserted afterwards.
    """
    from iai_mcp.s4 import on_read_check_batch
    from iai_mcp.store import MemoryStore

    store = MemoryStore(path=tmp_path)

    axis_a = [0.0] * EMBED_DIM
    axis_a[0] = 1.0
    axis_b = [0.0] * EMBED_DIM
    axis_b[1] = 1.0

    rec_x = _make_record(text="x", vec=axis_a)
    rec_y = _make_record(text="y", vec=axis_b)
    store.insert(rec_x)
    store.insert(rec_y)

    calls = {"get": 0, "all_records": 0}
    real_get = store.get
    real_all_records = store.all_records

    def _tracked_get(*args, **kwargs):
        calls["get"] += 1
        return real_get(*args, **kwargs)

    def _tracked_all_records(*args, **kwargs):
        calls["all_records"] += 1
        return real_all_records(*args, **kwargs)

    store.get = _tracked_get  # type: ignore[assignment]
    store.all_records = _tracked_all_records  # type: ignore[assignment]
    try:
        hits = [_hit_for(rec_x), _hit_for(rec_y)]
        on_read_check_batch(store, hits, session_id="test")
    finally:
        # Restore the real methods regardless of the outcome.
        store.get = real_get  # type: ignore[assignment]
        store.all_records = real_all_records  # type: ignore[assignment]

    get_count = calls["get"]
    all_count = calls["all_records"]
    assert get_count == 0, f"store.get called {get_count} times (should be 0)"
    assert all_count == 1, f"store.all_records called {all_count} times (should be 1)"
|
||||
|
||||
|
||||
def test_batch_api_equivalence_on_detection(tmp_path):
    """``on_read_check`` and ``on_read_check_batch`` agree on detected hints.

    Both functions run over the same store, hits and session id. Hints are
    compared as a set of ``(kind, frozenset(source_ids))`` pairs, so hint
    ordering and any other hint fields are ignored.
    """
    from iai_mcp.s4 import on_read_check, on_read_check_batch
    from iai_mcp.store import MemoryStore

    store = MemoryStore(path=tmp_path)

    # Pair 1: near-identical vectors with opposite polarity tags
    # (cosine > 0.97).
    near_a = [1.0] + [0.0] * (EMBED_DIM - 1)
    near_b = [0.99, 0.01] + [0.0] * (EMBED_DIM - 2)
    positive = _make_record(text="X good", vec=near_a, tags=["topic", "positive"])
    negative = _make_record(text="X bad", vec=near_b, tags=["topic", "negative"])

    # Pair 2: records joined by an explicit contradicts edge.
    axis_c = [0.0] * EMBED_DIM
    axis_c[2] = 1.0
    axis_d = [0.0] * EMBED_DIM
    axis_d[3] = 1.0
    claim_true = _make_record(text="Y true", vec=axis_c, tags=["claim"])
    claim_false = _make_record(text="Y false", vec=axis_d, tags=["claim"])

    everything = (positive, negative, claim_true, claim_false)
    for rec in everything:
        store.insert(rec)
    store.add_contradicts_edge(claim_true.id, claim_false.id)

    hits = [_hit_for(rec) for rec in everything]

    single = on_read_check(store, hits, session_id="eq_test")
    batch = on_read_check_batch(store, hits, session_id="eq_test")

    def _signatures(hints) -> set[tuple[str, frozenset[str]]]:
        # Reduce each hint to its order-insensitive identity.
        return {(hint["kind"], frozenset(hint["source_ids"])) for hint in hints}

    assert _signatures(single) == _signatures(batch)
    # Both should have detected at least 2 hints: polarity + contradicts.
    assert len(batch) >= 2
|
||||
|
||||
|
||||
def test_on_read_check_batch_empty_hits(tmp_path):
    """An empty hits list yields an empty hint list without raising."""
    from iai_mcp.s4 import on_read_check_batch
    from iai_mcp.store import MemoryStore

    empty_store = MemoryStore(path=tmp_path)
    no_hits: list = []

    hints = on_read_check_batch(empty_store, no_hits, session_id="test")

    assert hints == []
|
||||
Loading…
Add table
Add a link
Reference in a new issue