search-kb: on-demand KB tool on the [n] spine; drop kb_matched_chunk_ids

The main agent's search_knowledge_base tool runs the hybrid spine, renders
a <retrieved_context> of numbered [n] passages, and persists the registry.
KB subagent prompts teach citing [n] from <document view="full"> reads
(evidence.chunk_ids -> evidence.citations). Delete the now-unused
search->read highlighting hand-off: the kb_matched_chunk_ids state field,
its reducer default, the tool's _matched_chunk_ids writer, and the dead
KnowledgePriorityMiddleware writes.
This commit is contained in:
CREDO23 2026-06-25 15:26:39 +02:00
parent 04a76b163b
commit c98bdea5cf
16 changed files with 518 additions and 325 deletions

View file

@ -0,0 +1,237 @@
"""Behavior tests for the ``search_knowledge_base`` main-agent tool.
These exercise the tool through its public contract: seed a real document,
invoke the tool, and assert on the ``Command`` it returns — the rendered
``<retrieved_context>`` carries ``[n]`` labels and the citation registry handed
back on state is populated.
The tool's own DB session is redirected to the test session, and the embedding
leg is pinned so the search is deterministic without a live model.
"""
from __future__ import annotations
import contextlib
import uuid
from types import SimpleNamespace
import pytest
from langchain_core.messages import ToolMessage
from langgraph.types import Command
from app.agents.chat.multi_agent_chat.main_agent.tools import search_knowledge_base
from app.agents.chat.multi_agent_chat.main_agent.tools.search_knowledge_base import (
create_search_knowledge_base_tool,
)
from app.agents.chat.multi_agent_chat.shared.citations import CitationRegistry
from app.config import config
from app.db import Chunk, Document, DocumentType, Folder
# Module-level pytest marker: tags every test in this file as ``integration``.
pytestmark = pytest.mark.integration
# Dimension reported by the configured embedding model; sizes the vectors
# that ``_axis`` builds.
_DIM = config.embedding_model_instance.dimension
def _axis(index: int) -> list[float]:
    """Return a ``_DIM``-long vector that is 1.0 at ``index`` and 0.0 elsewhere."""
    unit = [0.0 for _ in range(_DIM)]
    unit[index] = 1.0
    return unit
async def _add_document(
    db_session,
    *,
    search_space_id: int,
    title: str,
    text: str,
    folder_id: int | None = None,
):
    """Add a ready FILE document with a single chunk and return the document.

    The chunk carries ``text`` at position 0 and the ``_axis(0)`` embedding.
    """
    doc = Document(
        search_space_id=search_space_id,
        folder_id=folder_id,
        title=title,
        content=text,
        content_hash=uuid.uuid4().hex,
        document_type=DocumentType.FILE,
        status={"state": "ready"},
    )
    db_session.add(doc)
    # Flush before building the chunk, which references ``doc.id``.
    await db_session.flush()

    only_chunk = Chunk(
        document_id=doc.id,
        position=0,
        content=text,
        embedding=_axis(0),
    )
    db_session.add(only_chunk)
    await db_session.flush()
    return doc
async def _add_folder(db_session, *, search_space_id: int, name: str = "Folder"):
    """Add a folder to the given search space, flush, and return it."""
    new_folder = Folder(
        search_space_id=search_space_id,
        name=name,
        position="0",
    )
    db_session.add(new_folder)
    await db_session.flush()
    return new_folder
@pytest.fixture
def _tool_uses_test_session(db_session, monkeypatch):
    """Make the tool module's ``shielded_async_session`` yield the test session."""

    @contextlib.asynccontextmanager
    async def _yield_test_session():
        # Hand back the fixture's session instead of opening a new one.
        yield db_session

    monkeypatch.setattr(
        search_knowledge_base,
        "shielded_async_session",
        _yield_test_session,
    )
@pytest.fixture
def _pinned_embedding(monkeypatch):
    """Replace the embedding model's ``embed`` so every query maps to ``_axis(0)``."""

    def _fixed_vector(_query):
        return _axis(0)

    monkeypatch.setattr(config.embedding_model_instance, "embed", _fixed_vector)
