Co-Authored-By: Claude <noreply@anthropic.com> Co-Authored-By: XNLLLLH <XNLLLLH@users.noreply.github.com>
211 lines
7 KiB
Python
211 lines
7 KiB
Python
"""Tests for s4.on_read_check_batch (Plan 02-07 D-SPEED gap closure).
|
|
|
|
D-SPEED contract: bench/neural_map p95<100ms at N=100. Root cause:
|
|
`s4.on_read_check` called per-hit inside pipeline_recall with no records_cache,
|
|
forcing N+1 store.get() round-trips. Fix: new `on_read_check_batch` that accepts
|
|
an optional records_cache from the caller and does ONE store.all_records() (or
|
|
zero if cache provided).
|
|
|
|
Equivalence contract: on_read_check_batch returns semantically identical hint
|
|
output to on_read_check for the same (store, hits, session_id) input. The
|
|
source_id contents of the returned hints must be a set-equal match; orderings
|
|
may differ because event-write side effects are intermingled.
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
from datetime import datetime, timezone
|
|
from uuid import uuid4
|
|
|
|
import pytest
|
|
|
|
from iai_mcp.types import EMBED_DIM, MemoryHit, MemoryRecord
|
|
|
|
|
|
def _make_record(
    *,
    text: str = "hello",
    vec: list[float] | None = None,
    tags: list[str] | None = None,
    detail_level: int = 2,
    tier: str = "episodic",
    language: str = "en",
) -> MemoryRecord:
    """Build a ``MemoryRecord`` fixture with a fresh id and neutral defaults.

    Args:
        text: Stored as the record's ``literal_surface``.
        vec: Embedding to use. When ``None``, a vector of length ``EMBED_DIM``
            with 1.0 in the first slot and 0.0 elsewhere is used.
        tags: Tags for the record; copied into a new list (``None`` -> ``[]``).
        detail_level: Passed through unchanged.
        tier: Passed through unchanged.
        language: Passed through unchanged.

    Returns:
        A new ``MemoryRecord`` whose ``created_at`` and ``updated_at`` share a
        single timezone-aware UTC timestamp taken at call time.
    """
    embedding = [1.0] + [0.0] * (EMBED_DIM - 1) if vec is None else vec
    timestamp = datetime.now(timezone.utc)

    # Everything not exposed as a parameter is pinned to a fixed neutral
    # value so that tests only vary the fields they pass in.
    fields = dict(
        id=uuid4(),
        tier=tier,
        literal_surface=text,
        aaak_index="",
        embedding=embedding,
        community_id=None,
        centrality=0.0,
        detail_level=detail_level,
        pinned=False,
        stability=0.0,
        difficulty=0.0,
        last_reviewed=None,
        never_decay=False,
        never_merge=False,
        provenance=[],
        created_at=timestamp,
        updated_at=timestamp,
        tags=list(tags or []),
        language=language,
    )
    return MemoryRecord(**fields)
|
|
|
|
|
|
def _hit_for(rec: MemoryRecord, score: float = 0.9) -> MemoryHit:
    """Wrap *rec* in a ``MemoryHit`` suitable for feeding to the S4 checks.

    Args:
        rec: Record whose ``id`` and ``literal_surface`` are copied to the hit.
        score: Score assigned to the hit.

    Returns:
        A ``MemoryHit`` with ``reason="test"`` and no adjacent suggestions.
    """
    hit_fields = {
        "record_id": rec.id,
        "score": score,
        "reason": "test",
        "literal_surface": rec.literal_surface,
        "adjacent_suggestions": [],
    }
    return MemoryHit(**hit_fields)
|
|
|
|
|
|
# ------------------------------------------------------------- contract
|
|
|
|
|
|
def test_s4_exports_on_read_check_batch():
    """``iai_mcp.s4`` exposes a callable named ``on_read_check_batch``."""
    from iai_mcp import s4

    attr_name = "on_read_check_batch"

    assert hasattr(s4, attr_name)
    assert callable(getattr(s4, attr_name))
|
|
|
|
|
|
# ------------------------------------------------------------- behaviour
|
|
|
|
|
|
def test_on_read_check_batch_uses_records_cache(tmp_path):
    """With a ``records_cache`` supplied, ``store.get`` is never invoked.

    This pins the core D-SPEED fix: the caller (pipeline_recall) already
    builds ``records_cache`` at stage 1, so S4 must not re-fetch records one
    by one. ``store.get`` is temporarily replaced with a function that
    raises; the batch call has to complete without tripping it and must still
    report the contradicts edge between the two records.
    """
    from iai_mcp.s4 import on_read_check_batch
    from iai_mcp.store import MemoryStore

    store = MemoryStore(path=tmp_path)

    # Two one-hot embeddings on different axes.
    axis0 = [0.0] * EMBED_DIM
    axis0[0] = 1.0
    axis1 = [0.0] * EMBED_DIM
    axis1[1] = 1.0

    claim = _make_record(text="X is true", vec=axis0, tags=["claim"])
    denial = _make_record(text="X is false", vec=axis1, tags=["claim"])
    for rec in (claim, denial):
        store.insert(rec)
    store.add_contradicts_edge(claim.id, denial.id)

    records_cache = {rec.id: rec for rec in (claim, denial)}
    hits = [_hit_for(rec) for rec in (claim, denial)]

    def _forbidden_get(*args, **kwargs):
        # Any call here means the cache was bypassed.
        raise RuntimeError("store.get must not be called when records_cache is provided")

    saved_get = store.get
    store.get = _forbidden_get  # type: ignore[assignment]
    try:
        result = on_read_check_batch(
            store,
            hits,
            session_id="test",
            records_cache=records_cache,
        )
    finally:
        # Always put the real method back, even if the call above raised.
        store.get = saved_get  # type: ignore[assignment]

    # The contradicts edge must still produce exactly one hint naming both
    # records.
    expected_sources = {str(claim.id), str(denial.id)}
    assert len(result) == 1
    assert set(result[0]["source_ids"]) == expected_sources
|
|
|
|
|
|
def test_on_read_check_batch_fallback_no_cache(tmp_path):
    """Without ``records_cache`` the batch check does one bulk read only.

    ``store.get`` and ``store.all_records`` are wrapped in pass-through
    counters for the duration of the call. Afterwards ``store.get`` must have
    been invoked zero times and ``store.all_records`` exactly once.
    """
    from iai_mcp.s4 import on_read_check_batch
    from iai_mcp.store import MemoryStore

    store = MemoryStore(path=tmp_path)

    # Two one-hot embeddings on different axes.
    axis0 = [0.0] * EMBED_DIM
    axis0[0] = 1.0
    axis1 = [0.0] * EMBED_DIM
    axis1[1] = 1.0

    rec_x = _make_record(text="x", vec=axis0)
    rec_y = _make_record(text="y", vec=axis1)
    for rec in (rec_x, rec_y):
        store.insert(rec)

    real_get = store.get
    real_all_records = store.all_records
    n_get = 0
    n_all_records = 0

    def _tally_get(*args, **kwargs):
        nonlocal n_get
        n_get += 1
        return real_get(*args, **kwargs)

    def _tally_all_records(*args, **kwargs):
        nonlocal n_all_records
        n_all_records += 1
        return real_all_records(*args, **kwargs)

    store.get = _tally_get  # type: ignore[assignment]
    store.all_records = _tally_all_records  # type: ignore[assignment]
    try:
        hits = [_hit_for(rec) for rec in (rec_x, rec_y)]
        on_read_check_batch(store, hits, session_id="test")
    finally:
        # Restore the real methods regardless of the outcome above.
        store.get = real_get  # type: ignore[assignment]
        store.all_records = real_all_records  # type: ignore[assignment]

    assert n_get == 0, f"store.get called {n_get} times (should be 0)"
    assert n_all_records == 1, (
        f"store.all_records called {n_all_records} times (should be 1)"
    )
|
|
|
|
|
|
def test_batch_api_equivalence_on_detection(tmp_path):
    """on_read_check and on_read_check_batch return semantically-identical
    hint output over the same (store, hits, session_id) input.

    Comparison is over the (kind, frozenset(source_ids)) pair so that event
    ordering / text wording differences don't invalidate parity.

    Because that comparison de-duplicates hints, it cannot notice one API
    returning fewer hints than the other; both result lists are therefore
    also required to contain at least two hints (polarity + contradicts).
    """
    from iai_mcp.s4 import on_read_check, on_read_check_batch
    from iai_mcp.store import MemoryStore

    store = MemoryStore(path=tmp_path)

    # Near-identical vectors + opposite polarity tags (cosine > 0.97)
    v1 = [1.0] + [0.0] * (EMBED_DIM - 1)
    v2 = [0.99] + [0.01] + [0.0] * (EMBED_DIM - 2)
    r1 = _make_record(text="X good", vec=v1, tags=["topic", "positive"])
    r2 = _make_record(text="X bad", vec=v2, tags=["topic", "negative"])

    # Additionally a contradicts pair, on one-hot axes 2 and 3.
    v3 = [0.0] * EMBED_DIM
    v3[2] = 1.0
    v4 = [0.0] * EMBED_DIM
    v4[3] = 1.0
    r3 = _make_record(text="Y true", vec=v3, tags=["claim"])
    r4 = _make_record(text="Y false", vec=v4, tags=["claim"])

    for r in (r1, r2, r3, r4):
        store.insert(r)
    store.add_contradicts_edge(r3.id, r4.id)

    hits = [_hit_for(r) for r in (r1, r2, r3, r4)]

    # Same store, same hits, same session: per-hit API first, then batch.
    single = on_read_check(store, hits, session_id="eq_test")
    batch = on_read_check_batch(store, hits, session_id="eq_test")

    def _key(h: dict) -> tuple[str, frozenset[str]]:
        """Reduce a hint to its order- and wording-independent identity."""
        return (h["kind"], frozenset(h["source_ids"]))

    assert {_key(h) for h in single} == {_key(h) for h in batch}
    # Both should have detected at least 2 hints: polarity + contradicts.
    # The set comparison above hides a count mismatch, so check each list.
    assert len(single) >= 2
    assert len(batch) >= 2
|
|
|
|
|
|
def test_on_read_check_batch_empty_hits(tmp_path):
    """An empty hit list yields an empty hint list and raises nothing."""
    from iai_mcp.s4 import on_read_check_batch
    from iai_mcp.store import MemoryStore

    empty_store = MemoryStore(path=tmp_path)
    no_hits: list = []

    hints = on_read_check_batch(empty_store, no_hits, session_id="test")

    assert hints == []
|