mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-28 21:49:40 +02:00
Merge pull request #1539 from CREDO23/improve-chat-agent-context-and-citations
[FEAT] Unified [n] citation registry for KB + web, pull-based retrieval
This commit is contained in:
commit
94fdb8a113
160 changed files with 4097 additions and 5238 deletions
|
|
@ -0,0 +1,237 @@
|
|||
"""Behavior tests for the ``search_knowledge_base`` main-agent tool.
|
||||
|
||||
These exercise the tool through its public contract: seed a real document,
|
||||
invoke the tool, and assert on the ``Command`` it returns — the rendered
|
||||
``<retrieved_context>`` carries ``[n]`` labels and the citation registry handed
|
||||
back on state is populated.
|
||||
The tool's own DB session is redirected to the test session, and the embedding
|
||||
leg is pinned so the search is deterministic without a live model.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import uuid
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langgraph.types import Command
|
||||
|
||||
from app.agents.chat.multi_agent_chat.main_agent.tools import search_knowledge_base
|
||||
from app.agents.chat.multi_agent_chat.main_agent.tools.search_knowledge_base import (
|
||||
create_search_knowledge_base_tool,
|
||||
)
|
||||
from app.agents.chat.multi_agent_chat.shared.citations import CitationRegistry
|
||||
from app.config import config
|
||||
from app.db import Chunk, Document, DocumentType, Folder
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
_DIM = config.embedding_model_instance.dimension
|
||||
|
||||
|
||||
def _axis(index: int) -> list[float]:
    """Build a ``_DIM``-length vector that is 1.0 at ``index`` and 0.0 elsewhere."""
    unit = [0.0 for _ in range(_DIM)]
    unit[index] = 1.0
    return unit
|
||||
|
||||
|
||||
async def _add_document(
    db_session,
    *,
    search_space_id: int,
    title: str,
    text: str,
    folder_id: int | None = None,
):
    """Persist a ready FILE document with one chunk and return the document.

    The single chunk reuses ``text`` as its content, sits at position 0 and is
    embedded along axis 0. Both rows are flushed to the session, not committed.
    """
    doc = Document(
        title=title,
        document_type=DocumentType.FILE,
        content=text,
        content_hash=uuid.uuid4().hex,
        search_space_id=search_space_id,
        folder_id=folder_id,
        status={"state": "ready"},
    )
    db_session.add(doc)
    # Flush first so ``doc.id`` is available for the chunk below.
    await db_session.flush()

    only_chunk = Chunk(
        content=text, document_id=doc.id, position=0, embedding=_axis(0)
    )
    db_session.add(only_chunk)
    await db_session.flush()
    return doc
|
||||
|
||||
|
||||
async def _add_folder(db_session, *, search_space_id: int, name: str = "Folder"):
    """Persist a folder at position "0" in the given search space and return it."""
    created = Folder(name=name, position="0", search_space_id=search_space_id)
    db_session.add(created)
    await db_session.flush()
    return created
|
||||
|
||||
|
||||
@pytest.fixture
def _tool_uses_test_session(db_session, monkeypatch):
    """Make the tool module's ``shielded_async_session`` yield the test session."""

    async def _yield_test_session():
        yield db_session

    monkeypatch.setattr(
        search_knowledge_base,
        "shielded_async_session",
        contextlib.asynccontextmanager(_yield_test_session),
    )
|
||||
|
||||
|
||||
@pytest.fixture
def _pinned_embedding(monkeypatch):
    """Replace the embedding model's ``embed`` so every query maps to axis 0."""

    def _embed_on_axis_zero(_query):
        return _axis(0)

    monkeypatch.setattr(
        config.embedding_model_instance, "embed", _embed_on_axis_zero
    )
|
||||
|
||||
|
||||
async def _invoke(tool, query: str, state: dict | None = None, context=None):
    """Await the tool's coroutine with a minimal stand-in runtime.

    The runtime exposes ``state`` (a fresh empty dict when the argument is
    falsy), a fixed ``tool_call_id`` of "call-1", and the supplied ``context``.
    """
    effective_state = state if state else {}
    fake_runtime = SimpleNamespace(
        state=effective_state,
        tool_call_id="call-1",
        context=context,
    )
    outcome = await tool.coroutine(query, fake_runtime)
    return outcome
|
||||
|
||||
|
||||
def _mentions(*, document_ids=(), folder_ids=()):
    """Build a stand-in context exposing mentioned document and folder ids as lists."""
    doc_id_list = list(document_ids)
    folder_id_list = list(folder_ids)
    return SimpleNamespace(
        mentioned_document_ids=doc_id_list,
        mentioned_folder_ids=folder_id_list,
    )
|
||||
|
||||
|
||||
async def test_tool_returns_retrieved_context_with_numbered_passages(
    db_session, db_search_space, _tool_uses_test_session, _pinned_embedding
):
    """A seeded document comes back as a ``Command`` whose tool message is labelled."""
    space_id = db_search_space.id
    await _add_document(
        db_session,
        search_space_id=space_id,
        title="Asyncio Guide",
        text="The asyncio library enables concurrency.",
    )
    kb_tool = create_search_knowledge_base_tool(search_space_id=space_id)

    outcome = await _invoke(kb_tool, "asyncio")

    assert isinstance(outcome, Command)
    first_message = outcome.update["messages"][0]
    assert isinstance(first_message, ToolMessage)
    rendered = first_message.content
    assert "<retrieved_context>" in rendered
    assert "[1]" in rendered
|
||||
|
||||
|
||||
async def test_tool_populates_citation_registry_on_state(
    db_session, db_search_space, _tool_uses_test_session, _pinned_embedding
):
    """The returned state update carries a non-empty ``CitationRegistry``."""
    space_id = db_search_space.id
    await _add_document(
        db_session,
        search_space_id=space_id,
        title="Asyncio Guide",
        text="The asyncio library enables concurrency.",
    )
    kb_tool = create_search_knowledge_base_tool(search_space_id=space_id)

    outcome = await _invoke(kb_tool, "asyncio")

    citation_registry = outcome.update["citation_registry"]
    assert isinstance(citation_registry, CitationRegistry)
    # A truthy ``by_n`` means at least one passage was registered as [n].
    assert citation_registry.by_n
|
||||
|
||||
|
||||
async def test_tool_reuses_existing_registry_numbering(
    db_session, db_search_space, _tool_uses_test_session, _pinned_embedding
):
    """Re-running a search with the carried registry keeps a single [n] entry."""
    space_id = db_search_space.id
    await _add_document(
        db_session,
        search_space_id=space_id,
        title="Asyncio Guide",
        text="The asyncio library enables concurrency.",
    )
    kb_tool = create_search_knowledge_base_tool(search_space_id=space_id)

    first_outcome = await _invoke(kb_tool, "asyncio")
    prior_state = {"citation_registry": first_outcome.update["citation_registry"]}
    second_outcome = await _invoke(kb_tool, "asyncio", state=prior_state)

    # Find-or-create: the same passage searched twice must not mint a new [n].
    labels = second_outcome.update["citation_registry"].by_n
    assert len(labels) == 1
|
||||
|
||||
|
||||
async def test_tool_reports_no_matches_without_touching_state(
    db_session, db_search_space, _tool_uses_test_session, _pinned_embedding
):
    """With nothing seeded, the tool answers with a plain no-match string."""
    kb_tool = create_search_knowledge_base_tool(search_space_id=db_search_space.id)

    outcome = await _invoke(kb_tool, "nonexistent-term-zzz")

    # A bare string (rather than a ``Command``) carries no state update.
    assert isinstance(outcome, str)
    assert "No knowledge-base matches" in outcome
|
||||
|
||||
|
||||
async def test_tool_rejects_empty_query(
    db_search_space, _tool_uses_test_session, _pinned_embedding
):
    """A whitespace-only query yields a string asking for a non-empty query."""
    kb_tool = create_search_knowledge_base_tool(search_space_id=db_search_space.id)

    outcome = await _invoke(kb_tool, " ")

    assert isinstance(outcome, str)
    assert "non-empty" in outcome
|
||||
|
||||
|
||||
async def test_document_mention_confines_search_to_pinned_doc(
    db_session, db_search_space, _tool_uses_test_session, _pinned_embedding
):
    """Mentioning one document keeps the other matching document out of the output."""
    space_id = db_search_space.id
    pinned_doc = await _add_document(
        db_session,
        search_space_id=space_id,
        title="Pinned",
        text="asyncio appears in the pinned doc.",
    )
    await _add_document(
        db_session,
        search_space_id=space_id,
        title="Other",
        text="asyncio appears in the other doc.",
    )
    kb_tool = create_search_knowledge_base_tool(search_space_id=space_id)
    mention_context = _mentions(document_ids=[pinned_doc.id])

    outcome = await _invoke(kb_tool, "asyncio", context=mention_context)

    # Only the pinned document's content may be rendered.
    rendered = outcome.update["messages"][0].content
    assert "Pinned" in rendered
    assert "Other" not in rendered
|
||||
|
||||
|
||||
async def test_folder_mention_confines_search_to_folder_documents(
    db_session, db_search_space, _tool_uses_test_session, _pinned_embedding
):
    """Mentioning a folder keeps documents outside that folder out of the output."""
    space_id = db_search_space.id
    mentioned_folder = await _add_folder(db_session, search_space_id=space_id)
    await _add_document(
        db_session,
        search_space_id=space_id,
        title="Inside",
        text="asyncio appears inside the folder.",
        folder_id=mentioned_folder.id,
    )
    await _add_document(
        db_session,
        search_space_id=space_id,
        title="Outside",
        text="asyncio appears outside the folder.",
    )
    kb_tool = create_search_knowledge_base_tool(search_space_id=space_id)
    mention_context = _mentions(folder_ids=[mentioned_folder.id])

    outcome = await _invoke(kb_tool, "asyncio", context=mention_context)

    # Only the document stored in the folder may be rendered.
    rendered = outcome.update["messages"][0].content
    assert "Inside" in rendered
    assert "Outside" not in rendered
|
||||
|
|
@ -0,0 +1,236 @@
|
|||
"""Behavior tests for the hybrid chunk retriever against a real Postgres.
|
||||
|
||||
These exercise ``search_chunks`` through its public surface only: seed real
|
||||
documents/chunks, run a search, and assert on the returned ``DocumentHit``s —
|
||||
never on SQL shape or internal ranking math. ``query_embedding`` is supplied
|
||||
directly (a public parameter) so the semantic leg is deterministic instead of
|
||||
depending on a live embedding model.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
|
||||
from app.agents.chat.multi_agent_chat.shared.retrieval.hybrid_search import (
|
||||
search_chunks,
|
||||
)
|
||||
from app.agents.chat.multi_agent_chat.shared.retrieval.models import SearchScope
|
||||
from app.config import config
|
||||
from app.db import Chunk, Document, DocumentType, SearchSpace
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
_DIM = config.embedding_model_instance.dimension
|
||||
|
||||
|
||||
def _axis(index: int) -> list[float]:
    """Return a unit vector along axis ``index``; distinct axes are orthogonal."""
    unit = [0.0 for _ in range(_DIM)]
    unit[index] = 1.0
    return unit
|
||||
|
||||
|
||||
async def _add_document(
    db_session,
    *,
    search_space_id: int,
    title: str = "Doc",
    document_type: DocumentType = DocumentType.FILE,
    state: str = "ready",
    chunks: list[tuple[str, int, list[float]]],
) -> Document:
    """Store a document plus its chunks and return the document.

    Each ``chunks`` item is a ``(content, position, embedding)`` triple. The
    document body is the chunk contents joined by newlines in the order given.
    """
    body = "\n".join(chunk_text for chunk_text, _position, _embedding in chunks)
    doc = Document(
        title=title,
        document_type=document_type,
        content=body,
        content_hash=uuid.uuid4().hex,
        search_space_id=search_space_id,
        status={"state": state},
    )
    db_session.add(doc)
    # Flush first so ``doc.id`` is available for the chunk rows.
    await db_session.flush()

    for chunk_text, chunk_position, chunk_embedding in chunks:
        row = Chunk(
            content=chunk_text,
            document_id=doc.id,
            position=chunk_position,
            embedding=chunk_embedding,
        )
        db_session.add(row)
    await db_session.flush()
    return doc
|
||||
|
||||
|
||||
async def test_keyword_relevant_document_is_retrieved(db_session, db_search_space):
    """A keyword match is returned even when the query vector is on another axis."""
    space_id = db_search_space.id
    seeded = await _add_document(
        db_session,
        search_space_id=space_id,
        title="Asyncio Guide",
        chunks=[("The asyncio library enables concurrency.", 0, _axis(0))],
    )

    hits = await search_chunks(
        db_session,
        search_space_id=space_id,
        query="asyncio",
        scope=SearchScope(),
        top_k=5,
        query_embedding=_axis(99),
    )

    returned_ids = {hit.document_id for hit in hits}
    assert seeded.id in returned_ids
|
||||
|
||||
|
||||
async def test_semantically_closest_document_ranks_first(db_session, db_search_space):
    """The document whose chunk shares the query's axis is the first result."""
    space_id = db_search_space.id
    aligned_doc = await _add_document(
        db_session,
        search_space_id=space_id,
        title="Background Work",
        chunks=[("Parallel execution of background work.", 0, _axis(0))],
    )
    await _add_document(
        db_session,
        search_space_id=space_id,
        title="Dessert",
        chunks=[("Recipes for chocolate cake.", 0, _axis(1))],
    )

    hits = await search_chunks(
        db_session,
        search_space_id=space_id,
        query="asynchronous coroutines",
        scope=SearchScope(),
        top_k=5,
        query_embedding=_axis(0),
    )

    top_hit = hits[0]
    assert top_hit.document_id == aligned_doc.id
|
||||
|
||||
|
||||
async def test_results_stay_within_the_search_space(db_session, db_search_space):
    """An identical document stored in another search space is never returned."""
    second_space = SearchSpace(name="Other Space", user_id=db_search_space.user_id)
    db_session.add(second_space)
    await db_session.flush()

    shared_chunks = ("Shared keyword asyncio here.", 0)
    own_doc = await _add_document(
        db_session,
        search_space_id=db_search_space.id,
        chunks=[(*shared_chunks, _axis(0))],
    )
    foreign_doc = await _add_document(
        db_session,
        search_space_id=second_space.id,
        chunks=[(*shared_chunks, _axis(0))],
    )

    hits = await search_chunks(
        db_session,
        search_space_id=db_search_space.id,
        query="asyncio",
        scope=SearchScope(),
        top_k=5,
        query_embedding=_axis(0),
    )

    returned_ids = {hit.document_id for hit in hits}
    assert own_doc.id in returned_ids
    assert foreign_doc.id not in returned_ids
