Merge pull request #1539 from CREDO23/improve-chat-agent-context-and-citations

[FEAT] Unified [n] citation registry for KB + web, pull-based retrieval
This commit is contained in:
Rohan Verma 2026-06-25 13:34:52 -07:00 committed by GitHub
commit 94fdb8a113
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
160 changed files with 4097 additions and 5238 deletions

View file

@ -0,0 +1,237 @@
"""Behavior tests for the ``search_knowledge_base`` main-agent tool.
These exercise the tool through its public contract: seed a real document,
invoke the tool, and assert on the ``Command`` it returns: the rendered
``<retrieved_context>`` carries ``[n]`` labels and the citation registry handed
back on state is populated.
The tool's own DB session is redirected to the test session, and the embedding
leg is pinned so the search is deterministic without a live model.
"""
from __future__ import annotations
import contextlib
import uuid
from types import SimpleNamespace
import pytest
from langchain_core.messages import ToolMessage
from langgraph.types import Command
from app.agents.chat.multi_agent_chat.main_agent.tools import search_knowledge_base
from app.agents.chat.multi_agent_chat.main_agent.tools.search_knowledge_base import (
create_search_knowledge_base_tool,
)
from app.agents.chat.multi_agent_chat.shared.citations import CitationRegistry
from app.config import config
from app.db import Chunk, Document, DocumentType, Folder
pytestmark = pytest.mark.integration
_DIM = config.embedding_model_instance.dimension
def _axis(index: int) -> list[float]:
    """Return a ``_DIM``-long vector that is 1.0 at ``index`` and 0.0 elsewhere."""
    components = [0.0 for _ in range(_DIM)]
    components[index] = 1.0
    return components
async def _add_document(
    db_session,
    *,
    search_space_id: int,
    title: str,
    text: str,
    folder_id: int | None = None,
):
    """Persist one ``FILE`` document plus a single chunk carrying ``text``.

    The document is stored with status ``{"state": "ready"}`` and a random
    ``content_hash``; its only chunk sits at position 0 with the ``_axis(0)``
    embedding. Returns the flushed ``Document``.
    """
    doc = Document(
        title=title,
        document_type=DocumentType.FILE,
        content=text,
        content_hash=uuid.uuid4().hex,
        search_space_id=search_space_id,
        folder_id=folder_id,
        status={"state": "ready"},
    )
    db_session.add(doc)
    # Flush before reading ``doc.id`` for the chunk below.
    await db_session.flush()

    only_chunk = Chunk(
        content=text,
        document_id=doc.id,
        position=0,
        embedding=_axis(0),
    )
    db_session.add(only_chunk)
    await db_session.flush()
    return doc
async def _add_folder(db_session, *, search_space_id: int, name: str = "Folder"):
    """Add a ``Folder`` at position ``"0"`` to the session, flush, and return it."""
    new_folder = Folder(
        name=name,
        position="0",
        search_space_id=search_space_id,
    )
    db_session.add(new_folder)
    await db_session.flush()
    return new_folder
@pytest.fixture
def _tool_uses_test_session(db_session, monkeypatch):
    """Make the tool module's ``shielded_async_session`` yield the test session."""

    @contextlib.asynccontextmanager
    async def _yield_test_session():
        yield db_session

    monkeypatch.setattr(
        search_knowledge_base,
        "shielded_async_session",
        _yield_test_session,
    )
@pytest.fixture
def _pinned_embedding(monkeypatch):
    """Replace the embedding model's ``embed`` so every query maps to ``_axis(0)``."""

    def _fixed_embed(_query):
        return _axis(0)

    monkeypatch.setattr(config.embedding_model_instance, "embed", _fixed_embed)
async def _invoke(tool, query: str, state: dict | None = None, context=None):
runtime = SimpleNamespace(
state=state or {}, tool_call_id="call-1", context=context
)
return await tool.coroutine(query, runtime)
def _mentions(*, document_ids=(), folder_ids=()):
return SimpleNamespace(
mentioned_document_ids=list(document_ids),
mentioned_folder_ids=list(folder_ids),
)
async def test_tool_returns_retrieved_context_with_numbered_passages(
    db_session, db_search_space, _tool_uses_test_session, _pinned_embedding
):
    """A matching search yields a ``Command`` whose tool message has ``[1]``-labelled context."""
    space_id = db_search_space.id
    await _add_document(
        db_session,
        search_space_id=space_id,
        title="Asyncio Guide",
        text="The asyncio library enables concurrency.",
    )
    kb_tool = create_search_knowledge_base_tool(search_space_id=space_id)

    outcome = await _invoke(kb_tool, "asyncio")

    assert isinstance(outcome, Command)
    tool_message = outcome.update["messages"][0]
    assert isinstance(tool_message, ToolMessage)
    rendered = tool_message.content
    assert "<retrieved_context>" in rendered
    assert "[1]" in rendered
async def test_tool_populates_citation_registry_on_state(
    db_session, db_search_space, _tool_uses_test_session, _pinned_embedding
):
    """The state update carries a non-empty ``CitationRegistry``."""
    space_id = db_search_space.id
    await _add_document(
        db_session,
        search_space_id=space_id,
        title="Asyncio Guide",
        text="The asyncio library enables concurrency.",
    )
    kb_tool = create_search_knowledge_base_tool(search_space_id=space_id)

    outcome = await _invoke(kb_tool, "asyncio")

    returned_registry = outcome.update["citation_registry"]
    assert isinstance(returned_registry, CitationRegistry)
    # At least one passage was registered as [n].
    assert returned_registry.by_n
async def test_tool_reuses_existing_registry_numbering(
    db_session, db_search_space, _tool_uses_test_session, _pinned_embedding
):
    """Re-running the same search with the carried registry keeps a single entry."""
    space_id = db_search_space.id
    await _add_document(
        db_session,
        search_space_id=space_id,
        title="Asyncio Guide",
        text="The asyncio library enables concurrency.",
    )
    kb_tool = create_search_knowledge_base_tool(search_space_id=space_id)

    first_run = await _invoke(kb_tool, "asyncio")
    carried_registry = first_run.update["citation_registry"]
    second_run = await _invoke(
        kb_tool, "asyncio", state={"citation_registry": carried_registry}
    )

    # Same passage searched twice keeps a single [n] (find-or-create).
    final_registry = second_run.update["citation_registry"]
    assert len(final_registry.by_n) == 1
async def test_tool_reports_no_matches_without_touching_state(
    db_session, db_search_space, _tool_uses_test_session, _pinned_embedding
):
    """With nothing seeded, the tool returns a plain string reporting no matches."""
    kb_tool = create_search_knowledge_base_tool(search_space_id=db_search_space.id)

    outcome = await _invoke(kb_tool, "nonexistent-term-zzz")

    assert isinstance(outcome, str)
    assert "No knowledge-base matches" in outcome
async def test_tool_rejects_empty_query(
    db_search_space, _tool_uses_test_session, _pinned_embedding
):
    """A whitespace-only query returns a string mentioning ``non-empty``."""
    kb_tool = create_search_knowledge_base_tool(search_space_id=db_search_space.id)

    outcome = await _invoke(kb_tool, " ")

    assert isinstance(outcome, str)
    assert "non-empty" in outcome
async def test_document_mention_confines_search_to_pinned_doc(
    db_session, db_search_space, _tool_uses_test_session, _pinned_embedding
):
    """Mentioning one document renders that document's title and not the other's."""
    space_id = db_search_space.id
    pinned_doc = await _add_document(
        db_session,
        search_space_id=space_id,
        title="Pinned",
        text="asyncio appears in the pinned doc.",
    )
    await _add_document(
        db_session,
        search_space_id=space_id,
        title="Other",
        text="asyncio appears in the other doc.",
    )
    kb_tool = create_search_knowledge_base_tool(search_space_id=space_id)
    mention_context = _mentions(document_ids=[pinned_doc.id])

    outcome = await _invoke(kb_tool, "asyncio", context=mention_context)

    # Search is confined to the pinned doc: only its content is rendered.
    rendered = outcome.update["messages"][0].content
    assert "Pinned" in rendered
    assert "Other" not in rendered
async def test_folder_mention_confines_search_to_folder_documents(
    db_session, db_search_space, _tool_uses_test_session, _pinned_embedding
):
    """Mentioning a folder renders the document inside it and not the one outside."""
    space_id = db_search_space.id
    mentioned_folder = await _add_folder(db_session, search_space_id=space_id)
    await _add_document(
        db_session,
        search_space_id=space_id,
        title="Inside",
        text="asyncio appears inside the folder.",
        folder_id=mentioned_folder.id,
    )
    await _add_document(
        db_session,
        search_space_id=space_id,
        title="Outside",
        text="asyncio appears outside the folder.",
    )
    kb_tool = create_search_knowledge_base_tool(search_space_id=space_id)
    mention_context = _mentions(folder_ids=[mentioned_folder.id])

    outcome = await _invoke(kb_tool, "asyncio", context=mention_context)

    # Search is confined to the folder's document: only its content is rendered.
    rendered = outcome.update["messages"][0].content
    assert "Inside" in rendered
    assert "Outside" not in rendered

View file

@ -0,0 +1,236 @@
"""Behavior tests for the hybrid chunk retriever against a real Postgres.
These exercise ``search_chunks`` through its public surface only: seed real
documents/chunks, run a search, and assert on the returned ``DocumentHit``s —
never on SQL shape or internal ranking math. ``query_embedding`` is supplied
directly (a public parameter) so the semantic leg is deterministic instead of
depending on a live embedding model.
"""
from __future__ import annotations
import uuid
import pytest
from app.agents.chat.multi_agent_chat.shared.retrieval.hybrid_search import (
search_chunks,
)
from app.agents.chat.multi_agent_chat.shared.retrieval.models import SearchScope
from app.config import config
from app.db import Chunk, Document, DocumentType, SearchSpace
pytestmark = pytest.mark.integration
_DIM = config.embedding_model_instance.dimension
def _axis(index: int) -> list[float]:
    """Build a unit vector along one axis; vectors on different axes are orthogonal."""
    components = [0.0 for _ in range(_DIM)]
    components[index] = 1.0
    return components
async def _add_document(
    db_session,
    *,
    search_space_id: int,
    title: str = "Doc",
    document_type: DocumentType = DocumentType.FILE,
    state: str = "ready",
    chunks: list[tuple[str, int, list[float]]],
) -> Document:
    """Store a document and its chunks, then return the document.

    ``chunks`` holds ``(content, position, embedding)`` tuples. The document's
    own ``content`` is the chunk contents joined by newlines, in the order given.
    """
    joined_content = "\n".join(chunk_text for chunk_text, _, _ in chunks)
    doc = Document(
        title=title,
        document_type=document_type,
        content=joined_content,
        content_hash=uuid.uuid4().hex,
        search_space_id=search_space_id,
        status={"state": state},
    )
    db_session.add(doc)
    # Flush before reading ``doc.id`` for the chunks below.
    await db_session.flush()

    for chunk_text, chunk_position, chunk_embedding in chunks:
        chunk_row = Chunk(
            content=chunk_text,
            document_id=doc.id,
            position=chunk_position,
            embedding=chunk_embedding,
        )
        db_session.add(chunk_row)
    await db_session.flush()
    return doc
async def test_keyword_relevant_document_is_retrieved(db_session, db_search_space):
    """A document containing the query word is returned even with an unrelated query embedding."""
    space_id = db_search_space.id
    seeded = await _add_document(
        db_session,
        search_space_id=space_id,
        title="Asyncio Guide",
        chunks=[("The asyncio library enables concurrency.", 0, _axis(0))],
    )

    hits = await search_chunks(
        db_session,
        search_space_id=space_id,
        query="asyncio",
        scope=SearchScope(),
        top_k=5,
        query_embedding=_axis(99),
    )

    returned_ids = {hit.document_id for hit in hits}
    assert seeded.id in returned_ids
async def test_semantically_closest_document_ranks_first(db_session, db_search_space):
    """The document whose chunk embedding equals the query embedding is the first hit."""
    space_id = db_search_space.id
    aligned_doc = await _add_document(
        db_session,
        search_space_id=space_id,
        title="Background Work",
        chunks=[("Parallel execution of background work.", 0, _axis(0))],
    )
    await _add_document(
        db_session,
        search_space_id=space_id,
        title="Dessert",
        chunks=[("Recipes for chocolate cake.", 0, _axis(1))],
    )

    hits = await search_chunks(
        db_session,
        search_space_id=space_id,
        query="asynchronous coroutines",
        scope=SearchScope(),
        top_k=5,
        query_embedding=_axis(0),
    )

    top_hit = hits[0]
    assert top_hit.document_id == aligned_doc.id
async def test_results_stay_within_the_search_space(db_session, db_search_space):
    """An identical document in a different search space is not returned."""
    second_space = SearchSpace(name="Other Space", user_id=db_search_space.user_id)
    db_session.add(second_space)
    await db_session.flush()

    own_doc = await _add_document(
        db_session,
        search_space_id=db_search_space.id,
        chunks=[("Shared keyword asyncio here.", 0, _axis(0))],
    )
    foreign_doc = await _add_document(
        db_session,
        search_space_id=second_space.id,
        chunks=[("Shared keyword asyncio here.", 0, _axis(0))],
    )

    hits = await search_chunks(
        db_session,
        search_space_id=db_search_space.id,
        query="asyncio",
        scope=SearchScope(),
        top_k=5,
        query_embedding=_axis(0),
    )

    returned_ids = {hit.document_id for hit in hits}
    assert own_doc.id in returned_ids
    assert foreign_doc.id not in returned_ids
