mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-26 21:39:43 +02:00
search-kb: on-demand KB tool on the [n] spine; drop kb_matched_chunk_ids
The main agent's search_knowledge_base tool runs the hybrid spine, renders a <retrieved_context> of numbered [n] passages, and persists the registry. KB subagent prompts teach citing [n] from <document view="full"> reads (evidence.chunk_ids -> evidence.citations). Delete the now-unused search->read highlighting hand-off: the kb_matched_chunk_ids state field, its reducer default, the tool's _matched_chunk_ids writer, and the dead KnowledgePriorityMiddleware writes.
This commit is contained in:
parent
04a76b163b
commit
c98bdea5cf
16 changed files with 518 additions and 325 deletions
|
|
@ -0,0 +1,237 @@
|
|||
"""Behavior tests for the ``search_knowledge_base`` main-agent tool.
|
||||
|
||||
These exercise the tool through its public contract: seed a real document,
|
||||
invoke the tool, and assert on the ``Command`` it returns — the rendered
|
||||
``<retrieved_context>`` carries ``[n]`` labels and the citation registry handed
|
||||
back on state is populated.
|
||||
The tool's own DB session is redirected to the test session, and the embedding
|
||||
leg is pinned so the search is deterministic without a live model.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import uuid
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langgraph.types import Command
|
||||
|
||||
from app.agents.chat.multi_agent_chat.main_agent.tools import search_knowledge_base
|
||||
from app.agents.chat.multi_agent_chat.main_agent.tools.search_knowledge_base import (
|
||||
create_search_knowledge_base_tool,
|
||||
)
|
||||
from app.agents.chat.multi_agent_chat.shared.citations import CitationRegistry
|
||||
from app.config import config
|
||||
from app.db import Chunk, Document, DocumentType, Folder
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
_DIM = config.embedding_model_instance.dimension
|
||||
|
||||
|
||||
def _axis(index: int) -> list[float]:
    """Return a ``_DIM``-length vector that is 1.0 at ``index`` and 0.0 elsewhere."""
    unit = [0.0 for _ in range(_DIM)]
    unit[index] = 1.0
    return unit
|
||||
|
||||
|
||||
async def _add_document(
    db_session,
    *,
    search_space_id: int,
    title: str,
    text: str,
    folder_id: int | None = None,
):
    """Persist a ready ``FILE`` document plus one chunk and return the document.

    The document row is flushed first so ``document.id`` is available for the
    chunk. The single chunk carries the same ``text`` at position 0 and is
    embedded as ``_axis(0)``.
    """
    document_fields = {
        "title": title,
        "document_type": DocumentType.FILE,
        "content": text,
        # A random hash per call, so repeated inserts never share a value.
        "content_hash": uuid.uuid4().hex,
        "search_space_id": search_space_id,
        "folder_id": folder_id,
        "status": {"state": "ready"},
    }
    document = Document(**document_fields)
    db_session.add(document)
    await db_session.flush()

    only_chunk = Chunk(
        content=text,
        document_id=document.id,
        position=0,
        embedding=_axis(0),
    )
    db_session.add(only_chunk)
    await db_session.flush()
    return document
|
||||
|
||||
|
||||
async def _add_folder(db_session, *, search_space_id: int, name: str = "Folder"):
    """Add a ``Folder`` at position ``"0"`` to the session, flush, and return it."""
    new_folder = Folder(
        name=name,
        position="0",
        search_space_id=search_space_id,
    )
    db_session.add(new_folder)
    await db_session.flush()
    return new_folder
|
||||
|
||||
|
||||
@pytest.fixture
def _tool_uses_test_session(db_session, monkeypatch):
    """Make the tool module's ``shielded_async_session`` yield the test session."""

    async def _yield_test_session():
        # Hand back the fixture's session instead of opening a new one.
        yield db_session

    monkeypatch.setattr(
        search_knowledge_base,
        "shielded_async_session",
        contextlib.asynccontextmanager(_yield_test_session),
    )
|
||||
|
||||
|
||||
@pytest.fixture
def _pinned_embedding(monkeypatch):
    """Replace the embedding model's ``embed`` with a stub returning ``_axis(0)``."""

    def _constant_embedding(_query):
        # The query is ignored: every call yields the same axis-0 vector.
        return _axis(0)

    embedding_model = config.embedding_model_instance
    monkeypatch.setattr(embedding_model, "embed", _constant_embedding)
|
||||
|
||||
|
||||
async def _invoke(tool, query: str, state: dict | None = None, context=None):
    """Await ``tool.coroutine`` with ``query`` and a stand-in runtime object.

    The runtime is a ``SimpleNamespace`` exposing ``state`` (``{}`` when the
    argument is falsy), a fixed ``tool_call_id`` of ``"call-1"``, and
    ``context``.
    """
    runtime_fields = {
        "state": state or {},
        "tool_call_id": "call-1",
        "context": context,
    }
    return await tool.coroutine(query, SimpleNamespace(**runtime_fields))