async def _invoke(tool, query: str, state: dict | None = None, context=None):
runtime = SimpleNamespace(
state=state or {}, tool_call_id="call-1", context=context
)
return await tool.coroutine(query, runtime)
def _mentions(*, document_ids=(), folder_ids=()):
    """Build a context stand-in carrying mentioned document and folder ids as lists."""
    ctx = SimpleNamespace()
    ctx.mentioned_document_ids = [*document_ids]
    ctx.mentioned_folder_ids = [*folder_ids]
    return ctx
async def test_tool_returns_retrieved_context_with_numbered_passages(
    db_session, db_search_space, _tool_uses_test_session, _pinned_embedding
):
    """A matching search yields a ``Command`` whose ToolMessage shows ``[1]``."""
    space_id = db_search_space.id
    await _add_document(
        db_session,
        search_space_id=space_id,
        title="Asyncio Guide",
        text="The asyncio library enables concurrency.",
    )
    kb_tool = create_search_knowledge_base_tool(search_space_id=space_id)

    outcome = await _invoke(kb_tool, "asyncio")

    assert isinstance(outcome, Command)
    tool_message = outcome.update["messages"][0]
    assert isinstance(tool_message, ToolMessage)
    rendered = tool_message.content
    assert "<retrieved_context>" in rendered
    assert "[1]" in rendered
async def test_tool_populates_citation_registry_on_state(
    db_session, db_search_space, _tool_uses_test_session, _pinned_embedding
):
    """The state update carries a non-empty ``CitationRegistry``."""
    space_id = db_search_space.id
    await _add_document(
        db_session,
        search_space_id=space_id,
        title="Asyncio Guide",
        text="The asyncio library enables concurrency.",
    )
    kb_tool = create_search_knowledge_base_tool(search_space_id=space_id)

    outcome = await _invoke(kb_tool, "asyncio")

    handed_back = outcome.update["citation_registry"]
    assert isinstance(handed_back, CitationRegistry)
    # Truthy ``by_n`` means at least one passage was registered as [n].
    assert handed_back.by_n
async def test_tool_reuses_existing_registry_numbering(
    db_session, db_search_space, _tool_uses_test_session, _pinned_embedding
):
    """Searching again with the carried registry keeps a single ``[n]`` entry."""
    space_id = db_search_space.id
    await _add_document(
        db_session,
        search_space_id=space_id,
        title="Asyncio Guide",
        text="The asyncio library enables concurrency.",
    )
    kb_tool = create_search_knowledge_base_tool(search_space_id=space_id)

    first_run = await _invoke(kb_tool, "asyncio")
    prior_registry = first_run.update["citation_registry"]
    second_run = await _invoke(
        kb_tool, "asyncio", state={"citation_registry": prior_registry}
    )

    # Find-or-create: the same passage searched twice must not get a second [n].
    numbered = second_run.update["citation_registry"].by_n
    assert len(numbered) == 1
async def test_tool_reports_no_matches_without_touching_state(
    db_session, db_search_space, _tool_uses_test_session, _pinned_embedding
):
    """With nothing seeded, the tool answers with a plain no-match string."""
    kb_tool = create_search_knowledge_base_tool(search_space_id=db_search_space.id)

    outcome = await _invoke(kb_tool, "nonexistent-term-zzz")

    # A bare string (not a ``Command``) is returned on this path.
    assert isinstance(outcome, str)
    assert "No knowledge-base matches" in outcome
async def test_tool_rejects_empty_query(
    db_search_space, _tool_uses_test_session, _pinned_embedding
):
    """A whitespace-only query gets a string reply mentioning ``non-empty``."""
    kb_tool = create_search_knowledge_base_tool(search_space_id=db_search_space.id)

    outcome = await _invoke(kb_tool, " ")

    assert isinstance(outcome, str)
    assert "non-empty" in outcome
async def test_document_mention_confines_search_to_pinned_doc(
    db_session, db_search_space, _tool_uses_test_session, _pinned_embedding
):
    """Mentioning one document keeps the other document out of the rendering."""
    space_id = db_search_space.id
    pinned_doc = await _add_document(
        db_session,
        search_space_id=space_id,
        title="Pinned",
        text="asyncio appears in the pinned doc.",
    )
    await _add_document(
        db_session,
        search_space_id=space_id,
        title="Other",
        text="asyncio appears in the other doc.",
    )
    kb_tool = create_search_knowledge_base_tool(search_space_id=space_id)
    mention_context = _mentions(document_ids=[pinned_doc.id])

    outcome = await _invoke(kb_tool, "asyncio", context=mention_context)

    # Only the pinned document's content may show up in the rendered message.
    rendered = outcome.update["messages"][0].content
    assert "Pinned" in rendered
    assert "Other" not in rendered
