Merge pull request #1539 from CREDO23/improve-chat-agent-context-and-citations

[FEAT] Unified [n] citation registry for KB + web, pull-based retrieval
This commit is contained in:
Rohan Verma 2026-06-25 13:34:52 -07:00 committed by GitHub
commit 94fdb8a113
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
160 changed files with 4097 additions and 5238 deletions

View file

@ -0,0 +1,237 @@
"""Behavior tests for the ``search_knowledge_base`` main-agent tool.
These exercise the tool through its public contract: seed a real document,
invoke the tool, and assert on the ``Command`` it returns: the rendered
``<retrieved_context>`` carries ``[n]`` labels and the citation registry handed
back on state is populated.
The tool's own DB session is redirected to the test session, and the embedding
leg is pinned so the search is deterministic without a live model.
"""
from __future__ import annotations
import contextlib
import uuid
from types import SimpleNamespace
import pytest
from langchain_core.messages import ToolMessage
from langgraph.types import Command
from app.agents.chat.multi_agent_chat.main_agent.tools import search_knowledge_base
from app.agents.chat.multi_agent_chat.main_agent.tools.search_knowledge_base import (
create_search_knowledge_base_tool,
)
from app.agents.chat.multi_agent_chat.shared.citations import CitationRegistry
from app.config import config
from app.db import Chunk, Document, DocumentType, Folder
pytestmark = pytest.mark.integration
_DIM = config.embedding_model_instance.dimension
def _axis(index: int) -> list[float]:
    """Return a ``_DIM``-long vector that is 1.0 at ``index`` and 0.0 elsewhere."""
    one_hot = [0.0 for _ in range(_DIM)]
    one_hot[index] = 1.0
    return one_hot
async def _add_document(
    db_session,
    *,
    search_space_id: int,
    title: str,
    text: str,
    folder_id: int | None = None,
):
    """Seed one ``FILE`` document in state ``"ready"`` with a single chunk.

    The chunk reuses ``text`` as its content, sits at position 0 and is
    embedded as ``_axis(0)``. Returns the flushed ``Document``.
    """
    seeded = Document(
        search_space_id=search_space_id,
        folder_id=folder_id,
        title=title,
        document_type=DocumentType.FILE,
        content=text,
        content_hash=uuid.uuid4().hex,
        status={"state": "ready"},
    )
    db_session.add(seeded)
    # Flush before reading ``seeded.id`` for the chunk below.
    await db_session.flush()

    sole_chunk = Chunk(
        content=text,
        document_id=seeded.id,
        position=0,
        embedding=_axis(0),
    )
    db_session.add(sole_chunk)
    await db_session.flush()
    return seeded
async def _add_folder(db_session, *, search_space_id: int, name: str = "Folder"):
    """Add a ``Folder`` (position ``"0"``) to the session, flush, and return it."""
    new_folder = Folder(search_space_id=search_space_id, name=name, position="0")
    db_session.add(new_folder)
    await db_session.flush()
    return new_folder
@pytest.fixture
def _tool_uses_test_session(db_session, monkeypatch):
    """Patch the tool module's ``shielded_async_session`` to yield ``db_session``."""

    @contextlib.asynccontextmanager
    async def _yield_test_session():
        yield db_session

    monkeypatch.setattr(
        search_knowledge_base, "shielded_async_session", _yield_test_session
    )
@pytest.fixture
def _pinned_embedding(monkeypatch):
    """Replace ``embed`` on the configured embedding model with a constant.

    Whatever the query, the patched ``embed`` returns ``_axis(0)``.
    """

    def _constant_embed(_query):
        return _axis(0)

    monkeypatch.setattr(config.embedding_model_instance, "embed", _constant_embed)
async def _invoke(tool, query: str, state: dict | None = None, context=None):
runtime = SimpleNamespace(
state=state or {}, tool_call_id="call-1", context=context
)
return await tool.coroutine(query, runtime)
def _mentions(*, document_ids=(), folder_ids=()):
    """Build a stand-in context holding mentioned document and folder ids as lists."""
    context = SimpleNamespace()
    context.mentioned_document_ids = [*document_ids]
    context.mentioned_folder_ids = [*folder_ids]
    return context