async def test_document_ids_scope_pins_results(db_session, db_search_space):
    """Scoping by ``document_ids`` returns exactly the listed document."""
    space_id = db_search_space.id
    pinned_doc = await _add_document(
        db_session,
        search_space_id=space_id,
        chunks=[("asyncio appears in the pinned doc.", 0, _axis(0))],
    )
    await _add_document(
        db_session,
        search_space_id=space_id,
        chunks=[("asyncio appears in the other doc too.", 0, _axis(0))],
    )
    pinned_scope = SearchScope(document_ids=(pinned_doc.id,))

    hits = await search_chunks(
        db_session,
        search_space_id=space_id,
        query="asyncio",
        scope=pinned_scope,
        top_k=5,
        query_embedding=_axis(0),
    )

    returned_ids = {hit.document_id for hit in hits}
    assert returned_ids == {pinned_doc.id}
async def test_deleting_documents_are_excluded(db_session, db_search_space):
    """A document whose status state is ``deleting`` is not returned; a ready one is."""
    space_id = db_search_space.id
    ready_doc = await _add_document(
        db_session,
        search_space_id=space_id,
        chunks=[("asyncio in a ready document.", 0, _axis(0))],
    )
    deleting_doc = await _add_document(
        db_session,
        search_space_id=space_id,
        state="deleting",
        chunks=[("asyncio in a deleting document.", 0, _axis(0))],
    )

    hits = await search_chunks(
        db_session,
        search_space_id=space_id,
        query="asyncio",
        scope=SearchScope(),
        top_k=5,
        query_embedding=_axis(0),
    )

    returned_ids = {hit.document_id for hit in hits}
    assert ready_doc.id in returned_ids
    assert deleting_doc.id not in returned_ids
async def test_matched_chunks_are_ordered_for_reading(db_session, db_search_space):
    """Chunks within a hit come back sorted by position."""
    # Insert out of order, and give the later-position chunk the stronger
    # semantic score, so reading order differs from both insertion and score.
    out_of_order_chunks = [
        ("asyncio paragraph two.", 1, _axis(0)),
        ("asyncio paragraph one.", 0, _axis(50)),
    ]
    seeded = await _add_document(
        db_session,
        search_space_id=db_search_space.id,
        chunks=out_of_order_chunks,
    )

    hits = await search_chunks(
        db_session,
        search_space_id=db_search_space.id,
        query="asyncio",
        scope=SearchScope(),
        top_k=5,
        query_embedding=_axis(0),
    )

    seeded_hit = next(hit for hit in hits if hit.document_id == seeded.id)
    positions = [chunk.position for chunk in seeded_hit.chunks]
    assert positions == [0, 1]
async def test_top_k_caps_the_number_of_documents(db_session, db_search_space):
    """With three matching documents and ``top_k=2``, exactly two hits are returned."""
    space_id = db_search_space.id
    for doc_number in range(3):
        await _add_document(
            db_session,
            search_space_id=space_id,
            title=f"Doc {doc_number}",
            chunks=[
                (f"asyncio mentioned in doc {doc_number}.", 0, _axis(doc_number)),
            ],
        )

    hits = await search_chunks(
        db_session,
        search_space_id=space_id,
        query="asyncio",
        scope=SearchScope(),
        top_k=2,
        query_embedding=_axis(0),
    )

    assert len(hits) == 2

View file

@ -3,7 +3,6 @@
from __future__ import annotations
import uuid
from contextlib import asynccontextmanager
from datetime import UTC, datetime
from unittest.mock import MagicMock
@ -227,23 +226,6 @@ def patched_embed(monkeypatch):
return mock
@pytest.fixture
def patched_shielded_session(async_engine, monkeypatch):
    """Swap the knowledge_base module's ``shielded_async_session`` for a
    context manager that yields sessions bound to the test engine."""
    session_factory = async_sessionmaker(async_engine, expire_on_commit=False)

    @asynccontextmanager
    async def _engine_backed_session():
        async with session_factory() as test_session:
            yield test_session

    patch_target = "app.agents.chat.multi_agent_chat.subagents.builtins.deliverables.tools.knowledge_base.shielded_async_session"
    monkeypatch.setattr(patch_target, _engine_backed_session)
# ---------------------------------------------------------------------------
# Indexer test helpers
# ---------------------------------------------------------------------------

View file

@ -1,46 +0,0 @@
"""Integration test: _browse_recent_documents returns docs of multiple types.
Exercises the browse path (degenerate-query fallback) with a real PostgreSQL
database. Verifies that passing a list of document types correctly returns
documents of all listed types -- the same ``.in_()`` SQL path used by hybrid
search but through the browse/recency-ordered code path.
"""
from __future__ import annotations
import pytest
pytestmark = pytest.mark.integration
async def test_browse_recent_documents_with_list_type_returns_both(
    committed_google_data, patched_shielded_session
):
    """_browse_recent_documents returns docs of all types when given a list."""
    from app.agents.chat.multi_agent_chat.subagents.builtins.deliverables.tools.knowledge_base import (
        _browse_recent_documents,
    )

    requested_types = ["GOOGLE_DRIVE_FILE", "COMPOSIO_GOOGLE_DRIVE_CONNECTOR"]
    rows = await _browse_recent_documents(
        search_space_id=committed_google_data["search_space_id"],
        document_type=requested_types,
        top_k=10,
        start_date=None,
        end_date=None,
    )

    # Collect every truthy document_type present in the results.
    candidate_types = (row.get("document", {}).get("document_type") for row in rows)
    seen_types = {kind for kind in candidate_types if kind}

    assert "GOOGLE_DRIVE_FILE" in seen_types, (
        "Native Drive docs should appear in browse results"
    )
    assert "COMPOSIO_GOOGLE_DRIVE_CONNECTOR" in seen_types, (
        "Legacy Composio Drive docs should appear in browse results"
    )

View file

@ -1,61 +0,0 @@
"""Integration smoke tests for KB search query/date scoping."""
from __future__ import annotations
from contextlib import asynccontextmanager
from datetime import UTC, datetime, timedelta
import numpy as np
import pytest
from app.agents.chat.multi_agent_chat.shared.middleware import knowledge_search as ks
from app.agents.chat.multi_agent_chat.shared.middleware.knowledge_search import (
search_knowledge_base,
)
from .conftest import DUMMY_EMBEDDING
pytestmark = pytest.mark.integration
async def test_search_knowledge_base_applies_date_filters(
    db_session,
    seed_date_filtered_docs,
    monkeypatch,
):
    """Date filters should remove older matching documents from scoped KB results."""

    @asynccontextmanager
    async def _session_from_fixture():
        yield db_session

    def _dummy_embeddings(texts):
        return [np.array(DUMMY_EMBEDDING) for _ in texts]

    monkeypatch.setattr(ks, "shielded_async_session", _session_from_fixture)
    monkeypatch.setattr(ks, "embed_texts", _dummy_embeddings)

    seeded = seed_date_filtered_docs
    space_id = seeded["search_space"].id
    recent_cutoff = datetime.now(UTC) - timedelta(days=30)

    everything = await search_knowledge_base(
        query="ocv meeting decisions",
        search_space_id=space_id,
        available_document_types=["FILE"],
        top_k=10,
    )
    windowed = await search_knowledge_base(
        query="ocv meeting decisions",
        search_space_id=space_id,
        available_document_types=["FILE"],
        top_k=10,
        start_date=recent_cutoff,
        end_date=datetime.now(UTC),
    )

    all_ids = {row["document"]["id"] for row in everything}
    windowed_ids = {row["document"]["id"] for row in windowed}

    # Without a date window both documents match.
    assert seeded["recent_doc"].id in all_ids
    assert seeded["old_doc"].id in all_ids
    # With the 30-day window only the recent one remains.
    assert seeded["recent_doc"].id in windowed_ids
    assert seeded["old_doc"].id not in windowed_ids

View file

@ -0,0 +1,41 @@
"""Tests for connector pointer field selection."""
from __future__ import annotations
import pytest
from app.agents.chat.runtime.references.connectors import connector_pointer_fields
pytestmark = pytest.mark.unit
def test_prefers_chip_account_and_type() -> None:
    """When account name and connector type are given, they become label and provider."""
    label, provider = connector_pointer_fields(
        account_name="work@acme.com",
        connector_type="Gmail",
        fallback_name="My Gmail",
    )

    assert label == "work@acme.com"
    assert provider == "Gmail"
def test_falls_back_to_stored_name_when_account_missing() -> None:
    """With no account name, the fallback name is used as the label."""
    fields = connector_pointer_fields(
        account_name=None,
        connector_type="Slack",
        fallback_name="Acme Slack",
    )

    assert tuple(fields) == ("Acme Slack", "Slack")
def test_provider_is_none_when_unknown() -> None:
    """With no connector type, the provider comes back as ``None``."""
    pointer_label, pointer_provider = connector_pointer_fields(
        account_name="a@b.com",
        connector_type=None,
        fallback_name=None,
    )

    assert pointer_label == "a@b.com"
    assert pointer_provider is None

View file

@ -0,0 +1,21 @@
"""Tests for folder pointer-path shaping."""
from __future__ import annotations
import pytest
from app.agents.chat.runtime.references.folders import folder_pointer_path
pytestmark = pytest.mark.unit
def test_adds_trailing_slash_so_path_reads_as_directory() -> None:
    """A stored path without a trailing slash gets one appended."""
    known_paths = {7: "/documents/Specs"}
    pointer = folder_pointer_path(7, known_paths)
    assert pointer == "/documents/Specs/"
def test_keeps_existing_trailing_slash() -> None:
    """A stored path that already ends in a slash is returned unchanged."""
    known_paths = {7: "/documents/Specs/"}
    pointer = folder_pointer_path(7, known_paths)
    assert pointer == "/documents/Specs/"
def test_unknown_folder_falls_back_to_documents_root() -> None:
    """A folder id absent from the mapping resolves to ``/documents/``."""
    pointer = folder_pointer_path(99, {})
    assert pointer == "/documents/"

View file

@ -0,0 +1,93 @@
"""Tests for reference pointer rendering."""
from __future__ import annotations
import pytest
from app.agents.chat.runtime.references import (
ChatReference,
ConnectorReference,
DocumentReference,
FolderReference,
render_reference_pointers,
)
pytestmark = pytest.mark.unit
def test_returns_none_when_no_references() -> None:
    """An empty reference list renders to ``None``."""
    rendered = render_reference_pointers([])
    assert rendered is None
def test_wraps_block_and_keeps_reference_order() -> None:
    """The block is wrapped in ``<referenced_this_turn>`` tags and keeps input order."""
    references = [
        DocumentReference(entity_id=42, label="Q3 Notes", path="/documents/q3.xml"),
        ChatReference(entity_id=5, label="Pricing"),
    ]

    rendered = render_reference_pointers(references)

    assert rendered is not None
    assert rendered.startswith("<referenced_this_turn>")
    assert rendered.endswith("</referenced_this_turn>")
    document_at = rendered.index("document 42")
    chat_at = rendered.index("chat 5")
    assert document_at < chat_at
def test_document_with_path_shows_title_and_path() -> None:
    """A document reference renders its id, quoted title, and path."""
    reference = DocumentReference(
        entity_id=42,
        label="Q3 Launch Notes",
        path="/documents/Launch/Q3.xml",
    )

    rendered = render_reference_pointers([reference])

    assert rendered is not None
    expected_line = '- document 42 — "Q3 Launch Notes" (/documents/Launch/Q3.xml)'
    assert expected_line in rendered
def test_folder_with_path_renders_with_folder_kind() -> None:
    """A folder reference renders with the ``folder`` kind, quoted label, and path."""
    reference = FolderReference(entity_id=7, label="Specs", path="/documents/Specs/")

    rendered = render_reference_pointers([reference])

    assert rendered is not None
    expected_line = '- folder 7 — "Specs" (/documents/Specs/)'
    assert expected_line in rendered
def test_connector_shows_provider_and_account() -> None:
    """A connector with a provider renders the provider followed by the label in parentheses."""
    reference = ConnectorReference(
        entity_id=12, label="work@acme.com", provider="Gmail"
    )

    rendered = render_reference_pointers([reference])

    assert rendered is not None
    expected_line = "- connector 12 — Gmail (work@acme.com)"
    assert expected_line in rendered
def test_connector_without_provider_falls_back_to_label() -> None:
    """A connector without a provider renders just its label."""
    reference = ConnectorReference(entity_id=12, label="work@acme.com")

    rendered = render_reference_pointers([reference])

    assert rendered is not None
    expected_line = "- connector 12 — work@acme.com"
    assert expected_line in rendered
def test_chat_shows_quoted_title() -> None:
    """A chat reference renders its id and quoted label."""
    reference = ChatReference(entity_id=5, label="Pricing debate")

    rendered = render_reference_pointers([reference])

    assert rendered is not None
    expected_line = '- chat 5 — "Pricing debate"'
    assert expected_line in rendered
def test_label_whitespace_is_collapsed_to_one_line() -> None:
    """A newline inside a label is rendered as a single space."""
    reference = DocumentReference(
        entity_id=1, label="line one\nline two", path="/d.xml"
    )

    rendered = render_reference_pointers([reference])

    assert rendered is not None
    expected_fragment = '- document 1 — "line one line two"'
    assert expected_fragment in rendered

View file

@ -0,0 +1,93 @@
"""Tests for the shared ``web_search`` tool's citable-result adaptation.
The tool's network path (SearXNG + live connectors) is out of scope here; these
cover the pure mapping from raw web results to renderable, citable documents and
the end-to-end registration of ``WEB_RESULT`` ``[n]`` labels.
"""
from __future__ import annotations
import pytest
from app.agents.chat.multi_agent_chat.shared.citations import (
CitationRegistry,
CitationSourceType,
)
from app.agents.chat.multi_agent_chat.shared.document_render import render_web_results
from app.agents.chat.shared.tools.web_search import (
_to_renderable_web_documents,
_web_source_label,
)
pytestmark = pytest.mark.unit
def _raw_result(url: str, title: str, content: str) -> dict:
return {
"document": {"title": title, "metadata": {"url": url}},
"content": content,
}
def test_web_source_label_strips_scheme_and_www() -> None:
    """The label shows only the host (without ``www.``); an empty URL gives plain ``Web``."""
    cases = (
        ("https://www.example.com/path", "Web · example.com"),
        ("http://news.site.org/a/b", "Web · news.site.org"),
        ("", "Web"),
    )
    for url, expected_label in cases:
        assert _web_source_label(url) == expected_label