async def test_folder_mention_confines_search_to_folder_documents(
    db_session, db_search_space, _tool_uses_test_session, _pinned_embedding
):
    """Mentioning a folder keeps documents outside it out of the rendering."""
    space_id = db_search_space.id
    mentioned_folder = await _add_folder(db_session, search_space_id=space_id)
    await _add_document(
        db_session,
        search_space_id=space_id,
        title="Inside",
        text="asyncio appears inside the folder.",
        folder_id=mentioned_folder.id,
    )
    await _add_document(
        db_session,
        search_space_id=space_id,
        title="Outside",
        text="asyncio appears outside the folder.",
    )
    kb_tool = create_search_knowledge_base_tool(search_space_id=space_id)
    mention_context = _mentions(folder_ids=[mentioned_folder.id])

    outcome = await _invoke(kb_tool, "asyncio", context=mention_context)

    # Only the folder's document may show up in the rendered message.
    rendered = outcome.update["messages"][0].content
    assert "Inside" in rendered
    assert "Outside" not in rendered

View file

@ -4,9 +4,14 @@ from __future__ import annotations
import pytest
from app.agents.chat.multi_agent_chat.shared.citations import (
CitationRegistry,
CitationSourceType,
)
from app.agents.chat.multi_agent_chat.shared.state.reducers import (
_CLEAR,
_add_unique_reducer,
_citation_registry_merge_reducer,
_dict_merge_with_tombstones_reducer,
_initial_filesystem_state,
_list_append_reducer,
@ -93,6 +98,57 @@ class TestDictMergeWithTombstones:
}
def _kb_registry(chunk_id: int) -> CitationRegistry:
    """Return a fresh registry holding one KB_CHUNK entry for ``chunk_id``."""
    locator = {"document_id": 1, "chunk_id": chunk_id}
    fresh = CitationRegistry()
    fresh.register(CitationSourceType.KB_CHUNK, locator)
    return fresh
class TestCitationRegistryMergeReducer:
    """Behavior of ``_citation_registry_merge_reducer`` for None, registry, and dict inputs."""

    @staticmethod
    def _merged_chunk_ids(merged):
        """Collect the ``chunk_id`` of every entry in ``merged.by_n``."""
        return {entry.locator["chunk_id"] for entry in merged.by_n.values()}

    def test_none_left_returns_right(self):
        incoming = _kb_registry(10)
        result = _citation_registry_merge_reducer(None, incoming)
        assert result is incoming

    def test_none_right_returns_left(self):
        existing = _kb_registry(10)
        result = _citation_registry_merge_reducer(existing, None)
        assert result is existing

    def test_both_none_returns_none(self):
        result = _citation_registry_merge_reducer(None, None)
        assert result is None

    def test_unions_two_registries(self):
        merged = _citation_registry_merge_reducer(_kb_registry(10), _kb_registry(11))
        assert self._merged_chunk_ids(merged) == {10, 11}

    def test_coerces_serialized_dict_update(self):
        # Command.update is serialized by the checkpointer (via ormsgpack)
        # before the reducer runs, so the right-hand side may be a plain dict.
        existing = _kb_registry(10)
        incoming_as_dict = _kb_registry(11).model_dump()
        merged = _citation_registry_merge_reducer(existing, incoming_as_dict)
        assert self._merged_chunk_ids(merged) == {10, 11}

    def test_coerces_both_sides_from_dict(self):
        existing_as_dict = _kb_registry(10).model_dump()
        incoming_as_dict = _kb_registry(11).model_dump()
        merged = _citation_registry_merge_reducer(existing_as_dict, incoming_as_dict)
        assert isinstance(merged, CitationRegistry)
        assert self._merged_chunk_ids(merged) == {10, 11}
class TestInitialFilesystemState:
def test_default_shape(self):
state = _initial_filesystem_state()
@ -106,7 +162,6 @@ class TestInitialFilesystemState:
assert state["dirty_paths"] == []
assert state["dirty_path_tool_calls"] == {}
assert state["kb_priority"] == []
assert state["kb_matched_chunk_ids"] == {}
assert state["kb_anon_doc"] is None
assert state["tree_version"] == 0

View file

@ -6,9 +6,6 @@ import pytest
from langchain_core.messages import AIMessage, HumanMessage
from app.agents.chat.multi_agent_chat.shared.middleware import knowledge_search as ks
from app.agents.chat.multi_agent_chat.shared.middleware.filesystem.backends.document_xml import (
build_document_xml as _build_document_xml,
)
from app.agents.chat.multi_agent_chat.shared.middleware.knowledge_search import (
KBSearchPlan,
KnowledgePriorityMiddleware,
@ -59,88 +56,6 @@ class TestResolveSearchTypes:
assert result.count("FILE") == 1
# ── _build_document_xml ────────────────────────────────────────────────
class TestBuildDocumentXml:
    """Checks on the XML text produced by ``_build_document_xml``."""

    @pytest.fixture
    def sample_document(self):
        """A FILE document (id 42) with three chunks, ids 101-103."""
        document_meta = {
            "id": 42,
            "document_type": "FILE",
            "title": "Test Doc",
            "metadata": {"url": "https://example.com"},
        }
        chunk_rows = [
            {"chunk_id": 101, "content": "First chunk content"},
            {"chunk_id": 102, "content": "Second chunk content"},
            {"chunk_id": 103, "content": "Third chunk content"},
        ]
        return {
            "document_id": 42,
            "document": document_meta,
            "chunks": chunk_rows,
        }

    def test_contains_document_metadata(self, sample_document):
        rendered = _build_document_xml(sample_document)
        assert "<document_id>42</document_id>" in rendered
        assert "<document_type>FILE</document_type>" in rendered
        assert "Test Doc" in rendered

    def test_contains_chunk_index(self, sample_document):
        rendered = _build_document_xml(sample_document)
        assert "<chunk_index>" in rendered
        assert "</chunk_index>" in rendered
        assert 'chunk_id="101"' in rendered
        assert 'chunk_id="102"' in rendered
        assert 'chunk_id="103"' in rendered

    def test_matched_chunks_flagged_in_index(self, sample_document):
        rendered = _build_document_xml(sample_document, matched_chunk_ids={101, 103})
        # For each chunk marker: should its line carry the matched flag?
        expect_flag = {
            'chunk_id="101"': True,
            'chunk_id="102"': False,
            'chunk_id="103"': True,
        }
        for line in rendered.split("\n"):
            for marker, should_be_flagged in expect_flag.items():
                if marker not in line:
                    continue
                if should_be_flagged:
                    assert 'matched="true"' in line
                else:
                    assert 'matched="true"' not in line

    def test_chunk_content_in_document_content_section(self, sample_document):
        rendered = _build_document_xml(sample_document)
        assert "<document_content>" in rendered
        assert "First chunk content" in rendered
        assert "Second chunk content" in rendered
        assert "Third chunk content" in rendered

    def test_line_numbers_in_chunk_index_are_accurate(self, sample_document):
        """Verify that the line range in chunk_index points at the chunk's own line."""
        import re

        rendered = _build_document_xml(sample_document, matched_chunk_ids={101})
        xml_lines = rendered.split("\n")
        # First line that mentions chunk 101 and carries a lines= attribute.
        line = next(
            (
                candidate
                for candidate in xml_lines
                if 'chunk_id="101"' in candidate and "lines=" in candidate
            ),
            None,
        )
        if line is None:
            pytest.fail("chunk_id=101 entry not found in chunk_index")
        m = re.search(r'lines="(\d+)-(\d+)"', line)
        assert m, f"No lines= attribute found in: {line}"
        start, _end = (int(bound) for bound in m.groups())
        # The reported range is 1-based, hence the ``start - 1`` list index.
        target_line = xml_lines[start - 1]
        assert "101" in target_line
        assert "First chunk content" in target_line

    def test_splits_into_lines_correctly(self, sample_document):
        """Each chunk occupies exactly one line (no embedded newlines)."""
        rendered = _build_document_xml(sample_document)
        chunk_line_count = sum(
            1
            for line in rendered.split("\n")
            if "<![CDATA[" in line and "<chunk" in line
        )
        assert chunk_line_count == 3
# ── planner parsing / date normalization ───────────────────────────────