async def test_tool_returns_retrieved_context_with_numbered_passages(
    db_session, db_search_space, _tool_uses_test_session, _pinned_embedding
):
    """A matching document comes back as a ``Command`` whose tool message
    contains ``<retrieved_context>`` and a ``[1]`` label."""
    space_id = db_search_space.id
    await _add_document(
        db_session,
        search_space_id=space_id,
        title="Asyncio Guide",
        text="The asyncio library enables concurrency.",
    )

    kb_tool = create_search_knowledge_base_tool(search_space_id=space_id)
    command = await _invoke(kb_tool, "asyncio")

    assert isinstance(command, Command)
    tool_message = command.update["messages"][0]
    assert isinstance(tool_message, ToolMessage)
    rendered = tool_message.content
    assert "<retrieved_context>" in rendered
    assert "[1]" in rendered
async def test_tool_populates_citation_registry_on_state(
    db_session, db_search_space, _tool_uses_test_session, _pinned_embedding
):
    """The returned update carries a non-empty ``CitationRegistry``."""
    space_id = db_search_space.id
    await _add_document(
        db_session,
        search_space_id=space_id,
        title="Asyncio Guide",
        text="The asyncio library enables concurrency.",
    )

    kb_tool = create_search_knowledge_base_tool(search_space_id=space_id)
    command = await _invoke(kb_tool, "asyncio")

    citation_registry = command.update["citation_registry"]
    assert isinstance(citation_registry, CitationRegistry)
    # At least one passage must have been registered as [n].
    assert citation_registry.by_n
async def test_tool_reuses_existing_registry_numbering(
    db_session, db_search_space, _tool_uses_test_session, _pinned_embedding
):
    """Repeating a search with the first registry carried on state leaves a
    single entry in ``by_n``."""
    space_id = db_search_space.id
    await _add_document(
        db_session,
        search_space_id=space_id,
        title="Asyncio Guide",
        text="The asyncio library enables concurrency.",
    )
    kb_tool = create_search_knowledge_base_tool(search_space_id=space_id)

    first_run = await _invoke(kb_tool, "asyncio")
    carried_state = {"citation_registry": first_run.update["citation_registry"]}
    second_run = await _invoke(kb_tool, "asyncio", state=carried_state)

    # Same passage searched twice keeps a single [n] (find-or-create).
    final_registry = second_run.update["citation_registry"]
    assert len(final_registry.by_n) == 1
async def test_tool_reports_no_matches_without_touching_state(
    db_session, db_search_space, _tool_uses_test_session, _pinned_embedding
):
    """With nothing seeded, the tool answers with a plain string notice."""
    kb_tool = create_search_knowledge_base_tool(search_space_id=db_search_space.id)

    outcome = await _invoke(kb_tool, "nonexistent-term-zzz")

    # A bare string (not a ``Command``) is what this test expects back.
    assert isinstance(outcome, str)
    assert "No knowledge-base matches" in outcome
async def test_tool_rejects_empty_query(
    db_search_space, _tool_uses_test_session, _pinned_embedding
):
    """A whitespace-only query yields a string mentioning ``non-empty``."""
    kb_tool = create_search_knowledge_base_tool(search_space_id=db_search_space.id)

    outcome = await _invoke(kb_tool, " ")

    assert isinstance(outcome, str)
    assert "non-empty" in outcome
async def test_document_mention_confines_search_to_pinned_doc(
    db_session, db_search_space, _tool_uses_test_session, _pinned_embedding
):
    """Mentioning one document keeps the other document out of the rendered text."""
    space_id = db_search_space.id
    pinned_doc = await _add_document(
        db_session,
        search_space_id=space_id,
        title="Pinned",
        text="asyncio appears in the pinned doc.",
    )
    await _add_document(
        db_session,
        search_space_id=space_id,
        title="Other",
        text="asyncio appears in the other doc.",
    )

    kb_tool = create_search_knowledge_base_tool(search_space_id=space_id)
    mention_context = _mentions(document_ids=[pinned_doc.id])
    command = await _invoke(kb_tool, "asyncio", context=mention_context)

    # Search is confined to the pinned doc: only its content is rendered.
    rendered = command.update["messages"][0].content
    assert "Pinned" in rendered
    assert "Other" not in rendered
async def test_folder_mention_confines_search_to_folder_documents(
    db_session, db_search_space, _tool_uses_test_session, _pinned_embedding
):
    """Mentioning a folder keeps documents outside it out of the rendered text."""
    space_id = db_search_space.id
    mentioned_folder = await _add_folder(db_session, search_space_id=space_id)
    await _add_document(
        db_session,
        search_space_id=space_id,
        title="Inside",
        text="asyncio appears inside the folder.",
        folder_id=mentioned_folder.id,
    )
    await _add_document(
        db_session,
        search_space_id=space_id,
        title="Outside",
        text="asyncio appears outside the folder.",
    )

    kb_tool = create_search_knowledge_base_tool(search_space_id=space_id)
    mention_context = _mentions(folder_ids=[mentioned_folder.id])
    command = await _invoke(kb_tool, "asyncio", context=mention_context)

    # Search is confined to the folder's document: only its content is rendered.
    rendered = command.update["messages"][0].content
    assert "Inside" in rendered
    assert "Outside" not in rendered

View file

@ -0,0 +1,236 @@
"""Behavior tests for the hybrid chunk retriever against a real Postgres.
These exercise ``search_chunks`` through its public surface only: seed real
documents/chunks, run a search, and assert on the returned ``DocumentHit``s —
never on SQL shape or internal ranking math. ``query_embedding`` is supplied
directly (a public parameter) so the semantic leg is deterministic instead of
depending on a live embedding model.
"""
from __future__ import annotations
import uuid
import pytest
from app.agents.chat.multi_agent_chat.shared.retrieval.hybrid_search import (
search_chunks,
)
from app.agents.chat.multi_agent_chat.shared.retrieval.models import SearchScope
from app.config import config
from app.db import Chunk, Document, DocumentType, SearchSpace
pytestmark = pytest.mark.integration
_DIM = config.embedding_model_instance.dimension
def _axis(index: int) -> list[float]:
    """Return the one-hot vector for axis ``index`` (length ``_DIM``).

    Vectors built for different axes are orthogonal, so they are dissimilar.
    """
    one_hot = [0.0 for _ in range(_DIM)]
    one_hot[index] = 1.0
    return one_hot
async def _add_document(
    db_session,
    *,
    search_space_id: int,
    title: str = "Doc",
    document_type: DocumentType = DocumentType.FILE,
    state: str = "ready",
    chunks: list[tuple[str, int, list[float]]],
) -> Document:
    """Store a document plus its chunks and return the document.

    Each item of ``chunks`` is ``(content, position, embedding)``. The
    document's own ``content`` is the chunk contents joined by newlines, in
    the order given.
    """
    full_text = "\n".join(chunk_text for chunk_text, _, _ in chunks)
    stored = Document(
        search_space_id=search_space_id,
        title=title,
        document_type=document_type,
        content=full_text,
        content_hash=uuid.uuid4().hex,
        status={"state": state},
    )
    db_session.add(stored)
    # Flush before reading ``stored.id`` for the chunks below.
    await db_session.flush()

    for chunk_text, chunk_position, chunk_embedding in chunks:
        chunk_row = Chunk(
            content=chunk_text,
            document_id=stored.id,
            position=chunk_position,
            embedding=chunk_embedding,
        )
        db_session.add(chunk_row)
    await db_session.flush()
    return stored
async def test_keyword_relevant_document_is_retrieved(db_session, db_search_space):
    """A document whose chunk contains the query word is among the hits.

    The query embedding (axis 99) is orthogonal to the chunk's (axis 0), so
    the match presumably comes from the keyword side of the search.
    """
    space_id = db_search_space.id
    guide = await _add_document(
        db_session,
        search_space_id=space_id,
        title="Asyncio Guide",
        chunks=[("The asyncio library enables concurrency.", 0, _axis(0))],
    )

    hits = await search_chunks(
        db_session,
        search_space_id=space_id,
        query="asyncio",
        scope=SearchScope(),
        top_k=5,
        query_embedding=_axis(99),
    )

    hit_document_ids = {hit.document_id for hit in hits}
    assert guide.id in hit_document_ids
async def test_semantically_closest_document_ranks_first(db_session, db_search_space):
    """The document whose chunk embedding equals the query embedding is hit #1."""
    space_id = db_search_space.id
    closest = await _add_document(
        db_session,
        search_space_id=space_id,
        title="Background Work",
        chunks=[("Parallel execution of background work.", 0, _axis(0))],
    )
    # A second document on a different axis acts as the dissimilar distractor.
    await _add_document(
        db_session,
        search_space_id=space_id,
        title="Dessert",
        chunks=[("Recipes for chocolate cake.", 0, _axis(1))],
    )

    hits = await search_chunks(
        db_session,
        search_space_id=space_id,
        query="asynchronous coroutines",
        scope=SearchScope(),
        top_k=5,
        query_embedding=_axis(0),
    )

    top_hit = hits[0]
    assert top_hit.document_id == closest.id
async def test_results_stay_within_the_search_space(db_session, db_search_space):
    """An identical document in another search space is not returned."""
    second_space = SearchSpace(name="Other Space", user_id=db_search_space.user_id)
    db_session.add(second_space)
    await db_session.flush()

    own_doc = await _add_document(
        db_session,
        search_space_id=db_search_space.id,
        chunks=[("Shared keyword asyncio here.", 0, _axis(0))],
    )
    outside_doc = await _add_document(
        db_session,
        search_space_id=second_space.id,
        chunks=[("Shared keyword asyncio here.", 0, _axis(0))],
    )

    hits = await search_chunks(
        db_session,
        search_space_id=db_search_space.id,
        query="asyncio",
        scope=SearchScope(),
        top_k=5,
        query_embedding=_axis(0),
    )

    hit_document_ids = {hit.document_id for hit in hits}
    assert own_doc.id in hit_document_ids
    assert outside_doc.id not in hit_document_ids
async def test_document_ids_scope_pins_results(db_session, db_search_space):
    """Scoping by ``document_ids`` returns hits for exactly that document."""
    space_id = db_search_space.id
    scoped_doc = await _add_document(
        db_session,
        search_space_id=space_id,
        chunks=[("asyncio appears in the pinned doc.", 0, _axis(0))],
    )
    # Equally relevant, but left out of the scope.
    await _add_document(
        db_session,
        search_space_id=space_id,
        chunks=[("asyncio appears in the other doc too.", 0, _axis(0))],
    )

    only_scoped = SearchScope(document_ids=(scoped_doc.id,))
    hits = await search_chunks(
        db_session,
        search_space_id=space_id,
        query="asyncio",
        scope=only_scoped,
        top_k=5,
        query_embedding=_axis(0),
    )

    hit_document_ids = {hit.document_id for hit in hits}
    assert hit_document_ids == {scoped_doc.id}
async def test_deleting_documents_are_excluded(db_session, db_search_space):
    """A document in state ``"deleting"`` is left out; a ready one is returned."""
    space_id = db_search_space.id
    ready_doc = await _add_document(
        db_session,
        search_space_id=space_id,
        chunks=[("asyncio in a ready document.", 0, _axis(0))],
    )
    deleting_doc = await _add_document(
        db_session,
        search_space_id=space_id,
        state="deleting",
        chunks=[("asyncio in a deleting document.", 0, _axis(0))],
    )

    hits = await search_chunks(
        db_session,
        search_space_id=space_id,
        query="asyncio",
        scope=SearchScope(),
        top_k=5,
        query_embedding=_axis(0),
    )

    hit_document_ids = {hit.document_id for hit in hits}
    assert ready_doc.id in hit_document_ids
    assert deleting_doc.id not in hit_document_ids
async def test_matched_chunks_are_ordered_for_reading(db_session, db_search_space):
    """Chunks inside a hit come back sorted by ``position``."""
    # Insert out of order, and give the later-position chunk the stronger
    # semantic score, so reading order differs from both insertion and score.
    document = await _add_document(
        db_session,
        search_space_id=db_search_space.id,
        chunks=[
            ("asyncio paragraph two.", 1, _axis(0)),
            ("asyncio paragraph one.", 0, _axis(50)),
        ],
    )
    results = await search_chunks(
        db_session,
        search_space_id=db_search_space.id,
        query="asyncio",
        scope=SearchScope(),
        top_k=5,
        query_embedding=_axis(0),
    )
    # A bare ``next()`` on an exhausted generator raises StopIteration, which a
    # coroutine surfaces as "RuntimeError: coroutine raised StopIteration".
    # Default to None and assert so a missing hit fails with a clear message.
    hit = next((hit for hit in results if hit.document_id == document.id), None)
    assert hit is not None, "seeded document was not returned by search_chunks"
    assert [chunk.position for chunk in hit.chunks] == [0, 1]
async def test_top_k_caps_the_number_of_documents(db_session, db_search_space):
    """Three matching documents with ``top_k=2`` produce exactly two hits."""
    space_id = db_search_space.id
    for doc_number in range(3):
        await _add_document(
            db_session,
            search_space_id=space_id,
            title=f"Doc {doc_number}",
            chunks=[
                (f"asyncio mentioned in doc {doc_number}.", 0, _axis(doc_number))
            ],
        )

    hits = await search_chunks(
        db_session,
        search_space_id=space_id,
        query="asyncio",
        scope=SearchScope(),
        top_k=2,
        query_embedding=_axis(0),
    )

    assert len(hits) == 2

View file