|
||||
|
||||
|
||||
async def test_document_ids_scope_pins_results(db_session, db_search_space):
    """Scoping by document id returns exactly that document and nothing else."""
    space_id = db_search_space.id
    pinned_doc = await _add_document(
        db_session,
        search_space_id=space_id,
        chunks=[("asyncio appears in the pinned doc.", 0, _axis(0))],
    )
    await _add_document(
        db_session,
        search_space_id=space_id,
        chunks=[("asyncio appears in the other doc too.", 0, _axis(0))],
    )
    pinned_scope = SearchScope(document_ids=(pinned_doc.id,))

    hits = await search_chunks(
        db_session,
        search_space_id=space_id,
        query="asyncio",
        scope=pinned_scope,
        top_k=5,
        query_embedding=_axis(0),
    )

    returned_ids = {hit.document_id for hit in hits}
    assert returned_ids == {pinned_doc.id}
|
||||
|
||||
|
||||
async def test_deleting_documents_are_excluded(db_session, db_search_space):
    """A document in the "deleting" state is left out; a ready one is returned."""
    space_id = db_search_space.id
    ready_doc = await _add_document(
        db_session,
        search_space_id=space_id,
        chunks=[("asyncio in a ready document.", 0, _axis(0))],
    )
    deleting_doc = await _add_document(
        db_session,
        search_space_id=space_id,
        state="deleting",
        chunks=[("asyncio in a deleting document.", 0, _axis(0))],
    )

    hits = await search_chunks(
        db_session,
        search_space_id=space_id,
        query="asyncio",
        scope=SearchScope(),
        top_k=5,
        query_embedding=_axis(0),
    )

    returned_ids = {hit.document_id for hit in hits}
    assert ready_doc.id in returned_ids
    assert deleting_doc.id not in returned_ids
|
||||
|
||||
|
||||
async def test_matched_chunks_are_ordered_for_reading(db_session, db_search_space):
    """Chunks inside a hit come back sorted by position."""
    space_id = db_search_space.id
    # Seed position 1 before position 0, and put the query's axis on the
    # position-1 chunk, so position order matches neither insertion nor score.
    seeded = await _add_document(
        db_session,
        search_space_id=space_id,
        chunks=[
            ("asyncio paragraph two.", 1, _axis(0)),
            ("asyncio paragraph one.", 0, _axis(50)),
        ],
    )

    hits = await search_chunks(
        db_session,
        search_space_id=space_id,
        query="asyncio",
        scope=SearchScope(),
        top_k=5,
        query_embedding=_axis(0),
    )

    seeded_hit = next(
        candidate for candidate in hits if candidate.document_id == seeded.id
    )
    positions = [chunk.position for chunk in seeded_hit.chunks]
    assert positions == [0, 1]
|
||||
|
||||
|
||||
async def test_top_k_caps_the_number_of_documents(db_session, db_search_space):
    """Three matching documents with ``top_k=2`` produce exactly two hits."""
    space_id = db_search_space.id
    for doc_number in range(3):
        await _add_document(
            db_session,
            search_space_id=space_id,
            title=f"Doc {doc_number}",
            chunks=[(f"asyncio mentioned in doc {doc_number}.", 0, _axis(doc_number))],
        )

    hits = await search_chunks(
        db_session,
        search_space_id=space_id,
        query="asyncio",
        scope=SearchScope(),
        top_k=2,
        query_embedding=_axis(0),
    )

    assert len(hits) == 2
|
||||
|
|
@ -3,7 +3,6 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
|
|
@ -227,23 +226,6 @@ def patched_embed(monkeypatch):
|
|||
return mock
|
||||
|
||||
|
||||
@pytest.fixture
def patched_shielded_session(async_engine, monkeypatch):
    """Point the knowledge_base module's ``shielded_async_session`` at the test engine.

    The replacement opens a new session from the test engine each time it is entered.
    """
    session_factory = async_sessionmaker(async_engine, expire_on_commit=False)

    @asynccontextmanager
    async def _engine_backed_session():
        async with session_factory() as test_session:
            yield test_session

    patch_target = "app.agents.chat.multi_agent_chat.subagents.builtins.deliverables.tools.knowledge_base.shielded_async_session"
    monkeypatch.setattr(patch_target, _engine_backed_session)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Indexer test helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
|
|||
|
|
@ -1,46 +0,0 @@
|
|||
"""Integration test: _browse_recent_documents returns docs of multiple types.
|
||||
|
||||
Exercises the browse path (degenerate-query fallback) with a real PostgreSQL
|
||||
database. Verifies that passing a list of document types correctly returns
|
||||
documents of all listed types -- the same ``.in_()`` SQL path used by hybrid
|
||||
search but through the browse/recency-ordered code path.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
|
||||
async def test_browse_recent_documents_with_list_type_returns_both(
    committed_google_data, patched_shielded_session
):
    """Browsing with a list of document types returns documents of every listed type."""
    from app.agents.chat.multi_agent_chat.subagents.builtins.deliverables.tools.knowledge_base import (
        _browse_recent_documents,
    )

    wanted_types = ["GOOGLE_DRIVE_FILE", "COMPOSIO_GOOGLE_DRIVE_CONNECTOR"]

    browsed = await _browse_recent_documents(
        search_space_id=committed_google_data["search_space_id"],
        document_type=wanted_types,
        top_k=10,
        start_date=None,
        end_date=None,
    )

    # Collect every truthy document type present in the results.
    returned_types = {
        dtype
        for item in browsed
        if (dtype := item.get("document", {}).get("document_type"))
    }

    assert "GOOGLE_DRIVE_FILE" in returned_types, (
        "Native Drive docs should appear in browse results"
    )
    assert "COMPOSIO_GOOGLE_DRIVE_CONNECTOR" in returned_types, (
        "Legacy Composio Drive docs should appear in browse results"
    )
|
||||
|
|
@ -1,61 +0,0 @@
|
|||
"""Integration smoke tests for KB search query/date scoping."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from app.agents.chat.multi_agent_chat.shared.middleware import knowledge_search as ks
|
||||
from app.agents.chat.multi_agent_chat.shared.middleware.knowledge_search import (
|
||||
search_knowledge_base,
|
||||
)
|
||||
|
||||
from .conftest import DUMMY_EMBEDDING
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
|
||||
async def test_search_knowledge_base_applies_date_filters(
    db_session,
    seed_date_filtered_docs,
    monkeypatch,
):
    """A 30-day date window drops the old document but keeps the recent one."""

    @asynccontextmanager
    async def _session_from_fixture():
        yield db_session

    def _dummy_embeddings(texts):
        return [np.array(DUMMY_EMBEDDING) for _ in texts]

    monkeypatch.setattr(ks, "shielded_async_session", _session_from_fixture)
    monkeypatch.setattr(ks, "embed_texts", _dummy_embeddings)

    space_id = seed_date_filtered_docs["search_space"].id
    window_start = datetime.now(UTC) - timedelta(days=30)

    without_dates = await search_knowledge_base(
        query="ocv meeting decisions",
        search_space_id=space_id,
        available_document_types=["FILE"],
        top_k=10,
    )
    with_dates = await search_knowledge_base(
        query="ocv meeting decisions",
        search_space_id=space_id,
        available_document_types=["FILE"],
        top_k=10,
        start_date=window_start,
        end_date=datetime.now(UTC),
    )

    ids_without_dates = {row["document"]["id"] for row in without_dates}
    ids_with_dates = {row["document"]["id"] for row in with_dates}
    recent_id = seed_date_filtered_docs["recent_doc"].id
    old_id = seed_date_filtered_docs["old_doc"].id

    assert recent_id in ids_without_dates
    assert old_id in ids_without_dates
    assert recent_id in ids_with_dates
    assert old_id not in ids_with_dates
|
||||
|
|
@ -0,0 +1,41 @@
|
|||
"""Tests for connector pointer field selection."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.agents.chat.runtime.references.connectors import connector_pointer_fields
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def test_prefers_chip_account_and_type() -> None:
    """The account name becomes the label even when a fallback name is supplied."""
    shown_label, shown_provider = connector_pointer_fields(
        account_name="work@acme.com",
        connector_type="Gmail",
        fallback_name="My Gmail",
    )

    expected = ("work@acme.com", "Gmail")
    assert (shown_label, shown_provider) == expected
|
||||
|
||||
|
||||
def test_falls_back_to_stored_name_when_account_missing() -> None:
    """With no account name, the fallback name is used as the label."""
    shown_label, shown_provider = connector_pointer_fields(
        account_name=None,
        connector_type="Slack",
        fallback_name="Acme Slack",
    )

    assert shown_label == "Acme Slack"
    assert shown_provider == "Slack"
|
||||
|
||||
|
||||
def test_provider_is_none_when_unknown() -> None:
    """A missing connector type yields ``None`` as the provider."""
    shown_label, shown_provider = connector_pointer_fields(
        account_name="a@b.com",
        connector_type=None,
        fallback_name=None,
    )

    assert shown_label == "a@b.com"
    assert shown_provider is None
|
||||
|
|
@ -0,0 +1,21 @@
|
|||
"""Tests for folder pointer-path shaping."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.agents.chat.runtime.references.folders import folder_pointer_path
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def test_adds_trailing_slash_so_path_reads_as_directory() -> None:
    """A known folder path without a trailing slash gains one."""
    known_paths = {7: "/documents/Specs"}

    pointer = folder_pointer_path(7, known_paths)

    assert pointer == "/documents/Specs/"
|
||||
|
||||
|
||||
def test_keeps_existing_trailing_slash() -> None:
    """A path that already ends in a slash is returned unchanged."""
    known_paths = {7: "/documents/Specs/"}

    pointer = folder_pointer_path(7, known_paths)

    assert pointer == "/documents/Specs/"
|
||||
|
||||
|
||||
def test_unknown_folder_falls_back_to_documents_root() -> None:
    """A folder id absent from the path mapping resolves to "/documents/"."""
    no_known_paths = {}

    pointer = folder_pointer_path(99, no_known_paths)

    assert pointer == "/documents/"