|
||||
|
||||
|
||||
def _mentions(*, document_ids=(), folder_ids=()):
    """Build a context object carrying mentioned document and folder ids as lists."""
    mention_context = SimpleNamespace()
    mention_context.mentioned_document_ids = [*document_ids]
    mention_context.mentioned_folder_ids = [*folder_ids]
    return mention_context
|
||||
|
||||
|
||||
async def test_tool_returns_retrieved_context_with_numbered_passages(
    db_session, db_search_space, _tool_uses_test_session, _pinned_embedding
):
    """A seeded document comes back as a ``ToolMessage`` with a ``[1]`` passage."""
    space_id = db_search_space.id
    await _add_document(
        db_session,
        search_space_id=space_id,
        title="Asyncio Guide",
        text="The asyncio library enables concurrency.",
    )
    kb_tool = create_search_knowledge_base_tool(search_space_id=space_id)

    outcome = await _invoke(kb_tool, "asyncio")

    assert isinstance(outcome, Command)
    tool_message = outcome.update["messages"][0]
    assert isinstance(tool_message, ToolMessage)
    rendered = tool_message.content
    assert "<retrieved_context>" in rendered
    assert "[1]" in rendered
|
||||
|
||||
|
||||
async def test_tool_populates_citation_registry_on_state(
    db_session, db_search_space, _tool_uses_test_session, _pinned_embedding
):
    """The returned update carries a non-empty ``CitationRegistry``."""
    space_id = db_search_space.id
    await _add_document(
        db_session,
        search_space_id=space_id,
        title="Asyncio Guide",
        text="The asyncio library enables concurrency.",
    )
    kb_tool = create_search_knowledge_base_tool(search_space_id=space_id)

    outcome = await _invoke(kb_tool, "asyncio")

    handed_back = outcome.update["citation_registry"]
    assert isinstance(handed_back, CitationRegistry)
    # At least one passage was registered as [n].
    assert handed_back.by_n
|
||||
|
||||
|
||||
async def test_tool_reuses_existing_registry_numbering(
    db_session, db_search_space, _tool_uses_test_session, _pinned_embedding
):
    """Searching again with the prior registry on state leaves one entry in ``by_n``."""
    space_id = db_search_space.id
    await _add_document(
        db_session,
        search_space_id=space_id,
        title="Asyncio Guide",
        text="The asyncio library enables concurrency.",
    )
    kb_tool = create_search_knowledge_base_tool(search_space_id=space_id)

    first_outcome = await _invoke(kb_tool, "asyncio")
    prior_state = {"citation_registry": first_outcome.update["citation_registry"]}
    second_outcome = await _invoke(kb_tool, "asyncio", state=prior_state)

    # Same passage searched twice keeps a single [n] (find-or-create).
    final_registry = second_outcome.update["citation_registry"]
    assert len(final_registry.by_n) == 1
|
||||
|
||||
|
||||
async def test_tool_reports_no_matches_without_touching_state(
    db_session, db_search_space, _tool_uses_test_session, _pinned_embedding
):
    """With nothing seeded, the tool returns a plain string reporting no matches."""
    kb_tool = create_search_knowledge_base_tool(search_space_id=db_search_space.id)

    outcome = await _invoke(kb_tool, "nonexistent-term-zzz")

    # A bare string (rather than a ``Command``) is what this test expects back.
    assert isinstance(outcome, str)
    assert "No knowledge-base matches" in outcome
|
||||
|
||||
|
||||
async def test_tool_rejects_empty_query(
    db_search_space, _tool_uses_test_session, _pinned_embedding
):
    """A whitespace-only query yields a string mentioning ``non-empty``."""
    kb_tool = create_search_knowledge_base_tool(search_space_id=db_search_space.id)
    blank_query = " "

    outcome = await _invoke(kb_tool, blank_query)

    assert isinstance(outcome, str)
    assert "non-empty" in outcome
|
||||
|
||||
|
||||
async def test_document_mention_confines_search_to_pinned_doc(
    db_session, db_search_space, _tool_uses_test_session, _pinned_embedding
):
    """Mentioning one document keeps the other document out of the rendered output."""
    space_id = db_search_space.id
    pinned = await _add_document(
        db_session,
        search_space_id=space_id,
        title="Pinned",
        text="asyncio appears in the pinned doc.",
    )
    await _add_document(
        db_session,
        search_space_id=space_id,
        title="Other",
        text="asyncio appears in the other doc.",
    )
    kb_tool = create_search_knowledge_base_tool(search_space_id=space_id)
    mention_context = _mentions(document_ids=[pinned.id])

    outcome = await _invoke(kb_tool, "asyncio", context=mention_context)

    # Search is confined to the pinned doc: only its content is rendered.
    rendered = outcome.update["messages"][0].content
    assert "Pinned" in rendered
    assert "Other" not in rendered