@ -3,7 +3,6 @@
from __future__ import annotations
import uuid
from contextlib import asynccontextmanager
from datetime import UTC, datetime
from unittest.mock import MagicMock
@ -227,23 +226,6 @@ def patched_embed(monkeypatch):
return mock
@pytest.fixture
def patched_shielded_session(async_engine, monkeypatch):
    """Point the knowledge_base module's ``shielded_async_session`` at the
    test engine, so each use opens a session from ``async_engine``."""
    session_factory = async_sessionmaker(async_engine, expire_on_commit=False)

    @asynccontextmanager
    async def _engine_backed_session():
        async with session_factory() as session:
            yield session

    patch_target = "app.agents.chat.multi_agent_chat.subagents.builtins.deliverables.tools.knowledge_base.shielded_async_session"
    monkeypatch.setattr(patch_target, _engine_backed_session)
# ---------------------------------------------------------------------------
# Indexer test helpers
# ---------------------------------------------------------------------------

View file

@ -1,46 +0,0 @@
"""Integration test: _browse_recent_documents returns docs of multiple types.
Exercises the browse path (degenerate-query fallback) with a real PostgreSQL
database. Verifies that passing a list of document types correctly returns
documents of all listed types -- the same ``.in_()`` SQL path used by hybrid
search but through the browse/recency-ordered code path.
"""
from __future__ import annotations
import pytest
pytestmark = pytest.mark.integration
async def test_browse_recent_documents_with_list_type_returns_both(
    committed_google_data, patched_shielded_session
):
    """Browsing with a list of two document types returns documents of both."""
    from app.agents.chat.multi_agent_chat.subagents.builtins.deliverables.tools.knowledge_base import (
        _browse_recent_documents,
    )

    browsed = await _browse_recent_documents(
        search_space_id=committed_google_data["search_space_id"],
        document_type=["GOOGLE_DRIVE_FILE", "COMPOSIO_GOOGLE_DRIVE_CONNECTOR"],
        top_k=10,
        start_date=None,
        end_date=None,
    )

    # Collect every truthy ``document_type`` seen in the results.
    returned_types = {
        found_type
        for entry in browsed
        if (found_type := entry.get("document", {}).get("document_type"))
    }

    assert "GOOGLE_DRIVE_FILE" in returned_types, (
        "Native Drive docs should appear in browse results"
    )
    assert "COMPOSIO_GOOGLE_DRIVE_CONNECTOR" in returned_types, (
        "Legacy Composio Drive docs should appear in browse results"
    )

View file

@ -1,61 +0,0 @@
"""Integration smoke tests for KB search query/date scoping."""
from __future__ import annotations
from contextlib import asynccontextmanager
from datetime import UTC, datetime, timedelta
import numpy as np
import pytest
from app.agents.chat.multi_agent_chat.shared.middleware import knowledge_search as ks
from app.agents.chat.multi_agent_chat.shared.middleware.knowledge_search import (
search_knowledge_base,
)
from .conftest import DUMMY_EMBEDDING
pytestmark = pytest.mark.integration
async def test_search_knowledge_base_applies_date_filters(
    db_session,
    seed_date_filtered_docs,
    monkeypatch,
):
    """A 30-day date window drops the old matching document but keeps the recent one."""

    @asynccontextmanager
    async def _test_session():
        yield db_session

    def _dummy_embeddings(texts):
        return [np.array(DUMMY_EMBEDDING) for _ in texts]

    monkeypatch.setattr(ks, "shielded_async_session", _test_session)
    monkeypatch.setattr(ks, "embed_texts", _dummy_embeddings)

    def _document_ids(search_results):
        return {entry["document"]["id"] for entry in search_results}

    space_id = seed_date_filtered_docs["search_space"].id
    window_start = datetime.now(UTC) - timedelta(days=30)

    without_dates = await search_knowledge_base(
        query="ocv meeting decisions",
        search_space_id=space_id,
        available_document_types=["FILE"],
        top_k=10,
    )
    with_dates = await search_knowledge_base(
        query="ocv meeting decisions",
        search_space_id=space_id,
        available_document_types=["FILE"],
        top_k=10,
        start_date=window_start,
        end_date=datetime.now(UTC),
    )

    unfiltered_ids = _document_ids(without_dates)
    filtered_ids = _document_ids(with_dates)
    recent_id = seed_date_filtered_docs["recent_doc"].id
    old_id = seed_date_filtered_docs["old_doc"].id

    # Both documents match without a date window.
    assert recent_id in unfiltered_ids
    assert old_id in unfiltered_ids
    # Only the recent one survives the window.
    assert recent_id in filtered_ids
    assert old_id not in filtered_ids