|
||||
|
|
@ -0,0 +1,93 @@
|
|||
"""Tests for reference pointer rendering."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.agents.chat.runtime.references import (
|
||||
ChatReference,
|
||||
ConnectorReference,
|
||||
DocumentReference,
|
||||
FolderReference,
|
||||
render_reference_pointers,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def test_returns_none_when_no_references() -> None:
    """An empty reference list renders to ``None`` rather than an empty block."""
    rendered = render_reference_pointers([])

    assert rendered is None
|
||||
|
||||
|
||||
def test_wraps_block_and_keeps_reference_order() -> None:
    """The output is wrapped in the turn tag and lists pointers in input order."""
    document_ref = DocumentReference(
        entity_id=42, label="Q3 Notes", path="/documents/q3.xml"
    )
    chat_ref = ChatReference(entity_id=5, label="Pricing")

    rendered = render_reference_pointers([document_ref, chat_ref])

    assert rendered is not None
    assert rendered.startswith("<referenced_this_turn>")
    assert rendered.endswith("</referenced_this_turn>")
    document_at = rendered.index("document 42")
    chat_at = rendered.index("chat 5")
    assert document_at < chat_at
|
||||
|
||||
|
||||
def test_document_with_path_shows_title_and_path() -> None:
    """A document pointer shows its id, quoted label and path in parentheses."""
    document_ref = DocumentReference(
        entity_id=42,
        label="Q3 Launch Notes",
        path="/documents/Launch/Q3.xml",
    )

    rendered = render_reference_pointers([document_ref])

    assert rendered is not None
    expected_line = '- document 42 — "Q3 Launch Notes" (/documents/Launch/Q3.xml)'
    assert expected_line in rendered
|
||||
|
||||
|
||||
def test_folder_with_path_renders_with_folder_kind() -> None:
    """A folder pointer is labelled "folder" and shows its quoted label and path."""
    folder_ref = FolderReference(entity_id=7, label="Specs", path="/documents/Specs/")

    rendered = render_reference_pointers([folder_ref])

    assert rendered is not None
    expected_line = '- folder 7 — "Specs" (/documents/Specs/)'
    assert expected_line in rendered
|
||||
|
||||
|
||||
def test_connector_shows_provider_and_account() -> None:
    """A connector with a provider renders the provider, then the label in parentheses."""
    connector_ref = ConnectorReference(
        entity_id=12, label="work@acme.com", provider="Gmail"
    )

    rendered = render_reference_pointers([connector_ref])

    assert rendered is not None
    expected_line = "- connector 12 — Gmail (work@acme.com)"
    assert expected_line in rendered
|
||||
|
||||
|
||||
def test_connector_without_provider_falls_back_to_label() -> None:
    """A connector built without a provider renders just its label."""
    connector_ref = ConnectorReference(entity_id=12, label="work@acme.com")

    rendered = render_reference_pointers([connector_ref])

    assert rendered is not None
    expected_line = "- connector 12 — work@acme.com"
    assert expected_line in rendered
|
||||
|
||||
|
||||
def test_chat_shows_quoted_title() -> None:
    """A chat pointer shows its id followed by the label in double quotes."""
    chat_ref = ChatReference(entity_id=5, label="Pricing debate")

    rendered = render_reference_pointers([chat_ref])

    assert rendered is not None
    expected_line = '- chat 5 — "Pricing debate"'
    assert expected_line in rendered
|
||||
|
||||
|
||||
def test_label_whitespace_is_collapsed_to_one_line() -> None:
    """A newline inside a label is rendered as a single space."""
    multiline_ref = DocumentReference(
        entity_id=1, label="line one\nline two", path="/d.xml"
    )

    rendered = render_reference_pointers([multiline_ref])

    assert rendered is not None
    expected_prefix = '- document 1 — "line one line two"'
    assert expected_prefix in rendered
|
||||
|
|
@ -0,0 +1,93 @@
|
|||
"""Tests for the shared ``web_search`` tool's citable-result adaptation.
|
||||
|
||||
The tool's network path (SearXNG + live connectors) is out of scope here; these
|
||||
cover the pure mapping from raw web results to renderable, citable documents and
|
||||
the end-to-end registration of ``WEB_RESULT`` ``[n]`` labels.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.agents.chat.multi_agent_chat.shared.citations import (
|
||||
CitationRegistry,
|
||||
CitationSourceType,
|
||||
)
|
||||
from app.agents.chat.multi_agent_chat.shared.document_render import render_web_results
|
||||
from app.agents.chat.shared.tools.web_search import (
|
||||
_to_renderable_web_documents,
|
||||
_web_source_label,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def _raw_result(url: str, title: str, content: str) -> dict:
    """Shape a raw web hit: title and URL under "document", text under "content"."""
    document_part = {"title": title, "metadata": {"url": url}}
    return {"document": document_part, "content": content}
|
||||
|
||||
|
||||
def test_web_source_label_strips_scheme_and_www() -> None:
    """Labels keep only the host (minus ``www.``); an empty URL gives plain "Web"."""
    expected_by_url = (
        ("https://www.example.com/path", "Web · example.com"),
        ("http://news.site.org/a/b", "Web · news.site.org"),
        ("", "Web"),
    )

    for url, expected_label in expected_by_url:
        assert _web_source_label(url) == expected_label
|
||||
|
||||
|
||||
def test_adapter_maps_each_result_to_one_web_passage() -> None:
    """Each raw result becomes a titled document whose passages are web results."""
    raw_results = [
        _raw_result("https://a.com/x", "Alpha", "alpha body"),
        _raw_result("https://b.com/y", "Beta", "beta body"),
    ]

    documents = _to_renderable_web_documents(raw_results)

    titles = [document.title for document in documents]
    assert titles == ["Alpha", "Beta"]

    all_passages = []
    for document in documents:
        all_passages.extend(document.passages)
    for passage in all_passages:
        assert passage.source_type is CitationSourceType.WEB_RESULT

    first_passage = all_passages[0]
    assert first_passage.locator == {"url": "https://a.com/x"}
    assert first_passage.content == "alpha body"
|
||||
|
||||
|
||||
def test_adapter_skips_results_without_url_or_content() -> None:
    """Results with an empty URL or whitespace-only content are left out."""
    raw_results = [
        _raw_result("", "No URL", "has content"),
        _raw_result("https://c.com/z", "Empty", " "),
        _raw_result("https://d.com/w", "Good", "real content"),
    ]

    documents = _to_renderable_web_documents(raw_results)

    kept_titles = [document.title for document in documents]
    assert kept_titles == ["Good"]
|
||||
|
||||
|
||||
def test_adapter_truncates_on_char_budget() -> None:
    """With a 50-character budget and 30-character bodies, only the first is kept."""
    body = "x" * 30
    raw_results = [
        _raw_result(url, title, body)
        for url, title in (
            ("https://a.com", "A"),
            ("https://b.com", "B"),
            ("https://c.com", "C"),
        )
    ]

    documents = _to_renderable_web_documents(raw_results, max_chars=50)

    # The first body fits (30); the second crosses 50 and stops the loop.
    kept_titles = [document.title for document in documents]
    assert kept_titles == ["A"]
|
||||
|
||||
|
||||
def test_end_to_end_registers_web_results_for_citation() -> None:
    """Rendering adapted web results labels the passage [1] and registers its URL."""
    citation_registry = CitationRegistry()
    raw_results = [_raw_result("https://example.com/a", "Example", "the answer is 42")]
    documents = _to_renderable_web_documents(raw_results)

    rendered = render_web_results(documents, citation_registry)

    assert rendered is not None
    assert "[1] the answer is 42" in rendered

    registered = citation_registry.resolve(1)
    assert registered is not None
    assert registered.source_type is CitationSourceType.WEB_RESULT
    assert registered.locator == {"url": "https://example.com/a"}
|
||||
|
|
@ -0,0 +1,49 @@
|
|||
"""Tests for citation-entry → frontend payload mapping."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.agents.chat.multi_agent_chat.shared.citations.markers import (
|
||||
to_frontend_payload,
|
||||
)
|
||||
from app.agents.chat.multi_agent_chat.shared.citations.models import (
|
||||
CitationEntry,
|
||||
CitationSourceType,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def _entry(source_type: CitationSourceType, locator: dict) -> CitationEntry:
    """Build a ``CitationEntry`` numbered 1 with the given source type and locator."""
    built = CitationEntry(n=1, source_type=source_type, locator=locator)
    return built
|
||||
|
||||
|
||||
def test_kb_chunk_maps_to_chunk_id() -> None:
    """A KB chunk entry maps to its chunk id rendered as a string."""
    locator = {"chunk_id": 42, "document_id": 7}
    kb_entry = _entry(CitationSourceType.KB_CHUNK, locator)

    payload = to_frontend_payload(kb_entry)

    assert payload == "42"
|
||||
|
||||
|
||||
def test_anon_chunk_keeps_negative_id() -> None:
    """An anonymous chunk's negative id keeps its sign in the payload."""
    anon_entry = _entry(CitationSourceType.ANON_CHUNK, {"chunk_id": -3})

    payload = to_frontend_payload(anon_entry)

    assert payload == "-3"
|
||||
|
||||
|
||||
def test_web_result_maps_to_url() -> None:
    """A web result entry maps to its URL."""
    web_entry = _entry(CitationSourceType.WEB_RESULT, {"url": "https://example.com/a"})

    payload = to_frontend_payload(web_entry)

    assert payload == "https://example.com/a"
|
||||
|
||||
|
||||
def test_not_yet_renderable_kind_is_dropped() -> None:
    """A chat-turn entry produces no frontend payload."""
    chat_turn_entry = _entry(CitationSourceType.CHAT_TURN, {"thread_id": 1, "turn": 2})

    payload = to_frontend_payload(chat_turn_entry)

    assert payload is None
|
||||
|
||||
|
||||
def test_missing_locator_field_is_dropped() -> None:
    """A KB chunk entry with an empty locator produces no frontend payload."""
    empty_locator = {}
    incomplete_entry = _entry(CitationSourceType.KB_CHUNK, empty_locator)

    payload = to_frontend_payload(incomplete_entry)

    assert payload is None
|
||||
|
|
@ -0,0 +1,113 @@
|
|||
"""Tests for rewriting model ``[n]`` ordinals into frontend citation markers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.agents.chat.multi_agent_chat.shared.citations.models import CitationSourceType
|
||||
from app.agents.chat.multi_agent_chat.shared.citations.normalizer import (
|
||||
normalize_citations,
|
||||
)
|
||||
from app.agents.chat.multi_agent_chat.shared.citations.registry import CitationRegistry
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def _registry_with_chunks(*chunk_ids: int) -> CitationRegistry:
    """Return a fresh registry with one ``KB_CHUNK`` entry per id, in argument order."""
    populated = CitationRegistry()
    for cid in chunk_ids:
        locator = {"chunk_id": cid}
        populated.register(CitationSourceType.KB_CHUNK, locator)
    return populated
|
||||
|
||||
|
||||
def test_single_ordinal_is_rewritten() -> None:
    """A lone ``[1]`` is expected to become the marker for the registered chunk."""
    rewritten = normalize_citations("We shipped it [1].", _registry_with_chunks(42))

    assert rewritten == "We shipped it [citation:42]."
|
||||
|
||||
|
||||
def test_adjacent_brackets_are_each_rewritten() -> None:
    """Back-to-back ``[1][2]`` are expected to be rewritten independently."""
    rewritten = normalize_citations(
        "Both agree [1][2].", _registry_with_chunks(42, 7)
    )

    assert rewritten == "Both agree [citation:42][citation:7]."
|
||||
|
||||
|
||||
def test_comma_separated_brackets_are_each_rewritten() -> None:
    """Comma-separated ``[1], [2]`` are expected to keep their separator."""
    rewritten = normalize_citations(
        "Both agree [1], [2].", _registry_with_chunks(42, 7)
    )

    assert rewritten == "Both agree [citation:42], [citation:7]."
|
||||
|
||||
|
||||
def test_unknown_ordinal_is_dropped() -> None:
    """An ordinal with no registry entry is expected to vanish from the text."""
    rewritten = normalize_citations("Maybe [9] is real.", _registry_with_chunks(42))

    assert rewritten == "Maybe is real."
|
||||
|
||||
|
||||
def test_unknown_ordinal_among_known_is_dropped() -> None:
    """Only the unresolved ordinal is expected to drop; the known one is rewritten."""
    rewritten = normalize_citations("See [1][9].", _registry_with_chunks(42))

    assert rewritten == "See [citation:42]."
|
||||
|
||||
|
||||
def test_web_result_rewrites_to_url() -> None:
    """A web-result ordinal is expected to be rewritten to a URL-carrying marker."""
    web_registry = CitationRegistry()
    web_registry.register(
        CitationSourceType.WEB_RESULT, {"url": "https://example.com"}
    )

    rewritten = normalize_citations("Per the docs [1].", web_registry)

    assert rewritten == "Per the docs [citation:https://example.com]."
|
||||
|
||||
|
||||
def test_word_glued_citation_is_rewritten() -> None:
    """A citation glued to the preceding word is still expected to resolve."""
    # Models often emit ``docs[1]`` with no space before the bracket; that form
    # has to become a marker rather than leaking through as raw text.
    rewritten = normalize_citations(
        "verifying against docs[1].", _registry_with_chunks(42)
    )

    assert rewritten == "verifying against docs[citation:42]."
|
||||
|
||||
|
||||
def test_word_glued_unknown_ordinal_drops() -> None:
    """A glued ordinal that does not resolve is expected to be removed outright."""
    # No broken marker and no raw ``[n]`` should remain; outside code regions
    # the bracket is not preserved as if it were array-index syntax.
    rewritten = normalize_citations("see notes[9] later", _registry_with_chunks(42))

    assert rewritten == "see notes later"
|
||||
|
||||
|
||||
def test_array_index_inside_code_is_left_alone() -> None:
    """Index syntax inside inline code is expected to pass through unchanged."""
    # Real array/index syntax is shielded by the code-region carve-out.
    original = "Read `arr[1]` carefully."

    rewritten = normalize_citations(original, _registry_with_chunks(42))

    assert rewritten == "Read `arr[1]` carefully."
|
||||
|
||||
|
||||
def test_ordinals_inside_inline_code_are_untouched() -> None:
    """Only the ordinal outside the inline code span is expected to be rewritten."""
    rewritten = normalize_citations(
        "Use `list[1]` here [1].", _registry_with_chunks(42)
    )

    assert rewritten == "Use `list[1]` here [citation:42]."
|
||||
|
||||
|
||||
def test_ordinals_inside_fenced_code_are_untouched() -> None:
    """Ordinals around a fenced block are rewritten; the one inside it is not."""
    source_text = "Before [1].\n```\nx = a[1]\n```\nAfter [1]."
    expected = "Before [citation:42].\n```\nx = a[1]\n```\nAfter [citation:42]."

    rewritten = normalize_citations(source_text, _registry_with_chunks(42))

    assert rewritten == expected
|
||||
|
||||
|
||||
def test_empty_text_is_returned_unchanged() -> None:
    """Normalizing an empty string is expected to yield an empty string."""
    rewritten = normalize_citations("", _registry_with_chunks(42))

    assert rewritten == ""
|
||||
|
|
@ -0,0 +1,174 @@
|
|||
"""Unit tests for the citation registry spine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.chat.multi_agent_chat.shared.citations import (
|
||||
CitationRegistry,
|
||||
CitationSourceType,
|
||||
make_key,
|
||||
)
|
||||
|
||||
|
||||
def test_register_assigns_monotonic_labels() -> None:
    """Two distinct chunks are expected to get labels 1 and 2, leaving ``next_n`` at 3."""
    registry = CitationRegistry()

    labels = tuple(
        registry.register(
            CitationSourceType.KB_CHUNK, {"document_id": 42, "chunk_id": chunk_id}
        )
        for chunk_id in (880, 881)
    )

    assert labels == (1, 2)
    assert registry.next_n == 3
|
||||
|
||||
|
||||
def test_register_is_find_or_create_for_same_unit() -> None:
    """Registering the same unit twice is expected to reuse label 1, not mint a new one."""
    registry = CitationRegistry()
    unit = {"document_id": 42, "chunk_id": 880}

    initial = registry.register(CitationSourceType.KB_CHUNK, unit)
    repeat = registry.register(CitationSourceType.KB_CHUNK, unit)

    assert initial == 1
    assert repeat == 1
    assert len(registry.by_n) == 1
    assert registry.next_n == 2
|
||||
|
||||
|
||||
def test_dedup_is_insensitive_to_locator_key_order() -> None:
    """Locators that differ only in key order are expected to share one label."""
    registry = CitationRegistry()
    document_first = {"document_id": 42, "chunk_id": 880}
    chunk_first = {"chunk_id": 880, "document_id": 42}

    first = registry.register(CitationSourceType.KB_CHUNK, document_first)
    reordered = registry.register(CitationSourceType.KB_CHUNK, chunk_first)

    assert first == reordered
|
||||
|
||||
|
||||
def test_same_locator_values_across_types_do_not_collide() -> None:
    """Identical locators under different source types are expected to get distinct labels."""
    registry = CitationRegistry()

    chunk_label = registry.register(CitationSourceType.KB_CHUNK, {"id": 7})
    chat_label = registry.register(CitationSourceType.CHAT_TURN, {"id": 7})

    assert chunk_label != chat_label
|
||||
|
||||
|
||||
def test_resolve_returns_entry_with_locator_and_display() -> None:
    """Resolving a label is expected to return the type, locator and display it was registered with."""
    registry = CitationRegistry()
    locator = {"url": "https://example.com"}
    display = {"title": "Example"}
    label = registry.register(CitationSourceType.WEB_RESULT, locator, display)

    resolved = registry.resolve(label)

    assert resolved is not None
    assert resolved.n == label
    assert resolved.source_type is CitationSourceType.WEB_RESULT
    assert resolved.locator == {"url": "https://example.com"}
    assert resolved.display == {"title": "Example"}
|
||||
|
||||
|
||||
def test_resolve_unknown_label_returns_none() -> None:
    """Resolving a label on an empty registry is expected to return ``None``."""
    empty = CitationRegistry()
    resolved = empty.resolve(999)

    assert resolved is None
|
||||
|
||||
|
||||
def test_registry_round_trips_through_serialization() -> None:
    """A dump/validate round trip is expected to preserve entries and ``next_n``."""
    registry = CitationRegistry()
    registry.register(
        CitationSourceType.KB_CHUNK,
        {"document_id": 42, "chunk_id": 880},
        {"title": "Q3 Launch Notes"},
    )

    dumped = registry.model_dump()
    restored = CitationRegistry.model_validate(dumped)
    recovered = restored.resolve(1)

    assert recovered is not None
    assert recovered.source_type is CitationSourceType.KB_CHUNK
    assert restored.next_n == registry.next_n
|
||||
|
||||
|
||||
def test_make_key_is_stable_and_type_prefixed() -> None:
    """Keys are expected to ignore locator key order and start with ``kb_chunk|``."""
    document_first = {"document_id": 42, "chunk_id": 880}
    chunk_first = {"chunk_id": 880, "document_id": 42}

    key_a = make_key(CitationSourceType.KB_CHUNK, document_first)
    key_b = make_key(CitationSourceType.KB_CHUNK, chunk_first)

    assert key_a == key_b
    assert key_a.startswith("kb_chunk|")
|
||||
|
||||
|
||||
def _kb(registry: CitationRegistry, chunk_id: int) -> int:
    """Register ``chunk_id`` of document 1 as a ``KB_CHUNK`` and return its label."""
    locator = {"document_id": 1, "chunk_id": chunk_id}
    return registry.register(CitationSourceType.KB_CHUNK, locator)
|
||||
|
||||
|
||||
def test_merge_unions_disjoint_registries_preserving_labels() -> None:
    """Merging a fork that only added entries is expected to keep every label as-is."""
    trunk = CitationRegistry()
    _kb(trunk, 10)  # [1]
    _kb(trunk, 11)  # [2]

    # The fork starts as a deep copy of the trunk and then mints its own [3].
    fork = trunk.model_copy(deep=True)
    fork_label = _kb(fork, 12)  # [3]
    assert fork_label == 3

    merged = trunk.merge(fork)

    assert merged.resolve(1).locator["chunk_id"] == 10
    assert merged.resolve(2).locator["chunk_id"] == 11
    assert merged.resolve(3).locator["chunk_id"] == 12
    assert merged.next_n == 4
|
||||
|
||||
|
||||
def test_merge_keeps_one_label_for_a_shared_source() -> None:
    """The same source registered on both sides is expected to merge to one entry."""
    ours = CitationRegistry()
    _kb(ours, 10)  # [1]
    theirs = CitationRegistry()
    _kb(theirs, 10)  # [1] as well, and the very same source

    merged = ours.merge(theirs)

    assert len(merged.by_n) == 1
    assert merged.resolve(1).locator["chunk_id"] == 10
    assert merged.next_n == 2
|
||||
|
||||
|
||||
def test_merge_remints_on_collision_without_losing_sources() -> None:
    """Colliding labels are expected to be resolved by re-minting, never by dropping a source."""
    # Both branches fork from a base holding [1] and each mints a *different*
    # source at [2]; the merge has to keep both, re-labelling one of them.
    base = CitationRegistry()
    _kb(base, 10)  # [1]

    left = base.model_copy(deep=True)
    _kb(left, 11)  # [2] -> chunk 11

    right = base.model_copy(deep=True)
    _kb(right, 12)  # [2] -> chunk 12 (collision)

    merged = left.merge(right)

    merged_chunk_ids = set()
    for entry in merged.by_n.values():
        merged_chunk_ids.add(entry.locator["chunk_id"])

    assert merged_chunk_ids == {10, 11, 12}
    assert merged.resolve(2).locator["chunk_id"] == 11  # left wins the slot
    assert merged.resolve(3).locator["chunk_id"] == 12  # right re-minted
    assert merged.next_n == 4
|
||||
|
||||
|
||||
def test_merge_does_not_mutate_inputs() -> None:
    """``merge`` must leave both operand registries exactly as they were.

    Each operand is snapshotted with ``model_dump()`` before the merge and
    compared afterwards, in addition to the original label-key checks.
    """
    left = CitationRegistry()
    _kb(left, 10)
    right = CitationRegistry()
    _kb(right, 11)
    left_before = left.model_dump()
    right_before = right.model_dump()

    left.merge(right)

    assert list(left.by_n) == [1]
    assert list(right.by_n) == [1]
    # Key lists alone would not notice an entry overwritten in place or a
    # bumped counter, so compare the full serialized state as well.
    assert left.model_dump() == left_before
    assert right.model_dump() == right_before
|
||||
|
|
@ -0,0 +1,152 @@
|
|||
"""Tests for the shared ``render_document`` (one ``<document>`` block)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.agents.chat.multi_agent_chat.shared.citations import (
|
||||
CitationRegistry,
|
||||
CitationSourceType,
|
||||
)
|
||||
from app.agents.chat.multi_agent_chat.shared.document_render import (
|
||||
RenderableDocument,
|
||||
RenderablePassage,
|
||||
render_document,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def _document(
    document_id: int,
    title: str,
    chunk_ids: list[int],
    *,
    source: str | None = None,
) -> RenderableDocument:
    """Build a ``RenderableDocument`` with one ``"text <id>"`` passage per chunk id."""
    passages = [
        RenderablePassage(
            content=f"text {cid}",
            locator={"document_id": document_id, "chunk_id": cid},
        )
        for cid in chunk_ids
    ]
    return RenderableDocument(title=title, source=source, passages=passages)
|
||||
|
||||
|
||||
def test_returns_none_when_no_passages() -> None:
    """A document without passages is expected to render to ``None``."""
    registry = CitationRegistry()
    empty_document = _document(1, "Empty", [])

    block = render_document(empty_document, view="excerpt", registry=registry)

    assert block is None
|
||||
|
||||
|
||||
def test_excerpt_open_and_close_tags() -> None:
    """An excerpt block is expected to open with title/source/view and close the tag."""
    registry = CitationRegistry()
    document = _document(1, "Q3 Launch Notes", [880], source="Slack · #launch")
    opening_tag = (
        '<document title="Q3 Launch Notes" source="Slack · #launch" view="excerpt">'
    )

    block = render_document(document, view="excerpt", registry=registry)

    assert block is not None
    assert block.startswith(opening_tag)
    assert block.endswith("</document>")
|
||||
|
||||
|
||||
def test_full_view_renders_view_attribute() -> None:
    """Rendering with ``view="full"`` is expected to emit that value on the tag."""
    registry = CitationRegistry()
    document = _document(1, "Doc", [880])

    block = render_document(document, view="full", registry=registry)

    assert block is not None
    assert '<document title="Doc" view="full">' in block
|
||||
|
||||
|
||||
def test_source_attribute_omitted_when_absent() -> None:
    """With no source given, the opening tag is expected to carry no ``source``."""
    registry = CitationRegistry()
    sourceless = _document(1, "Plain", [1])

    block = render_document(sourceless, view="excerpt", registry=registry)

    assert block is not None
    assert block.startswith('<document title="Plain" view="excerpt">')
|
||||
|
||||
|
||||
def test_registers_passages_with_chunk_locators() -> None:
    """Rendering is expected to register each passage as a ``KB_CHUNK`` with its locator and display."""
    registry = CitationRegistry()
    document = _document(1, "Doc", [880], source="Slack")

    render_document(document, view="excerpt", registry=registry)
    registered = registry.resolve(1)

    assert registered is not None
    assert registered.source_type is CitationSourceType.KB_CHUNK
    assert registered.locator == {"document_id": 1, "chunk_id": 880}
    assert registered.display == {"title": "Doc", "source": "Slack"}
|
||||
|
||||
|
||||
def test_passages_get_monotonic_labels() -> None:
    """Passages of one document are expected to be labelled ``[1]``, ``[2]`` in order."""
    registry = CitationRegistry()
    document = _document(1, "Doc", [880, 881])

    block = render_document(document, view="excerpt", registry=registry)

    assert block is not None
    # NOTE(review): the leading whitespace in these expected fragments may have
    # been collapsed in this copy of the file - confirm against the renderer.
    for fragment in (" [1] text 880", " [2] text 881"):
        assert fragment in block
|
||||
|
||||
|
||||
def test_multiline_passage_indents_under_label() -> None:
    """Continuation lines of a multi-line passage are expected to sit under the label."""
    registry = CitationRegistry()
    two_line_passage = RenderablePassage(
        content="line one\nline two",
        locator={"document_id": 1, "chunk_id": 5},
    )
    document = RenderableDocument(title="Doc", passages=[two_line_passage])

    block = render_document(document, view="excerpt", registry=registry)

    assert block is not None
    # NOTE(review): the indentation inside this expected literal may have been
    # collapsed in this copy of the file - confirm against the renderer.
    assert " [1] line one\n line two" in block
|
||||
|
||||
|
||||
def test_attribute_values_are_escaped() -> None:
    """Markup-significant characters in title and source must be entity-escaped.

    The title and source contain ``&``, ``<``, ``>`` and ``"``; the rendered
    attributes are expected to carry the escaped forms rather than the raw
    characters.
    """
    registry = CitationRegistry()

    block = render_document(
        _document(1, 'A & B <c> "d"', [1], source="x & y"),
        view="excerpt",
        registry=registry,
    )

    assert block is not None
    # NOTE(review): the entity spellings assume conventional HTML/XML attribute
    # escaping (&amp; &lt; &gt; &quot;) - confirm against the renderer.
    assert 'title="A &amp; B &lt;c&gt; &quot;d&quot;"' in block
    assert 'source="x &amp; y"' in block
|
||||
|
||||
|
||||
def test_same_passage_reuses_label_across_calls() -> None:
    """Rendering one document in both views is expected to mint only one label."""
    registry = CitationRegistry()
    document = _document(1, "Doc", [880])

    for view in ("excerpt", "full"):
        render_document(document, view=view, registry=registry)

    assert registry.next_n == 2
|
||||
|
|
@ -0,0 +1,94 @@
|
|||
"""Tests for the ``<retrieved_context>`` wrapper around excerpt documents."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.agents.chat.multi_agent_chat.shared.citations import CitationRegistry
|
||||
from app.agents.chat.multi_agent_chat.shared.document_render import (
|
||||
RenderableDocument,
|
||||
RenderablePassage,
|
||||
render_search_context,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def _document(
    document_id: int,
    title: str,
    chunk_ids: list[int],
    *,
    source: str | None = None,
) -> RenderableDocument:
    """Build a ``RenderableDocument`` with one ``"text <id>"`` passage per chunk id."""
    passages = []
    for cid in chunk_ids:
        passage = RenderablePassage(
            content=f"text {cid}",
            locator={"document_id": document_id, "chunk_id": cid},
        )
        passages.append(passage)
    return RenderableDocument(title=title, source=source, passages=passages)
|
||||
|
||||
|
||||
def test_returns_none_when_nothing_to_show() -> None:
    """No documents, or only passage-less documents, is expected to render to ``None``."""
    registry = CitationRegistry()

    without_documents = render_search_context([], registry)
    assert without_documents is None

    without_passages = render_search_context([_document(1, "Empty", [])], registry)
    assert without_passages is None
|
||||
|
||||
|
||||
def test_assigns_monotonic_labels_across_documents() -> None:
    """Labels are expected to keep counting across document boundaries."""
    registry = CitationRegistry()
    documents = [
        _document(1, "Q3 Launch Notes", [880, 881], source="Slack"),
        _document(2, "Timeline", [12], source="Notion"),
    ]

    block = render_search_context(documents, registry)

    assert block is not None
    for fragment in ("[1] text 880", "[2] text 881", "[3] text 12"):
        assert fragment in block
|
||||
|
||||
|
||||
def test_wraps_in_retrieved_context_and_teaches_excerpt_and_citation() -> None:
    """The block is expected to be wrapped in ``<retrieved_context>`` and carry both hints."""
    registry = CitationRegistry()
    documents = [_document(1, "Doc", [1])]

    block = render_search_context(documents, registry)

    assert block is not None
    assert block.startswith("<retrieved_context>")
    assert block.endswith("</retrieved_context>")
    assert "excerpt view" in block
    assert "Cite a chunk with its [n]." in block
|
||||
|
||||
|
||||
def test_documents_render_as_excerpt_blocks() -> None:
    """Each document is expected to appear as an excerpt-view ``<document>`` block."""
    registry = CitationRegistry()
    documents = [_document(1, "Q3", [1], source="Slack · #launch")]

    block = render_search_context(documents, registry)

    assert block is not None
    assert '<document title="Q3" source="Slack · #launch" view="excerpt">' in block
    assert "</document>" in block
|
||||
|
||||
|
||||
def test_same_passage_reuses_label_across_calls() -> None:
    """Rendering the same document twice is expected to reuse label ``[1]``."""
    registry = CitationRegistry()
    documents = [_document(1, "Doc", [880])]

    render_search_context(documents, registry)
    second_block = render_search_context(documents, registry)

    assert second_block is not None
    assert "[1] text 880" in second_block
    assert registry.next_n == 2
|
||||
|
|
@ -0,0 +1,35 @@
|
|||
"""Tests for building a document's source label."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.agents.chat.multi_agent_chat.shared.document_render import source_label
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def test_known_type_uses_friendly_name() -> None:
    """``SLACK_CONNECTOR`` with no metadata is expected to be labelled ``"Slack"``."""
    label = source_label("SLACK_CONNECTOR", {})

    assert label == "Slack"
|
||||
|
||||
|
||||
def test_unmapped_type_is_prettified() -> None:
    """``GOOGLE_DRIVE_FILE`` is expected to be labelled ``"Google Drive"``."""
    label = source_label("GOOGLE_DRIVE_FILE", {})

    assert label == "Google Drive"
|
||||
|
||||
|
||||
def test_url_host_is_appended_and_www_stripped() -> None:
    """The URL host is expected to follow the type name, without a leading ``www.``."""
    metadata = {"url": "https://www.docs.python.org/3/"}

    label = source_label("CRAWLED_URL", metadata)

    assert label == "Web · docs.python.org"
|
||||
|
||||
|
||||
def test_host_only_when_type_unknown() -> None:
    """With no document type, the label is expected to be the URL host alone."""
    label = source_label(None, {"url": "https://example.com/a"})

    assert label == "example.com"
|
||||
|
||||
|
||||
def test_returns_none_when_nothing_known() -> None:
    """With neither a type nor metadata, no label is expected."""
    label = source_label(None, {})

    assert label is None
|
||||
|
||||
|
||||
def test_non_http_url_is_ignored() -> None:
    """A local-path ``url`` is expected to add nothing to the ``"File"`` label."""
    label = source_label("FILE", {"url": "/local/path"})

    assert label == "File"
|
||||
|
|
@ -0,0 +1,82 @@
|
|||
"""Tests for the ``<web_results>`` wrapper around web-result excerpt documents."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.agents.chat.multi_agent_chat.shared.citations import (
|
||||
CitationRegistry,
|
||||
CitationSourceType,
|
||||
)
|
||||
from app.agents.chat.multi_agent_chat.shared.document_render import (
|
||||
RenderableDocument,
|
||||
RenderablePassage,
|
||||
render_web_results,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def _web_doc(url: str, title: str, content: str) -> RenderableDocument:
    """Build a single-passage web-result document whose source is ``Web · <host>``."""
    # Host = whatever follows the first "//" up to the next "/".
    after_scheme = url.split("//", 1)[-1]
    host = after_scheme.split("/", 1)[0]
    passage = RenderablePassage(
        content=content,
        locator={"url": url},
        source_type=CitationSourceType.WEB_RESULT,
    )
    return RenderableDocument(title=title, source=f"Web · {host}", passages=[passage])
|
||||
|
||||
|
||||
def test_returns_none_when_nothing_to_show() -> None:
    """An empty result list is expected to render to ``None``."""
    registry = CitationRegistry()

    block = render_web_results([], registry)

    assert block is None
|
||||
|
||||
|
||||
def test_wraps_in_web_results_container() -> None:
    """Results are expected inside ``<web_results>`` with the hint, document tag and label."""
    registry = CitationRegistry()
    results = [_web_doc("https://example.com/a", "Example", "the answer is 42")]

    block = render_web_results(results, registry)

    assert block is not None
    assert block.startswith("<web_results>")
    assert block.endswith("</web_results>")
    expected_fragments = (
        "cite a result with its [n]",
        '<document title="Example" source="Web · example.com" view="excerpt">',
        "[1] the answer is 42",
    )
    for fragment in expected_fragments:
        assert fragment in block
|
||||
|
||||
|
||||
def test_registers_each_result_as_web_result_with_url_locator() -> None:
    """Each rendered result is expected to be registered under its own URL locator."""
    registry = CitationRegistry()
    results = [
        _web_doc("https://a.com/x", "A", "alpha"),
        _web_doc("https://b.com/y", "B", "beta"),
    ]

    render_web_results(results, registry)
    first, second = registry.resolve(1), registry.resolve(2)

    assert first is not None and second is not None
    assert first.source_type is CitationSourceType.WEB_RESULT
    assert first.locator == {"url": "https://a.com/x"}
    assert second.locator == {"url": "https://b.com/y"}
|
||||
|
||||
|
||||
def test_same_url_reuses_label_across_calls() -> None:
    """Rendering the same URL twice is expected to mint only one label."""
    registry = CitationRegistry()
    results = [_web_doc("https://example.com/a", "Example", "stable fact")]

    for _ in range(2):
        render_web_results(results, registry)

    assert registry.next_n == 2
|
||||
|
|
@ -0,0 +1,51 @@
|
|||
"""Tests for mapping a DocumentHit to a renderable document."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.agents.chat.multi_agent_chat.shared.retrieval.adapter import (
|
||||
to_renderable_document,
|
||||
)
|
||||
from app.agents.chat.multi_agent_chat.shared.retrieval.models import (
|
||||
ChunkHit,
|
||||
DocumentHit,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def test_maps_identity_source_and_passages() -> None:
    """Title, source label and per-chunk passages are expected to carry over from the hit."""
    chunks = [
        ChunkHit(chunk_id=880, content="a", position=4, score=0.9),
        ChunkHit(chunk_id=881, content="b", position=7, score=0.5),
    ]
    hit = DocumentHit(
        document_id=42,
        title="Q3 Launch Notes",
        document_type="SLACK_CONNECTOR",
        metadata={},
        score=0.9,
        chunks=chunks,
    )

    document = to_renderable_document(hit)

    assert document.title == "Q3 Launch Notes"
    assert document.source == "Slack"
    pairs = [
        (passage.locator["chunk_id"], passage.content)
        for passage in document.passages
    ]
    assert pairs == [(880, "a"), (881, "b")]
    assert all(
        passage.locator["document_id"] == 42 for passage in document.passages
    )
|
||||
|
||||
|
||||
def test_document_with_no_chunks_maps_to_no_passages() -> None:
    """A hit without chunks is expected to map to a document with no passages."""
    chunkless = DocumentHit(
        document_id=1,
        title="Empty",
        document_type=None,
        metadata={},
        score=0.0,
        chunks=[],
    )

    document = to_renderable_document(chunkless)

    assert document.passages == []
|
||||
|
|
@ -0,0 +1,65 @@
|
|||
"""Tests for the build_context pipeline (rerank → adapt → render)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from app.agents.chat.multi_agent_chat.shared.citations import CitationRegistry
|
||||
from app.agents.chat.multi_agent_chat.shared.retrieval.models import (
|
||||
ChunkHit,
|
||||
DocumentHit,
|
||||
)
|
||||
from app.agents.chat.multi_agent_chat.shared.retrieval.service import build_context
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def _hit(document_id: int, chunk_id: int) -> DocumentHit:
    """Build a single-chunk ``FILE`` hit scored ``1.0 / document_id``."""
    only_chunk = ChunkHit(
        chunk_id=chunk_id, content=f"text {chunk_id}", position=0, score=1.0
    )
    return DocumentHit(
        document_id=document_id,
        title=f"Doc {document_id}",
        document_type="FILE",
        metadata={},
        score=1.0 / document_id,
        chunks=[only_chunk],
    )
|
||||
|
||||
|
||||
def test_no_hits_renders_nothing() -> None:
    """An empty hit list is expected to produce no context block."""
    block = build_context("q", [], CitationRegistry())

    assert block is None
|
||||
|
||||
|
||||
def test_renders_block_and_registers_labels_in_order() -> None:
    """Without a reranker, labels are expected to follow the input hit order."""
    registry = CitationRegistry()
    hits = [_hit(1, 880), _hit(2, 12)]

    block = build_context("q", hits, registry)

    assert block is not None
    assert "[1] text 880" in block
    assert "[2] text 12" in block
    assert registry.resolve(1).locator == {"document_id": 1, "chunk_id": 880}
    assert registry.resolve(2).locator == {"document_id": 2, "chunk_id": 12}
|
||||
|
||||
|
||||
class _ReverseReranker:
|
||||
"""Stand-in reranker that simply reverses document order."""
|
||||
|
||||
def rerank_documents(
|
||||
self, query_text: str, documents: list[dict[str, Any]]
|
||||
) -> list[dict[str, Any]]:
|
||||
return list(reversed(documents))
|
||||
|
||||
|
||||
def test_reranker_reorders_documents_before_labeling() -> None:
    """Labels are expected to follow the reranked order, not the input order."""
    registry = CitationRegistry()
    hits = [_hit(1, 880), _hit(2, 12)]

    block = build_context("q", hits, registry, reranker=_ReverseReranker())

    assert block is not None
    # The stand-in reverses the order, so document 2 renders first and takes [1].
    assert registry.resolve(1).locator == {"document_id": 2, "chunk_id": 12}
    assert registry.resolve(2).locator == {"document_id": 1, "chunk_id": 880}
|
||||
|
|
@ -1,295 +0,0 @@
|
|||
"""Tests for the prompt fragment composer."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
import pytest
|
||||
|
||||
from app.db import ChatVisibility
|
||||
from app.prompts.system_prompt_composer.composer import (
|
||||
ALL_TOOL_NAMES_ORDERED,
|
||||
compose_system_prompt,
|
||||
detect_provider_variant,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
@pytest.fixture
def fixed_today() -> datetime:
    """Provide a fixed, UTC-aware timestamp (2025-06-01 12:00) for prompt composition."""
    return datetime(year=2025, month=6, day=1, hour=12, minute=0, tzinfo=UTC)
|
||||
|
||||
|
||||
class TestProviderVariantDetection:
    """Pin the mapping from model identifiers to provider prompt variants."""

    @pytest.mark.parametrize(
        "model_name,expected",
        [
            # GPT-4 family routes to "classic" (autonomous-persistence style)
            ("openai:gpt-4o-mini", "openai_classic"),
            ("openai:gpt-4-turbo", "openai_classic"),
            # GPT-5 / o-series route to "reasoning" (channel-aware pragmatic)
            ("openai:gpt-5", "openai_reasoning"),
            ("openai:o1-preview", "openai_reasoning"),
            ("openai:o3-mini", "openai_reasoning"),
            # Codex family beats reasoning (more specific). Mirrors OpenCode
            # ``system.ts`` — ``gpt-*-codex`` gets the code-purist prompt.
            ("openai:gpt-5-codex", "openai_codex"),
            ("openai:gpt-codex", "openai_codex"),
            ("openai:codex-mini", "openai_codex"),
            # Anthropic + Google
            ("anthropic:claude-3-5-sonnet", "anthropic"),
            ("anthropic/claude-opus-4", "anthropic"),
            ("google:gemini-2.0-flash", "google"),
            ("vertex:gemini-1.5-pro", "google"),
            # Newly-covered families
            ("moonshot:kimi-k2", "kimi"),
            ("openrouter:moonshot/kimi-k2.5", "kimi"),
            ("xai:grok-2", "grok"),
            ("openrouter:x-ai/grok-3", "grok"),
            ("openai:deepseek-v3", "deepseek"),
            ("deepseek:deepseek-r1", "deepseek"),
            # Unknown families fall back to default (no provider block emitted)
            ("groq:mixtral-8x7b", "default"),
            ("together:llama-3.1-70b", "default"),
            (None, "default"),
            ("", "default"),
        ],
    )
    def test_detection(self, model_name: str | None, expected: str) -> None:
        """Each model identifier must be detected as its expected variant."""
        assert detect_provider_variant(model_name) == expected

    def test_codex_takes_precedence_over_reasoning(self) -> None:
        """Regression guard: ``gpt-5-codex`` must NOT match the generic
        ``gpt-5`` reasoning regex first. Codex is the more specialised
        prompt and mirrors OpenCode's dispatch order.
        """
        # ``detect_provider_variant`` is already imported at module level, so
        # no function-scope re-import is needed here.
        assert detect_provider_variant("openai:gpt-5-codex") == "openai_codex"
        assert detect_provider_variant("openai:gpt-5") == "openai_reasoning"
|
||||
|
||||
|
||||
class TestCompose:
|
||||
def test_default_prompt_has_required_blocks(self, fixed_today: datetime) -> None:
|
||||
prompt = compose_system_prompt(today=fixed_today)
|
||||
# System instruction wrapper
|
||||
assert "<system_instruction>" in prompt
|
||||
assert "</system_instruction>" in prompt
|
||||
# Date interpolated
|
||||
assert "2025-06-01" in prompt
|
||||
# Core policy blocks present
|
||||
assert "<knowledge_base_only_policy>" in prompt
|
||||
assert "<tool_routing>" in prompt
|
||||
assert "<parameter_resolution>" in prompt
|
||||
assert "<memory_protocol>" in prompt
|
||||
# Tools
|
||||
assert "<tools>" in prompt
|
||||
assert "</tools>" in prompt
|
||||
# Citations on by default
|
||||
assert "<citation_instructions>" in prompt
|
||||
assert "[citation:chunk_id]" in prompt
|
||||
|
||||
def test_team_visibility_uses_team_variants(self, fixed_today: datetime) -> None:
|
||||
prompt = compose_system_prompt(
|
||||
today=fixed_today,
|
||||
thread_visibility=ChatVisibility.SEARCH_SPACE,
|
||||
)
|
||||
# Team-specific phrasing in the agent block
|
||||
assert "team space" in prompt
|
||||
# Memory protocol mentions team
|
||||
assert "team" in prompt
|
||||
# Should NOT mention the user-only memory phrasing
|
||||
assert "personal knowledge base" not in prompt
|
||||
|
||||
def test_private_visibility_uses_private_variants(
|
||||
self, fixed_today: datetime
|
||||
) -> None:
|
||||
prompt = compose_system_prompt(
|
||||
today=fixed_today,
|
||||
thread_visibility=ChatVisibility.PRIVATE,
|
||||
)
|
||||
assert "personal knowledge base" in prompt
|
||||
# Should NOT mention the team-specific phrasing about prefixed authors
|
||||
assert "[DisplayName of the author]" not in prompt
|
||||
|
||||
def test_citations_disabled_swaps_block(self, fixed_today: datetime) -> None:
|
||||
prompt_on = compose_system_prompt(today=fixed_today, citations_enabled=True)
|
||||
prompt_off = compose_system_prompt(today=fixed_today, citations_enabled=False)
|
||||
assert "Citations are DISABLED" in prompt_off
|
||||
assert "Citations are DISABLED" not in prompt_on
|
||||
assert "[citation:chunk_id]" in prompt_on
|
||||
|
||||
def test_enabled_tool_filter_only_includes_listed_tools(
|
||||
self, fixed_today: datetime
|
||||
) -> None:
|
||||
prompt = compose_system_prompt(
|
||||
today=fixed_today,
|
||||
enabled_tool_names={"web_search", "scrape_webpage"},
|
||||
)
|
||||
assert "web_search:" in prompt or "- web_search:" in prompt
|
||||
assert "scrape_webpage:" in prompt or "- scrape_webpage:" in prompt
|
||||
# Excluded tools should NOT appear in tool listing
|
||||
assert "generate_podcast:" not in prompt
|
||||
assert "generate_image:" not in prompt
|
||||
|
||||
def test_disabled_tool_note_is_appended(self, fixed_today: datetime) -> None:
|
||||
prompt = compose_system_prompt(
|
||||
today=fixed_today,
|
||||
enabled_tool_names={"web_search"},
|
||||
disabled_tool_names={"generate_image", "generate_podcast"},
|
||||
)
|
||||
assert "DISABLED TOOLS (by user):" in prompt
|
||||
assert "Generate Image" in prompt
|
||||
assert "Generate Podcast" in prompt
|
||||
|
||||
def test_mcp_routing_block_emits_when_provided(self, fixed_today: datetime) -> None:
|
||||
prompt = compose_system_prompt(
|
||||
today=fixed_today,
|
||||
mcp_connector_tools={"My GitLab": ["gitlab_search", "gitlab_create_mr"]},
|
||||
)
|
||||
assert "<mcp_tool_routing>" in prompt
|
||||
assert "My GitLab" in prompt
|
||||
assert "gitlab_search" in prompt
|
||||
|
||||
def test_mcp_routing_block_absent_when_no_servers(
|
||||
self, fixed_today: datetime
|
||||
) -> None:
|
||||
prompt = compose_system_prompt(today=fixed_today, mcp_connector_tools={})
|
||||
assert "<mcp_tool_routing>" not in prompt
|
||||
|
||||
def test_provider_block_renders_when_anthropic(self, fixed_today: datetime) -> None:
|
||||
prompt = compose_system_prompt(
|
||||
today=fixed_today, model_name="anthropic:claude-3-5-sonnet"
|
||||
)
|
||||
assert "<provider_hints>" in prompt
|
||||
assert "Anthropic" in prompt or "Claude" in prompt
|
||||
|
||||
def test_provider_block_absent_for_default(self, fixed_today: datetime) -> None:
|
||||
prompt = compose_system_prompt(today=fixed_today, model_name="custom:foo")
|
||||
assert "<provider_hints>" not in prompt
|
||||
|
||||
@pytest.mark.parametrize(
    "model_name,expected_marker",
    [
        # Markers are distinctive phrases taken from each provider fragment.
        # When a fragment is reworded so that its marker disappears, change
        # the fragment and this table together, on purpose.
        ("openai:gpt-5-codex", "Codex-class"),
        ("openai:gpt-5", "OpenAI reasoning model"),
        ("openai:gpt-4o", "classic OpenAI chat model"),
        ("anthropic:claude-3-5-sonnet", "Anthropic Claude"),
        ("google:gemini-2.0-flash", "Google Gemini"),
        ("moonshot:kimi-k2", "Moonshot Kimi"),
        ("xai:grok-2", "xAI Grok"),
        ("deepseek:deepseek-r1", "DeepSeek"),
    ],
)
def test_each_known_variant_renders_with_its_marker(
    self,
    fixed_today: datetime,
    model_name: str,
    expected_marker: str,
) -> None:
    """Each supported model id must yield a ``<provider_hints>`` block that
    contains the marker listed for it.

    Checking the block and the marker together ties the dispatch to the
    fragment text, so a missing or renamed fragment fails here first.
    """
    composed = compose_system_prompt(today=fixed_today, model_name=model_name)

    missing_block_msg = (
        f"variant for {model_name!r} did not emit a provider_hints block; "
        "the corresponding providers/<variant>.md may be missing"
    )
    assert "<provider_hints>" in composed, missing_block_msg

    missing_marker_msg = (
        f"variant for {model_name!r} emitted hints but lacked the "
        f"expected marker {expected_marker!r} — the fragment may have "
        "drifted from the dispatch table"
    )
    assert expected_marker in composed, missing_marker_msg
|
||||
|
||||
def test_provider_blocks_are_byte_stable_across_calls(
    self, fixed_today: datetime
) -> None:
    """Cache-stability guard: composing twice for one model id must give
    exactly the same prompt text."""
    model_id = "moonshot:kimi-k2"
    first = compose_system_prompt(today=fixed_today, model_name=model_id)
    second = compose_system_prompt(today=fixed_today, model_name=model_id)
    assert first == second
|
||||
|
||||
def test_custom_system_instructions_override_default(
    self, fixed_today: datetime
) -> None:
    """Custom instructions replace the default block, and their
    ``{resolved_today}`` placeholder is filled in with the date."""
    template = "You are a custom assistant. Today is {resolved_today}."
    composed = compose_system_prompt(
        today=fixed_today,
        custom_system_instructions=template,
    )
    rendered_custom = "You are a custom assistant. Today is 2025-06-01."
    assert rendered_custom in composed
    # The default instruction block must be gone once a custom one is set.
    assert "<knowledge_base_only_policy>" not in composed
|
||||
|
||||
def test_provider_hints_render_with_custom_system_instructions(
    self, fixed_today: datetime
) -> None:
    """Regression guard for the always-append decision: provider hints
    are still emitted, AFTER the text, when a custom system prompt is set.

    The hints are stylistic nudges (parallel tool-call rules, formatting
    guidance, and so on) that help the model whatever the system
    instructions say, so dropping them for custom prompts would partly
    defeat the per-family prompt machinery.
    """
    custom_text = "You are a custom assistant."
    hints_tag = "<provider_hints>"
    composed = compose_system_prompt(
        today=fixed_today,
        custom_system_instructions=custom_text,
        model_name="anthropic/claude-3-5-sonnet",
    )
    assert custom_text in composed
    assert hints_tag in composed
    # Ordering matters: the user's framing comes first so the stylistic
    # nudges do not drown it out.
    custom_at = composed.index(custom_text)
    hints_at = composed.index(hints_tag)
    assert custom_at < hints_at
|
||||
|
||||
def test_use_default_false_with_no_custom_yields_no_system_block(
    self, fixed_today: datetime
) -> None:
    """With defaults disabled and no custom text, the system-instruction
    wrapper is omitted while the tools section is still emitted."""
    composed = compose_system_prompt(
        today=fixed_today,
        use_default_system_instructions=False,
    )
    # The wrapper is dropped; the tools section survives.
    assert "<system_instruction>" not in composed
    assert "<tools>" in composed
|
||||
|
||||
def test_all_known_tools_have_fragments(self) -> None:
    """Enabling any single tool from the canonical order must make that
    tool's name appear in the composed prompt."""
    reference_day = datetime(2025, 1, 1, tzinfo=UTC)
    for tool_name in ALL_TOOL_NAMES_ORDERED:
        composed = compose_system_prompt(
            today=reference_day,
            enabled_tool_names={tool_name},
        )
        assert tool_name in composed, f"tool {tool_name!r} missing from composed prompt"
|
||||
|
||||
|
||||
class TestStableOrderingForCacheStability:
    """Regression guard: prompt cache hit-rate depends on a byte-stable prefix."""

    def test_composition_is_deterministic_given_same_inputs(
        self, fixed_today: datetime
    ) -> None:
        """The same inputs give the same prompt, even when the enabled-tool
        set literal is written in a different order."""
        connector_tools = {"X": ["x_a", "x_b"]}
        first = compose_system_prompt(
            today=fixed_today,
            enabled_tool_names={"web_search", "scrape_webpage"},
            mcp_connector_tools=dict(connector_tools),
        )
        # Same members, listed the other way round: set order must not matter.
        second = compose_system_prompt(
            today=fixed_today,
            enabled_tool_names={"scrape_webpage", "web_search"},
            mcp_connector_tools=dict(connector_tools),
        )
        assert first == second
|
||||
|
|
@ -38,7 +38,7 @@ class TestIsProtectedSystemMessage:
|
|||
)
|
||||
|
||||
def test_tolerates_leading_whitespace(self) -> None:
|
||||
msg = SystemMessage(content=" \n<priority_documents>\n...")
|
||||
msg = SystemMessage(content=" \n<workspace_tree>\n...")
|
||||
assert _is_protected_system_message(msg) is True
|
||||
|
||||
|
||||
|
|
@ -89,7 +89,7 @@ class TestPartitionMessages:
|
|||
|
||||
def test_protected_system_message_preserved_even_in_summarize_half(self) -> None:
|
||||
partitioner = self._build_partitioner()
|
||||
protected = SystemMessage(content="<priority_documents>\n...")
|
||||
protected = SystemMessage(content="<workspace_tree>\n...")
|
||||
msgs = [
|
||||
HumanMessage(content="old human"),
|
||||
AIMessage(content="old ai"),
|
||||
|
|
|
|||
|
|
@ -28,7 +28,6 @@ def _clear_all(monkeypatch: pytest.MonkeyPatch) -> None:
|
|||
"SURFSENSE_ENABLE_LLM_TOOL_SELECTOR",
|
||||
"SURFSENSE_ENABLE_SKILLS",
|
||||
"SURFSENSE_ENABLE_SPECIALIZED_SUBAGENTS",
|
||||
"SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE",
|
||||
"SURFSENSE_ENABLE_ACTION_LOG",
|
||||
"SURFSENSE_ENABLE_REVERT_ROUTE",
|
||||
"SURFSENSE_ENABLE_PLUGIN_LOADER",
|
||||
|
|
@ -57,7 +56,6 @@ def test_defaults_match_shipped_agent_stack(monkeypatch: pytest.MonkeyPatch) ->
|
|||
assert flags.enable_llm_tool_selector is False
|
||||
assert flags.enable_skills is True
|
||||
assert flags.enable_specialized_subagents is True
|
||||
assert flags.enable_kb_planner_runnable is True
|
||||
assert flags.enable_action_log is True
|
||||
assert flags.enable_revert_route is True
|
||||
assert flags.enable_plugin_loader is False
|
||||
|
|
@ -122,7 +120,6 @@ def test_each_flag_can_be_set_independently(monkeypatch: pytest.MonkeyPatch) ->
|
|||
"enable_llm_tool_selector": "SURFSENSE_ENABLE_LLM_TOOL_SELECTOR",
|
||||
"enable_skills": "SURFSENSE_ENABLE_SKILLS",
|
||||
"enable_specialized_subagents": "SURFSENSE_ENABLE_SPECIALIZED_SUBAGENTS",
|
||||
"enable_kb_planner_runnable": "SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE",
|
||||
"enable_action_log": "SURFSENSE_ENABLE_ACTION_LOG",
|
||||
"enable_revert_route": "SURFSENSE_ENABLE_REVERT_ROUTE",
|
||||
"enable_plugin_loader": "SURFSENSE_ENABLE_PLUGIN_LOADER",
|
||||
|
|
|
|||
|
|
@ -90,8 +90,8 @@ class TestSubstituteInText:
|
|||
|
||||
class TestResolveMentions:
|
||||
"""``resolve_mentions`` resolves chip ids → virtual paths and emits
|
||||
a ``ResolvedMentionSet`` whose id partitions feed
|
||||
``KnowledgePriorityMiddleware``."""
|
||||
a ``ResolvedMentionSet`` whose id partitions feed the
|
||||
``search_knowledge_base`` retrieval scope."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_empty_when_no_mentions(self):
|
||||
|
|
|
|||
|
|
@ -4,9 +4,14 @@ from __future__ import annotations
|
|||
|
||||
import pytest
|
||||
|
||||
from app.agents.chat.multi_agent_chat.shared.citations import (
|
||||
CitationRegistry,
|
||||
CitationSourceType,
|
||||
)
|
||||
from app.agents.chat.multi_agent_chat.shared.state.reducers import (
|
||||
_CLEAR,
|
||||
_add_unique_reducer,
|
||||
_citation_registry_merge_reducer,
|
||||
_dict_merge_with_tombstones_reducer,
|
||||
_initial_filesystem_state,
|
||||
_list_append_reducer,
|
||||
|
|
@ -93,6 +98,57 @@ class TestDictMergeWithTombstones:
|
|||
}
|
||||
|
||||
|
||||
def _kb_registry(chunk_id: int) -> CitationRegistry:
    """Build a registry holding one KB-chunk citation for ``chunk_id``
    (always under document id 1)."""
    locator = {"document_id": 1, "chunk_id": chunk_id}
    single_entry = CitationRegistry()
    single_entry.register(CitationSourceType.KB_CHUNK, locator)
    return single_entry
|
||||
|
||||
|
||||
class TestCitationRegistryMergeReducer:
    """Behaviour of ``_citation_registry_merge_reducer`` for ``None``,
    registry and plain-dict operands."""

    @staticmethod
    def _chunk_ids(registry) -> set:
        """Collect the ``chunk_id`` of every entry in ``registry.by_n``."""
        return {entry.locator["chunk_id"] for entry in registry.by_n.values()}

    def test_none_left_returns_right(self):
        incoming = _kb_registry(10)
        assert _citation_registry_merge_reducer(None, incoming) is incoming

    def test_none_right_returns_left(self):
        existing = _kb_registry(10)
        assert _citation_registry_merge_reducer(existing, None) is existing

    def test_both_none_returns_none(self):
        assert _citation_registry_merge_reducer(None, None) is None

    def test_unions_two_registries(self):
        existing, incoming = _kb_registry(10), _kb_registry(11)

        combined = _citation_registry_merge_reducer(existing, incoming)

        assert self._chunk_ids(combined) == {10, 11}

    def test_coerces_serialized_dict_update(self):
        # The checkpointer serializes Command.update via ormsgpack before the
        # reducer runs, so the right-hand operand can arrive as a plain dict.
        existing = _kb_registry(10)
        incoming_as_dict = _kb_registry(11).model_dump()

        combined = _citation_registry_merge_reducer(existing, incoming_as_dict)

        assert self._chunk_ids(combined) == {10, 11}

    def test_coerces_both_sides_from_dict(self):
        existing_as_dict = _kb_registry(10).model_dump()
        incoming_as_dict = _kb_registry(11).model_dump()

        combined = _citation_registry_merge_reducer(
            existing_as_dict, incoming_as_dict
        )

        assert isinstance(combined, CitationRegistry)
        assert self._chunk_ids(combined) == {10, 11}
|
||||
|
||||
|
||||
class TestInitialFilesystemState:
|
||||
def test_default_shape(self):
|
||||
state = _initial_filesystem_state()
|
||||
|
|
@ -105,8 +161,6 @@ class TestInitialFilesystemState:
|
|||
assert state["doc_id_by_path"] == {}
|
||||
assert state["dirty_paths"] == []
|
||||
assert state["dirty_path_tool_calls"] == {}
|
||||
assert state["kb_priority"] == []
|
||||
assert state["kb_matched_chunk_ids"] == {}
|
||||
assert state["kb_anon_doc"] is None
|
||||
assert state["tree_version"] == 0
|
||||
|
||||
|
|
|
|||
124
surfsense_backend/tests/unit/middleware/test_kb_postgres_read.py
Normal file
124
surfsense_backend/tests/unit/middleware/test_kb_postgres_read.py
Normal file
|
|
@ -0,0 +1,124 @@
|
|||
"""Unit tests for the KB read path: full-view render + anonymous-doc loading.
|
||||
|
||||
DB-backed loads are exercised by the integration suite; here we lock the pure
|
||||
pieces — ``render_full_document`` and the anonymous-upload branch of
|
||||
``aload_document`` — which need no database.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from app.agents.chat.multi_agent_chat.shared.citations import (
|
||||
CitationRegistry,
|
||||
CitationSourceType,
|
||||
)
|
||||
from app.agents.chat.multi_agent_chat.shared.document_render import (
|
||||
RenderableDocument,
|
||||
RenderablePassage,
|
||||
)
|
||||
from app.agents.chat.multi_agent_chat.shared.middleware.filesystem.backends.kb_postgres import (
|
||||
KBPostgresBackend,
|
||||
render_full_document,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def _backend(state: dict) -> KBPostgresBackend:
    """Create a backend for search space 1 whose stub runtime exposes ``state``."""
    stub_runtime = SimpleNamespace(state=state)
    return KBPostgresBackend(search_space_id=1, runtime=stub_runtime)
|
||||
|
||||
|
||||
def test_render_full_document_uses_full_view_and_registers() -> None:
    """Rendering a one-passage document emits the ``view="full"`` header,
    labels the passage ``[1]`` and records its locator in the registry."""
    locator = {"document_id": 7, "chunk_id": 880}
    passage = RenderablePassage(content="push to March 10", locator=locator)
    document = RenderableDocument(
        title="Launch Notes",
        source="Slack",
        passages=[passage],
    )
    registry = CitationRegistry()

    output = render_full_document(document, registry)

    assert '<document title="Launch Notes" source="Slack" view="full">' in output
    assert "[1] push to March 10" in output
    registered = registry.resolve(1)
    assert registered is not None
    assert registered.locator == {"document_id": 7, "chunk_id": 880}
|
||||
|
||||
|
||||
def test_render_full_document_reuses_search_label() -> None:
    """A chunk registered before the full read keeps its existing [n];
    the chunk not yet registered is labelled [2]."""
    registry = CitationRegistry()
    known_locator = {"document_id": 7, "chunk_id": 880}
    known_label = registry.register(
        CitationSourceType.KB_CHUNK,
        known_locator,
        {"title": "Launch Notes", "source": "Slack"},
    )
    # The unseen chunk is listed first, so label reuse cannot be a side
    # effect of passage order.
    unseen_passage = RenderablePassage(
        content="new chunk",
        locator={"document_id": 7, "chunk_id": 881},
    )
    known_passage = RenderablePassage(
        content="push to March 10",
        locator={"document_id": 7, "chunk_id": 880},
    )
    document = RenderableDocument(
        title="Launch Notes",
        source="Slack",
        passages=[unseen_passage, known_passage],
    )

    output = render_full_document(document, registry)

    assert f"[{known_label}] push to March 10" in output
    assert "[2] new chunk" in output
|
||||
|
||||
|
||||
def test_render_full_document_empty_falls_back_to_notice() -> None:
    """A document with no passages renders as the fixed notice string."""
    empty_document = RenderableDocument(title="Empty", passages=[])
    output = render_full_document(empty_document, CitationRegistry())

    assert output == "(This document has no readable content.)"
|
||||
|
||||
|
||||
async def test_aload_document_anonymous_upload() -> None:
    """Loading the path held in ``kb_anon_doc`` returns a document built
    from state, with no doc id and anonymous-chunk passages."""
    anon_doc = {
        "path": "/anon_upload.md",
        "title": "Quarterly Report",
        "chunks": [
            {"chunk_id": -1, "content": "revenue grew"},
            {"chunk_id": -2, "content": "costs fell"},
        ],
    }
    backend = _backend({"kb_anon_doc": anon_doc})

    loaded = await backend.aload_document("/anon_upload.md")

    assert loaded is not None
    document, doc_id = loaded
    assert doc_id is None
    assert document.title == "Quarterly Report"
    passages = document.passages
    assert [passage.locator["chunk_id"] for passage in passages] == [-1, -2]
    for passage in passages:
        assert passage.locator["document_id"] == -1
    for passage in passages:
        assert passage.source_type is CitationSourceType.ANON_CHUNK
|
||||
|
||||
|
||||
async def test_aload_document_unknown_path_returns_none() -> None:
    """With empty state, this path resolves to no document."""
    backend = _backend({})
    loaded = await backend.aload_document("/not/under/documents.md")

    assert loaded is None
|
||||
|
|
@ -1,689 +0,0 @@
|
|||
"""Unit tests for knowledge_search middleware helpers."""
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
|
||||
from app.agents.chat.multi_agent_chat.shared.middleware import knowledge_search as ks
|
||||
from app.agents.chat.multi_agent_chat.shared.middleware.filesystem.backends.document_xml import (
|
||||
build_document_xml as _build_document_xml,
|
||||
)
|
||||
from app.agents.chat.multi_agent_chat.shared.middleware.knowledge_search import (
|
||||
KBSearchPlan,
|
||||
KnowledgePriorityMiddleware,
|
||||
_normalize_optional_date_range,
|
||||
_parse_kb_search_plan_response,
|
||||
_render_recent_conversation,
|
||||
_resolve_search_types,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
# ── _resolve_search_types ──────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestResolveSearchTypes:
    """``_resolve_search_types`` merges connector and document-type filters,
    adding the paired ``COMPOSIO_*`` type for Google sources."""

    def test_returns_none_when_no_inputs(self):
        resolved = _resolve_search_types(None, None)
        assert resolved is None

    def test_returns_none_when_both_empty(self):
        resolved = _resolve_search_types([], [])
        assert resolved is None

    def test_includes_legacy_type_for_google_gmail(self):
        resolved = _resolve_search_types(["GOOGLE_GMAIL_CONNECTOR"], None)
        for expected in ("GOOGLE_GMAIL_CONNECTOR", "COMPOSIO_GMAIL_CONNECTOR"):
            assert expected in resolved

    def test_includes_legacy_type_for_google_drive(self):
        resolved = _resolve_search_types(None, ["GOOGLE_DRIVE_FILE"])
        for expected in ("GOOGLE_DRIVE_FILE", "COMPOSIO_GOOGLE_DRIVE_CONNECTOR"):
            assert expected in resolved

    def test_includes_legacy_type_for_google_calendar(self):
        resolved = _resolve_search_types(["GOOGLE_CALENDAR_CONNECTOR"], None)
        for expected in (
            "GOOGLE_CALENDAR_CONNECTOR",
            "COMPOSIO_GOOGLE_CALENDAR_CONNECTOR",
        ):
            assert expected in resolved

    def test_no_legacy_expansion_for_unrelated_types(self):
        resolved = _resolve_search_types(["FILE", "NOTE"], None)
        assert set(resolved) == {"FILE", "NOTE"}

    def test_combines_connectors_and_document_types(self):
        resolved = _resolve_search_types(["FILE"], ["NOTE", "CRAWLED_URL"])
        assert {"FILE", "NOTE", "CRAWLED_URL"} <= set(resolved)

    def test_deduplicates(self):
        resolved = _resolve_search_types(["FILE", "FILE"], ["FILE"])
        occurrences = resolved.count("FILE")
        assert occurrences == 1
|
||||
|
||||
|
||||
# ── _build_document_xml ────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestBuildDocumentXml:
    """Checks on the XML text that ``_build_document_xml`` produces for a
    three-chunk sample document."""

    @pytest.fixture
    def sample_document(self):
        """Return a FILE document (id 42) with chunks 101, 102 and 103."""
        return {
            "document_id": 42,
            "document": {
                "id": 42,
                "document_type": "FILE",
                "title": "Test Doc",
                "metadata": {"url": "https://example.com"},
            },
            "chunks": [
                {"chunk_id": 101, "content": "First chunk content"},
                {"chunk_id": 102, "content": "Second chunk content"},
                {"chunk_id": 103, "content": "Third chunk content"},
            ],
        }

    def test_contains_document_metadata(self, sample_document):
        """Document id, type and title all appear in the output."""
        xml = _build_document_xml(sample_document)
        assert "<document_id>42</document_id>" in xml
        assert "<document_type>FILE</document_type>" in xml
        assert "Test Doc" in xml

    def test_contains_chunk_index(self, sample_document):
        """A ``<chunk_index>`` section exists and references every chunk id."""
        xml = _build_document_xml(sample_document)
        assert "<chunk_index>" in xml
        assert "</chunk_index>" in xml
        for chunk_id in (101, 102, 103):
            assert f'chunk_id="{chunk_id}"' in xml

    def test_matched_chunks_flagged_in_index(self, sample_document):
        """Lines for matched chunks carry ``matched="true"``; lines for
        unmatched chunks do not."""
        xml = _build_document_xml(sample_document, matched_chunk_ids={101, 103})
        expected_flags = {101: True, 102: False, 103: True}
        seen_ids = set()
        for line in xml.split("\n"):
            for chunk_id, should_be_flagged in expected_flags.items():
                if f'chunk_id="{chunk_id}"' not in line:
                    continue
                seen_ids.add(chunk_id)
                flagged = 'matched="true"' in line
                assert flagged is should_be_flagged, line
        # Without this check the loop passes vacuously when the output
        # contains no chunk lines at all.
        assert seen_ids == set(expected_flags)

    def test_chunk_content_in_document_content_section(self, sample_document):
        """The content section exists and holds every chunk's text."""
        xml = _build_document_xml(sample_document)
        assert "<document_content>" in xml
        for text in (
            "First chunk content",
            "Second chunk content",
            "Third chunk content",
        ):
            assert text in xml

    def test_line_numbers_in_chunk_index_are_accurate(self, sample_document):
        """Verify that the line range in chunk_index points at the right content."""
        # Imported once per test rather than inside the loop body.
        import re

        xml = _build_document_xml(sample_document, matched_chunk_ids={101})
        xml_lines = xml.split("\n")

        for line in xml_lines:
            if 'chunk_id="101"' in line and "lines=" in line:
                m = re.search(r'lines="(\d+)-(\d+)"', line)
                assert m, f"No lines= attribute found in: {line}"
                # ``lines`` is 1-based; only the start line is inspected.
                start = int(m.group(1))
                target_line = xml_lines[start - 1]
                assert "101" in target_line
                assert "First chunk content" in target_line
                break
        else:
            pytest.fail("chunk_id=101 entry not found in chunk_index")

    def test_splits_into_lines_correctly(self, sample_document):
        """Each chunk occupies exactly one line (no embedded newlines)."""
        xml = _build_document_xml(sample_document)
        chunk_lines = [
            line
            for line in xml.split("\n")
            if "<![CDATA[" in line and "<chunk" in line
        ]
        assert len(chunk_lines) == 3
|
||||
|
||||
|
||||
# ── planner parsing / date normalization ───────────────────────────────
|
||||
|
||||
|
||||
class TestPlannerHelpers:
    """Parsing of planner replies and normalisation of optional date bounds."""

    def test_parse_kb_search_plan_response_accepts_plain_json(self):
        payload = {
            "optimized_query": "ocv meeting decisions summary",
            "start_date": "2026-03-01",
            "end_date": "2026-03-31",
        }
        plan = _parse_kb_search_plan_response(json.dumps(payload))
        assert plan.optimized_query == "ocv meeting decisions summary"
        assert plan.start_date == "2026-03-01"
        assert plan.end_date == "2026-03-31"

    def test_parse_kb_search_plan_response_accepts_fenced_json(self):
        fenced_reply = """```json
{"optimized_query":"deel founders guide","start_date":null,"end_date":null}
```"""
        plan = _parse_kb_search_plan_response(fenced_reply)
        assert plan.optimized_query == "deel founders guide"
        assert plan.start_date is None
        assert plan.end_date is None

    def test_normalize_optional_date_range_returns_none_when_absent(self):
        lower, upper = _normalize_optional_date_range(None, None)
        assert lower is None
        assert upper is None

    def test_normalize_optional_date_range_resolves_single_bound(self):
        lower, upper = _normalize_optional_date_range("2026-03-01", None)
        assert lower is not None
        assert upper is not None
        assert lower.date().isoformat() == "2026-03-01"
        assert upper >= lower
|
||||
|
||||
|
||||
class FakeLLM:
    """Test double whose ``ainvoke`` records each call and always answers
    with the same canned text."""

    def __init__(self, response_text: str):
        self.calls: list[dict] = []
        self.response_text = response_text

    async def ainvoke(self, messages, config=None):
        """Log the call arguments, then reply with the canned text."""
        call_record = {"messages": messages, "config": config}
        self.calls.append(call_record)
        return AIMessage(content=self.response_text)
|
||||
|
||||
|
||||
class FakeBudgetLLM:
    """Test double exposing a fixed input-token limit and a character-count
    token estimate."""

    def __init__(self, *, max_input_tokens: int):
        self._max_input_tokens_value = max_input_tokens

    def _get_max_input_tokens(self) -> int:
        """Return the limit given at construction."""
        return self._max_input_tokens_value

    def _count_tokens(self, messages) -> int:
        """Sum the length of each message's ``content`` (missing → empty)."""
        # Deterministic stand-in for tests: one character counts as one token.
        total = 0
        for message in messages:
            total += len(message.get("content", ""))
        return total
|
||||
|
||||
|
||||
class TestKnowledgePriorityMiddlewarePlanner:
    """Planner-driven behaviour of ``KnowledgePriorityMiddleware`` and of
    ``_render_recent_conversation``, using fake LLMs and patched searches."""

    @pytest.fixture(autouse=True)
    def _disable_planner_runnable(self, monkeypatch):
        # ``FakeLLM`` is a duck-typed mock. With the planner Runnable path
        # enabled, ``create_agent`` calls ``.bind()`` on the LLM, which the
        # mock lacks. Pinning the flag off sends the planner down the legacy
        # ``self.llm.ainvoke`` path that these tests assert against
        # (``llm.calls[0]["config"]``).
        monkeypatch.setenv("SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE", "false")

    @staticmethod
    def _planner_reply(optimized_query, start_date=None, end_date=None, **extra):
        """Serialize a planner reply, keeping the key order the tests rely on."""
        payload = {
            "optimized_query": optimized_query,
            "start_date": start_date,
            "end_date": end_date,
        }
        payload.update(extra)
        return json.dumps(payload)

    @staticmethod
    def _capture_search(monkeypatch) -> dict:
        """Patch ``ks.search_knowledge_base`` with a stub that records its
        keyword arguments and returns no results."""
        recorded: dict = {}

        async def fake_search_knowledge_base(**kwargs):
            recorded.update(kwargs)
            return []

        monkeypatch.setattr(ks, "search_knowledge_base", fake_search_knowledge_base)
        return recorded

    def test_render_recent_conversation_prefers_latest_messages_under_budget(self):
        history = [
            HumanMessage(content="old user context " * 40),
            AIMessage(content="old assistant answer " * 35),
            HumanMessage(content="recent user context " * 20),
            AIMessage(content="recent assistant answer " * 18),
            HumanMessage(content="latest question"),
        ]
        budget_llm = FakeBudgetLLM(max_input_tokens=900)

        transcript = _render_recent_conversation(
            history,
            llm=budget_llm,
            user_text="latest question",
        )

        assert "recent user context" in transcript
        assert "recent assistant answer" in transcript
        assert "latest question" not in transcript
        user_at = transcript.index("recent user context")
        assistant_at = transcript.index("recent assistant answer")
        assert user_at < assistant_at

    def test_render_recent_conversation_falls_back_to_legacy_without_budgeting(self):
        history = [
            HumanMessage(content="message one"),
            AIMessage(content="message two"),
            HumanMessage(content="latest question"),
        ]

        transcript = _render_recent_conversation(
            history,
            llm=None,
            user_text="latest question",
        )

        assert "user: message one" in transcript
        assert "assistant: message two" in transcript
        assert "latest question" not in transcript

    async def test_middleware_uses_optimized_query_and_dates(self, monkeypatch):
        recorded = self._capture_search(monkeypatch)
        llm = FakeLLM(
            self._planner_reply(
                "ocv meeting decisions action items",
                start_date="2026-03-01",
                end_date="2026-03-31",
            )
        )
        middleware = KnowledgePriorityMiddleware(llm=llm, search_space_id=37)
        state = {
            "messages": [
                HumanMessage(content="what happened in our OCV meeting last month?")
            ]
        }

        outcome = await middleware.abefore_agent(state, runtime=None)

        assert outcome is not None
        assert recorded["query"] == "ocv meeting decisions action items"
        assert recorded["start_date"] is not None
        assert recorded["end_date"] is not None
        assert recorded["start_date"].date().isoformat() == "2026-03-01"
        assert recorded["end_date"].date().isoformat() == "2026-03-31"
        assert llm.calls[0]["config"] == {"tags": ["surfsense:internal"]}

    async def test_middleware_falls_back_when_planner_returns_invalid_json(
        self,
        monkeypatch,
    ):
        recorded = self._capture_search(monkeypatch)
        middleware = KnowledgePriorityMiddleware(
            llm=FakeLLM("not json"),
            search_space_id=37,
        )
        state = {"messages": [HumanMessage(content="summarize founders guide by deel")]}

        await middleware.abefore_agent(state, runtime=None)

        # The raw user text is used when the planner reply cannot be parsed.
        assert recorded["query"] == "summarize founders guide by deel"
        assert recorded["start_date"] is None
        assert recorded["end_date"] is None

    async def test_middleware_passes_none_dates_when_planner_returns_nulls(
        self,
        monkeypatch,
    ):
        recorded = self._capture_search(monkeypatch)
        middleware = KnowledgePriorityMiddleware(
            llm=FakeLLM(self._planner_reply("deel founders guide summary")),
            search_space_id=37,
        )
        state = {"messages": [HumanMessage(content="summarize founders guide by deel")]}

        await middleware.abefore_agent(state, runtime=None)

        assert recorded["query"] == "deel founders guide summary"
        assert recorded["start_date"] is None
        assert recorded["end_date"] is None

    async def test_middleware_routes_to_recency_browse_when_flagged(
        self,
        monkeypatch,
    ):
        """When the planner sets is_recency_query=true, browse_recent_documents
        is called and search_knowledge_base is not."""
        browse_kwargs: dict = {}
        search_calls: list[dict] = []

        async def fake_browse_recent_documents(**kwargs):
            browse_kwargs.update(kwargs)
            return []

        async def fake_search_knowledge_base(**kwargs):
            search_calls.append(kwargs)
            return []

        monkeypatch.setattr(ks, "browse_recent_documents", fake_browse_recent_documents)
        monkeypatch.setattr(ks, "search_knowledge_base", fake_search_knowledge_base)

        llm = FakeLLM(
            self._planner_reply("latest uploaded file", is_recency_query=True)
        )
        middleware = KnowledgePriorityMiddleware(llm=llm, search_space_id=42)

        outcome = await middleware.abefore_agent(
            {"messages": [HumanMessage(content="what's my latest file?")]},
            runtime=None,
        )

        assert outcome is not None
        assert browse_kwargs["search_space_id"] == 42
        assert not search_calls

    async def test_middleware_uses_hybrid_search_when_not_recency(
        self,
        monkeypatch,
    ):
        """When is_recency_query is false, the search path is used and the
        recency browse is not called."""
        search_kwargs: dict = {}
        browse_calls: list[dict] = []

        async def fake_browse_recent_documents(**kwargs):
            browse_calls.append(kwargs)
            return []

        async def fake_search_knowledge_base(**kwargs):
            search_kwargs.update(kwargs)
            return []

        monkeypatch.setattr(ks, "browse_recent_documents", fake_browse_recent_documents)
        monkeypatch.setattr(ks, "search_knowledge_base", fake_search_knowledge_base)

        llm = FakeLLM(
            self._planner_reply(
                "quarterly revenue report analysis", is_recency_query=False
            )
        )
        middleware = KnowledgePriorityMiddleware(llm=llm, search_space_id=42)

        await middleware.abefore_agent(
            {"messages": [HumanMessage(content="find the quarterly revenue report")]},
            runtime=None,
        )

        assert search_kwargs["query"] == "quarterly revenue report analysis"
        assert not browse_calls
|
||||
|
||||
|
||||
# ── KBSearchPlan schema ────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestKBSearchPlanSchema:
    """Default and parsed values of ``KBSearchPlan.is_recency_query``."""

    def test_is_recency_query_defaults_to_false(self):
        default_plan = KBSearchPlan(optimized_query="test query")
        assert default_plan.is_recency_query is False

    def test_is_recency_query_parses_true(self):
        payload = {
            "optimized_query": "latest uploaded file",
            "start_date": None,
            "end_date": None,
            "is_recency_query": True,
        }
        parsed = _parse_kb_search_plan_response(json.dumps(payload))
        assert parsed.is_recency_query is True
        assert parsed.optimized_query == "latest uploaded file"

    def test_missing_is_recency_query_defaults_to_false(self):
        payload = {
            "optimized_query": "meeting notes",
            "start_date": None,
            "end_date": None,
        }
        parsed = _parse_kb_search_plan_response(json.dumps(payload))
        assert parsed.is_recency_query is False
|
||||
|
||||
|
||||
# ── mentioned_document_ids cross-turn drain ────────────────────────────
|
||||
|
||||
|
||||
class TestKnowledgePriorityMentionDrain:
    """Regression tests for the cross-turn ``mentioned_document_ids`` drain.

    The compiled-agent cache reuses a single :class:`KnowledgePriorityMiddleware`
    instance across turns of the same thread. ``mentioned_document_ids``
    can therefore enter the middleware via two paths:

    1. The constructor closure (``__init__(mentioned_document_ids=...)``) —
       seeded by the cache-miss build on turn 1.
    2. ``runtime.context.mentioned_document_ids`` — supplied freshly per
       turn by the streaming task.

    Without the drain fix, an empty ``runtime.context.mentioned_document_ids``
    on turn 2 would fall through to the closure (because ``[]`` is falsy in
    Python) and replay turn 1's mentions. This class pins down the
    correct behaviour: the runtime path is authoritative even when empty,
    and the closure is drained the first time the runtime path fires so
    no later turn can ever resurrect stale state.
    """

    @staticmethod
    def _make_runtime(mention_ids: list[int]):
        """Minimal runtime stub exposing only ``runtime.context.mentioned_document_ids``."""
        from types import SimpleNamespace

        return SimpleNamespace(
            context=SimpleNamespace(mentioned_document_ids=mention_ids),
        )

    @staticmethod
    def _planner_llm() -> "FakeLLM":
        # Planner returns a stable, non-recency plan so we always land in
        # the hybrid-search branch (where ``fetch_mentioned_documents`` is
        # invoked alongside the main search).
        return FakeLLM(
            json.dumps(
                {
                    "optimized_query": "follow up question",
                    "start_date": None,
                    "end_date": None,
                    "is_recency_query": False,
                }
            )
        )

    @staticmethod
    def _install_kb_fakes(monkeypatch) -> list[list[int]]:
        """Replace the two ``ks`` lookups with fakes and return the fetch log.

        ``ks.fetch_mentioned_documents`` is swapped for a coroutine that
        appends a copy of each ``document_ids`` argument to the returned list
        and yields no documents; ``ks.search_knowledge_base`` is swapped for
        a coroutine that ignores its keyword arguments and yields no results.
        Every test in this class needs exactly this setup, so it lives here
        once instead of being repeated per test.
        """
        fetched_ids: list[list[int]] = []

        async def fake_fetch_mentioned_documents(*, document_ids, search_space_id):
            fetched_ids.append(list(document_ids))
            return []

        async def fake_search_knowledge_base(**_kwargs):
            return []

        monkeypatch.setattr(
            ks,
            "fetch_mentioned_documents",
            fake_fetch_mentioned_documents,
        )
        monkeypatch.setattr(
            ks,
            "search_knowledge_base",
            fake_search_knowledge_base,
        )
        return fetched_ids

    async def test_runtime_context_overrides_closure_and_drains_it(self, monkeypatch):
        """Turn 1 with mentions in BOTH closure and runtime context: the
        runtime path wins AND the closure is drained so a future turn
        cannot replay it.
        """
        fetched_ids = self._install_kb_fakes(monkeypatch)

        middleware = KnowledgePriorityMiddleware(
            llm=self._planner_llm(),
            search_space_id=42,
            mentioned_document_ids=[1, 2, 3],
        )

        await middleware.abefore_agent(
            {"messages": [HumanMessage(content="what is in those docs?")]},
            runtime=self._make_runtime([1, 2, 3]),
        )

        assert fetched_ids == [[1, 2, 3]], (
            "runtime.context mentions must be the source of truth on turn 1"
        )
        assert middleware.mentioned_document_ids == [], (
            "closure must be drained the first time the runtime path fires "
            "so no later turn can replay stale mentions"
        )

    async def test_empty_runtime_context_does_not_replay_closure_mentions(
        self, monkeypatch
    ):
        """Regression: turn 2 with NO mentions must not surface turn 1's
        mentions from the constructor closure.

        Before the fix, ``if ctx_mentions:`` treated an empty list as
        absent and fell through to ``elif self.mentioned_document_ids:``,
        replaying turn 1's mentions. This test pins down the corrected
        behaviour.
        """
        fetched_ids = self._install_kb_fakes(monkeypatch)

        # Simulate a cached middleware instance whose closure was seeded
        # by a previous turn's cache-miss build (mentions=[1,2,3]).
        middleware = KnowledgePriorityMiddleware(
            llm=self._planner_llm(),
            search_space_id=42,
            mentioned_document_ids=[1, 2, 3],
        )

        # Turn 2: streaming task supplies an EMPTY mention list (no
        # mentions on this follow-up turn).
        await middleware.abefore_agent(
            {"messages": [HumanMessage(content="what about the next steps?")]},
            runtime=self._make_runtime([]),
        )

        assert fetched_ids == [], (
            "fetch_mentioned_documents must NOT be called when the runtime "
            "context says there are no mentions for this turn"
        )

    async def test_legacy_path_fires_only_when_runtime_context_absent(
        self, monkeypatch
    ):
        """Backward-compat: if a caller doesn't supply runtime.context (old
        non-streaming code path), the closure-injected mentions are still
        honoured exactly once and then drained.
        """
        fetched_ids = self._install_kb_fakes(monkeypatch)

        middleware = KnowledgePriorityMiddleware(
            llm=self._planner_llm(),
            search_space_id=42,
            mentioned_document_ids=[7, 8],
        )

        # First call: no runtime → legacy path uses the closure.
        await middleware.abefore_agent(
            {"messages": [HumanMessage(content="initial question")]},
            runtime=None,
        )
        # Second call: still no runtime — closure already drained, so no replay.
        await middleware.abefore_agent(
            {"messages": [HumanMessage(content="follow up")]},
            runtime=None,
        )

        assert fetched_ids == [[7, 8]], (
            "legacy path must honour the closure exactly once and then drain it"
        )
        assert middleware.mentioned_document_ids == []
|
||||
|
|
@ -0,0 +1,85 @@
|
|||
"""Behavior tests for finalize-time citation resolution.
|
||||
|
||||
The finalize step is the single server-side seam that turns the model's bare
|
||||
``[n]`` ordinals into renderer-ready ``[citation:<payload>]`` markers, using the
|
||||
registry captured from the run's final state. These tests pin that contract:
|
||||
known ordinals resolve, unknown ones drop, foreign markers survive, and a
|
||||
serialized (dict) registry is accepted just like a live one.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.chat.multi_agent_chat.shared.citations import (
|
||||
CitationRegistry,
|
||||
CitationSourceType,
|
||||
)
|
||||
from app.tasks.chat.streaming.flows.shared.assistant_finalize import _resolve_citations
|
||||
|
||||
|
||||
def _registry_with_chunk(chunk_id: int = 42) -> CitationRegistry:
    """Return a fresh registry with one KB-chunk source registered.

    The registered source is ``{"document_id": 1, "chunk_id": chunk_id}``.
    """
    seeded = CitationRegistry()
    chunk_source = {"document_id": 1, "chunk_id": chunk_id}
    seeded.register(CitationSourceType.KB_CHUNK, chunk_source)
    return seeded
|
||||
|
||||
|
||||
def _text(value: str) -> list[dict]:
|
||||
return [{"type": "text", "text": value}]
|
||||
|
||||
|
||||
def test_known_ordinal_resolves_to_chunk_marker():
    """A registered ``[1]`` is rewritten to ``[citation:<chunk_id>]``."""
    parts = _text("Launch is March 10 [1].")
    registry = _registry_with_chunk(42)

    resolved = _resolve_citations(parts, registry)

    assert resolved[0]["text"] == "Launch is March 10 [citation:42]."
|
||||
|
||||
|
||||
def test_unknown_ordinal_is_dropped():
    """An ordinal with no registry entry is removed from the text."""
    parts = _text("Unsupported claim [9].")
    registry = _registry_with_chunk(42)

    resolved = _resolve_citations(parts, registry)

    assert resolved[0]["text"] == "Unsupported claim ."
|
||||
|
||||
|
||||
def test_foreign_citation_marker_is_preserved():
    """A pre-existing ``[citation:...]`` marker passes through unchanged."""
    original_text = "From the web [citation:https://example.com]."
    parts = _text(original_text)

    resolved = _resolve_citations(parts, _registry_with_chunk(42))

    assert resolved[0]["text"] == "From the web [citation:https://example.com]."
|
||||
|
||||
|
||||
def test_serialized_registry_is_accepted():
    """A ``model_dump()``-ed registry resolves ordinals like a live one."""
    live_registry = _registry_with_chunk(7)
    serialized = live_registry.model_dump()

    resolved = _resolve_citations(_text("See [1]."), serialized)

    assert resolved[0]["text"] == "See [citation:7]."
|
||||
|
||||
|
||||
def test_empty_registry_leaves_text_untouched():
    """With nothing registered, bare ordinals stay in the text as written."""
    parts = _text("No sources here [1].")
    empty_registry = CitationRegistry()

    resolved = _resolve_citations(parts, empty_registry)

    assert resolved[0]["text"] == "No sources here [1]."
|
||||
|
||||
|
||||
def test_missing_registry_is_a_noop():
    """Passing ``None`` as the registry returns the text unchanged."""
    parts = _text("Nothing to resolve [1].")

    resolved = _resolve_citations(parts, None)

    assert resolved[0]["text"] == "Nothing to resolve [1]."
|
||||
|
||||
|
||||
def test_non_text_parts_are_left_alone():
    """Only text parts are rewritten; a tool-call part keeps its ``[1]`` arg."""
    tool_call_part = {
        "type": "tool_call",
        "name": "search_knowledge_base",
        "args": {"q": "[1]"},
    }
    text_part = {"type": "text", "text": "Result [1]."}

    resolved = _resolve_citations(
        [tool_call_part, text_part], _registry_with_chunk(5)
    )

    assert resolved[0]["args"]["q"] == "[1]"
    assert resolved[1]["text"] == "Result [citation:5]."
|
||||
Loading…
Add table
Add a link
Reference in a new issue