|
||||
|
||||
|
||||
async def test_folder_mention_confines_search_to_folder_documents(
    db_session, db_search_space, _tool_uses_test_session, _pinned_embedding
):
    """Mentioning a folder keeps the document outside it out of the rendered output."""
    space_id = db_search_space.id
    folder = await _add_folder(db_session, search_space_id=space_id)
    await _add_document(
        db_session,
        search_space_id=space_id,
        title="Inside",
        text="asyncio appears inside the folder.",
        folder_id=folder.id,
    )
    await _add_document(
        db_session,
        search_space_id=space_id,
        title="Outside",
        text="asyncio appears outside the folder.",
    )
    kb_tool = create_search_knowledge_base_tool(search_space_id=space_id)
    mention_context = _mentions(folder_ids=[folder.id])

    outcome = await _invoke(kb_tool, "asyncio", context=mention_context)

    # Search is confined to the folder's document: only its content is rendered.
    rendered = outcome.update["messages"][0].content
    assert "Inside" in rendered
    assert "Outside" not in rendered
|
||||
|
|
@ -4,9 +4,14 @@ from __future__ import annotations
|
|||
|
||||
import pytest
|
||||
|
||||
from app.agents.chat.multi_agent_chat.shared.citations import (
|
||||
CitationRegistry,
|
||||
CitationSourceType,
|
||||
)
|
||||
from app.agents.chat.multi_agent_chat.shared.state.reducers import (
|
||||
_CLEAR,
|
||||
_add_unique_reducer,
|
||||
_citation_registry_merge_reducer,
|
||||
_dict_merge_with_tombstones_reducer,
|
||||
_initial_filesystem_state,
|
||||
_list_append_reducer,
|
||||
|
|
@ -93,6 +98,57 @@ class TestDictMergeWithTombstones:
|
|||
}
|
||||
|
||||
|
||||
def _kb_registry(chunk_id: int) -> CitationRegistry:
    """Return a fresh registry with one ``KB_CHUNK`` entry for document 1 and ``chunk_id``."""
    locator = {"document_id": 1, "chunk_id": chunk_id}
    single_entry = CitationRegistry()
    single_entry.register(CitationSourceType.KB_CHUNK, locator)
    return single_entry
|
||||
|
||||
|
||||
class TestCitationRegistryMergeReducer:
    """Checks for ``_citation_registry_merge_reducer`` across ``None``, registry and dict inputs."""

    @staticmethod
    def _chunk_ids(merged):
        """Collect the ``chunk_id`` of every entry in ``merged.by_n``."""
        return {entry.locator["chunk_id"] for entry in merged.by_n.values()}

    def test_none_left_returns_right(self):
        incoming = _kb_registry(10)
        outcome = _citation_registry_merge_reducer(None, incoming)
        assert outcome is incoming

    def test_none_right_returns_left(self):
        existing = _kb_registry(10)
        outcome = _citation_registry_merge_reducer(existing, None)
        assert outcome is existing

    def test_both_none_returns_none(self):
        outcome = _citation_registry_merge_reducer(None, None)
        assert outcome is None

    def test_unions_two_registries(self):
        existing, incoming = _kb_registry(10), _kb_registry(11)

        merged = _citation_registry_merge_reducer(existing, incoming)

        assert self._chunk_ids(merged) == {10, 11}

    def test_coerces_serialized_dict_update(self):
        # The checkpointer serializes Command.update via ormsgpack before the
        # reducer runs, so `right` can arrive as a plain dict.
        existing = _kb_registry(10)
        incoming_as_dict = _kb_registry(11).model_dump()

        merged = _citation_registry_merge_reducer(existing, incoming_as_dict)

        assert self._chunk_ids(merged) == {10, 11}

    def test_coerces_both_sides_from_dict(self):
        existing_as_dict = _kb_registry(10).model_dump()
        incoming_as_dict = _kb_registry(11).model_dump()

        merged = _citation_registry_merge_reducer(existing_as_dict, incoming_as_dict)

        assert isinstance(merged, CitationRegistry)
        assert self._chunk_ids(merged) == {10, 11}
