mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-07-04 22:02:16 +02:00
Merge pull request #1539 from CREDO23/improve-chat-agent-context-and-citations
[FEAT] Unified [n] citation registry for KB + web, pull-based retrieval
This commit is contained in:
commit
94fdb8a113
160 changed files with 4097 additions and 5238 deletions
|
|
@ -0,0 +1,237 @@
|
|||
"""Behavior tests for the ``search_knowledge_base`` main-agent tool.
|
||||
|
||||
These exercise the tool through its public contract: seed a real document,
|
||||
invoke the tool, and assert on the ``Command`` it returns — the rendered
|
||||
``<retrieved_context>`` carries ``[n]`` labels and the citation registry handed
|
||||
back on state is populated.
|
||||
The tool's own DB session is redirected to the test session, and the embedding
|
||||
leg is pinned so the search is deterministic without a live model.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import uuid
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langgraph.types import Command
|
||||
|
||||
from app.agents.chat.multi_agent_chat.main_agent.tools import search_knowledge_base
|
||||
from app.agents.chat.multi_agent_chat.main_agent.tools.search_knowledge_base import (
|
||||
create_search_knowledge_base_tool,
|
||||
)
|
||||
from app.agents.chat.multi_agent_chat.shared.citations import CitationRegistry
|
||||
from app.config import config
|
||||
from app.db import Chunk, Document, DocumentType, Folder
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
_DIM = config.embedding_model_instance.dimension
|
||||
|
||||
|
||||
def _axis(index: int) -> list[float]:
|
||||
vector = [0.0] * _DIM
|
||||
vector[index] = 1.0
|
||||
return vector
|
||||
|
||||
|
||||
async def _add_document(
|
||||
db_session,
|
||||
*,
|
||||
search_space_id: int,
|
||||
title: str,
|
||||
text: str,
|
||||
folder_id: int | None = None,
|
||||
):
|
||||
document = Document(
|
||||
title=title,
|
||||
document_type=DocumentType.FILE,
|
||||
content=text,
|
||||
content_hash=uuid.uuid4().hex,
|
||||
search_space_id=search_space_id,
|
||||
folder_id=folder_id,
|
||||
status={"state": "ready"},
|
||||
)
|
||||
db_session.add(document)
|
||||
await db_session.flush()
|
||||
db_session.add(
|
||||
Chunk(content=text, document_id=document.id, position=0, embedding=_axis(0))
|
||||
)
|
||||
await db_session.flush()
|
||||
return document
|
||||
|
||||
|
||||
async def _add_folder(db_session, *, search_space_id: int, name: str = "Folder"):
|
||||
folder = Folder(name=name, position="0", search_space_id=search_space_id)
|
||||
db_session.add(folder)
|
||||
await db_session.flush()
|
||||
return folder
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def _tool_uses_test_session(db_session, monkeypatch):
|
||||
"""Redirect the tool's ``shielded_async_session`` to the test transaction."""
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def _session():
|
||||
yield db_session
|
||||
|
||||
monkeypatch.setattr(search_knowledge_base, "shielded_async_session", _session)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def _pinned_embedding(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
config.embedding_model_instance, "embed", lambda _query: _axis(0)
|
||||
)
|
||||
|
||||
|
||||
async def _invoke(tool, query: str, state: dict | None = None, context=None):
|
||||
runtime = SimpleNamespace(
|
||||
state=state or {}, tool_call_id="call-1", context=context
|
||||
)
|
||||
return await tool.coroutine(query, runtime)
|
||||
|
||||
|
||||
def _mentions(*, document_ids=(), folder_ids=()):
|
||||
return SimpleNamespace(
|
||||
mentioned_document_ids=list(document_ids),
|
||||
mentioned_folder_ids=list(folder_ids),
|
||||
)
|
||||
|
||||
|
||||
async def test_tool_returns_retrieved_context_with_numbered_passages(
|
||||
db_session, db_search_space, _tool_uses_test_session, _pinned_embedding
|
||||
):
|
||||
await _add_document(
|
||||
db_session,
|
||||
search_space_id=db_search_space.id,
|
||||
title="Asyncio Guide",
|
||||
text="The asyncio library enables concurrency.",
|
||||
)
|
||||
tool = create_search_knowledge_base_tool(search_space_id=db_search_space.id)
|
||||
|
||||
result = await _invoke(tool, "asyncio")
|
||||
|
||||
assert isinstance(result, Command)
|
||||
message = result.update["messages"][0]
|
||||
assert isinstance(message, ToolMessage)
|
||||
assert "<retrieved_context>" in message.content
|
||||
assert "[1]" in message.content
|
||||
|
||||
|
||||
async def test_tool_populates_citation_registry_on_state(
|
||||
db_session, db_search_space, _tool_uses_test_session, _pinned_embedding
|
||||
):
|
||||
await _add_document(
|
||||
db_session,
|
||||
search_space_id=db_search_space.id,
|
||||
title="Asyncio Guide",
|
||||
text="The asyncio library enables concurrency.",
|
||||
)
|
||||
tool = create_search_knowledge_base_tool(search_space_id=db_search_space.id)
|
||||
|
||||
result = await _invoke(tool, "asyncio")
|
||||
|
||||
registry = result.update["citation_registry"]
|
||||
assert isinstance(registry, CitationRegistry)
|
||||
assert registry.by_n # at least one passage was registered as [n]
|
||||
|
||||
|
||||
async def test_tool_reuses_existing_registry_numbering(
|
||||
db_session, db_search_space, _tool_uses_test_session, _pinned_embedding
|
||||
):
|
||||
await _add_document(
|
||||
db_session,
|
||||
search_space_id=db_search_space.id,
|
||||
title="Asyncio Guide",
|
||||
text="The asyncio library enables concurrency.",
|
||||
)
|
||||
tool = create_search_knowledge_base_tool(search_space_id=db_search_space.id)
|
||||
|
||||
first = await _invoke(tool, "asyncio")
|
||||
carried = first.update["citation_registry"]
|
||||
second = await _invoke(tool, "asyncio", state={"citation_registry": carried})
|
||||
|
||||
# Same passage searched twice keeps a single [n] (find-or-create).
|
||||
assert len(second.update["citation_registry"].by_n) == 1
|
||||
|
||||
|
||||
async def test_tool_reports_no_matches_without_touching_state(
|
||||
db_session, db_search_space, _tool_uses_test_session, _pinned_embedding
|
||||
):
|
||||
tool = create_search_knowledge_base_tool(search_space_id=db_search_space.id)
|
||||
|
||||
result = await _invoke(tool, "nonexistent-term-zzz")
|
||||
|
||||
assert isinstance(result, str)
|
||||
assert "No knowledge-base matches" in result
|
||||
|
||||
|
||||
async def test_tool_rejects_empty_query(
|
||||
db_search_space, _tool_uses_test_session, _pinned_embedding
|
||||
):
|
||||
tool = create_search_knowledge_base_tool(search_space_id=db_search_space.id)
|
||||
|
||||
result = await _invoke(tool, " ")
|
||||
|
||||
assert isinstance(result, str)
|
||||
assert "non-empty" in result
|
||||
|
||||
|
||||
async def test_document_mention_confines_search_to_pinned_doc(
|
||||
db_session, db_search_space, _tool_uses_test_session, _pinned_embedding
|
||||
):
|
||||
pinned = await _add_document(
|
||||
db_session,
|
||||
search_space_id=db_search_space.id,
|
||||
title="Pinned",
|
||||
text="asyncio appears in the pinned doc.",
|
||||
)
|
||||
await _add_document(
|
||||
db_session,
|
||||
search_space_id=db_search_space.id,
|
||||
title="Other",
|
||||
text="asyncio appears in the other doc.",
|
||||
)
|
||||
tool = create_search_knowledge_base_tool(search_space_id=db_search_space.id)
|
||||
|
||||
result = await _invoke(
|
||||
tool, "asyncio", context=_mentions(document_ids=[pinned.id])
|
||||
)
|
||||
|
||||
# Search is confined to the pinned doc: only its content is rendered.
|
||||
content = result.update["messages"][0].content
|
||||
assert "Pinned" in content
|
||||
assert "Other" not in content
|
||||
|
||||
|
||||
async def test_folder_mention_confines_search_to_folder_documents(
|
||||
db_session, db_search_space, _tool_uses_test_session, _pinned_embedding
|
||||
):
|
||||
folder = await _add_folder(db_session, search_space_id=db_search_space.id)
|
||||
await _add_document(
|
||||
db_session,
|
||||
search_space_id=db_search_space.id,
|
||||
title="Inside",
|
||||
text="asyncio appears inside the folder.",
|
||||
folder_id=folder.id,
|
||||
)
|
||||
await _add_document(
|
||||
db_session,
|
||||
search_space_id=db_search_space.id,
|
||||
title="Outside",
|
||||
text="asyncio appears outside the folder.",
|
||||
)
|
||||
tool = create_search_knowledge_base_tool(search_space_id=db_search_space.id)
|
||||
|
||||
result = await _invoke(
|
||||
tool, "asyncio", context=_mentions(folder_ids=[folder.id])
|
||||
)
|
||||
|
||||
# Search is confined to the folder's document: only its content is rendered.
|
||||
content = result.update["messages"][0].content
|
||||
assert "Inside" in content
|
||||
assert "Outside" not in content
|
||||
|
|
@ -0,0 +1,236 @@
|
|||
"""Behavior tests for the hybrid chunk retriever against a real Postgres.
|
||||
|
||||
These exercise ``search_chunks`` through its public surface only: seed real
|
||||
documents/chunks, run a search, and assert on the returned ``DocumentHit``s —
|
||||
never on SQL shape or internal ranking math. ``query_embedding`` is supplied
|
||||
directly (a public parameter) so the semantic leg is deterministic instead of
|
||||
depending on a live embedding model.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
|
||||
from app.agents.chat.multi_agent_chat.shared.retrieval.hybrid_search import (
|
||||
search_chunks,
|
||||
)
|
||||
from app.agents.chat.multi_agent_chat.shared.retrieval.models import SearchScope
|
||||
from app.config import config
|
||||
from app.db import Chunk, Document, DocumentType, SearchSpace
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
_DIM = config.embedding_model_instance.dimension
|
||||
|
||||
|
||||
def _axis(index: int) -> list[float]:
|
||||
"""A unit vector pointing along one axis — orthogonal axes are dissimilar."""
|
||||
vector = [0.0] * _DIM
|
||||
vector[index] = 1.0
|
||||
return vector
|
||||
|
||||
|
||||
async def _add_document(
|
||||
db_session,
|
||||
*,
|
||||
search_space_id: int,
|
||||
title: str = "Doc",
|
||||
document_type: DocumentType = DocumentType.FILE,
|
||||
state: str = "ready",
|
||||
chunks: list[tuple[str, int, list[float]]],
|
||||
) -> Document:
|
||||
"""Persist one document and its chunks; ``chunks`` is (content, position, embedding)."""
|
||||
document = Document(
|
||||
title=title,
|
||||
document_type=document_type,
|
||||
content="\n".join(content for content, _, _ in chunks),
|
||||
content_hash=uuid.uuid4().hex,
|
||||
search_space_id=search_space_id,
|
||||
status={"state": state},
|
||||
)
|
||||
db_session.add(document)
|
||||
await db_session.flush()
|
||||
for content, position, embedding in chunks:
|
||||
db_session.add(
|
||||
Chunk(
|
||||
content=content,
|
||||
document_id=document.id,
|
||||
position=position,
|
||||
embedding=embedding,
|
||||
)
|
||||
)
|
||||
await db_session.flush()
|
||||
return document
|
||||
|
||||
|
||||
async def test_keyword_relevant_document_is_retrieved(db_session, db_search_space):
|
||||
document = await _add_document(
|
||||
db_session,
|
||||
search_space_id=db_search_space.id,
|
||||
title="Asyncio Guide",
|
||||
chunks=[("The asyncio library enables concurrency.", 0, _axis(0))],
|
||||
)
|
||||
|
||||
results = await search_chunks(
|
||||
db_session,
|
||||
search_space_id=db_search_space.id,
|
||||
query="asyncio",
|
||||
scope=SearchScope(),
|
||||
top_k=5,
|
||||
query_embedding=_axis(99),
|
||||
)
|
||||
|
||||
assert document.id in {hit.document_id for hit in results}
|
||||
|
||||
|
||||
async def test_semantically_closest_document_ranks_first(db_session, db_search_space):
|
||||
aligned = await _add_document(
|
||||
db_session,
|
||||
search_space_id=db_search_space.id,
|
||||
title="Background Work",
|
||||
chunks=[("Parallel execution of background work.", 0, _axis(0))],
|
||||
)
|
||||
await _add_document(
|
||||
db_session,
|
||||
search_space_id=db_search_space.id,
|
||||
title="Dessert",
|
||||
chunks=[("Recipes for chocolate cake.", 0, _axis(1))],
|
||||
)
|
||||
|
||||
results = await search_chunks(
|
||||
db_session,
|
||||
search_space_id=db_search_space.id,
|
||||
query="asynchronous coroutines",
|
||||
scope=SearchScope(),
|
||||
top_k=5,
|
||||
query_embedding=_axis(0),
|
||||
)
|
||||
|
||||
assert results[0].document_id == aligned.id
|
||||
|
||||
|
||||
async def test_results_stay_within_the_search_space(db_session, db_search_space):
|
||||
other_space = SearchSpace(name="Other Space", user_id=db_search_space.user_id)
|
||||
db_session.add(other_space)
|
||||
await db_session.flush()
|
||||
|
||||
mine = await _add_document(
|
||||
db_session,
|
||||
search_space_id=db_search_space.id,
|
||||
chunks=[("Shared keyword asyncio here.", 0, _axis(0))],
|
||||
)
|
||||
foreign = await _add_document(
|
||||
db_session,
|
||||
search_space_id=other_space.id,
|
||||
chunks=[("Shared keyword asyncio here.", 0, _axis(0))],
|
||||
)
|
||||
|
||||
results = await search_chunks(
|
||||
db_session,
|
||||
search_space_id=db_search_space.id,
|
||||
query="asyncio",
|
||||
scope=SearchScope(),
|
||||
top_k=5,
|
||||
query_embedding=_axis(0),
|
||||
)
|
||||
|
||||
found = {hit.document_id for hit in results}
|
||||
assert mine.id in found and foreign.id not in found
|
||||
|
||||
|
||||
async def test_document_ids_scope_pins_results(db_session, db_search_space):
|
||||
pinned = await _add_document(
|
||||
db_session,
|
||||
search_space_id=db_search_space.id,
|
||||
chunks=[("asyncio appears in the pinned doc.", 0, _axis(0))],
|
||||
)
|
||||
await _add_document(
|
||||
db_session,
|
||||
search_space_id=db_search_space.id,
|
||||
chunks=[("asyncio appears in the other doc too.", 0, _axis(0))],
|
||||
)
|
||||
|
||||
results = await search_chunks(
|
||||
db_session,
|
||||
search_space_id=db_search_space.id,
|
||||
query="asyncio",
|
||||
scope=SearchScope(document_ids=(pinned.id,)),
|
||||
top_k=5,
|
||||
query_embedding=_axis(0),
|
||||
)
|
||||
|
||||
assert {hit.document_id for hit in results} == {pinned.id}
|
||||
|
||||
|
||||
async def test_deleting_documents_are_excluded(db_session, db_search_space):
|
||||
ready = await _add_document(
|
||||
db_session,
|
||||
search_space_id=db_search_space.id,
|
||||
chunks=[("asyncio in a ready document.", 0, _axis(0))],
|
||||
)
|
||||
deleting = await _add_document(
|
||||
db_session,
|
||||
search_space_id=db_search_space.id,
|
||||
state="deleting",
|
||||
chunks=[("asyncio in a deleting document.", 0, _axis(0))],
|
||||
)
|
||||
|
||||
results = await search_chunks(
|
||||
db_session,
|
||||
search_space_id=db_search_space.id,
|
||||
query="asyncio",
|
||||
scope=SearchScope(),
|
||||
top_k=5,
|
||||
query_embedding=_axis(0),
|
||||
)
|
||||
|
||||
found = {hit.document_id for hit in results}
|
||||
assert ready.id in found and deleting.id not in found
|
||||
|
||||
|
||||
async def test_matched_chunks_are_ordered_for_reading(db_session, db_search_space):
|
||||
# Insert out of order, and give the later-position chunk the stronger
|
||||
# semantic score, so reading order differs from both insertion and score.
|
||||
document = await _add_document(
|
||||
db_session,
|
||||
search_space_id=db_search_space.id,
|
||||
chunks=[
|
||||
("asyncio paragraph two.", 1, _axis(0)),
|
||||
("asyncio paragraph one.", 0, _axis(50)),
|
||||
],
|
||||
)
|
||||
|
||||
results = await search_chunks(
|
||||
db_session,
|
||||
search_space_id=db_search_space.id,
|
||||
query="asyncio",
|
||||
scope=SearchScope(),
|
||||
top_k=5,
|
||||
query_embedding=_axis(0),
|
||||
)
|
||||
|
||||
hit = next(hit for hit in results if hit.document_id == document.id)
|
||||
assert [chunk.position for chunk in hit.chunks] == [0, 1]
|
||||
|
||||
|
||||
async def test_top_k_caps_the_number_of_documents(db_session, db_search_space):
|
||||
for index in range(3):
|
||||
await _add_document(
|
||||
db_session,
|
||||
search_space_id=db_search_space.id,
|
||||
title=f"Doc {index}",
|
||||
chunks=[(f"asyncio mentioned in doc {index}.", 0, _axis(index))],
|
||||
)
|
||||
|
||||
results = await search_chunks(
|
||||
db_session,
|
||||
search_space_id=db_search_space.id,
|
||||
query="asyncio",
|
||||
scope=SearchScope(),
|
||||
top_k=2,
|
||||
query_embedding=_axis(0),
|
||||
)
|
||||
|
||||
assert len(results) == 2
|
||||
|
|
@ -3,7 +3,6 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
|
|
@ -227,23 +226,6 @@ def patched_embed(monkeypatch):
|
|||
return mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def patched_shielded_session(async_engine, monkeypatch):
|
||||
"""Replace ``shielded_async_session`` in the knowledge_base module
|
||||
with one that yields sessions from the test engine."""
|
||||
test_maker = async_sessionmaker(async_engine, expire_on_commit=False)
|
||||
|
||||
@asynccontextmanager
|
||||
async def _test_shielded():
|
||||
async with test_maker() as session:
|
||||
yield session
|
||||
|
||||
monkeypatch.setattr(
|
||||
"app.agents.chat.multi_agent_chat.subagents.builtins.deliverables.tools.knowledge_base.shielded_async_session",
|
||||
_test_shielded,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Indexer test helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
|
|||
|
|
@ -1,46 +0,0 @@
|
|||
"""Integration test: _browse_recent_documents returns docs of multiple types.
|
||||
|
||||
Exercises the browse path (degenerate-query fallback) with a real PostgreSQL
|
||||
database. Verifies that passing a list of document types correctly returns
|
||||
documents of all listed types -- the same ``.in_()`` SQL path used by hybrid
|
||||
search but through the browse/recency-ordered code path.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
|
||||
async def test_browse_recent_documents_with_list_type_returns_both(
|
||||
committed_google_data, patched_shielded_session
|
||||
):
|
||||
"""_browse_recent_documents returns docs of all types when given a list."""
|
||||
from app.agents.chat.multi_agent_chat.subagents.builtins.deliverables.tools.knowledge_base import (
|
||||
_browse_recent_documents,
|
||||
)
|
||||
|
||||
space_id = committed_google_data["search_space_id"]
|
||||
|
||||
results = await _browse_recent_documents(
|
||||
search_space_id=space_id,
|
||||
document_type=["GOOGLE_DRIVE_FILE", "COMPOSIO_GOOGLE_DRIVE_CONNECTOR"],
|
||||
top_k=10,
|
||||
start_date=None,
|
||||
end_date=None,
|
||||
)
|
||||
|
||||
returned_types = set()
|
||||
for doc in results:
|
||||
doc_info = doc.get("document", {})
|
||||
dtype = doc_info.get("document_type")
|
||||
if dtype:
|
||||
returned_types.add(dtype)
|
||||
|
||||
assert "GOOGLE_DRIVE_FILE" in returned_types, (
|
||||
"Native Drive docs should appear in browse results"
|
||||
)
|
||||
assert "COMPOSIO_GOOGLE_DRIVE_CONNECTOR" in returned_types, (
|
||||
"Legacy Composio Drive docs should appear in browse results"
|
||||
)
|
||||
|
|
@ -1,61 +0,0 @@
|
|||
"""Integration smoke tests for KB search query/date scoping."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from app.agents.chat.multi_agent_chat.shared.middleware import knowledge_search as ks
|
||||
from app.agents.chat.multi_agent_chat.shared.middleware.knowledge_search import (
|
||||
search_knowledge_base,
|
||||
)
|
||||
|
||||
from .conftest import DUMMY_EMBEDDING
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
|
||||
async def test_search_knowledge_base_applies_date_filters(
|
||||
db_session,
|
||||
seed_date_filtered_docs,
|
||||
monkeypatch,
|
||||
):
|
||||
"""Date filters should remove older matching documents from scoped KB results."""
|
||||
|
||||
@asynccontextmanager
|
||||
async def fake_shielded_async_session():
|
||||
yield db_session
|
||||
|
||||
monkeypatch.setattr(ks, "shielded_async_session", fake_shielded_async_session)
|
||||
monkeypatch.setattr(
|
||||
ks, "embed_texts", lambda texts: [np.array(DUMMY_EMBEDDING) for _ in texts]
|
||||
)
|
||||
|
||||
space_id = seed_date_filtered_docs["search_space"].id
|
||||
recent_cutoff = datetime.now(UTC) - timedelta(days=30)
|
||||
|
||||
unfiltered_results = await search_knowledge_base(
|
||||
query="ocv meeting decisions",
|
||||
search_space_id=space_id,
|
||||
available_document_types=["FILE"],
|
||||
top_k=10,
|
||||
)
|
||||
filtered_results = await search_knowledge_base(
|
||||
query="ocv meeting decisions",
|
||||
search_space_id=space_id,
|
||||
available_document_types=["FILE"],
|
||||
top_k=10,
|
||||
start_date=recent_cutoff,
|
||||
end_date=datetime.now(UTC),
|
||||
)
|
||||
|
||||
unfiltered_ids = {result["document"]["id"] for result in unfiltered_results}
|
||||
filtered_ids = {result["document"]["id"] for result in filtered_results}
|
||||
|
||||
assert seed_date_filtered_docs["recent_doc"].id in unfiltered_ids
|
||||
assert seed_date_filtered_docs["old_doc"].id in unfiltered_ids
|
||||
assert seed_date_filtered_docs["recent_doc"].id in filtered_ids
|
||||
assert seed_date_filtered_docs["old_doc"].id not in filtered_ids
|
||||
Loading…
Add table
Add a link
Reference in a new issue