def test_adapter_maps_each_result_to_one_web_passage() -> None:
    """Each raw result becomes a document whose passages are ``WEB_RESULT`` typed."""
    raw_results = [
        _raw_result("https://a.com/x", "Alpha", "alpha body"),
        _raw_result("https://b.com/y", "Beta", "beta body"),
    ]

    documents = _to_renderable_web_documents(raw_results)

    assert [document.title for document in documents] == ["Alpha", "Beta"]

    all_passages = []
    for document in documents:
        all_passages.extend(document.passages)
    for passage in all_passages:
        assert passage.source_type is CitationSourceType.WEB_RESULT

    first_passage = all_passages[0]
    assert first_passage.locator == {"url": "https://a.com/x"}
    assert first_passage.content == "alpha body"
def test_adapter_skips_results_without_url_or_content() -> None:
    """Results with an empty URL or whitespace-only content are left out."""
    raw_results = [
        _raw_result("", "No URL", "has content"),
        _raw_result("https://c.com/z", "Empty", " "),
        _raw_result("https://d.com/w", "Good", "real content"),
    ]

    documents = _to_renderable_web_documents(raw_results)

    kept_titles = [document.title for document in documents]
    assert kept_titles == ["Good"]
def test_adapter_truncates_on_char_budget() -> None:
    """With ``max_chars=50`` and 30-char bodies, only the first result is kept."""
    body = "x" * 30
    raw_results = [
        _raw_result("https://a.com", "A", body),
        _raw_result("https://b.com", "B", body),
        _raw_result("https://c.com", "C", body),
    ]

    documents = _to_renderable_web_documents(raw_results, max_chars=50)

    # First fits (30), second crosses 50 and stops the loop.
    kept_titles = [document.title for document in documents]
    assert kept_titles == ["A"]
def test_end_to_end_registers_web_results_for_citation() -> None:
    """Rendering adapted web results labels the passage ``[1]`` and registers its URL."""
    citation_registry = CitationRegistry()
    raw_results = [_raw_result("https://example.com/a", "Example", "the answer is 42")]
    documents = _to_renderable_web_documents(raw_results)

    rendered = render_web_results(documents, citation_registry)

    assert rendered is not None
    assert "[1] the answer is 42" in rendered

    registered = citation_registry.resolve(1)
    assert registered is not None
    assert registered.source_type is CitationSourceType.WEB_RESULT
    assert registered.locator == {"url": "https://example.com/a"}

View file

@ -0,0 +1,49 @@
"""Tests for citation-entry → frontend payload mapping."""
from __future__ import annotations
import pytest
from app.agents.chat.multi_agent_chat.shared.citations.markers import (
to_frontend_payload,
)
from app.agents.chat.multi_agent_chat.shared.citations.models import (
CitationEntry,
CitationSourceType,
)
pytestmark = pytest.mark.unit
def _entry(source_type: CitationSourceType, locator: dict) -> CitationEntry:
    """Build a ``CitationEntry`` labelled ``n=1`` for the given type and locator."""
    entry = CitationEntry(n=1, source_type=source_type, locator=locator)
    return entry
def test_kb_chunk_maps_to_chunk_id() -> None:
    """A KB chunk entry maps to its chunk id as a string."""
    locator = {"chunk_id": 42, "document_id": 7}
    payload = to_frontend_payload(_entry(CitationSourceType.KB_CHUNK, locator))
    assert payload == "42"
def test_anon_chunk_keeps_negative_id() -> None:
    """An anonymous chunk entry keeps the sign of its negative chunk id."""
    anon_entry = _entry(CitationSourceType.ANON_CHUNK, {"chunk_id": -3})
    payload = to_frontend_payload(anon_entry)
    assert payload == "-3"
def test_web_result_maps_to_url() -> None:
    """A web result entry maps to its URL."""
    web_entry = _entry(CitationSourceType.WEB_RESULT, {"url": "https://example.com/a"})
    payload = to_frontend_payload(web_entry)
    assert payload == "https://example.com/a"
def test_not_yet_renderable_kind_is_dropped() -> None:
    """A ``CHAT_TURN`` entry produces no frontend payload."""
    chat_entry = _entry(CitationSourceType.CHAT_TURN, {"thread_id": 1, "turn": 2})
    payload = to_frontend_payload(chat_entry)
    assert payload is None
def test_missing_locator_field_is_dropped() -> None:
    """A KB chunk entry with an empty locator produces no frontend payload."""
    bare_entry = _entry(CitationSourceType.KB_CHUNK, {})
    payload = to_frontend_payload(bare_entry)
    assert payload is None

View file

@ -0,0 +1,113 @@
"""Tests for rewriting model ``[n]`` ordinals into frontend citation markers."""
from __future__ import annotations
import pytest
from app.agents.chat.multi_agent_chat.shared.citations.models import CitationSourceType
from app.agents.chat.multi_agent_chat.shared.citations.normalizer import (
normalize_citations,
)
from app.agents.chat.multi_agent_chat.shared.citations.registry import CitationRegistry
pytestmark = pytest.mark.unit
def _registry_with_chunks(*chunk_ids: int) -> CitationRegistry:
    """Return a fresh registry with one ``KB_CHUNK`` registered per id, in order."""
    seeded = CitationRegistry()
    for cid in chunk_ids:
        locator = {"chunk_id": cid}
        seeded.register(CitationSourceType.KB_CHUNK, locator)
    return seeded
def test_single_ordinal_is_rewritten() -> None:
    """A lone ``[1]`` becomes the marker for the first registered chunk."""
    rewritten = normalize_citations("We shipped it [1].", _registry_with_chunks(42))
    assert rewritten == "We shipped it [citation:42]."
def test_adjacent_brackets_are_each_rewritten() -> None:
    """Back-to-back ``[1][2]`` are rewritten individually."""
    rewritten = normalize_citations("Both agree [1][2].", _registry_with_chunks(42, 7))
    assert rewritten == "Both agree [citation:42][citation:7]."
def test_comma_separated_brackets_are_each_rewritten() -> None:
    """``[1], [2]`` keeps its separator while each ordinal is rewritten."""
    rewritten = normalize_citations(
        "Both agree [1], [2].", _registry_with_chunks(42, 7)
    )
    assert rewritten == "Both agree [citation:42], [citation:7]."
def test_unknown_ordinal_is_dropped() -> None:
    """An ordinal with no registry entry is removed from the text."""
    rewritten = normalize_citations("Maybe [9] is real.", _registry_with_chunks(42))
    assert rewritten == "Maybe is real."
def test_unknown_ordinal_among_known_is_dropped() -> None:
    """In ``[1][9]`` the known ordinal is rewritten and the unknown one removed."""
    rewritten = normalize_citations("See [1][9].", _registry_with_chunks(42))
    assert rewritten == "See [citation:42]."
def test_web_result_rewrites_to_url() -> None:
    """A web-result ordinal is rewritten to a marker carrying the URL."""
    web_registry = CitationRegistry()
    web_registry.register(
        CitationSourceType.WEB_RESULT, {"url": "https://example.com"}
    )

    rewritten = normalize_citations("Per the docs [1].", web_registry)

    assert rewritten == "Per the docs [citation:https://example.com]."
def test_word_glued_citation_is_rewritten() -> None:
    """A citation glued to the preceding word is still rewritten."""
    # The model frequently writes citations glued to the preceding word
    # (``docs[1]``); these must still resolve to a marker, not leak as raw text.
    rewritten = normalize_citations(
        "verifying against docs[1].", _registry_with_chunks(42)
    )
    assert rewritten == "verifying against docs[citation:42]."
def test_word_glued_unknown_ordinal_drops() -> None:
    """An unresolvable ordinal glued to a word is removed."""
    # A glued ordinal that doesn't resolve drops harmlessly (no broken marker,
    # no raw ``[n]`` leak) rather than being preserved as array-index syntax.
    rewritten = normalize_citations("see notes[9] later", _registry_with_chunks(42))
    assert rewritten == "see notes later"
def test_array_index_inside_code_is_left_alone() -> None:
    """Index syntax inside inline code passes through unchanged."""
    # Genuine array/index syntax is protected by the code-region carve-out.
    original = "Read `arr[1]` carefully."
    rewritten = normalize_citations(original, _registry_with_chunks(42))
    assert rewritten == "Read `arr[1]` carefully."
def test_ordinals_inside_inline_code_are_untouched() -> None:
    """Only the ordinal outside the inline code span is rewritten."""
    rewritten = normalize_citations(
        "Use `list[1]` here [1].", _registry_with_chunks(42)
    )
    assert rewritten == "Use `list[1]` here [citation:42]."
def test_ordinals_inside_fenced_code_are_untouched() -> None:
    """Ordinals around a fenced block are rewritten; the fenced content is not."""
    source_text = "Before [1].\n```\nx = a[1]\n```\nAfter [1]."
    expected_text = "Before [citation:42].\n```\nx = a[1]\n```\nAfter [citation:42]."

    rewritten = normalize_citations(source_text, _registry_with_chunks(42))

    assert rewritten == expected_text
def test_empty_text_is_returned_unchanged() -> None:
    """Empty input comes back as an empty string."""
    rewritten = normalize_citations("", _registry_with_chunks(42))
    assert rewritten == ""

View file

@ -0,0 +1,174 @@
"""Unit tests for the citation registry spine."""
from __future__ import annotations
from app.agents.chat.multi_agent_chat.shared.citations import (
CitationRegistry,
CitationSourceType,
make_key,
)
def test_register_assigns_monotonic_labels() -> None:
    """Two distinct chunks get labels 1 and 2, leaving ``next_n`` at 3."""
    registry = CitationRegistry()

    label_one = registry.register(
        CitationSourceType.KB_CHUNK, {"document_id": 42, "chunk_id": 880}
    )
    label_two = registry.register(
        CitationSourceType.KB_CHUNK, {"document_id": 42, "chunk_id": 881}
    )

    assert label_one == 1
    assert label_two == 2
    assert registry.next_n == 3
def test_register_is_find_or_create_for_same_unit() -> None:
    """Registering the same locator twice returns the same label and adds no entry."""
    registry = CitationRegistry()
    same_locator = {"document_id": 42, "chunk_id": 880}

    initial = registry.register(CitationSourceType.KB_CHUNK, same_locator)
    repeated = registry.register(CitationSourceType.KB_CHUNK, same_locator)

    assert initial == 1
    assert repeated == 1
    assert len(registry.by_n) == 1
    assert registry.next_n == 2
def test_dedup_is_insensitive_to_locator_key_order() -> None:
    """The same locator with its keys in a different order maps to the same label."""
    registry = CitationRegistry()
    original_order = {"document_id": 42, "chunk_id": 880}
    swapped_order = {"chunk_id": 880, "document_id": 42}

    original_label = registry.register(CitationSourceType.KB_CHUNK, original_order)
    swapped_label = registry.register(CitationSourceType.KB_CHUNK, swapped_order)

    assert original_label == swapped_label
def test_same_locator_values_across_types_do_not_collide() -> None:
    """Identical locators under different source types get different labels."""
    registry = CitationRegistry()

    chunk_label = registry.register(CitationSourceType.KB_CHUNK, {"id": 7})
    chat_label = registry.register(CitationSourceType.CHAT_TURN, {"id": 7})

    assert chunk_label != chat_label
def test_resolve_returns_entry_with_locator_and_display() -> None:
    """Resolving a label returns the entry with its type, locator, and display data."""
    registry = CitationRegistry()
    label = registry.register(
        CitationSourceType.WEB_RESULT,
        {"url": "https://example.com"},
        {"title": "Example"},
    )

    resolved = registry.resolve(label)

    assert resolved is not None
    assert resolved.n == label
    assert resolved.source_type is CitationSourceType.WEB_RESULT
    assert resolved.locator == {"url": "https://example.com"}
    assert resolved.display == {"title": "Example"}
def test_resolve_unknown_label_returns_none() -> None:
    """Resolving a label on an empty registry returns ``None``."""
    empty_registry = CitationRegistry()
    resolved = empty_registry.resolve(999)
    assert resolved is None
def test_registry_round_trips_through_serialization() -> None:
    """Dumping and re-validating a registry keeps its entry and ``next_n``."""
    source_registry = CitationRegistry()
    source_registry.register(
        CitationSourceType.KB_CHUNK,
        {"document_id": 42, "chunk_id": 880},
        {"title": "Q3 Launch Notes"},
    )

    dumped = source_registry.model_dump()
    rebuilt = CitationRegistry.model_validate(dumped)

    rebuilt_entry = rebuilt.resolve(1)
    assert rebuilt_entry is not None
    assert rebuilt_entry.source_type is CitationSourceType.KB_CHUNK
    assert rebuilt.next_n == source_registry.next_n
def test_make_key_is_stable_and_type_prefixed() -> None:
    """``make_key`` ignores locator key order and starts with the source type."""
    forward = make_key(
        CitationSourceType.KB_CHUNK, {"document_id": 42, "chunk_id": 880}
    )
    reversed_order = make_key(
        CitationSourceType.KB_CHUNK, {"chunk_id": 880, "document_id": 42}
    )

    assert forward == reversed_order
    assert forward.startswith("kb_chunk|")
def _kb(registry: CitationRegistry, chunk_id: int) -> int:
    """Register a ``KB_CHUNK`` under document 1 with ``chunk_id`` and return its label."""
    locator = {"document_id": 1, "chunk_id": chunk_id}
    return registry.register(CitationSourceType.KB_CHUNK, locator)
def test_merge_unions_disjoint_registries_preserving_labels() -> None:
    """Merging a fork that added one new chunk keeps all three labels in place."""
    trunk = CitationRegistry()
    _kb(trunk, 10)  # [1]
    _kb(trunk, 11)  # [2]

    # A branch that forked from `trunk`, then registered its own chunk at [3].
    fork = trunk.model_copy(deep=True)
    fork_label = _kb(fork, 12)  # [3]
    assert fork_label == 3

    merged = trunk.merge(fork)

    for label, expected_chunk in ((1, 10), (2, 11), (3, 12)):
        assert merged.resolve(label).locator["chunk_id"] == expected_chunk
    assert merged.next_n == 4
def test_merge_keeps_one_label_for_a_shared_source() -> None:
    """Two independent registries holding the same chunk merge into a single entry."""
    first = CitationRegistry()
    second = CitationRegistry()
    _kb(first, 10)  # [1]
    _kb(second, 10)  # also [1], same source

    combined = first.merge(second)

    assert len(combined.by_n) == 1
    assert combined.resolve(1).locator["chunk_id"] == 10
    assert combined.next_n == 2
def test_merge_remints_on_collision_without_losing_sources() -> None:
    """Colliding labels from two forks are both kept; the merged-in one moves to a new label."""
    # Both forks start from the same base holding chunk 10 at [1]. Each then
    # mints a *different* source at [2], so the merge must re-mint one of them.
    shared = CitationRegistry()
    _kb(shared, 10)  # [1]

    ours = shared.model_copy(deep=True)
    theirs = shared.model_copy(deep=True)
    _kb(ours, 11)  # [2] -> chunk 11
    _kb(theirs, 12)  # [2] -> chunk 12 (collision)

    merged = ours.merge(theirs)

    kept_chunks = set()
    for entry in merged.by_n.values():
        kept_chunks.add(entry.locator["chunk_id"])
    assert kept_chunks == {10, 11, 12}
    assert merged.resolve(2).locator["chunk_id"] == 11  # receiver of merge() wins the slot
    assert merged.resolve(3).locator["chunk_id"] == 12  # merged-in entry re-minted
    assert merged.next_n == 4
def test_merge_does_not_mutate_inputs() -> None:
    """After ``merge`` both operands still hold only their own single label."""
    receiver, argument = CitationRegistry(), CitationRegistry()
    _kb(receiver, 10)
    _kb(argument, 11)

    receiver.merge(argument)  # result intentionally discarded

    for operand in (receiver, argument):
        assert list(operand.by_n) == [1]

View file

@ -0,0 +1,152 @@
"""Tests for the shared ``render_document`` (one ``<document>`` block)."""
from __future__ import annotations
import pytest
from app.agents.chat.multi_agent_chat.shared.citations import (
CitationRegistry,
CitationSourceType,
)
from app.agents.chat.multi_agent_chat.shared.document_render import (
RenderableDocument,
RenderablePassage,
render_document,
)
pytestmark = pytest.mark.unit
def _document(
    document_id: int,
    title: str,
    chunk_ids: list[int],
    *,
    source: str | None = None,
) -> RenderableDocument:
    """Build a ``RenderableDocument`` with one ``"text <id>"`` passage per chunk id."""
    passages = []
    for cid in chunk_ids:
        locator = {"document_id": document_id, "chunk_id": cid}
        passages.append(RenderablePassage(content=f"text {cid}", locator=locator))
    return RenderableDocument(title=title, source=source, passages=passages)
def test_returns_none_when_no_passages() -> None:
    """A document without passages renders to ``None``."""
    empty_document = _document(1, "Empty", [])
    rendered = render_document(
        empty_document, view="excerpt", registry=CitationRegistry()
    )
    assert rendered is None
def test_excerpt_open_and_close_tags() -> None:
    """The excerpt block opens with title/source/view attributes and closes with ``</document>``."""
    document = _document(1, "Q3 Launch Notes", [880], source="Slack · #launch")
    rendered = render_document(
        document, view="excerpt", registry=CitationRegistry()
    )
    assert rendered is not None
    expected_open = (
        '<document title="Q3 Launch Notes" source="Slack · #launch" view="excerpt">'
    )
    assert rendered.startswith(expected_open)
    assert rendered.endswith("</document>")
def test_full_view_renders_view_attribute() -> None:
    """Rendering with ``view="full"`` emits ``view="full"`` on the opening tag."""
    document = _document(1, "Doc", [880])
    rendered = render_document(document, view="full", registry=CitationRegistry())
    assert rendered is not None
    assert '<document title="Doc" view="full">' in rendered
def test_source_attribute_omitted_when_absent() -> None:
    """With no source, the opening tag carries only ``title`` and ``view``."""
    sourceless = _document(1, "Plain", [1])
    rendered = render_document(
        sourceless, view="excerpt", registry=CitationRegistry()
    )
    assert rendered is not None
    assert rendered.startswith('<document title="Plain" view="excerpt">')
def test_registers_passages_with_chunk_locators() -> None:
    """Rendering registers the passage as a KB chunk with its locator and display data."""
    citations = CitationRegistry()
    document = _document(1, "Doc", [880], source="Slack")

    render_document(document, view="excerpt", registry=citations)

    registered = citations.resolve(1)
    assert registered is not None
    assert registered.source_type is CitationSourceType.KB_CHUNK
    assert registered.locator == {"document_id": 1, "chunk_id": 880}
    assert registered.display == {"title": "Doc", "source": "Slack"}
def test_passages_get_monotonic_labels() -> None:
    """Consecutive passages of one document are labelled ``[1]`` then ``[2]``."""
    two_chunks = _document(1, "Doc", [880, 881])
    rendered = render_document(
        two_chunks, view="excerpt", registry=CitationRegistry()
    )
    assert rendered is not None
    for expected_line in (" [1] text 880", " [2] text 881"):
        assert expected_line in rendered
def test_multiline_passage_indents_under_label() -> None:
    """A passage spanning two lines renders its second line under the label."""
    passage = RenderablePassage(
        content="line one\nline two",
        locator={"document_id": 1, "chunk_id": 5},
    )
    document = RenderableDocument(title="Doc", passages=[passage])

    rendered = render_document(
        document, view="excerpt", registry=CitationRegistry()
    )

    assert rendered is not None
    assert " [1] line one\n line two" in rendered
def test_attribute_values_are_escaped() -> None:
    """``&``, ``<``, ``>`` and ``"`` in title/source are entity-escaped in the tag."""
    tricky = _document(1, 'A & B <c> "d"', [1], source="x & y")
    rendered = render_document(
        tricky, view="excerpt", registry=CitationRegistry()
    )
    assert rendered is not None
    assert 'title="A &amp; B &lt;c&gt; &quot;d&quot;"' in rendered
    assert 'source="x &amp; y"' in rendered
def test_same_passage_reuses_label_across_calls() -> None:
    """Rendering the same document in both views mints only one label."""
    citations = CitationRegistry()
    document = _document(1, "Doc", [880])
    for view in ("excerpt", "full"):
        render_document(document, view=view, registry=citations)
    # Only label 1 was handed out, so the next free label is 2.
    assert citations.next_n == 2

View file

@ -0,0 +1,94 @@
"""Tests for the ``<retrieved_context>`` wrapper around excerpt documents."""
from __future__ import annotations
import pytest
from app.agents.chat.multi_agent_chat.shared.citations import CitationRegistry
from app.agents.chat.multi_agent_chat.shared.document_render import (
RenderableDocument,
RenderablePassage,
render_search_context,
)
pytestmark = pytest.mark.unit
def _document(
    document_id: int,
    title: str,
    chunk_ids: list[int],
    *,
    source: str | None = None,
) -> RenderableDocument:
    """Build a ``RenderableDocument`` with one ``"text <id>"`` passage per chunk id."""

    def _passage(cid: int) -> RenderablePassage:
        return RenderablePassage(
            content=f"text {cid}",
            locator={"document_id": document_id, "chunk_id": cid},
        )

    return RenderableDocument(
        title=title,
        source=source,
        passages=[_passage(cid) for cid in chunk_ids],
    )
def test_returns_none_when_nothing_to_show() -> None:
    """Both an empty list and a list of passage-less documents render to ``None``."""
    citations = CitationRegistry()
    no_documents = render_search_context([], citations)
    assert no_documents is None
    only_empty_document = render_search_context(
        [_document(1, "Empty", [])], citations
    )
    assert only_empty_document is None
def test_assigns_monotonic_labels_across_documents() -> None:
    """Label numbering continues from one document into the next."""
    documents = [
        _document(1, "Q3 Launch Notes", [880, 881], source="Slack"),
        _document(2, "Timeline", [12], source="Notion"),
    ]
    rendered = render_search_context(documents, CitationRegistry())
    assert rendered is not None
    for expected in ("[1] text 880", "[2] text 881", "[3] text 12"):
        assert expected in rendered
def test_wraps_in_retrieved_context_and_teaches_excerpt_and_citation() -> None:
    """The block is wrapped in ``<retrieved_context>`` and includes the usage guidance text."""
    rendered = render_search_context([_document(1, "Doc", [1])], CitationRegistry())
    assert rendered is not None
    # Wrapper tags.
    assert rendered.startswith("<retrieved_context>")
    assert rendered.endswith("</retrieved_context>")
    # Guidance phrases embedded in the block.
    assert "excerpt view" in rendered
    assert "Cite a chunk with its [n]." in rendered
def test_documents_render_as_excerpt_blocks() -> None:
    """Each document inside the wrapper is an excerpt-view ``<document>`` block."""
    document = _document(1, "Q3", [1], source="Slack · #launch")
    rendered = render_search_context([document], CitationRegistry())
    assert rendered is not None
    opening_tag = '<document title="Q3" source="Slack · #launch" view="excerpt">'
    assert opening_tag in rendered
    assert "</document>" in rendered
def test_same_passage_reuses_label_across_calls() -> None:
    """A second render of the same document shows the same ``[1]`` and mints nothing new."""
    citations = CitationRegistry()
    documents = [_document(1, "Doc", [880])]

    render_search_context(documents, citations)  # first render mints [1]
    second_render = render_search_context(documents, citations)

    assert second_render is not None
    assert "[1] text 880" in second_render
    assert citations.next_n == 2

View file

@ -0,0 +1,35 @@
"""Tests for building a document's source label."""
from __future__ import annotations
import pytest
from app.agents.chat.multi_agent_chat.shared.document_render import source_label
pytestmark = pytest.mark.unit
def test_known_type_uses_friendly_name() -> None:
    """``SLACK_CONNECTOR`` with empty metadata is labelled ``Slack``."""
    label = source_label("SLACK_CONNECTOR", {})
    assert label == "Slack"
def test_unmapped_type_is_prettified() -> None:
    """``GOOGLE_DRIVE_FILE`` with empty metadata is labelled ``Google Drive``."""
    label = source_label("GOOGLE_DRIVE_FILE", {})
    assert label == "Google Drive"
def test_url_host_is_appended_and_www_stripped() -> None:
    """The URL host, minus its ``www.`` prefix, is appended after the type name."""
    metadata = {"url": "https://www.docs.python.org/3/"}
    assert source_label("CRAWLED_URL", metadata) == "Web · docs.python.org"
def test_host_only_when_type_unknown() -> None:
    """With no document type, the label is just the URL host."""
    label = source_label(None, {"url": "https://example.com/a"})
    assert label == "example.com"
def test_returns_none_when_nothing_known() -> None:
    """No document type and no metadata yields ``None``."""
    label = source_label(None, {})
    assert label is None
def test_non_http_url_is_ignored() -> None:
    """A local-path ``url`` contributes no host; only the type name remains."""
    label = source_label("FILE", {"url": "/local/path"})
    assert label == "File"

View file

@ -0,0 +1,82 @@
"""Tests for the ``<web_results>`` wrapper around web-result excerpt documents."""
from __future__ import annotations
import pytest
from app.agents.chat.multi_agent_chat.shared.citations import (
CitationRegistry,
CitationSourceType,
)
from app.agents.chat.multi_agent_chat.shared.document_render import (
RenderableDocument,
RenderablePassage,
render_web_results,
)
pytestmark = pytest.mark.unit
def _web_doc(url: str, title: str, content: str) -> RenderableDocument:
    """Build a single-passage web-result document whose source is ``Web · <host>``."""
    # Drop everything up to the first "//", then everything from the next "/".
    after_scheme = url.split("//", 1)[-1]
    host = after_scheme.split("/", 1)[0]
    passage = RenderablePassage(
        content=content,
        locator={"url": url},
        source_type=CitationSourceType.WEB_RESULT,
    )
    return RenderableDocument(title=title, source=f"Web · {host}", passages=[passage])
def test_returns_none_when_nothing_to_show() -> None:
    """An empty result list renders to ``None``."""
    rendered = render_web_results([], CitationRegistry())
    assert rendered is None
def test_wraps_in_web_results_container() -> None:
    """Results are wrapped in ``<web_results>`` with guidance text and an excerpt document."""
    result = _web_doc("https://example.com/a", "Example", "the answer is 42")
    rendered = render_web_results([result], CitationRegistry())
    assert rendered is not None
    # Container tags.
    assert rendered.startswith("<web_results>")
    assert rendered.endswith("</web_results>")
    # Guidance, document header and labelled passage.
    expected_fragments = (
        "cite a result with its [n]",
        '<document title="Example" source="Web · example.com" view="excerpt">',
        "[1] the answer is 42",
    )
    for fragment in expected_fragments:
        assert fragment in rendered
def test_registers_each_result_as_web_result_with_url_locator() -> None:
    """Each rendered result is registered under its own label with a ``url`` locator."""
    citations = CitationRegistry()
    results = [
        _web_doc("https://a.com/x", "A", "alpha"),
        _web_doc("https://b.com/y", "B", "beta"),
    ]

    render_web_results(results, citations)

    entry_one = citations.resolve(1)
    entry_two = citations.resolve(2)
    assert entry_one is not None and entry_two is not None
    assert entry_one.source_type is CitationSourceType.WEB_RESULT
    assert entry_one.locator == {"url": "https://a.com/x"}
    assert entry_two.locator == {"url": "https://b.com/y"}
def test_same_url_reuses_label_across_calls() -> None:
    """Rendering the same URL twice mints only one label."""
    citations = CitationRegistry()
    results = [_web_doc("https://example.com/a", "Example", "stable fact")]
    for _ in range(2):
        render_web_results(results, citations)
    assert citations.next_n == 2

View file

@ -0,0 +1,51 @@
"""Tests for mapping a DocumentHit to a renderable document."""
from __future__ import annotations
import pytest
from app.agents.chat.multi_agent_chat.shared.retrieval.adapter import (
to_renderable_document,
)
from app.agents.chat.multi_agent_chat.shared.retrieval.models import (
ChunkHit,
DocumentHit,
)
pytestmark = pytest.mark.unit
def test_maps_identity_source_and_passages() -> None:
    """Title, source label and per-chunk passages carry over from the hit."""
    chunks = [
        ChunkHit(chunk_id=880, content="a", position=4, score=0.9),
        ChunkHit(chunk_id=881, content="b", position=7, score=0.5),
    ]
    hit = DocumentHit(
        document_id=42,
        title="Q3 Launch Notes",
        document_type="SLACK_CONNECTOR",
        metadata={},
        score=0.9,
        chunks=chunks,
    )

    renderable = to_renderable_document(hit)

    assert renderable.title == "Q3 Launch Notes"
    assert renderable.source == "Slack"
    id_and_text = [
        (passage.locator["chunk_id"], passage.content)
        for passage in renderable.passages
    ]
    assert id_and_text == [(880, "a"), (881, "b")]
    # Every passage points back at the hit's document id.
    assert all(
        passage.locator["document_id"] == 42 for passage in renderable.passages
    )
def test_document_with_no_chunks_maps_to_no_passages() -> None:
    """A hit with an empty chunk list maps to an empty passage list."""
    chunkless_hit = DocumentHit(
        document_id=1,
        title="Empty",
        document_type=None,
        metadata={},
        score=0.0,
        chunks=[],
    )
    renderable = to_renderable_document(chunkless_hit)
    assert renderable.passages == []

View file

@ -0,0 +1,65 @@
"""Tests for the build_context pipeline (rerank → adapt → render)."""
from __future__ import annotations
from typing import Any
import pytest
from app.agents.chat.multi_agent_chat.shared.citations import CitationRegistry
from app.agents.chat.multi_agent_chat.shared.retrieval.models import (
ChunkHit,
DocumentHit,
)
from app.agents.chat.multi_agent_chat.shared.retrieval.service import build_context
pytestmark = pytest.mark.unit
def _hit(document_id: int, chunk_id: int) -> DocumentHit:
    """Build a ``FILE`` hit with one chunk; the document score is ``1.0 / document_id``."""
    only_chunk = ChunkHit(
        chunk_id=chunk_id,
        content=f"text {chunk_id}",
        position=0,
        score=1.0,
    )
    return DocumentHit(
        document_id=document_id,
        title=f"Doc {document_id}",
        document_type="FILE",
        metadata={},
        score=1.0 / document_id,
        chunks=[only_chunk],
    )
def test_no_hits_renders_nothing() -> None:
    """With no hits, ``build_context`` returns ``None``."""
    rendered = build_context("q", [], CitationRegistry())
    assert rendered is None
def test_renders_block_and_registers_labels_in_order() -> None:
    """Without a reranker, hits are labelled in the order they were passed."""
    citations = CitationRegistry()
    hits = [_hit(1, 880), _hit(2, 12)]

    rendered = build_context("q", hits, citations)

    assert rendered is not None
    assert "[1] text 880" in rendered
    assert "[2] text 12" in rendered
    expected_locators = {
        1: {"document_id": 1, "chunk_id": 880},
        2: {"document_id": 2, "chunk_id": 12},
    }
    for label, locator in expected_locators.items():
        assert citations.resolve(label).locator == locator
class _ReverseReranker:
"""Stand-in reranker that simply reverses document order."""
def rerank_documents(
self, query_text: str, documents: list[dict[str, Any]]
) -> list[dict[str, Any]]:
return list(reversed(documents))
def test_reranker_reorders_documents_before_labeling() -> None:
    """Labels follow the reranked order, not the order hits were passed in."""
    citations = CitationRegistry()
    hits = [_hit(1, 880), _hit(2, 12)]

    rendered = build_context("q", hits, citations, reranker=_ReverseReranker())

    assert rendered is not None
    # Reversed: doc 2 now renders first and gets [1].
    first, second = citations.resolve(1), citations.resolve(2)
    assert first.locator == {"document_id": 2, "chunk_id": 12}
    assert second.locator == {"document_id": 1, "chunk_id": 880}

View file

@ -1,295 +0,0 @@
"""Tests for the prompt fragment composer."""
from __future__ import annotations
from datetime import UTC, datetime
import pytest
from app.db import ChatVisibility
from app.prompts.system_prompt_composer.composer import (
ALL_TOOL_NAMES_ORDERED,
compose_system_prompt,
detect_provider_variant,
)
pytestmark = pytest.mark.unit
@pytest.fixture
def fixed_today() -> datetime:
    """Provide a fixed UTC timestamp (2025-06-01 12:00) so composed prompts are reproducible."""
    return datetime(year=2025, month=6, day=1, hour=12, minute=0, tzinfo=UTC)
class TestProviderVariantDetection:
    """Pin the model-id → provider-variant mapping of ``detect_provider_variant``."""

    @pytest.mark.parametrize(
        "model_name,expected",
        [
            # GPT-4 family routes to "classic" (autonomous-persistence style)
            ("openai:gpt-4o-mini", "openai_classic"),
            ("openai:gpt-4-turbo", "openai_classic"),
            # GPT-5 / o-series route to "reasoning" (channel-aware pragmatic)
            ("openai:gpt-5", "openai_reasoning"),
            ("openai:o1-preview", "openai_reasoning"),
            ("openai:o3-mini", "openai_reasoning"),
            # Codex family beats reasoning (more specific). Mirrors OpenCode
            # ``system.ts`` — ``gpt-*-codex`` gets the code-purist prompt.
            ("openai:gpt-5-codex", "openai_codex"),
            ("openai:gpt-codex", "openai_codex"),
            ("openai:codex-mini", "openai_codex"),
            # Anthropic + Google
            ("anthropic:claude-3-5-sonnet", "anthropic"),
            ("anthropic/claude-opus-4", "anthropic"),
            ("google:gemini-2.0-flash", "google"),
            ("vertex:gemini-1.5-pro", "google"),
            # Newly-covered families
            ("moonshot:kimi-k2", "kimi"),
            ("openrouter:moonshot/kimi-k2.5", "kimi"),
            ("xai:grok-2", "grok"),
            ("openrouter:x-ai/grok-3", "grok"),
            ("openai:deepseek-v3", "deepseek"),
            ("deepseek:deepseek-r1", "deepseek"),
            # Unknown families fall back to default (no provider block emitted)
            ("groq:mixtral-8x7b", "default"),
            ("together:llama-3.1-70b", "default"),
            (None, "default"),
            ("", "default"),
        ],
    )
    def test_detection(self, model_name: str | None, expected: str) -> None:
        """Each model id (including ``None`` and ``""``) maps to its expected variant."""
        assert detect_provider_variant(model_name) == expected

    def test_codex_takes_precedence_over_reasoning(self) -> None:
        """Regression guard: ``gpt-5-codex`` must NOT match the generic
        ``gpt-5`` reasoning regex first. Codex is the more specialised
        prompt and mirrors OpenCode's dispatch order.
        """
        # ``detect_provider_variant`` is already imported at module level, so
        # no function-scope re-import is needed here.
        assert detect_provider_variant("openai:gpt-5-codex") == "openai_codex"
        assert detect_provider_variant("openai:gpt-5") == "openai_reasoning"
class TestCompose:
    """Behavior tests for ``compose_system_prompt`` across its keyword options."""

    def test_default_prompt_has_required_blocks(self, fixed_today: datetime) -> None:
        """With only ``today`` set, the prompt contains every default block."""
        prompt = compose_system_prompt(today=fixed_today)
        # System instruction wrapper
        assert "<system_instruction>" in prompt
        assert "</system_instruction>" in prompt
        # Date interpolated
        assert "2025-06-01" in prompt
        # Core policy blocks present
        assert "<knowledge_base_only_policy>" in prompt
        assert "<tool_routing>" in prompt
        assert "<parameter_resolution>" in prompt
        assert "<memory_protocol>" in prompt
        # Tools
        assert "<tools>" in prompt
        assert "</tools>" in prompt
        # Citations on by default
        assert "<citation_instructions>" in prompt
        assert "[citation:chunk_id]" in prompt

    def test_team_visibility_uses_team_variants(self, fixed_today: datetime) -> None:
        """``SEARCH_SPACE`` visibility selects the team phrasing, not the personal one."""
        prompt = compose_system_prompt(
            today=fixed_today,
            thread_visibility=ChatVisibility.SEARCH_SPACE,
        )
        # Team-specific phrasing in the agent block
        assert "team space" in prompt
        # Memory protocol mentions team
        # NOTE(review): implied by the "team space" check above; kept as written.
        assert "team" in prompt
        # Should NOT mention the user-only memory phrasing
        assert "personal knowledge base" not in prompt

    def test_private_visibility_uses_private_variants(
        self, fixed_today: datetime
    ) -> None:
        """``PRIVATE`` visibility selects the personal phrasing, not the team one."""
        prompt = compose_system_prompt(
            today=fixed_today,
            thread_visibility=ChatVisibility.PRIVATE,
        )
        assert "personal knowledge base" in prompt
        # Should NOT mention the team-specific phrasing about prefixed authors
        assert "[DisplayName of the author]" not in prompt

    def test_citations_disabled_swaps_block(self, fixed_today: datetime) -> None:
        """Toggling ``citations_enabled`` swaps between the citation and DISABLED blocks."""
        prompt_on = compose_system_prompt(today=fixed_today, citations_enabled=True)
        prompt_off = compose_system_prompt(today=fixed_today, citations_enabled=False)
        assert "Citations are DISABLED" in prompt_off
        assert "Citations are DISABLED" not in prompt_on
        assert "[citation:chunk_id]" in prompt_on

    def test_enabled_tool_filter_only_includes_listed_tools(
        self, fixed_today: datetime
    ) -> None:
        """Only the tools named in ``enabled_tool_names`` appear in the tool listing."""
        prompt = compose_system_prompt(
            today=fixed_today,
            enabled_tool_names={"web_search", "scrape_webpage"},
        )
        # A "- name:" bullet necessarily contains "name:", so the bare form is
        # the whole condition; the former `or "- name:" in prompt` was redundant.
        assert "web_search:" in prompt
        assert "scrape_webpage:" in prompt
        # Excluded tools should NOT appear in tool listing
        assert "generate_podcast:" not in prompt
        assert "generate_image:" not in prompt

    def test_disabled_tool_note_is_appended(self, fixed_today: datetime) -> None:
        """Disabled tools are listed by display name under a DISABLED TOOLS note."""
        prompt = compose_system_prompt(
            today=fixed_today,
            enabled_tool_names={"web_search"},
            disabled_tool_names={"generate_image", "generate_podcast"},
        )
        assert "DISABLED TOOLS (by user):" in prompt
        assert "Generate Image" in prompt
        assert "Generate Podcast" in prompt

    def test_mcp_routing_block_emits_when_provided(self, fixed_today: datetime) -> None:
        """A non-empty ``mcp_connector_tools`` mapping emits the MCP routing block."""
        prompt = compose_system_prompt(
            today=fixed_today,
            mcp_connector_tools={"My GitLab": ["gitlab_search", "gitlab_create_mr"]},
        )
        assert "<mcp_tool_routing>" in prompt
        assert "My GitLab" in prompt
        assert "gitlab_search" in prompt

    def test_mcp_routing_block_absent_when_no_servers(
        self, fixed_today: datetime
    ) -> None:
        """An empty ``mcp_connector_tools`` mapping emits no MCP routing block."""
        prompt = compose_system_prompt(today=fixed_today, mcp_connector_tools={})
        assert "<mcp_tool_routing>" not in prompt

    def test_provider_block_renders_when_anthropic(self, fixed_today: datetime) -> None:
        """An Anthropic model id yields a provider-hints block naming the family."""
        prompt = compose_system_prompt(
            today=fixed_today, model_name="anthropic:claude-3-5-sonnet"
        )
        assert "<provider_hints>" in prompt
        assert "Anthropic" in prompt or "Claude" in prompt

    def test_provider_block_absent_for_default(self, fixed_today: datetime) -> None:
        """An unrecognised model id yields no provider-hints block."""
        prompt = compose_system_prompt(today=fixed_today, model_name="custom:foo")
        assert "<provider_hints>" not in prompt

    @pytest.mark.parametrize(
        "model_name,expected_marker",
        [
            # Each marker is a unique-ish phrase from the corresponding fragment.
            # If a fragment is renamed/rewritten such that the marker is gone,
            # update both the fragment and this test deliberately.
            ("openai:gpt-5-codex", "Codex-class"),
            ("openai:gpt-5", "OpenAI reasoning model"),
            ("openai:gpt-4o", "classic OpenAI chat model"),
            ("anthropic:claude-3-5-sonnet", "Anthropic Claude"),
            ("google:gemini-2.0-flash", "Google Gemini"),
            ("moonshot:kimi-k2", "Moonshot Kimi"),
            ("xai:grok-2", "xAI Grok"),
            ("deepseek:deepseek-r1", "DeepSeek"),
        ],
    )
    def test_each_known_variant_renders_with_its_marker(
        self,
        fixed_today: datetime,
        model_name: str,
        expected_marker: str,
    ) -> None:
        """Every supported variant must produce a ``<provider_hints>`` block
        containing its identifying marker. This pins the dispatch + the
        on-disk fragments together so a missing/renamed file is caught
        immediately.
        """
        prompt = compose_system_prompt(today=fixed_today, model_name=model_name)
        assert "<provider_hints>" in prompt, (
            f"variant for {model_name!r} did not emit a provider_hints block; "
            "the corresponding providers/<variant>.md may be missing"
        )
        assert expected_marker in prompt, (
            f"variant for {model_name!r} emitted hints but lacked the "
            f"expected marker {expected_marker!r} — the fragment may have "
            "drifted from the dispatch table"
        )

    def test_provider_blocks_are_byte_stable_across_calls(
        self, fixed_today: datetime
    ) -> None:
        """Cache-stability guard: same model id → byte-identical prompt."""
        a = compose_system_prompt(today=fixed_today, model_name="moonshot:kimi-k2")
        b = compose_system_prompt(today=fixed_today, model_name="moonshot:kimi-k2")
        assert a == b

    def test_custom_system_instructions_override_default(
        self, fixed_today: datetime
    ) -> None:
        """Custom instructions replace the default block and get the date interpolated."""
        custom = "You are a custom assistant. Today is {resolved_today}."
        prompt = compose_system_prompt(
            today=fixed_today, custom_system_instructions=custom
        )
        assert "You are a custom assistant. Today is 2025-06-01." in prompt
        # Default block should NOT be present
        assert "<knowledge_base_only_policy>" not in prompt

    def test_provider_hints_render_with_custom_system_instructions(
        self, fixed_today: datetime
    ) -> None:
        """Regression guard for the always-append decision: provider hints
        append AFTER a custom system prompt.

        Provider hints are stylistic nudges (parallel tool-call rules,
        formatting guidance, etc.) that help the model regardless of
        what the system instructions say. Suppressing them when a
        custom prompt is set would partially defeat the per-family
        prompt machinery.
        """
        prompt = compose_system_prompt(
            today=fixed_today,
            custom_system_instructions="You are a custom assistant.",
            model_name="anthropic/claude-3-5-sonnet",
        )
        assert "You are a custom assistant." in prompt
        assert "<provider_hints>" in prompt
        # The custom prompt must come BEFORE the provider hints so the
        # user's framing isn't drowned out by the stylistic nudges.
        assert prompt.index("You are a custom assistant.") < prompt.index(
            "<provider_hints>"
        )

    def test_use_default_false_with_no_custom_yields_no_system_block(
        self, fixed_today: datetime
    ) -> None:
        """Disabling the default instructions without a custom prompt drops the wrapper only."""
        prompt = compose_system_prompt(
            today=fixed_today,
            use_default_system_instructions=False,
        )
        # No system_instruction wrapper but tools/citations still emitted
        assert "<system_instruction>" not in prompt
        assert "<tools>" in prompt

    def test_all_known_tools_have_fragments(self) -> None:
        """Each tool in ``ALL_TOOL_NAMES_ORDERED`` shows up when enabled on its own."""
        # Soft assertion: verify that every tool in the canonical order
        # produces non-empty content for at least one variant.
        for tool in ALL_TOOL_NAMES_ORDERED:
            prompt = compose_system_prompt(
                today=datetime(2025, 1, 1, tzinfo=UTC),
                enabled_tool_names={tool},
            )
            assert tool in prompt, f"tool {tool!r} missing from composed prompt"
class TestStableOrderingForCacheStability:
    """Regression guard: prompt cache hit-rate depends on byte-stable prefix."""

    def test_composition_is_deterministic_given_same_inputs(
        self, fixed_today: datetime
    ) -> None:
        """The same tool set written in a different literal order composes an equal prompt."""

        def compose_with(tool_names: set[str]) -> str:
            # Fresh MCP mapping per call so the two compositions share no objects.
            return compose_system_prompt(
                today=fixed_today,
                enabled_tool_names=tool_names,
                mcp_connector_tools={"X": ["x_a", "x_b"]},
            )

        first = compose_with({"web_search", "scrape_webpage"})
        # set order shouldn't matter
        second = compose_with({"scrape_webpage", "web_search"})
        assert first == second

View file

@ -38,7 +38,7 @@ class TestIsProtectedSystemMessage:
)
def test_tolerates_leading_whitespace(self) -> None:
msg = SystemMessage(content=" \n<priority_documents>\n...")
msg = SystemMessage(content=" \n<workspace_tree>\n...")
assert _is_protected_system_message(msg) is True
@ -89,7 +89,7 @@ class TestPartitionMessages:
def test_protected_system_message_preserved_even_in_summarize_half(self) -> None:
partitioner = self._build_partitioner()
protected = SystemMessage(content="<priority_documents>\n...")
protected = SystemMessage(content="<workspace_tree>\n...")
msgs = [
HumanMessage(content="old human"),
AIMessage(content="old ai"),

View file

@ -28,7 +28,6 @@ def _clear_all(monkeypatch: pytest.MonkeyPatch) -> None:
"SURFSENSE_ENABLE_LLM_TOOL_SELECTOR",
"SURFSENSE_ENABLE_SKILLS",
"SURFSENSE_ENABLE_SPECIALIZED_SUBAGENTS",
"SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE",
"SURFSENSE_ENABLE_ACTION_LOG",
"SURFSENSE_ENABLE_REVERT_ROUTE",
"SURFSENSE_ENABLE_PLUGIN_LOADER",
@ -57,7 +56,6 @@ def test_defaults_match_shipped_agent_stack(monkeypatch: pytest.MonkeyPatch) ->
assert flags.enable_llm_tool_selector is False
assert flags.enable_skills is True
assert flags.enable_specialized_subagents is True
assert flags.enable_kb_planner_runnable is True
assert flags.enable_action_log is True
assert flags.enable_revert_route is True
assert flags.enable_plugin_loader is False
@ -122,7 +120,6 @@ def test_each_flag_can_be_set_independently(monkeypatch: pytest.MonkeyPatch) ->
"enable_llm_tool_selector": "SURFSENSE_ENABLE_LLM_TOOL_SELECTOR",
"enable_skills": "SURFSENSE_ENABLE_SKILLS",
"enable_specialized_subagents": "SURFSENSE_ENABLE_SPECIALIZED_SUBAGENTS",
"enable_kb_planner_runnable": "SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE",
"enable_action_log": "SURFSENSE_ENABLE_ACTION_LOG",
"enable_revert_route": "SURFSENSE_ENABLE_REVERT_ROUTE",
"enable_plugin_loader": "SURFSENSE_ENABLE_PLUGIN_LOADER",

View file

@ -90,8 +90,8 @@ class TestSubstituteInText:
class TestResolveMentions:
"""``resolve_mentions`` resolves chip ids → virtual paths and emits
a ``ResolvedMentionSet`` whose id partitions feed
``KnowledgePriorityMiddleware``."""
a ``ResolvedMentionSet`` whose id partitions feed the
``search_knowledge_base`` retrieval scope."""
@pytest.mark.asyncio
async def test_returns_empty_when_no_mentions(self):

View file

@ -4,9 +4,14 @@ from __future__ import annotations
import pytest
from app.agents.chat.multi_agent_chat.shared.citations import (
CitationRegistry,
CitationSourceType,
)
from app.agents.chat.multi_agent_chat.shared.state.reducers import (
_CLEAR,
_add_unique_reducer,
_citation_registry_merge_reducer,
_dict_merge_with_tombstones_reducer,
_initial_filesystem_state,
_list_append_reducer,
@ -93,6 +98,57 @@ class TestDictMergeWithTombstones:
}
def _kb_registry(chunk_id: int) -> CitationRegistry:
    """Return a fresh registry holding one KB chunk (document 1, ``chunk_id``)."""
    fresh = CitationRegistry()
    locator = {"document_id": 1, "chunk_id": chunk_id}
    fresh.register(CitationSourceType.KB_CHUNK, locator)
    return fresh
class TestCitationRegistryMergeReducer:
    """Tests for ``_citation_registry_merge_reducer`` with ``None``, registry and dict inputs."""

    @staticmethod
    def _chunk_ids(merged):
        """Collect the ``chunk_id`` of every entry in a merged registry."""
        return {entry.locator["chunk_id"] for entry in merged.by_n.values()}

    def test_none_left_returns_right(self):
        """A ``None`` left side returns the right registry object itself."""
        incoming = _kb_registry(10)
        result = _citation_registry_merge_reducer(None, incoming)
        assert result is incoming

    def test_none_right_returns_left(self):
        """A ``None`` right side returns the left registry object itself."""
        current = _kb_registry(10)
        result = _citation_registry_merge_reducer(current, None)
        assert result is current

    def test_both_none_returns_none(self):
        """Two ``None`` inputs reduce to ``None``."""
        result = _citation_registry_merge_reducer(None, None)
        assert result is None

    def test_unions_two_registries(self):
        """Two registries with different chunks reduce to one holding both."""
        current, incoming = _kb_registry(10), _kb_registry(11)
        merged = _citation_registry_merge_reducer(current, incoming)
        assert self._chunk_ids(merged) == {10, 11}

    def test_coerces_serialized_dict_update(self):
        """A dict-shaped right side is accepted and merged like a registry."""
        # The checkpointer serializes Command.update via ormsgpack before the
        # reducer runs, so the right side can arrive as a plain dict.
        current = _kb_registry(10)
        incoming_dump = _kb_registry(11).model_dump()
        merged = _citation_registry_merge_reducer(current, incoming_dump)
        assert self._chunk_ids(merged) == {10, 11}

    def test_coerces_both_sides_from_dict(self):
        """Two dict-shaped sides still reduce to a ``CitationRegistry`` holding both chunks."""
        current_dump = _kb_registry(10).model_dump()
        incoming_dump = _kb_registry(11).model_dump()
        merged = _citation_registry_merge_reducer(current_dump, incoming_dump)
        assert isinstance(merged, CitationRegistry)
        assert self._chunk_ids(merged) == {10, 11}
class TestInitialFilesystemState:
def test_default_shape(self):
state = _initial_filesystem_state()
@ -105,8 +161,6 @@ class TestInitialFilesystemState:
assert state["doc_id_by_path"] == {}
assert state["dirty_paths"] == []
assert state["dirty_path_tool_calls"] == {}
assert state["kb_priority"] == []
assert state["kb_matched_chunk_ids"] == {}
assert state["kb_anon_doc"] is None
assert state["tree_version"] == 0

View file

@ -0,0 +1,124 @@
"""Unit tests for the KB read path: full-view render + anonymous-doc loading.
DB-backed loads are exercised by the integration suite; here we lock the pure
pieces (``render_full_document`` and the anonymous-upload branch of
``aload_document``), which need no database.
"""
from __future__ import annotations
from types import SimpleNamespace
import pytest
from app.agents.chat.multi_agent_chat.shared.citations import (
CitationRegistry,
CitationSourceType,
)
from app.agents.chat.multi_agent_chat.shared.document_render import (
RenderableDocument,
RenderablePassage,
)
from app.agents.chat.multi_agent_chat.shared.middleware.filesystem.backends.kb_postgres import (
KBPostgresBackend,
render_full_document,
)
pytestmark = pytest.mark.unit
def _backend(state: dict) -> KBPostgresBackend:
    """Build a ``KBPostgresBackend`` for search space 1 over a stub runtime exposing ``state``."""
    stub_runtime = SimpleNamespace(state=state)
    return KBPostgresBackend(search_space_id=1, runtime=stub_runtime)
def test_render_full_document_uses_full_view_and_registers() -> None:
    """The full render uses ``view="full"`` and registers the passage under ``[1]``."""
    citations = CitationRegistry()
    passage = RenderablePassage(
        content="push to March 10",
        locator={"document_id": 7, "chunk_id": 880},
    )
    document = RenderableDocument(
        title="Launch Notes", source="Slack", passages=[passage]
    )

    output = render_full_document(document, citations)

    assert '<document title="Launch Notes" source="Slack" view="full">' in output
    assert "[1] push to March 10" in output
    registered = citations.resolve(1)
    assert registered is not None
    assert registered.locator == {"document_id": 7, "chunk_id": 880}
def test_render_full_document_reuses_search_label() -> None:
    """A chunk already registered from search keeps its [n] on a later full read."""
    citations = CitationRegistry()
    # Pre-register chunk 880 the way an earlier search would have.
    existing_label = citations.register(
        CitationSourceType.KB_CHUNK,
        {"document_id": 7, "chunk_id": 880},
        {"title": "Launch Notes", "source": "Slack"},
    )
    # The unseen chunk comes first in the document; the known one second.
    unseen = RenderablePassage(
        content="new chunk",
        locator={"document_id": 7, "chunk_id": 881},
    )
    known = RenderablePassage(
        content="push to March 10",
        locator={"document_id": 7, "chunk_id": 880},
    )
    document = RenderableDocument(
        title="Launch Notes", source="Slack", passages=[unseen, known]
    )

    output = render_full_document(document, citations)

    assert f"[{existing_label}] push to March 10" in output
    assert "[2] new chunk" in output
def test_render_full_document_empty_falls_back_to_notice() -> None:
    """A document with no passages renders the fixed "no readable content" notice."""
    passageless = RenderableDocument(title="Empty", passages=[])
    output = render_full_document(passageless, CitationRegistry())
    assert output == "(This document has no readable content.)"
async def test_aload_document_anonymous_upload() -> None:
    """The anonymous upload held in state loads as a document of ``ANON_CHUNK`` passages."""
    anonymous_upload = {
        "path": "/anon_upload.md",
        "title": "Quarterly Report",
        "chunks": [
            {"chunk_id": -1, "content": "revenue grew"},
            {"chunk_id": -2, "content": "costs fell"},
        ],
    }
    backend = _backend({"kb_anon_doc": anonymous_upload})

    loaded = await backend.aload_document("/anon_upload.md")

    assert loaded is not None
    document, doc_id = loaded
    assert doc_id is None
    assert document.title == "Quarterly Report"
    passages = document.passages
    assert [passage.locator["chunk_id"] for passage in passages] == [-1, -2]
    assert all(passage.locator["document_id"] == -1 for passage in passages)
    assert all(
        passage.source_type is CitationSourceType.ANON_CHUNK for passage in passages
    )
async def test_aload_document_unknown_path_returns_none() -> None:
    """With empty state, loading this path returns ``None``."""
    backend = _backend({})
    loaded = await backend.aload_document("/not/under/documents.md")
    assert loaded is None

View file

@ -1,689 +0,0 @@
"""Unit tests for knowledge_search middleware helpers."""
import json
import pytest
from langchain_core.messages import AIMessage, HumanMessage
from app.agents.chat.multi_agent_chat.shared.middleware import knowledge_search as ks
from app.agents.chat.multi_agent_chat.shared.middleware.filesystem.backends.document_xml import (
build_document_xml as _build_document_xml,
)
from app.agents.chat.multi_agent_chat.shared.middleware.knowledge_search import (
KBSearchPlan,
KnowledgePriorityMiddleware,
_normalize_optional_date_range,
_parse_kb_search_plan_response,
_render_recent_conversation,
_resolve_search_types,
)
pytestmark = pytest.mark.unit
# ── _resolve_search_types ──────────────────────────────────────────────
class TestResolveSearchTypes:
    """Tests for ``_resolve_search_types(connectors, document_types)``."""

    def test_returns_none_when_no_inputs(self):
        """Two ``None`` arguments resolve to ``None``."""
        resolved = _resolve_search_types(None, None)
        assert resolved is None

    def test_returns_none_when_both_empty(self):
        """Two empty lists resolve to ``None``."""
        resolved = _resolve_search_types([], [])
        assert resolved is None

    def test_includes_legacy_type_for_google_gmail(self):
        """The Gmail connector also brings in its Composio counterpart."""
        resolved = _resolve_search_types(["GOOGLE_GMAIL_CONNECTOR"], None)
        for expected in ("GOOGLE_GMAIL_CONNECTOR", "COMPOSIO_GMAIL_CONNECTOR"):
            assert expected in resolved

    def test_includes_legacy_type_for_google_drive(self):
        """The Drive document type also brings in its Composio counterpart."""
        resolved = _resolve_search_types(None, ["GOOGLE_DRIVE_FILE"])
        for expected in ("GOOGLE_DRIVE_FILE", "COMPOSIO_GOOGLE_DRIVE_CONNECTOR"):
            assert expected in resolved

    def test_includes_legacy_type_for_google_calendar(self):
        """The Calendar connector also brings in its Composio counterpart."""
        resolved = _resolve_search_types(["GOOGLE_CALENDAR_CONNECTOR"], None)
        for expected in (
            "GOOGLE_CALENDAR_CONNECTOR",
            "COMPOSIO_GOOGLE_CALENDAR_CONNECTOR",
        ):
            assert expected in resolved

    def test_no_legacy_expansion_for_unrelated_types(self):
        """Types without a legacy counterpart come back unchanged."""
        resolved = _resolve_search_types(["FILE", "NOTE"], None)
        assert set(resolved) == {"FILE", "NOTE"}

    def test_combines_connectors_and_document_types(self):
        """Entries from both arguments end up in the result."""
        resolved = _resolve_search_types(["FILE"], ["NOTE", "CRAWLED_URL"])
        assert {"FILE", "NOTE", "CRAWLED_URL"} <= set(resolved)

    def test_deduplicates(self):
        """A type named several times across both arguments appears once."""
        resolved = _resolve_search_types(["FILE", "FILE"], ["FILE"])
        assert resolved.count("FILE") == 1
# ── _build_document_xml ────────────────────────────────────────────────
class TestBuildDocumentXml:
    """Checks for the XML rendering produced by ``_build_document_xml``.

    Each test renders the same three-chunk sample document and inspects the
    resulting text: metadata tags, the ``<chunk_index>`` section, the
    ``matched="true"`` flag, and the line ranges advertised by the index.
    """

    @pytest.fixture
    def sample_document(self):
        """Return a FILE document (id 42) with three chunks (ids 101-103)."""
        return {
            "document_id": 42,
            "document": {
                "id": 42,
                "document_type": "FILE",
                "title": "Test Doc",
                "metadata": {"url": "https://example.com"},
            },
            "chunks": [
                {"chunk_id": 101, "content": "First chunk content"},
                {"chunk_id": 102, "content": "Second chunk content"},
                {"chunk_id": 103, "content": "Third chunk content"},
            ],
        }

    def test_contains_document_metadata(self, sample_document):
        xml = _build_document_xml(sample_document)
        assert "<document_id>42</document_id>" in xml
        assert "<document_type>FILE</document_type>" in xml
        assert "Test Doc" in xml

    def test_contains_chunk_index(self, sample_document):
        xml = _build_document_xml(sample_document)
        assert "<chunk_index>" in xml
        assert "</chunk_index>" in xml
        assert 'chunk_id="101"' in xml
        assert 'chunk_id="102"' in xml
        assert 'chunk_id="103"' in xml

    def test_matched_chunks_flagged_in_index(self, sample_document):
        """Only the chunks passed as ``matched_chunk_ids`` carry the flag."""
        xml = _build_document_xml(sample_document, matched_chunk_ids={101, 103})
        expected_flags = {101: True, 102: False, 103: True}
        seen_chunk_ids = set()
        for line in xml.split("\n"):
            for chunk_id, should_be_matched in expected_flags.items():
                if f'chunk_id="{chunk_id}"' not in line:
                    continue
                seen_chunk_ids.add(chunk_id)
                if should_be_matched:
                    assert 'matched="true"' in line
                else:
                    assert 'matched="true"' not in line
        # Without this guard the loop passes vacuously when no line mentions
        # a chunk id at all.
        assert seen_chunk_ids == set(expected_flags)

    def test_chunk_content_in_document_content_section(self, sample_document):
        xml = _build_document_xml(sample_document)
        assert "<document_content>" in xml
        assert "First chunk content" in xml
        assert "Second chunk content" in xml
        assert "Third chunk content" in xml

    def test_line_numbers_in_chunk_index_are_accurate(self, sample_document):
        """Verify that the line ranges in chunk_index actually point to the right content."""
        # Local import: ``re`` is not imported at module level in this file.
        import re

        xml = _build_document_xml(sample_document, matched_chunk_ids={101})
        xml_lines = xml.split("\n")
        index_entry = next(
            (
                line
                for line in xml_lines
                if 'chunk_id="101"' in line and "lines=" in line
            ),
            None,
        )
        if index_entry is None:
            pytest.fail("chunk_id=101 entry not found in chunk_index")

        m = re.search(r'lines="(\d+)-(\d+)"', index_entry)
        assert m, f"No lines= attribute found in: {index_entry}"
        # ``lines`` is 1-based, so subtract one to index into ``xml_lines``.
        start = int(m.group(1))
        target_line = xml_lines[start - 1]
        assert "101" in target_line
        assert "First chunk content" in target_line

    def test_splits_into_lines_correctly(self, sample_document):
        """Each chunk occupies exactly one line (no embedded newlines)."""
        xml = _build_document_xml(sample_document)
        chunk_lines = [
            line
            for line in xml.split("\n")
            if "<![CDATA[" in line and "<chunk" in line
        ]
        assert len(chunk_lines) == 3
# ── planner parsing / date normalization ───────────────────────────────
class TestPlannerHelpers:
    """Checks for planner-response parsing and optional date-range normalization."""

    def test_parse_kb_search_plan_response_accepts_plain_json(self):
        raw_plan = {
            "optimized_query": "ocv meeting decisions summary",
            "start_date": "2026-03-01",
            "end_date": "2026-03-31",
        }

        plan = _parse_kb_search_plan_response(json.dumps(raw_plan))

        assert plan.optimized_query == "ocv meeting decisions summary"
        assert plan.start_date == "2026-03-01"
        assert plan.end_date == "2026-03-31"

    def test_parse_kb_search_plan_response_accepts_fenced_json(self):
        # The payload is wrapped in a Markdown ```json fence on purpose.
        fenced_reply = """```json
{"optimized_query":"deel founders guide","start_date":null,"end_date":null}
```"""

        plan = _parse_kb_search_plan_response(fenced_reply)

        assert plan.optimized_query == "deel founders guide"
        assert plan.start_date is None
        assert plan.end_date is None

    def test_normalize_optional_date_range_returns_none_when_absent(self):
        bounds = _normalize_optional_date_range(None, None)
        start_date, end_date = bounds
        assert start_date is None
        assert end_date is None

    def test_normalize_optional_date_range_resolves_single_bound(self):
        bounds = _normalize_optional_date_range("2026-03-01", None)
        start_date, end_date = bounds
        assert start_date is not None
        assert end_date is not None
        assert start_date.date().isoformat() == "2026-03-01"
        assert end_date >= start_date
class FakeLLM:
    """Duck-typed LLM stub that records calls and replies with fixed text."""

    def __init__(self, response_text: str):
        # Every ``ainvoke`` call is logged here as a messages/config dict.
        self.calls: list[dict] = []
        self.response_text = response_text

    async def ainvoke(self, messages, config=None):
        """Record the call and return ``response_text`` as an ``AIMessage``."""
        call_record = {"messages": messages, "config": config}
        self.calls.append(call_record)
        return AIMessage(content=self.response_text)
class FakeBudgetLLM:
    """LLM stub exposing only the token-budget hooks, with a fixed input limit."""

    def __init__(self, *, max_input_tokens: int):
        self._max_input_tokens_value = max_input_tokens

    def _get_max_input_tokens(self) -> int:
        """Return the limit supplied at construction."""
        return self._max_input_tokens_value

    def _count_tokens(self, messages) -> int:
        """Return the summed length of each message's ``content`` value."""
        # Deterministic, simple proxy for tests: count characters as tokens.
        total = 0
        for message in messages:
            total += len(message.get("content", ""))
        return total
class TestKnowledgePriorityMiddlewarePlanner:
    """Planner-related behaviour of ``KnowledgePriorityMiddleware.abefore_agent``.

    Covers rendering of recent conversation for the planner, forwarding of the
    planner's optimized query and dates to the search call, the fallback when
    the planner reply is not JSON, and routing between recency browse and
    regular search.
    """

    @pytest.fixture(autouse=True)
    def _disable_planner_runnable(self, monkeypatch):
        # ``FakeLLM`` is duck-typed and has no ``.bind()``, which
        # ``create_agent`` calls when the planner Runnable path is enabled.
        # Pinning the flag off keeps the planner on the legacy
        # ``self.llm.ainvoke`` path that these tests inspect through
        # ``llm.calls[0]["config"]``.
        monkeypatch.setenv("SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE", "false")

    def test_render_recent_conversation_prefers_latest_messages_under_budget(self):
        history = [
            HumanMessage(content="old user context " * 40),
            AIMessage(content="old assistant answer " * 35),
            HumanMessage(content="recent user context " * 20),
            AIMessage(content="recent assistant answer " * 18),
            HumanMessage(content="latest question"),
        ]
        budget_llm = FakeBudgetLLM(max_input_tokens=900)

        rendered = _render_recent_conversation(
            history,
            llm=budget_llm,
            user_text="latest question",
        )

        for expected_fragment in ("recent user context", "recent assistant answer"):
            assert expected_fragment in rendered
        assert "latest question" not in rendered
        user_position = rendered.index("recent user context")
        assistant_position = rendered.index("recent assistant answer")
        assert user_position < assistant_position

    def test_render_recent_conversation_falls_back_to_legacy_without_budgeting(self):
        history = [
            HumanMessage(content="message one"),
            AIMessage(content="message two"),
            HumanMessage(content="latest question"),
        ]

        rendered = _render_recent_conversation(
            history,
            llm=None,
            user_text="latest question",
        )

        assert "user: message one" in rendered
        assert "assistant: message two" in rendered
        assert "latest question" not in rendered

    async def test_middleware_uses_optimized_query_and_dates(self, monkeypatch):
        seen_kwargs: dict = {}

        async def record_search(**search_kwargs):
            seen_kwargs.update(search_kwargs)
            return []

        monkeypatch.setattr(ks, "search_knowledge_base", record_search)

        planner_reply = json.dumps(
            {
                "optimized_query": "ocv meeting decisions action items",
                "start_date": "2026-03-01",
                "end_date": "2026-03-31",
            }
        )
        llm = FakeLLM(planner_reply)
        middleware = KnowledgePriorityMiddleware(llm=llm, search_space_id=37)
        state = {
            "messages": [
                HumanMessage(content="what happened in our OCV meeting last month?")
            ]
        }

        result = await middleware.abefore_agent(state, runtime=None)

        assert result is not None
        assert seen_kwargs["query"] == "ocv meeting decisions action items"
        assert seen_kwargs["start_date"] is not None
        assert seen_kwargs["end_date"] is not None
        assert seen_kwargs["start_date"].date().isoformat() == "2026-03-01"
        assert seen_kwargs["end_date"].date().isoformat() == "2026-03-31"
        first_call = llm.calls[0]
        assert first_call["config"] == {"tags": ["surfsense:internal"]}

    async def test_middleware_falls_back_when_planner_returns_invalid_json(
        self,
        monkeypatch,
    ):
        seen_kwargs: dict = {}

        async def record_search(**search_kwargs):
            seen_kwargs.update(search_kwargs)
            return []

        monkeypatch.setattr(ks, "search_knowledge_base", record_search)

        middleware = KnowledgePriorityMiddleware(
            llm=FakeLLM("not json"),
            search_space_id=37,
        )
        state = {"messages": [HumanMessage(content="summarize founders guide by deel")]}

        await middleware.abefore_agent(state, runtime=None)

        # With an unparseable plan the raw user text is the query.
        assert seen_kwargs["query"] == "summarize founders guide by deel"
        assert seen_kwargs["start_date"] is None
        assert seen_kwargs["end_date"] is None

    async def test_middleware_passes_none_dates_when_planner_returns_nulls(
        self,
        monkeypatch,
    ):
        seen_kwargs: dict = {}

        async def record_search(**search_kwargs):
            seen_kwargs.update(search_kwargs)
            return []

        monkeypatch.setattr(ks, "search_knowledge_base", record_search)

        planner_reply = json.dumps(
            {
                "optimized_query": "deel founders guide summary",
                "start_date": None,
                "end_date": None,
            }
        )
        middleware = KnowledgePriorityMiddleware(
            llm=FakeLLM(planner_reply),
            search_space_id=37,
        )
        state = {"messages": [HumanMessage(content="summarize founders guide by deel")]}

        await middleware.abefore_agent(state, runtime=None)

        assert seen_kwargs["query"] == "deel founders guide summary"
        assert seen_kwargs["start_date"] is None
        assert seen_kwargs["end_date"] is None

    async def test_middleware_routes_to_recency_browse_when_flagged(
        self,
        monkeypatch,
    ):
        """When the planner sets is_recency_query=true, browse_recent_documents
        is called instead of search_knowledge_base."""
        browse_kwargs: dict = {}
        search_invocations: list[dict] = []

        async def record_browse(**kwargs):
            browse_kwargs.update(kwargs)
            return []

        async def record_search(**kwargs):
            search_invocations.append(kwargs)
            return []

        monkeypatch.setattr(ks, "browse_recent_documents", record_browse)
        monkeypatch.setattr(ks, "search_knowledge_base", record_search)

        planner_reply = json.dumps(
            {
                "optimized_query": "latest uploaded file",
                "start_date": None,
                "end_date": None,
                "is_recency_query": True,
            }
        )
        llm = FakeLLM(planner_reply)
        middleware = KnowledgePriorityMiddleware(llm=llm, search_space_id=42)
        state = {"messages": [HumanMessage(content="what's my latest file?")]}

        result = await middleware.abefore_agent(state, runtime=None)

        assert result is not None
        assert browse_kwargs["search_space_id"] == 42
        assert not search_invocations

    async def test_middleware_uses_hybrid_search_when_not_recency(
        self,
        monkeypatch,
    ):
        """When is_recency_query is false (default), hybrid search is used."""
        search_kwargs: dict = {}
        browse_invocations: list[dict] = []

        async def record_browse(**kwargs):
            browse_invocations.append(kwargs)
            return []

        async def record_search(**kwargs):
            search_kwargs.update(kwargs)
            return []

        monkeypatch.setattr(ks, "browse_recent_documents", record_browse)
        monkeypatch.setattr(ks, "search_knowledge_base", record_search)

        planner_reply = json.dumps(
            {
                "optimized_query": "quarterly revenue report analysis",
                "start_date": None,
                "end_date": None,
                "is_recency_query": False,
            }
        )
        llm = FakeLLM(planner_reply)
        middleware = KnowledgePriorityMiddleware(llm=llm, search_space_id=42)
        state = {
            "messages": [HumanMessage(content="find the quarterly revenue report")]
        }

        await middleware.abefore_agent(state, runtime=None)

        assert search_kwargs["query"] == "quarterly revenue report analysis"
        assert not browse_invocations
# ── KBSearchPlan schema ────────────────────────────────────────────────
class TestKBSearchPlanSchema:
    """Checks for the ``is_recency_query`` field of ``KBSearchPlan``."""

    def test_is_recency_query_defaults_to_false(self):
        plan = KBSearchPlan(optimized_query="test query")
        assert plan.is_recency_query is False

    def test_is_recency_query_parses_true(self):
        raw_plan = {
            "optimized_query": "latest uploaded file",
            "start_date": None,
            "end_date": None,
            "is_recency_query": True,
        }

        plan = _parse_kb_search_plan_response(json.dumps(raw_plan))

        assert plan.is_recency_query is True
        assert plan.optimized_query == "latest uploaded file"

    def test_missing_is_recency_query_defaults_to_false(self):
        # The flag is deliberately left out of the payload.
        raw_plan = {
            "optimized_query": "meeting notes",
            "start_date": None,
            "end_date": None,
        }

        plan = _parse_kb_search_plan_response(json.dumps(raw_plan))

        assert plan.is_recency_query is False
# ── mentioned_document_ids cross-turn drain ────────────────────────────
class TestKnowledgePriorityMentionDrain:
    """Regression tests for the cross-turn ``mentioned_document_ids`` drain.

    A single :class:`KnowledgePriorityMiddleware` instance is reused across
    turns of the same thread by the compiled-agent cache, so mentions can
    reach it in two ways:

    1. The constructor closure (``__init__(mentioned_document_ids=...)``),
       seeded by the cache-miss build on turn 1.
    2. ``runtime.context.mentioned_document_ids``, supplied freshly per turn
       by the streaming task.

    Without the drain fix, an empty runtime list on turn 2 fell through to
    the closure (``[]`` is falsy) and replayed turn 1's mentions. These tests
    pin the intended behaviour: the runtime path is authoritative even when
    empty, and the closure is drained the first time the runtime path fires
    so stale mentions can never resurface.
    """

    @staticmethod
    def _make_runtime(mention_ids: list[int]):
        """Minimal runtime stub exposing only ``runtime.context.mentioned_document_ids``."""
        from types import SimpleNamespace

        context = SimpleNamespace(mentioned_document_ids=mention_ids)
        return SimpleNamespace(context=context)

    @staticmethod
    def _planner_llm() -> "FakeLLM":
        # A stable, non-recency plan keeps every test in the hybrid-search
        # branch, where ``fetch_mentioned_documents`` runs alongside the
        # main search.
        stable_plan = {
            "optimized_query": "follow up question",
            "start_date": None,
            "end_date": None,
            "is_recency_query": False,
        }
        return FakeLLM(json.dumps(stable_plan))

    @staticmethod
    def _stub_retrieval(monkeypatch) -> list[list[int]]:
        """Replace both retrieval entry points on ``ks`` and return the fetch log.

        The returned list receives one entry (a copy of ``document_ids``) per
        ``fetch_mentioned_documents`` call; the search stub returns ``[]``.
        """
        fetch_log: list[list[int]] = []

        async def record_fetch(*, document_ids, search_space_id):
            fetch_log.append(list(document_ids))
            return []

        async def empty_search(**_ignored):
            return []

        monkeypatch.setattr(ks, "fetch_mentioned_documents", record_fetch)
        monkeypatch.setattr(ks, "search_knowledge_base", empty_search)
        return fetch_log

    async def test_runtime_context_overrides_closure_and_drains_it(self, monkeypatch):
        """Turn 1 with mentions in BOTH closure and runtime context: the
        runtime path wins AND the closure is drained so a future turn
        cannot replay it.
        """
        fetched_ids = self._stub_retrieval(monkeypatch)
        middleware = KnowledgePriorityMiddleware(
            llm=self._planner_llm(),
            search_space_id=42,
            mentioned_document_ids=[1, 2, 3],
        )
        state = {"messages": [HumanMessage(content="what is in those docs?")]}

        await middleware.abefore_agent(
            state,
            runtime=self._make_runtime([1, 2, 3]),
        )

        assert fetched_ids == [[1, 2, 3]], (
            "runtime.context mentions must be the source of truth on turn 1"
        )
        assert middleware.mentioned_document_ids == [], (
            "closure must be drained the first time the runtime path fires "
            "so no later turn can replay stale mentions"
        )

    async def test_empty_runtime_context_does_not_replay_closure_mentions(
        self, monkeypatch
    ):
        """Regression: turn 2 with NO mentions must not surface turn 1's
        mentions from the constructor closure.

        Before the fix, ``if ctx_mentions:`` treated an empty list as
        absent and fell through to ``elif self.mentioned_document_ids:``,
        replaying turn 1's mentions. This test pins down the corrected
        behaviour.
        """
        fetched_ids = self._stub_retrieval(monkeypatch)

        # Stand-in for a cached middleware whose closure was seeded by an
        # earlier turn's cache-miss build (mentions=[1,2,3]).
        middleware = KnowledgePriorityMiddleware(
            llm=self._planner_llm(),
            search_space_id=42,
            mentioned_document_ids=[1, 2, 3],
        )
        state = {"messages": [HumanMessage(content="what about the next steps?")]}

        # Turn 2: the streaming task supplies an EMPTY mention list.
        await middleware.abefore_agent(
            state,
            runtime=self._make_runtime([]),
        )

        assert fetched_ids == [], (
            "fetch_mentioned_documents must NOT be called when the runtime "
            "context says there are no mentions for this turn"
        )

    async def test_legacy_path_fires_only_when_runtime_context_absent(
        self, monkeypatch
    ):
        """Backward-compat: if a caller doesn't supply runtime.context (old
        non-streaming code path), the closure-injected mentions are still
        honoured exactly once and then drained.
        """
        fetched_ids = self._stub_retrieval(monkeypatch)
        middleware = KnowledgePriorityMiddleware(
            llm=self._planner_llm(),
            search_space_id=42,
            mentioned_document_ids=[7, 8],
        )

        # First call: no runtime, so the legacy path reads the closure.
        first_state = {"messages": [HumanMessage(content="initial question")]}
        await middleware.abefore_agent(first_state, runtime=None)

        # Second call: still no runtime; the closure is already drained.
        second_state = {"messages": [HumanMessage(content="follow up")]}
        await middleware.abefore_agent(second_state, runtime=None)

        assert fetched_ids == [[7, 8]], (
            "legacy path must honour the closure exactly once and then drain it"
        )
        assert middleware.mentioned_document_ids == []

View file

@ -0,0 +1,85 @@
"""Behavior tests for finalize-time citation resolution.
The finalize step is the single server-side seam that turns the model's bare
``[n]`` ordinals into renderer-ready ``[citation:<payload>]`` markers, using the
registry captured from the run's final state. These tests pin that contract:
known ordinals resolve, unknown ones drop, foreign markers survive, and a
serialized (dict) registry is accepted just like a live one.
"""
from __future__ import annotations
from app.agents.chat.multi_agent_chat.shared.citations import (
CitationRegistry,
CitationSourceType,
)
from app.tasks.chat.streaming.flows.shared.assistant_finalize import _resolve_citations
def _registry_with_chunk(chunk_id: int = 42) -> CitationRegistry:
    """Return a fresh registry holding one KB chunk entry for ``chunk_id``."""
    locator = {"document_id": 1, "chunk_id": chunk_id}
    registry = CitationRegistry()
    registry.register(CitationSourceType.KB_CHUNK, locator)
    return registry
def _text(value: str) -> list[dict]:
return [{"type": "text", "text": value}]
def test_known_ordinal_resolves_to_chunk_marker():
    """A registered ordinal is rewritten to its ``[citation:<chunk_id>]`` marker."""
    registry = _registry_with_chunk(42)
    parts = _text("Launch is March 10 [1].")

    payload = _resolve_citations(parts, registry)

    assert payload[0]["text"] == "Launch is March 10 [citation:42]."
def test_unknown_ordinal_is_dropped():
    """An ordinal missing from the registry is removed from the text."""
    registry = _registry_with_chunk(42)
    parts = _text("Unsupported claim [9].")

    payload = _resolve_citations(parts, registry)

    assert payload[0]["text"] == "Unsupported claim ."
def test_foreign_citation_marker_is_preserved():
    """An existing ``[citation:...]`` marker passes through unchanged."""
    original_text = "From the web [citation:https://example.com]."
    registry = _registry_with_chunk(42)

    payload = _resolve_citations(_text(original_text), registry)

    assert payload[0]["text"] == "From the web [citation:https://example.com]."
def test_serialized_registry_is_accepted():
    """A registry passed as its ``model_dump()`` output resolves like a live one."""
    registry = _registry_with_chunk(7)
    serialized = registry.model_dump()

    payload = _resolve_citations(_text("See [1]."), serialized)

    assert payload[0]["text"] == "See [citation:7]."
def test_empty_registry_leaves_text_untouched():
    """With an empty registry the bare ordinal stays in the text."""
    empty_registry = CitationRegistry()
    parts = _text("No sources here [1].")

    payload = _resolve_citations(parts, empty_registry)

    assert payload[0]["text"] == "No sources here [1]."
def test_missing_registry_is_a_noop():
    """Passing ``None`` as the registry leaves the text as written."""
    parts = _text("Nothing to resolve [1].")

    payload = _resolve_citations(parts, None)

    assert payload[0]["text"] == "Nothing to resolve [1]."
def test_non_text_parts_are_left_alone():
    """Only ``text`` parts are rewritten; a tool-call part keeps its ``[1]``."""
    tool_call_part = {
        "type": "tool_call",
        "name": "search_knowledge_base",
        "args": {"q": "[1]"},
    }
    text_part = {"type": "text", "text": "Result [1]."}
    registry = _registry_with_chunk(5)

    payload = _resolve_citations([tool_call_part, text_part], registry)

    assert payload[0]["args"]["q"] == "[1]"
    assert payload[1]["text"] == "Result [citation:5]."