|
||||
|
||||
|
||||
class TestInitialFilesystemState:
|
||||
def test_default_shape(self):
|
||||
state = _initial_filesystem_state()
|
||||
|
|
@ -106,7 +162,6 @@ class TestInitialFilesystemState:
|
|||
assert state["dirty_paths"] == []
|
||||
assert state["dirty_path_tool_calls"] == {}
|
||||
assert state["kb_priority"] == []
|
||||
assert state["kb_matched_chunk_ids"] == {}
|
||||
assert state["kb_anon_doc"] is None
|
||||
assert state["tree_version"] == 0
|
||||
|
||||
|
|
|
|||
|
|
@ -6,9 +6,6 @@ import pytest
|
|||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
|
||||
from app.agents.chat.multi_agent_chat.shared.middleware import knowledge_search as ks
|
||||
from app.agents.chat.multi_agent_chat.shared.middleware.filesystem.backends.document_xml import (
|
||||
build_document_xml as _build_document_xml,
|
||||
)
|
||||
from app.agents.chat.multi_agent_chat.shared.middleware.knowledge_search import (
|
||||
KBSearchPlan,
|
||||
KnowledgePriorityMiddleware,
|
||||
|
|
@ -59,88 +56,6 @@ class TestResolveSearchTypes:
|
|||
assert result.count("FILE") == 1
|
||||
|
||||
|
||||
# ── _build_document_xml ────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestBuildDocumentXml:
    """Checks on the XML text that ``_build_document_xml`` produces for a three-chunk document."""

    @pytest.fixture
    def sample_document(self):
        """Return a document payload with id 42 and chunks 101, 102 and 103."""
        document_meta = {
            "id": 42,
            "document_type": "FILE",
            "title": "Test Doc",
            "metadata": {"url": "https://example.com"},
        }
        chunks = [
            {"chunk_id": 101, "content": "First chunk content"},
            {"chunk_id": 102, "content": "Second chunk content"},
            {"chunk_id": 103, "content": "Third chunk content"},
        ]
        return {"document_id": 42, "document": document_meta, "chunks": chunks}

    def test_contains_document_metadata(self, sample_document):
        rendered = _build_document_xml(sample_document)
        for expected in (
            "<document_id>42</document_id>",
            "<document_type>FILE</document_type>",
            "Test Doc",
        ):
            assert expected in rendered

    def test_contains_chunk_index(self, sample_document):
        rendered = _build_document_xml(sample_document)
        for expected in (
            "<chunk_index>",
            "</chunk_index>",
            'chunk_id="101"',
            'chunk_id="102"',
            'chunk_id="103"',
        ):
            assert expected in rendered

    def test_matched_chunks_flagged_in_index(self, sample_document):
        rendered = _build_document_xml(sample_document, matched_chunk_ids={101, 103})
        # Marker -> whether a line carrying it must also carry matched="true".
        should_be_flagged = {
            'chunk_id="101"': True,
            'chunk_id="102"': False,
            'chunk_id="103"': True,
        }
        for line in rendered.split("\n"):
            for marker, flagged in should_be_flagged.items():
                if marker not in line:
                    continue
                if flagged:
                    assert 'matched="true"' in line
                else:
                    assert 'matched="true"' not in line

    def test_chunk_content_in_document_content_section(self, sample_document):
        rendered = _build_document_xml(sample_document)
        for expected in (
            "<document_content>",
            "First chunk content",
            "Second chunk content",
            "Third chunk content",
        ):
            assert expected in rendered

    def test_line_numbers_in_chunk_index_are_accurate(self, sample_document):
        """Verify that the line ranges in chunk_index actually point to the right content."""
        import re

        rendered = _build_document_xml(sample_document, matched_chunk_ids={101})
        xml_lines = rendered.split("\n")

        # Only the first line mentioning chunk 101 with a lines= attribute is checked.
        index_entry = next(
            (
                candidate
                for candidate in xml_lines
                if 'chunk_id="101"' in candidate and "lines=" in candidate
            ),
            None,
        )
        if index_entry is None:
            pytest.fail("chunk_id=101 entry not found in chunk_index")

        m = re.search(r'lines="(\d+)-(\d+)"', index_entry)
        assert m, f"No lines= attribute found in: {index_entry}"
        start, _end = int(m.group(1)), int(m.group(2))
        # The recorded start is 1-based, so shift by one to index the list.
        target_line = xml_lines[start - 1]
        assert "101" in target_line
        assert "First chunk content" in target_line

    def test_splits_into_lines_correctly(self, sample_document):
        """Each chunk occupies exactly one line (no embedded newlines)."""
        rendered = _build_document_xml(sample_document)
        chunk_line_count = sum(
            1
            for line in rendered.split("\n")
            if "<![CDATA[" in line and "<chunk" in line
        )
        assert chunk_line_count == 3
|
||||
|
||||
|
||||
# ── planner parsing / date normalization ───────────────────────────────
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue