Merge pull request #1539 from CREDO23/improve-chat-agent-context-and-citations

[FEAT] Unified [n] citation registry for KB + web, pull-based retrieval
This commit is contained in:
Rohan Verma 2026-06-25 13:34:52 -07:00 committed by GitHub
commit 94fdb8a113
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
160 changed files with 4097 additions and 5238 deletions

View file

@ -0,0 +1,41 @@
"""Tests for connector pointer field selection."""
from __future__ import annotations
import pytest
from app.agents.chat.runtime.references.connectors import connector_pointer_fields
pytestmark = pytest.mark.unit
def test_prefers_chip_account_and_type() -> None:
    """An explicit account name and connector type are returned as (label, provider)."""
    resolved_label, resolved_provider = connector_pointer_fields(
        account_name="work@acme.com",
        connector_type="Gmail",
        fallback_name="My Gmail",
    )
    # The fallback name must not override the account name.
    assert resolved_label == "work@acme.com"
    assert resolved_provider == "Gmail"
def test_falls_back_to_stored_name_when_account_missing() -> None:
    """With no account name, the fallback name becomes the label."""
    arguments = {
        "account_name": None,
        "connector_type": "Slack",
        "fallback_name": "Acme Slack",
    }
    label, provider = connector_pointer_fields(**arguments)
    assert (label, provider) == ("Acme Slack", "Slack")
def test_provider_is_none_when_unknown() -> None:
    """A missing connector type yields ``None`` as the provider."""
    pointer = connector_pointer_fields(
        account_name="a@b.com", connector_type=None, fallback_name=None
    )
    label, provider = pointer
    assert label == "a@b.com"
    assert provider is None

View file

@ -0,0 +1,21 @@
"""Tests for folder pointer-path shaping."""
from __future__ import annotations
import pytest
from app.agents.chat.runtime.references.folders import folder_pointer_path
pytestmark = pytest.mark.unit
def test_adds_trailing_slash_so_path_reads_as_directory() -> None:
    """A stored folder path without a trailing slash gains one."""
    known_paths = {7: "/documents/Specs"}
    pointer_path = folder_pointer_path(7, known_paths)
    assert pointer_path == "/documents/Specs/"
def test_keeps_existing_trailing_slash() -> None:
    """A stored folder path already ending in a slash is returned unchanged."""
    known_paths = {7: "/documents/Specs/"}
    pointer_path = folder_pointer_path(7, known_paths)
    assert pointer_path == "/documents/Specs/"
def test_unknown_folder_falls_back_to_documents_root() -> None:
    """A folder id absent from the path map resolves to ``/documents/``."""
    no_known_paths: dict = {}
    pointer_path = folder_pointer_path(99, no_known_paths)
    assert pointer_path == "/documents/"

View file

@ -0,0 +1,93 @@
"""Tests for reference pointer rendering."""
from __future__ import annotations
import pytest
from app.agents.chat.runtime.references import (
ChatReference,
ConnectorReference,
DocumentReference,
FolderReference,
render_reference_pointers,
)
pytestmark = pytest.mark.unit
def test_returns_none_when_no_references() -> None:
    """An empty reference list produces no block at all."""
    rendered = render_reference_pointers([])
    assert rendered is None
def test_wraps_block_and_keeps_reference_order() -> None:
    """The block is wrapped in its tag pair and lists references in input order."""
    references = [
        DocumentReference(entity_id=42, label="Q3 Notes", path="/documents/q3.xml"),
        ChatReference(entity_id=5, label="Pricing"),
    ]
    rendered = render_reference_pointers(references)
    assert rendered is not None
    assert rendered.startswith("<referenced_this_turn>")
    assert rendered.endswith("</referenced_this_turn>")
    # The document was passed first, so it must appear before the chat.
    document_offset = rendered.index("document 42")
    chat_offset = rendered.index("chat 5")
    assert document_offset < chat_offset
def test_document_with_path_shows_title_and_path() -> None:
    """A document pointer shows its id, quoted title and path."""
    reference = DocumentReference(
        entity_id=42, label="Q3 Launch Notes", path="/documents/Launch/Q3.xml"
    )
    rendered = render_reference_pointers([reference])
    assert rendered is not None
    expected_line = '- document 42 — "Q3 Launch Notes" (/documents/Launch/Q3.xml)'
    assert expected_line in rendered
def test_folder_with_path_renders_with_folder_kind() -> None:
    """A folder pointer is labelled ``folder`` and shows its title and path."""
    reference = FolderReference(entity_id=7, label="Specs", path="/documents/Specs/")
    rendered = render_reference_pointers([reference])
    assert rendered is not None
    expected_line = '- folder 7 — "Specs" (/documents/Specs/)'
    assert expected_line in rendered
def test_connector_shows_provider_and_account() -> None:
    """A connector pointer shows the provider followed by the account in parentheses."""
    reference = ConnectorReference(
        entity_id=12, label="work@acme.com", provider="Gmail"
    )
    rendered = render_reference_pointers([reference])
    assert rendered is not None
    expected_line = "- connector 12 — Gmail (work@acme.com)"
    assert expected_line in rendered
def test_connector_without_provider_falls_back_to_label() -> None:
    """Without a provider, the connector pointer shows just the label."""
    reference = ConnectorReference(entity_id=12, label="work@acme.com")
    rendered = render_reference_pointers([reference])
    assert rendered is not None
    expected_line = "- connector 12 — work@acme.com"
    assert expected_line in rendered
def test_chat_shows_quoted_title() -> None:
    """A chat pointer shows its id and the title in double quotes."""
    reference = ChatReference(entity_id=5, label="Pricing debate")
    rendered = render_reference_pointers([reference])
    assert rendered is not None
    expected_line = '- chat 5 — "Pricing debate"'
    assert expected_line in rendered
def test_label_whitespace_is_collapsed_to_one_line() -> None:
    """A newline inside a label is rendered as a single space."""
    multiline_label = "line one\nline two"
    reference = DocumentReference(entity_id=1, label=multiline_label, path="/d.xml")
    rendered = render_reference_pointers([reference])
    assert rendered is not None
    expected_prefix = '- document 1 — "line one line two"'
    assert expected_prefix in rendered

View file

@ -0,0 +1,93 @@
"""Tests for the shared ``web_search`` tool's citable-result adaptation.
The tool's network path (SearXNG + live connectors) is out of scope here; these
cover the pure mapping from raw web results to renderable, citable documents and
the end-to-end registration of ``WEB_RESULT`` ``[n]`` labels.
"""
from __future__ import annotations
import pytest
from app.agents.chat.multi_agent_chat.shared.citations import (
CitationRegistry,
CitationSourceType,
)
from app.agents.chat.multi_agent_chat.shared.document_render import render_web_results
from app.agents.chat.shared.tools.web_search import (
_to_renderable_web_documents,
_web_source_label,
)
pytestmark = pytest.mark.unit
def _raw_result(url: str, title: str, content: str) -> dict:
    """Build a raw web-result dict with the given url, title and content."""
    document = {"title": title, "metadata": {"url": url}}
    return {"document": document, "content": content}
def test_web_source_label_strips_scheme_and_www() -> None:
    """The label keeps only the host (minus ``www.``); an empty URL gives ``Web``."""
    expected_by_url = {
        "https://www.example.com/path": "Web · example.com",
        "http://news.site.org/a/b": "Web · news.site.org",
        "": "Web",
    }
    for url, expected_label in expected_by_url.items():
        assert _web_source_label(url) == expected_label
def test_adapter_maps_each_result_to_one_web_passage() -> None:
    """Each raw result becomes a document whose passages are ``WEB_RESULT`` typed."""
    raw_results = [
        _raw_result("https://a.com/x", "Alpha", "alpha body"),
        _raw_result("https://b.com/y", "Beta", "beta body"),
    ]
    documents = _to_renderable_web_documents(raw_results)
    titles = [document.title for document in documents]
    assert titles == ["Alpha", "Beta"]

    # Flatten the passages of every document, preserving document order.
    passages = []
    for document in documents:
        passages.extend(document.passages)
    for passage in passages:
        assert passage.source_type is CitationSourceType.WEB_RESULT

    first_passage = passages[0]
    assert first_passage.locator == {"url": "https://a.com/x"}
    assert first_passage.content == "alpha body"
def test_adapter_skips_results_without_url_or_content() -> None:
    """Results with an empty URL or whitespace-only content are left out."""
    missing_url = _raw_result("", "No URL", "has content")
    blank_content = _raw_result("https://c.com/z", "Empty", " ")
    usable = _raw_result("https://d.com/w", "Good", "real content")
    documents = _to_renderable_web_documents([missing_url, blank_content, usable])
    kept_titles = [document.title for document in documents]
    assert kept_titles == ["Good"]
def test_adapter_truncates_on_char_budget() -> None:
    """Only the first result is kept when three 30-char bodies meet a 50-char budget."""
    body = "x" * 30
    raw_results = [
        _raw_result(url, title, body)
        for url, title in (
            ("https://a.com", "A"),
            ("https://b.com", "B"),
            ("https://c.com", "C"),
        )
    ]
    documents = _to_renderable_web_documents(raw_results, max_chars=50)
    # First fits (30), second crosses 50 and stops the loop.
    kept_titles = [document.title for document in documents]
    assert kept_titles == ["A"]
def test_end_to_end_registers_web_results_for_citation() -> None:
    """Adapting then rendering a web result registers it as ``[1]`` with its URL."""
    registry = CitationRegistry()
    raw_results = [_raw_result("https://example.com/a", "Example", "the answer is 42")]
    documents = _to_renderable_web_documents(raw_results)

    rendered = render_web_results(documents, registry)
    assert rendered is not None
    assert "[1] the answer is 42" in rendered

    registered = registry.resolve(1)
    assert registered is not None
    assert registered.source_type is CitationSourceType.WEB_RESULT
    assert registered.locator == {"url": "https://example.com/a"}

View file

@ -0,0 +1,49 @@
"""Tests for citation-entry → frontend payload mapping."""
from __future__ import annotations
import pytest
from app.agents.chat.multi_agent_chat.shared.citations.markers import (
to_frontend_payload,
)
from app.agents.chat.multi_agent_chat.shared.citations.models import (
CitationEntry,
CitationSourceType,
)
pytestmark = pytest.mark.unit
def _entry(source_type: CitationSourceType, locator: dict) -> CitationEntry:
    """Build a citation entry labelled ``1`` for the given source type and locator."""
    fields = {"n": 1, "source_type": source_type, "locator": locator}
    return CitationEntry(**fields)
def test_kb_chunk_maps_to_chunk_id() -> None:
    """A KB chunk entry maps to its chunk id as a string."""
    locator = {"chunk_id": 42, "document_id": 7}
    payload = to_frontend_payload(_entry(CitationSourceType.KB_CHUNK, locator))
    assert payload == "42"
def test_anon_chunk_keeps_negative_id() -> None:
    """An anonymous chunk's negative id is kept, sign included."""
    anonymous = _entry(CitationSourceType.ANON_CHUNK, {"chunk_id": -3})
    payload = to_frontend_payload(anonymous)
    assert payload == "-3"
def test_web_result_maps_to_url() -> None:
    """A web result entry maps to its URL."""
    url = "https://example.com/a"
    payload = to_frontend_payload(_entry(CitationSourceType.WEB_RESULT, {"url": url}))
    assert payload == "https://example.com/a"
def test_not_yet_renderable_kind_is_dropped() -> None:
    """A ``CHAT_TURN`` entry yields no frontend payload."""
    chat_turn = _entry(CitationSourceType.CHAT_TURN, {"thread_id": 1, "turn": 2})
    payload = to_frontend_payload(chat_turn)
    assert payload is None
def test_missing_locator_field_is_dropped() -> None:
    """A KB chunk entry with an empty locator yields no frontend payload."""
    incomplete = _entry(CitationSourceType.KB_CHUNK, {})
    payload = to_frontend_payload(incomplete)
    assert payload is None

View file

@ -0,0 +1,113 @@
"""Tests for rewriting model ``[n]`` ordinals into frontend citation markers."""
from __future__ import annotations
import pytest
from app.agents.chat.multi_agent_chat.shared.citations.models import CitationSourceType
from app.agents.chat.multi_agent_chat.shared.citations.normalizer import (
normalize_citations,
)
from app.agents.chat.multi_agent_chat.shared.citations.registry import CitationRegistry
pytestmark = pytest.mark.unit
def _registry_with_chunks(*chunk_ids: int) -> CitationRegistry:
    """Return a fresh registry with one KB chunk registered per id, in order."""
    populated = CitationRegistry()
    for current_id in chunk_ids:
        locator = {"chunk_id": current_id}
        populated.register(CitationSourceType.KB_CHUNK, locator)
    return populated
def test_single_ordinal_is_rewritten() -> None:
    """A lone ``[1]`` becomes the citation marker of the registered chunk."""
    registry = _registry_with_chunks(42)
    rewritten = normalize_citations("We shipped it [1].", registry)
    assert rewritten == "We shipped it [citation:42]."
def test_adjacent_brackets_are_each_rewritten() -> None:
    """Back-to-back ``[1][2]`` are rewritten independently."""
    registry = _registry_with_chunks(42, 7)
    rewritten = normalize_citations("Both agree [1][2].", registry)
    assert rewritten == "Both agree [citation:42][citation:7]."
def test_comma_separated_brackets_are_each_rewritten() -> None:
    """``[1], [2]`` are rewritten independently and the separator is kept."""
    registry = _registry_with_chunks(42, 7)
    rewritten = normalize_citations("Both agree [1], [2].", registry)
    assert rewritten == "Both agree [citation:42], [citation:7]."
def test_unknown_ordinal_is_dropped() -> None:
    """An ordinal with no registry entry is removed from the text."""
    registry = _registry_with_chunks(42)
    rewritten = normalize_citations("Maybe [9] is real.", registry)
    assert rewritten == "Maybe is real."
def test_unknown_ordinal_among_known_is_dropped() -> None:
    """Only the unresolved ordinal is removed; the known one is rewritten."""
    registry = _registry_with_chunks(42)
    rewritten = normalize_citations("See [1][9].", registry)
    assert rewritten == "See [citation:42]."
def test_web_result_rewrites_to_url() -> None:
    """A web-result ordinal is rewritten to a marker carrying the URL."""
    registry = CitationRegistry()
    locator = {"url": "https://example.com"}
    registry.register(CitationSourceType.WEB_RESULT, locator)
    rewritten = normalize_citations("Per the docs [1].", registry)
    assert rewritten == "Per the docs [citation:https://example.com]."
def test_word_glued_citation_is_rewritten() -> None:
    """A citation glued to the preceding word is still rewritten."""
    # The model frequently writes citations glued to the preceding word
    # (``docs[1]``); these must still resolve to a marker, not leak as raw text.
    registry = _registry_with_chunks(42)
    rewritten = normalize_citations("verifying against docs[1].", registry)
    assert rewritten == "verifying against docs[citation:42]."
def test_word_glued_unknown_ordinal_drops() -> None:
    """A glued ordinal with no registry entry is removed."""
    # A glued ordinal that doesn't resolve drops harmlessly (no broken marker,
    # no raw ``[n]`` leak) rather than being preserved as array-index syntax.
    registry = _registry_with_chunks(42)
    rewritten = normalize_citations("see notes[9] later", registry)
    assert rewritten == "see notes later"
def test_array_index_inside_code_is_left_alone() -> None:
    """Index syntax inside inline code passes through unchanged."""
    # Genuine array/index syntax is protected by the code-region carve-out.
    registry = _registry_with_chunks(42)
    original_text = "Read `arr[1]` carefully."
    rewritten = normalize_citations(original_text, registry)
    assert rewritten == "Read `arr[1]` carefully."
def test_ordinals_inside_inline_code_are_untouched() -> None:
    """Only the ordinal outside the inline code span is rewritten."""
    registry = _registry_with_chunks(42)
    rewritten = normalize_citations("Use `list[1]` here [1].", registry)
    assert rewritten == "Use `list[1]` here [citation:42]."
def test_ordinals_inside_fenced_code_are_untouched() -> None:
    """Ordinals around a fenced block are rewritten; the fenced body is not."""
    registry = _registry_with_chunks(42)
    source_text = "Before [1].\n```\nx = a[1]\n```\nAfter [1]."
    expected_text = "Before [citation:42].\n```\nx = a[1]\n```\nAfter [citation:42]."
    rewritten = normalize_citations(source_text, registry)
    assert rewritten == expected_text
def test_empty_text_is_returned_unchanged() -> None:
    """Empty input comes back as an empty string."""
    registry = _registry_with_chunks(42)
    rewritten = normalize_citations("", registry)
    assert rewritten == ""

View file

@ -0,0 +1,174 @@
"""Unit tests for the citation registry spine."""
from __future__ import annotations
from app.agents.chat.multi_agent_chat.shared.citations import (
CitationRegistry,
CitationSourceType,
make_key,
)
def test_register_assigns_monotonic_labels() -> None:
    """Two distinct chunks get labels 1 and 2, leaving ``next_n`` at 3."""
    registry = CitationRegistry()
    kind = CitationSourceType.KB_CHUNK
    label_one = registry.register(kind, {"document_id": 42, "chunk_id": 880})
    label_two = registry.register(kind, {"document_id": 42, "chunk_id": 881})
    assert label_one == 1
    assert label_two == 2
    assert registry.next_n == 3
def test_register_is_find_or_create_for_same_unit() -> None:
    """Registering the same locator twice returns label 1 both times."""
    registry = CitationRegistry()
    locator = {"document_id": 42, "chunk_id": 880}
    labels = [
        registry.register(CitationSourceType.KB_CHUNK, locator) for _ in range(2)
    ]
    assert labels == [1, 1]
    # No second entry was minted for the repeat registration.
    assert len(registry.by_n) == 1
    assert registry.next_n == 2
def test_dedup_is_insensitive_to_locator_key_order() -> None:
    """Locators with the same items in a different key order share one label."""
    registry = CitationRegistry()
    kind = CitationSourceType.KB_CHUNK
    original_label = registry.register(kind, {"document_id": 42, "chunk_id": 880})
    swapped_label = registry.register(kind, {"chunk_id": 880, "document_id": 42})
    assert original_label == swapped_label
def test_same_locator_values_across_types_do_not_collide() -> None:
    """An identical locator under two source types gets two different labels."""
    registry = CitationRegistry()
    shared_locator_items = {"id": 7}
    chunk_label = registry.register(
        CitationSourceType.KB_CHUNK, dict(shared_locator_items)
    )
    chat_label = registry.register(
        CitationSourceType.CHAT_TURN, dict(shared_locator_items)
    )
    assert chunk_label != chat_label
def test_resolve_returns_entry_with_locator_and_display() -> None:
    """A resolved entry exposes the label, type, locator and display it was given."""
    registry = CitationRegistry()
    locator = {"url": "https://example.com"}
    display = {"title": "Example"}
    label = registry.register(CitationSourceType.WEB_RESULT, locator, display)

    resolved = registry.resolve(label)
    assert resolved is not None
    assert resolved.n == label
    assert resolved.source_type is CitationSourceType.WEB_RESULT
    assert resolved.locator == {"url": "https://example.com"}
    assert resolved.display == {"title": "Example"}
def test_resolve_unknown_label_returns_none() -> None:
    """Resolving a label on an empty registry returns ``None``."""
    empty_registry = CitationRegistry()
    resolved = empty_registry.resolve(999)
    assert resolved is None
def test_registry_round_trips_through_serialization() -> None:
    """Dumping and re-validating a registry keeps its entry and ``next_n``."""
    registry = CitationRegistry()
    locator = {"document_id": 42, "chunk_id": 880}
    display = {"title": "Q3 Launch Notes"}
    registry.register(CitationSourceType.KB_CHUNK, locator, display)

    dumped = registry.model_dump()
    restored = CitationRegistry.model_validate(dumped)

    restored_entry = restored.resolve(1)
    assert restored_entry is not None
    assert restored_entry.source_type is CitationSourceType.KB_CHUNK
    assert restored.next_n == registry.next_n
def test_make_key_is_stable_and_type_prefixed() -> None:
    """Keys ignore locator key order and start with ``kb_chunk|``."""
    kind = CitationSourceType.KB_CHUNK
    document_first = make_key(kind, {"document_id": 42, "chunk_id": 880})
    chunk_first = make_key(kind, {"chunk_id": 880, "document_id": 42})
    assert document_first == chunk_first
    assert document_first.startswith("kb_chunk|")
def _kb(registry: CitationRegistry, chunk_id: int) -> int:
    """Register ``chunk_id`` of document 1 as a KB chunk and return its label."""
    locator = {"document_id": 1, "chunk_id": chunk_id}
    return registry.register(CitationSourceType.KB_CHUNK, locator)
def test_merge_unions_disjoint_registries_preserving_labels() -> None:
    """Merging a fork that added one chunk keeps labels 1-3 and sets ``next_n`` to 4."""
    left = CitationRegistry()
    _kb(left, 10)  # [1]
    _kb(left, 11)  # [2]
    # A branch that forked from `left`, then registered its own chunk at [3].
    right = left.model_copy(deep=True)
    third = _kb(right, 12)  # [3]
    assert third == 3

    merged = left.merge(right)
    expected_chunk_by_label = {1: 10, 2: 11, 3: 12}
    for label, chunk_id in expected_chunk_by_label.items():
        entry = merged.resolve(label)
        # Guard first: a missing label should fail as a readable assertion,
        # not as an AttributeError on ``None``.
        assert entry is not None, f"label [{label}] missing after merge"
        assert entry.locator["chunk_id"] == chunk_id
    assert merged.next_n == 4
def test_merge_keeps_one_label_for_a_shared_source() -> None:
    """Two registries holding the same source merge into a single entry at ``[1]``."""
    left = CitationRegistry()
    _kb(left, 10)  # [1]
    right = CitationRegistry()
    _kb(right, 10)  # also [1], same source

    merged = left.merge(right)
    assert len(merged.by_n) == 1
    entry = merged.resolve(1)
    # Guard first: a missing label should fail as a readable assertion,
    # not as an AttributeError on ``None``.
    assert entry is not None, "label [1] missing after merge"
    assert entry.locator["chunk_id"] == 10
    assert merged.next_n == 2
def test_merge_remints_on_collision_without_losing_sources() -> None:
    """Colliding labels keep both sources: left keeps ``[2]``, right moves to ``[3]``."""
    # Two branches forked from the same base [1], each minting a *different*
    # source at [2]. Merge must keep both sources, re-minting one.
    base = CitationRegistry()
    _kb(base, 10)  # [1]
    left = base.model_copy(deep=True)
    _kb(left, 11)  # [2] -> chunk 11
    right = base.model_copy(deep=True)
    _kb(right, 12)  # [2] -> chunk 12 (collision)

    merged = left.merge(right)
    chunk_ids = {entry.locator["chunk_id"] for entry in merged.by_n.values()}
    assert chunk_ids == {10, 11, 12}

    # left wins the slot at [2]; right is re-minted at [3].
    expected_chunk_by_label = {2: 11, 3: 12}
    for label, chunk_id in expected_chunk_by_label.items():
        entry = merged.resolve(label)
        # Guard first: a missing label should fail as a readable assertion,
        # not as an AttributeError on ``None``.
        assert entry is not None, f"label [{label}] missing after merge"
        assert entry.locator["chunk_id"] == chunk_id
    assert merged.next_n == 4
def test_merge_does_not_mutate_inputs() -> None:
    """After a merge, both input registries still hold only their own label 1."""
    left, right = CitationRegistry(), CitationRegistry()
    _kb(left, 10)
    _kb(right, 11)

    left.merge(right)

    for untouched in (left, right):
        assert list(untouched.by_n) == [1]

View file

@ -0,0 +1,152 @@
"""Tests for the shared ``render_document`` (one ``<document>`` block)."""
from __future__ import annotations
import pytest
from app.agents.chat.multi_agent_chat.shared.citations import (
CitationRegistry,
CitationSourceType,
)
from app.agents.chat.multi_agent_chat.shared.document_render import (
RenderableDocument,
RenderablePassage,
render_document,
)
pytestmark = pytest.mark.unit
def _document(
    document_id: int,
    title: str,
    chunk_ids: list[int],
    *,
    source: str | None = None,
) -> RenderableDocument:
    """Build a renderable document with one ``text <id>`` passage per chunk id."""
    passages = []
    for chunk_id in chunk_ids:
        locator = {"document_id": document_id, "chunk_id": chunk_id}
        passages.append(RenderablePassage(content=f"text {chunk_id}", locator=locator))
    return RenderableDocument(title=title, source=source, passages=passages)
def test_returns_none_when_no_passages() -> None:
    """A document with no passages renders to ``None``."""
    registry = CitationRegistry()
    empty_document = _document(1, "Empty", [])
    rendered = render_document(empty_document, view="excerpt", registry=registry)
    assert rendered is None
def test_excerpt_open_and_close_tags() -> None:
    """The block opens with title, source and view attributes and closes the tag."""
    registry = CitationRegistry()
    document = _document(1, "Q3 Launch Notes", [880], source="Slack · #launch")
    rendered = render_document(document, view="excerpt", registry=registry)
    assert rendered is not None
    opening_tag = (
        '<document title="Q3 Launch Notes" source="Slack · #launch" view="excerpt">'
    )
    assert rendered.startswith(opening_tag)
    assert rendered.endswith("</document>")
def test_full_view_renders_view_attribute() -> None:
    """Rendering with ``view="full"`` puts that value in the opening tag."""
    registry = CitationRegistry()
    document = _document(1, "Doc", [880])
    rendered = render_document(document, view="full", registry=registry)
    assert rendered is not None
    assert '<document title="Doc" view="full">' in rendered
def test_source_attribute_omitted_when_absent() -> None:
    """A document without a source renders no ``source`` attribute."""
    registry = CitationRegistry()
    sourceless = _document(1, "Plain", [1])
    rendered = render_document(sourceless, view="excerpt", registry=registry)
    assert rendered is not None
    assert rendered.startswith('<document title="Plain" view="excerpt">')
def test_registers_passages_with_chunk_locators() -> None:
    """Rendering registers the passage as a KB chunk with locator and display."""
    registry = CitationRegistry()
    document = _document(1, "Doc", [880], source="Slack")
    render_document(document, view="excerpt", registry=registry)

    registered = registry.resolve(1)
    assert registered is not None
    assert registered.source_type is CitationSourceType.KB_CHUNK
    assert registered.locator == {"document_id": 1, "chunk_id": 880}
    assert registered.display == {"title": "Doc", "source": "Slack"}
def test_passages_get_monotonic_labels() -> None:
    """Two passages of one document are labelled ``[1]`` and ``[2]`` in order."""
    registry = CitationRegistry()
    document = _document(1, "Doc", [880, 881])
    rendered = render_document(document, view="excerpt", registry=registry)
    assert rendered is not None
    for labelled_line in (" [1] text 880", " [2] text 881"):
        assert labelled_line in rendered
def test_multiline_passage_indents_under_label() -> None:
    """A passage's second line is rendered on its own line after the label line."""
    registry = CitationRegistry()
    two_line_passage = RenderablePassage(
        content="line one\nline two",
        locator={"document_id": 1, "chunk_id": 5},
    )
    document = RenderableDocument(title="Doc", passages=[two_line_passage])
    rendered = render_document(document, view="excerpt", registry=registry)
    assert rendered is not None
    assert " [1] line one\n line two" in rendered
def test_attribute_values_are_escaped() -> None:
    """``&``, ``<``, ``>`` and ``"`` in title and source are entity-escaped."""
    registry = CitationRegistry()
    document = _document(1, 'A & B <c> "d"', [1], source="x & y")
    rendered = render_document(document, view="excerpt", registry=registry)
    assert rendered is not None
    escaped_title = 'title="A &amp; B &lt;c&gt; &quot;d&quot;"'
    escaped_source = 'source="x &amp; y"'
    assert escaped_title in rendered
    assert escaped_source in rendered
def test_same_passage_reuses_label_across_calls() -> None:
    """Rendering the same document in both views mints only one label."""
    registry = CitationRegistry()
    document = _document(1, "Doc", [880])
    for view in ("excerpt", "full"):
        render_document(document, view=view, registry=registry)
    # Only label 1 was minted, so the next free label is 2.
    assert registry.next_n == 2

View file

@ -0,0 +1,94 @@
"""Tests for the ``<retrieved_context>`` wrapper around excerpt documents."""
from __future__ import annotations
import pytest
from app.agents.chat.multi_agent_chat.shared.citations import CitationRegistry
from app.agents.chat.multi_agent_chat.shared.document_render import (
RenderableDocument,
RenderablePassage,
render_search_context,
)
pytestmark = pytest.mark.unit
def _document(
    document_id: int,
    title: str,
    chunk_ids: list[int],
    *,
    source: str | None = None,
) -> RenderableDocument:
    """Build a renderable document with one ``text <id>`` passage per chunk id."""
    passages = []
    for chunk_id in chunk_ids:
        locator = {"document_id": document_id, "chunk_id": chunk_id}
        passages.append(RenderablePassage(content=f"text {chunk_id}", locator=locator))
    return RenderableDocument(title=title, source=source, passages=passages)
def test_returns_none_when_nothing_to_show() -> None:
    """No documents, or only a passage-less document, render to ``None``."""
    registry = CitationRegistry()
    without_documents = render_search_context([], registry)
    assert without_documents is None
    without_passages = render_search_context([_document(1, "Empty", [])], registry)
    assert without_passages is None
def test_assigns_monotonic_labels_across_documents() -> None:
    """Labels keep counting up from one document to the next."""
    registry = CitationRegistry()
    documents = [
        _document(1, "Q3 Launch Notes", [880, 881], source="Slack"),
        _document(2, "Timeline", [12], source="Notion"),
    ]
    rendered = render_search_context(documents, registry)
    assert rendered is not None
    for labelled_passage in ("[1] text 880", "[2] text 881", "[3] text 12"):
        assert labelled_passage in rendered
def test_wraps_in_retrieved_context_and_teaches_excerpt_and_citation() -> None:
    """The block is wrapped in ``<retrieved_context>`` and includes usage guidance."""
    registry = CitationRegistry()
    documents = [_document(1, "Doc", [1])]
    rendered = render_search_context(documents, registry)
    assert rendered is not None
    assert rendered.startswith("<retrieved_context>")
    assert rendered.endswith("</retrieved_context>")
    for guidance in ("excerpt view", "Cite a chunk with its [n]."):
        assert guidance in rendered
def test_documents_render_as_excerpt_blocks() -> None:
    """Each document appears as a ``view="excerpt"`` ``<document>`` block."""
    registry = CitationRegistry()
    document = _document(1, "Q3", [1], source="Slack · #launch")
    rendered = render_search_context([document], registry)
    assert rendered is not None
    opening_tag = '<document title="Q3" source="Slack · #launch" view="excerpt">'
    assert opening_tag in rendered
    assert "</document>" in rendered
def test_same_passage_reuses_label_across_calls() -> None:
    """Rendering the same document twice keeps its passage at ``[1]``."""
    registry = CitationRegistry()
    documents = [_document(1, "Doc", [880])]
    render_search_context(documents, registry)
    second_render = render_search_context(documents, registry)
    assert second_render is not None
    assert "[1] text 880" in second_render
    # Only label 1 was minted, so the next free label is 2.
    assert registry.next_n == 2

View file

@ -0,0 +1,35 @@
"""Tests for building a document's source label."""
from __future__ import annotations
import pytest
from app.agents.chat.multi_agent_chat.shared.document_render import source_label
pytestmark = pytest.mark.unit
def test_known_type_uses_friendly_name() -> None:
    """``SLACK_CONNECTOR`` is labelled ``Slack``."""
    label = source_label("SLACK_CONNECTOR", {})
    assert label == "Slack"
def test_unmapped_type_is_prettified() -> None:
    """``GOOGLE_DRIVE_FILE`` is labelled ``Google Drive``."""
    label = source_label("GOOGLE_DRIVE_FILE", {})
    assert label == "Google Drive"
def test_url_host_is_appended_and_www_stripped() -> None:
    """The URL host is appended to the type name with ``www.`` removed."""
    metadata = {"url": "https://www.docs.python.org/3/"}
    assert source_label("CRAWLED_URL", metadata) == "Web · docs.python.org"
def test_host_only_when_type_unknown() -> None:
    """With no document type, the label is just the URL host."""
    metadata = {"url": "https://example.com/a"}
    label = source_label(None, metadata)
    assert label == "example.com"
def test_returns_none_when_nothing_known() -> None:
    """No type and empty metadata give no label."""
    label = source_label(None, {})
    assert label is None
def test_non_http_url_is_ignored() -> None:
    """A local path in ``url`` adds no host to the label."""
    metadata = {"url": "/local/path"}
    label = source_label("FILE", metadata)
    assert label == "File"

View file

@ -0,0 +1,82 @@
"""Tests for the ``<web_results>`` wrapper around web-result excerpt documents."""
from __future__ import annotations
import pytest
from app.agents.chat.multi_agent_chat.shared.citations import (
CitationRegistry,
CitationSourceType,
)
from app.agents.chat.multi_agent_chat.shared.document_render import (
RenderableDocument,
RenderablePassage,
render_web_results,
)
pytestmark = pytest.mark.unit
def _web_doc(url: str, title: str, content: str) -> RenderableDocument:
    """Build a one-passage web document whose source is ``Web · <host>``."""
    # Host = text after the first "//" up to the next "/".
    after_scheme = url.split("//", 1)[-1]
    host = after_scheme.split("/", 1)[0]
    passage = RenderablePassage(
        content=content,
        locator={"url": url},
        source_type=CitationSourceType.WEB_RESULT,
    )
    return RenderableDocument(title=title, source=f"Web · {host}", passages=[passage])
def test_returns_none_when_nothing_to_show() -> None:
    """An empty document list renders to ``None``."""
    registry = CitationRegistry()
    rendered = render_web_results([], registry)
    assert rendered is None
def test_wraps_in_web_results_container() -> None:
    """The block is wrapped in ``<web_results>`` and holds guidance, document and label."""
    registry = CitationRegistry()
    documents = [_web_doc("https://example.com/a", "Example", "the answer is 42")]
    rendered = render_web_results(documents, registry)
    assert rendered is not None
    assert rendered.startswith("<web_results>")
    assert rendered.endswith("</web_results>")
    expected_fragments = (
        "cite a result with its [n]",
        '<document title="Example" source="Web · example.com" view="excerpt">',
        "[1] the answer is 42",
    )
    for fragment in expected_fragments:
        assert fragment in rendered
def test_registers_each_result_as_web_result_with_url_locator() -> None:
    """Each rendered result is registered in order with its URL as the locator."""
    registry = CitationRegistry()
    documents = [
        _web_doc("https://a.com/x", "A", "alpha"),
        _web_doc("https://b.com/y", "B", "beta"),
    ]
    render_web_results(documents, registry)

    entry_one = registry.resolve(1)
    entry_two = registry.resolve(2)
    assert entry_one is not None
    assert entry_two is not None
    assert entry_one.source_type is CitationSourceType.WEB_RESULT
    assert entry_one.locator == {"url": "https://a.com/x"}
    assert entry_two.locator == {"url": "https://b.com/y"}
def test_same_url_reuses_label_across_calls() -> None:
    """Rendering the same web document twice mints only one label."""
    registry = CitationRegistry()
    documents = [_web_doc("https://example.com/a", "Example", "stable fact")]
    for _ in range(2):
        render_web_results(documents, registry)
    # Only label 1 was minted, so the next free label is 2.
    assert registry.next_n == 2

View file

@ -0,0 +1,51 @@
"""Tests for mapping a DocumentHit to a renderable document."""
from __future__ import annotations
import pytest
from app.agents.chat.multi_agent_chat.shared.retrieval.adapter import (
to_renderable_document,
)
from app.agents.chat.multi_agent_chat.shared.retrieval.models import (
ChunkHit,
DocumentHit,
)
pytestmark = pytest.mark.unit
def test_maps_identity_source_and_passages() -> None:
    """Title, source label and per-chunk passages carry over from the hit."""
    chunks = [
        ChunkHit(chunk_id=880, content="a", position=4, score=0.9),
        ChunkHit(chunk_id=881, content="b", position=7, score=0.5),
    ]
    hit = DocumentHit(
        document_id=42,
        title="Q3 Launch Notes",
        document_type="SLACK_CONNECTOR",
        metadata={},
        score=0.9,
        chunks=chunks,
    )

    document = to_renderable_document(hit)

    assert document.title == "Q3 Launch Notes"
    assert document.source == "Slack"
    chunk_and_content = [
        (passage.locator["chunk_id"], passage.content)
        for passage in document.passages
    ]
    assert chunk_and_content == [(880, "a"), (881, "b")]
    # Every passage points back to the originating document.
    for passage in document.passages:
        assert passage.locator["document_id"] == 42
def test_document_with_no_chunks_maps_to_no_passages() -> None:
    """A hit without chunks maps to a document with an empty passage list."""
    chunkless_hit = DocumentHit(
        document_id=1,
        title="Empty",
        document_type=None,
        metadata={},
        score=0.0,
        chunks=[],
    )
    document = to_renderable_document(chunkless_hit)
    assert document.passages == []

View file

@ -0,0 +1,65 @@
"""Tests for the build_context pipeline (rerank → adapt → render)."""
from __future__ import annotations
from typing import Any
import pytest
from app.agents.chat.multi_agent_chat.shared.citations import CitationRegistry
from app.agents.chat.multi_agent_chat.shared.retrieval.models import (
ChunkHit,
DocumentHit,
)
from app.agents.chat.multi_agent_chat.shared.retrieval.service import build_context
pytestmark = pytest.mark.unit
def _hit(document_id: int, chunk_id: int) -> DocumentHit:
    """Build a ``FILE`` hit with one chunk; the score is ``1.0 / document_id``."""
    only_chunk = ChunkHit(
        chunk_id=chunk_id, content=f"text {chunk_id}", position=0, score=1.0
    )
    return DocumentHit(
        document_id=document_id,
        title=f"Doc {document_id}",
        document_type="FILE",
        metadata={},
        score=1.0 / document_id,
        chunks=[only_chunk],
    )
def test_no_hits_renders_nothing() -> None:
    """An empty hit list produces no context block."""
    rendered = build_context("q", [], CitationRegistry())
    assert rendered is None
def test_renders_block_and_registers_labels_in_order() -> None:
    """Hits render in order and register ``[1]``, ``[2]`` with their chunk locators."""
    registry = CitationRegistry()
    block = build_context("q", [_hit(1, 880), _hit(2, 12)], registry)
    assert block is not None
    assert "[1] text 880" in block
    assert "[2] text 12" in block

    expected_locator_by_label = {
        1: {"document_id": 1, "chunk_id": 880},
        2: {"document_id": 2, "chunk_id": 12},
    }
    for label, locator in expected_locator_by_label.items():
        entry = registry.resolve(label)
        # Guard first: a missing label should fail as a readable assertion,
        # not as an AttributeError on ``None``.
        assert entry is not None, f"label [{label}] was not registered"
        assert entry.locator == locator
class _ReverseReranker:
    """Test double for a reranker: returns the documents in reverse order."""

    def rerank_documents(
        self, query_text: str, documents: list[dict[str, Any]]
    ) -> list[dict[str, Any]]:
        """Return a new list holding ``documents`` back to front."""
        return [*reversed(documents)]
def test_reranker_reorders_documents_before_labeling() -> None:
    """Labels follow the reranked order, not the input order."""
    registry = CitationRegistry()
    block = build_context(
        "q", [_hit(1, 880), _hit(2, 12)], registry, reranker=_ReverseReranker()
    )
    assert block is not None

    # Reversed: doc 2 now renders first and gets [1].
    expected_locator_by_label = {
        1: {"document_id": 2, "chunk_id": 12},
        2: {"document_id": 1, "chunk_id": 880},
    }
    for label, locator in expected_locator_by_label.items():
        entry = registry.resolve(label)
        # Guard first: a missing label should fail as a readable assertion,
        # not as an AttributeError on ``None``.
        assert entry is not None, f"label [{label}] was not registered"
        assert entry.locator == locator

View file

@ -1,295 +0,0 @@
"""Tests for the prompt fragment composer."""
from __future__ import annotations
from datetime import UTC, datetime
import pytest
from app.db import ChatVisibility
from app.prompts.system_prompt_composer.composer import (
ALL_TOOL_NAMES_ORDERED,
compose_system_prompt,
detect_provider_variant,
)
pytestmark = pytest.mark.unit
@pytest.fixture
def fixed_today() -> datetime:
    """Provide a fixed, timezone-aware "today": 2025-06-01 12:00 UTC."""
    return datetime(year=2025, month=6, day=1, hour=12, minute=0, tzinfo=UTC)
class TestProviderVariantDetection:
    """Checks that model identifiers map to the expected provider variant name."""

    @pytest.mark.parametrize(
        ("model_name", "expected"),
        [
            # GPT-4 family routes to "classic" (autonomous-persistence style)
            ("openai:gpt-4o-mini", "openai_classic"),
            ("openai:gpt-4-turbo", "openai_classic"),
            # GPT-5 / o-series route to "reasoning" (channel-aware pragmatic)
            ("openai:gpt-5", "openai_reasoning"),
            ("openai:o1-preview", "openai_reasoning"),
            ("openai:o3-mini", "openai_reasoning"),
            # Codex family beats reasoning (more specific). Mirrors OpenCode
            # ``system.ts`` — ``gpt-*-codex`` gets the code-purist prompt.
            ("openai:gpt-5-codex", "openai_codex"),
            ("openai:gpt-codex", "openai_codex"),
            ("openai:codex-mini", "openai_codex"),
            # Anthropic + Google
            ("anthropic:claude-3-5-sonnet", "anthropic"),
            ("anthropic/claude-opus-4", "anthropic"),
            ("google:gemini-2.0-flash", "google"),
            ("vertex:gemini-1.5-pro", "google"),
            # Newly-covered families
            ("moonshot:kimi-k2", "kimi"),
            ("openrouter:moonshot/kimi-k2.5", "kimi"),
            ("xai:grok-2", "grok"),
            ("openrouter:x-ai/grok-3", "grok"),
            ("openai:deepseek-v3", "deepseek"),
            ("deepseek:deepseek-r1", "deepseek"),
            # Unknown families fall back to default (no provider block emitted)
            ("groq:mixtral-8x7b", "default"),
            ("together:llama-3.1-70b", "default"),
            (None, "default"),
            ("", "default"),
        ],
    )
    def test_detection(self, model_name: str | None, expected: str) -> None:
        """Each model identifier resolves to its expected variant name."""
        assert detect_provider_variant(model_name) == expected

    def test_codex_takes_precedence_over_reasoning(self) -> None:
        """Regression guard: ``gpt-5-codex`` must NOT match the generic
        ``gpt-5`` reasoning regex first. Codex is the more specialised
        prompt and mirrors OpenCode's dispatch order.
        """
        # ``detect_provider_variant`` is already imported at module level, so
        # no function-scope re-import is needed here.
        assert detect_provider_variant("openai:gpt-5-codex") == "openai_codex"
        assert detect_provider_variant("openai:gpt-5") == "openai_reasoning"
class TestCompose:
def test_default_prompt_has_required_blocks(self, fixed_today: datetime) -> None:
    """The default prompt contains every required wrapper, policy and tool marker."""
    prompt = compose_system_prompt(today=fixed_today)
    required_markers = (
        # System instruction wrapper
        "<system_instruction>",
        "</system_instruction>",
        # Date interpolated
        "2025-06-01",
        # Core policy blocks present
        "<knowledge_base_only_policy>",
        "<tool_routing>",
        "<parameter_resolution>",
        "<memory_protocol>",
        # Tools
        "<tools>",
        "</tools>",
        # Citations on by default
        "<citation_instructions>",
        "[citation:chunk_id]",
    )
    for marker in required_markers:
        assert marker in prompt
def test_team_visibility_uses_team_variants(self, fixed_today: datetime) -> None:
    """``SEARCH_SPACE`` visibility yields team phrasing and no personal phrasing."""
    visibility = ChatVisibility.SEARCH_SPACE
    prompt = compose_system_prompt(today=fixed_today, thread_visibility=visibility)
    # Team-specific phrasing in the agent block, then the memory protocol's
    # mention of the team.
    for team_phrase in ("team space", "team"):
        assert team_phrase in prompt
    # Should NOT mention the user-only memory phrasing
    assert "personal knowledge base" not in prompt
def test_private_visibility_uses_private_variants(
    self, fixed_today: datetime
) -> None:
    """``PRIVATE`` visibility yields personal phrasing and no team author prefix."""
    visibility = ChatVisibility.PRIVATE
    prompt = compose_system_prompt(today=fixed_today, thread_visibility=visibility)
    assert "personal knowledge base" in prompt
    # Should NOT mention the team-specific phrasing about prefixed authors
    team_author_prefix = "[DisplayName of the author]"
    assert team_author_prefix not in prompt
def test_citations_disabled_swaps_block(self, fixed_today: datetime) -> None:
prompt_on = compose_system_prompt(today=fixed_today, citations_enabled=True)
prompt_off = compose_system_prompt(today=fixed_today, citations_enabled=False)
assert "Citations are DISABLED" in prompt_off
assert "Citations are DISABLED" not in prompt_on
assert "[citation:chunk_id]" in prompt_on
def test_enabled_tool_filter_only_includes_listed_tools(
self, fixed_today: datetime
) -> None:
prompt = compose_system_prompt(
today=fixed_today,
enabled_tool_names={"web_search", "scrape_webpage"},
)
assert "web_search:" in prompt or "- web_search:" in prompt
assert "scrape_webpage:" in prompt or "- scrape_webpage:" in prompt
# Excluded tools should NOT appear in tool listing
assert "generate_podcast:" not in prompt
assert "generate_image:" not in prompt
def test_disabled_tool_note_is_appended(self, fixed_today: datetime) -> None:
prompt = compose_system_prompt(
today=fixed_today,
enabled_tool_names={"web_search"},
disabled_tool_names={"generate_image", "generate_podcast"},
)
assert "DISABLED TOOLS (by user):" in prompt
assert "Generate Image" in prompt
assert "Generate Podcast" in prompt
def test_mcp_routing_block_emits_when_provided(self, fixed_today: datetime) -> None:
prompt = compose_system_prompt(
today=fixed_today,
mcp_connector_tools={"My GitLab": ["gitlab_search", "gitlab_create_mr"]},
)
assert "<mcp_tool_routing>" in prompt
assert "My GitLab" in prompt
assert "gitlab_search" in prompt
def test_mcp_routing_block_absent_when_no_servers(
self, fixed_today: datetime
) -> None:
prompt = compose_system_prompt(today=fixed_today, mcp_connector_tools={})
assert "<mcp_tool_routing>" not in prompt
def test_provider_block_renders_when_anthropic(self, fixed_today: datetime) -> None:
prompt = compose_system_prompt(
today=fixed_today, model_name="anthropic:claude-3-5-sonnet"
)
assert "<provider_hints>" in prompt
assert "Anthropic" in prompt or "Claude" in prompt
def test_provider_block_absent_for_default(self, fixed_today: datetime) -> None:
prompt = compose_system_prompt(today=fixed_today, model_name="custom:foo")
assert "<provider_hints>" not in prompt
@pytest.mark.parametrize(
"model_name,expected_marker",
[
# Each marker is a unique-ish phrase from the corresponding fragment.
# If a fragment is renamed/rewritten such that the marker is gone,
# update both the fragment and this test deliberately.
("openai:gpt-5-codex", "Codex-class"),
("openai:gpt-5", "OpenAI reasoning model"),
("openai:gpt-4o", "classic OpenAI chat model"),
("anthropic:claude-3-5-sonnet", "Anthropic Claude"),
("google:gemini-2.0-flash", "Google Gemini"),
("moonshot:kimi-k2", "Moonshot Kimi"),
("xai:grok-2", "xAI Grok"),
("deepseek:deepseek-r1", "DeepSeek"),
],
)
def test_each_known_variant_renders_with_its_marker(
self,
fixed_today: datetime,
model_name: str,
expected_marker: str,
) -> None:
"""Every supported variant must produce a ``<provider_hints>`` block
containing its identifying marker. This pins the dispatch + the
on-disk fragments together so a missing/renamed file is caught
immediately.
"""
prompt = compose_system_prompt(today=fixed_today, model_name=model_name)
assert "<provider_hints>" in prompt, (
f"variant for {model_name!r} did not emit a provider_hints block; "
"the corresponding providers/<variant>.md may be missing"
)
assert expected_marker in prompt, (
f"variant for {model_name!r} emitted hints but lacked the "
f"expected marker {expected_marker!r} — the fragment may have "
"drifted from the dispatch table"
)
def test_provider_blocks_are_byte_stable_across_calls(
self, fixed_today: datetime
) -> None:
"""Cache-stability guard: same model id → byte-identical prompt."""
a = compose_system_prompt(today=fixed_today, model_name="moonshot:kimi-k2")
b = compose_system_prompt(today=fixed_today, model_name="moonshot:kimi-k2")
assert a == b
def test_custom_system_instructions_override_default(
self, fixed_today: datetime
) -> None:
custom = "You are a custom assistant. Today is {resolved_today}."
prompt = compose_system_prompt(
today=fixed_today, custom_system_instructions=custom
)
assert "You are a custom assistant. Today is 2025-06-01." in prompt
# Default block should NOT be present
assert "<knowledge_base_only_policy>" not in prompt
def test_provider_hints_render_with_custom_system_instructions(
self, fixed_today: datetime
) -> None:
"""Regression guard for the always-append decision: provider hints
append AFTER a custom system prompt.
Provider hints are stylistic nudges (parallel tool-call rules,
formatting guidance, etc.) that help the model regardless of
what the system instructions say. Suppressing them when a
custom prompt is set would partially defeat the per-family
prompt machinery.
"""
prompt = compose_system_prompt(
today=fixed_today,
custom_system_instructions="You are a custom assistant.",
model_name="anthropic/claude-3-5-sonnet",
)
assert "You are a custom assistant." in prompt
assert "<provider_hints>" in prompt
# The custom prompt must come BEFORE the provider hints so the
# user's framing isn't drowned out by the stylistic nudges.
assert prompt.index("You are a custom assistant.") < prompt.index(
"<provider_hints>"
)
def test_use_default_false_with_no_custom_yields_no_system_block(
self, fixed_today: datetime
) -> None:
prompt = compose_system_prompt(
today=fixed_today,
use_default_system_instructions=False,
)
# No system_instruction wrapper but tools/citations still emitted
assert "<system_instruction>" not in prompt
assert "<tools>" in prompt
def test_all_known_tools_have_fragments(self) -> None:
# Soft assertion: verify that every tool in the canonical order
# produces non-empty content for at least one variant.
for tool in ALL_TOOL_NAMES_ORDERED:
prompt = compose_system_prompt(
today=datetime(2025, 1, 1, tzinfo=UTC),
enabled_tool_names={tool},
)
assert tool in prompt, f"tool {tool!r} missing from composed prompt"
class TestStableOrderingForCacheStability:
    """Regression guard: prompt cache hit-rate depends on byte-stable prefix."""

    def test_composition_is_deterministic_given_same_inputs(
        self, fixed_today: datetime
    ) -> None:
        """Equal inputs give equal prompts, however the tool-name set was written."""
        # The same two tool names in opposite orders; set order shouldn't matter.
        tool_orderings = (
            ("web_search", "scrape_webpage"),
            ("scrape_webpage", "web_search"),
        )
        first, second = (
            compose_system_prompt(
                today=fixed_today,
                enabled_tool_names=set(names),
                mcp_connector_tools={"X": ["x_a", "x_b"]},
            )
            for names in tool_orderings
        )
        assert first == second

View file

@ -38,7 +38,7 @@ class TestIsProtectedSystemMessage:
)
def test_tolerates_leading_whitespace(self) -> None:
    """A protected tag preceded by whitespace and a newline is still detected."""
    # ``msg`` used to be assigned twice in a row, so only this second
    # ``<workspace_tree>`` message was ever checked; the overwritten
    # assignment has been dropped.
    msg = SystemMessage(content=" \n<workspace_tree>\n...")
    assert _is_protected_system_message(msg) is True
@ -89,7 +89,7 @@ class TestPartitionMessages:
def test_protected_system_message_preserved_even_in_summarize_half(self) -> None:
partitioner = self._build_partitioner()
protected = SystemMessage(content="<priority_documents>\n...")
protected = SystemMessage(content="<workspace_tree>\n...")
msgs = [
HumanMessage(content="old human"),
AIMessage(content="old ai"),

View file

@ -28,7 +28,6 @@ def _clear_all(monkeypatch: pytest.MonkeyPatch) -> None:
"SURFSENSE_ENABLE_LLM_TOOL_SELECTOR",
"SURFSENSE_ENABLE_SKILLS",
"SURFSENSE_ENABLE_SPECIALIZED_SUBAGENTS",
"SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE",
"SURFSENSE_ENABLE_ACTION_LOG",
"SURFSENSE_ENABLE_REVERT_ROUTE",
"SURFSENSE_ENABLE_PLUGIN_LOADER",
@ -57,7 +56,6 @@ def test_defaults_match_shipped_agent_stack(monkeypatch: pytest.MonkeyPatch) ->
assert flags.enable_llm_tool_selector is False
assert flags.enable_skills is True
assert flags.enable_specialized_subagents is True
assert flags.enable_kb_planner_runnable is True
assert flags.enable_action_log is True
assert flags.enable_revert_route is True
assert flags.enable_plugin_loader is False
@ -122,7 +120,6 @@ def test_each_flag_can_be_set_independently(monkeypatch: pytest.MonkeyPatch) ->
"enable_llm_tool_selector": "SURFSENSE_ENABLE_LLM_TOOL_SELECTOR",
"enable_skills": "SURFSENSE_ENABLE_SKILLS",
"enable_specialized_subagents": "SURFSENSE_ENABLE_SPECIALIZED_SUBAGENTS",
"enable_kb_planner_runnable": "SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE",
"enable_action_log": "SURFSENSE_ENABLE_ACTION_LOG",
"enable_revert_route": "SURFSENSE_ENABLE_REVERT_ROUTE",
"enable_plugin_loader": "SURFSENSE_ENABLE_PLUGIN_LOADER",

View file

@ -90,8 +90,8 @@ class TestSubstituteInText:
class TestResolveMentions:
"""``resolve_mentions`` resolves chip ids → virtual paths and emits
a ``ResolvedMentionSet`` whose id partitions feed
``KnowledgePriorityMiddleware``."""
a ``ResolvedMentionSet`` whose id partitions feed the
``search_knowledge_base`` retrieval scope."""
@pytest.mark.asyncio
async def test_returns_empty_when_no_mentions(self):

View file

@ -4,9 +4,14 @@ from __future__ import annotations
import pytest
from app.agents.chat.multi_agent_chat.shared.citations import (
CitationRegistry,
CitationSourceType,
)
from app.agents.chat.multi_agent_chat.shared.state.reducers import (
_CLEAR,
_add_unique_reducer,
_citation_registry_merge_reducer,
_dict_merge_with_tombstones_reducer,
_initial_filesystem_state,
_list_append_reducer,
@ -93,6 +98,57 @@ class TestDictMergeWithTombstones:
}
def _kb_registry(chunk_id: int) -> CitationRegistry:
    """Return a fresh registry holding one ``KB_CHUNK`` entry for ``chunk_id``.

    The entry's locator always uses ``document_id`` 1.
    """
    locator = {"document_id": 1, "chunk_id": chunk_id}
    fresh = CitationRegistry()
    fresh.register(CitationSourceType.KB_CHUNK, locator)
    return fresh
class TestCitationRegistryMergeReducer:
    """``_citation_registry_merge_reducer`` with ``None``, live, and dict inputs."""

    @staticmethod
    def _chunk_ids(registry) -> set:
        """Collect the ``chunk_id`` of every entry in ``registry.by_n``."""
        return {entry.locator["chunk_id"] for entry in registry.by_n.values()}

    def test_none_left_returns_right(self):
        incoming = _kb_registry(10)
        result = _citation_registry_merge_reducer(None, incoming)
        assert result is incoming

    def test_none_right_returns_left(self):
        existing = _kb_registry(10)
        result = _citation_registry_merge_reducer(existing, None)
        assert result is existing

    def test_both_none_returns_none(self):
        result = _citation_registry_merge_reducer(None, None)
        assert result is None

    def test_unions_two_registries(self):
        existing, incoming = _kb_registry(10), _kb_registry(11)
        combined = _citation_registry_merge_reducer(existing, incoming)
        assert self._chunk_ids(combined) == {10, 11}

    def test_coerces_serialized_dict_update(self):
        # The checkpointer serializes Command.update via ormsgpack before the
        # reducer runs, so `right` can arrive as a plain dict.
        existing = _kb_registry(10)
        incoming_dump = _kb_registry(11).model_dump()
        combined = _citation_registry_merge_reducer(existing, incoming_dump)
        assert self._chunk_ids(combined) == {10, 11}

    def test_coerces_both_sides_from_dict(self):
        existing_dump = _kb_registry(10).model_dump()
        incoming_dump = _kb_registry(11).model_dump()
        combined = _citation_registry_merge_reducer(existing_dump, incoming_dump)
        assert isinstance(combined, CitationRegistry)
        assert self._chunk_ids(combined) == {10, 11}
class TestInitialFilesystemState:
def test_default_shape(self):
state = _initial_filesystem_state()
@ -105,8 +161,6 @@ class TestInitialFilesystemState:
assert state["doc_id_by_path"] == {}
assert state["dirty_paths"] == []
assert state["dirty_path_tool_calls"] == {}
assert state["kb_priority"] == []
assert state["kb_matched_chunk_ids"] == {}
assert state["kb_anon_doc"] is None
assert state["tree_version"] == 0

View file

@ -0,0 +1,124 @@
"""Unit tests for the KB read path: full-view render + anonymous-doc loading.
DB-backed loads are exercised by the integration suite; here we lock the pure
pieces (``render_full_document`` and the anonymous-upload branch of
``aload_document``), which need no database.
"""
from __future__ import annotations
from types import SimpleNamespace
import pytest
from app.agents.chat.multi_agent_chat.shared.citations import (
CitationRegistry,
CitationSourceType,
)
from app.agents.chat.multi_agent_chat.shared.document_render import (
RenderableDocument,
RenderablePassage,
)
from app.agents.chat.multi_agent_chat.shared.middleware.filesystem.backends.kb_postgres import (
KBPostgresBackend,
render_full_document,
)
pytestmark = pytest.mark.unit
def _backend(state: dict) -> KBPostgresBackend:
    """Build a ``KBPostgresBackend`` whose runtime stub exposes only ``state``."""
    runtime_stub = SimpleNamespace(state=state)
    return KBPostgresBackend(search_space_id=1, runtime=runtime_stub)
def test_render_full_document_uses_full_view_and_registers() -> None:
    """Rendering emits a ``view="full"`` tag, labels the passage ``[1]``,
    and leaves the passage's locator resolvable under that number.
    """
    passage = RenderablePassage(
        content="push to March 10",
        locator={"document_id": 7, "chunk_id": 880},
    )
    doc = RenderableDocument(title="Launch Notes", source="Slack", passages=[passage])
    citations = CitationRegistry()

    output = render_full_document(doc, citations)

    assert '<document title="Launch Notes" source="Slack" view="full">' in output
    assert "[1] push to March 10" in output

    registered = citations.resolve(1)
    assert registered is not None
    assert registered.locator == {"document_id": 7, "chunk_id": 880}
def test_render_full_document_reuses_search_label() -> None:
    """A chunk already registered from search keeps its [n] on a later full read."""
    citations = CitationRegistry()
    search_label = citations.register(
        CitationSourceType.KB_CHUNK,
        {"document_id": 7, "chunk_id": 880},
        {"title": "Launch Notes", "source": "Slack"},
    )

    # The unseen chunk (881) is listed first; the pre-registered one (880) second.
    passages = [
        RenderablePassage(
            content=text,
            locator={"document_id": 7, "chunk_id": chunk_id},
        )
        for text, chunk_id in (("new chunk", 881), ("push to March 10", 880))
    ]
    doc = RenderableDocument(title="Launch Notes", source="Slack", passages=passages)

    output = render_full_document(doc, citations)

    assert f"[{search_label}] push to March 10" in output
    assert "[2] new chunk" in output
def test_render_full_document_empty_falls_back_to_notice() -> None:
    """A document with no passages renders as the fixed "no readable content" notice."""
    empty_doc = RenderableDocument(title="Empty", passages=[])
    output = render_full_document(empty_doc, CitationRegistry())
    assert output == "(This document has no readable content.)"
async def test_aload_document_anonymous_upload() -> None:
    """Loading the path stored in ``kb_anon_doc`` returns the anonymous upload.

    The result carries no document id, and every passage uses
    ``document_id`` -1 with the ``ANON_CHUNK`` source type.
    """
    anon_doc = {
        "path": "/anon_upload.md",
        "title": "Quarterly Report",
        "chunks": [
            {"chunk_id": -1, "content": "revenue grew"},
            {"chunk_id": -2, "content": "costs fell"},
        ],
    }
    backend = _backend({"kb_anon_doc": anon_doc})

    loaded = await backend.aload_document("/anon_upload.md")
    assert loaded is not None

    document, doc_id = loaded
    assert doc_id is None
    assert document.title == "Quarterly Report"

    passages = document.passages
    assert [passage.locator["chunk_id"] for passage in passages] == [-1, -2]
    for passage in passages:
        assert passage.locator["document_id"] == -1
    for passage in passages:
        assert passage.source_type is CitationSourceType.ANON_CHUNK
async def test_aload_document_unknown_path_returns_none() -> None:
    """With empty state, ``/not/under/documents.md`` loads as ``None``."""
    backend = _backend({})
    loaded = await backend.aload_document("/not/under/documents.md")
    assert loaded is None

View file

@ -1,689 +0,0 @@
"""Unit tests for knowledge_search middleware helpers."""
import json
import pytest
from langchain_core.messages import AIMessage, HumanMessage
from app.agents.chat.multi_agent_chat.shared.middleware import knowledge_search as ks
from app.agents.chat.multi_agent_chat.shared.middleware.filesystem.backends.document_xml import (
build_document_xml as _build_document_xml,
)
from app.agents.chat.multi_agent_chat.shared.middleware.knowledge_search import (
KBSearchPlan,
KnowledgePriorityMiddleware,
_normalize_optional_date_range,
_parse_kb_search_plan_response,
_render_recent_conversation,
_resolve_search_types,
)
pytestmark = pytest.mark.unit
# ── _resolve_search_types ──────────────────────────────────────────────
class TestResolveSearchTypes:
    """``_resolve_search_types``: merging, legacy expansion and de-duplication."""

    def test_returns_none_when_no_inputs(self):
        resolved = _resolve_search_types(None, None)
        assert resolved is None

    def test_returns_none_when_both_empty(self):
        resolved = _resolve_search_types([], [])
        assert resolved is None

    def test_includes_legacy_type_for_google_gmail(self):
        resolved = _resolve_search_types(["GOOGLE_GMAIL_CONNECTOR"], None)
        for expected in ("GOOGLE_GMAIL_CONNECTOR", "COMPOSIO_GMAIL_CONNECTOR"):
            assert expected in resolved

    def test_includes_legacy_type_for_google_drive(self):
        resolved = _resolve_search_types(None, ["GOOGLE_DRIVE_FILE"])
        for expected in ("GOOGLE_DRIVE_FILE", "COMPOSIO_GOOGLE_DRIVE_CONNECTOR"):
            assert expected in resolved

    def test_includes_legacy_type_for_google_calendar(self):
        resolved = _resolve_search_types(["GOOGLE_CALENDAR_CONNECTOR"], None)
        for expected in (
            "GOOGLE_CALENDAR_CONNECTOR",
            "COMPOSIO_GOOGLE_CALENDAR_CONNECTOR",
        ):
            assert expected in resolved

    def test_no_legacy_expansion_for_unrelated_types(self):
        resolved = _resolve_search_types(["FILE", "NOTE"], None)
        assert set(resolved) == {"FILE", "NOTE"}

    def test_combines_connectors_and_document_types(self):
        resolved = _resolve_search_types(["FILE"], ["NOTE", "CRAWLED_URL"])
        assert set(resolved) >= {"FILE", "NOTE", "CRAWLED_URL"}

    def test_deduplicates(self):
        resolved = _resolve_search_types(["FILE", "FILE"], ["FILE"])
        assert resolved.count("FILE") == 1
# ── _build_document_xml ────────────────────────────────────────────────
class TestBuildDocumentXml:
    """Structure checks on the XML text produced by ``_build_document_xml``."""

    @pytest.fixture
    def sample_document(self):
        """One ``FILE`` document (id 42) with three chunks, ids 101-103."""
        chunk_texts = {
            101: "First chunk content",
            102: "Second chunk content",
            103: "Third chunk content",
        }
        return {
            "document_id": 42,
            "document": {
                "id": 42,
                "document_type": "FILE",
                "title": "Test Doc",
                "metadata": {"url": "https://example.com"},
            },
            "chunks": [
                {"chunk_id": chunk_id, "content": text}
                for chunk_id, text in chunk_texts.items()
            ],
        }

    def test_contains_document_metadata(self, sample_document):
        rendered = _build_document_xml(sample_document)
        for fragment in (
            "<document_id>42</document_id>",
            "<document_type>FILE</document_type>",
            "Test Doc",
        ):
            assert fragment in rendered

    def test_contains_chunk_index(self, sample_document):
        rendered = _build_document_xml(sample_document)
        for fragment in (
            "<chunk_index>",
            "</chunk_index>",
            'chunk_id="101"',
            'chunk_id="102"',
            'chunk_id="103"',
        ):
            assert fragment in rendered

    def test_matched_chunks_flagged_in_index(self, sample_document):
        rendered = _build_document_xml(sample_document, matched_chunk_ids={101, 103})
        # chunk id -> whether its line must carry the matched flag
        expected_flags = {"101": True, "102": False, "103": True}
        for line in rendered.split("\n"):
            for chunk_id, flagged in expected_flags.items():
                if f'chunk_id="{chunk_id}"' in line:
                    assert ('matched="true"' in line) is flagged

    def test_chunk_content_in_document_content_section(self, sample_document):
        rendered = _build_document_xml(sample_document)
        for fragment in (
            "<document_content>",
            "First chunk content",
            "Second chunk content",
            "Third chunk content",
        ):
            assert fragment in rendered

    def test_line_numbers_in_chunk_index_are_accurate(self, sample_document):
        """Verify that the line ranges in chunk_index actually point to the right content."""
        import re

        rendered = _build_document_xml(sample_document, matched_chunk_ids={101})
        xml_lines = rendered.split("\n")

        # First line that both names chunk 101 and carries a ``lines=`` attribute.
        index_line = next(
            (
                candidate
                for candidate in xml_lines
                if 'chunk_id="101"' in candidate and "lines=" in candidate
            ),
            None,
        )
        if index_line is None:
            pytest.fail("chunk_id=101 entry not found in chunk_index")

        m = re.search(r'lines="(\d+)-(\d+)"', index_line)
        assert m, f"No lines= attribute found in: {index_line}"
        start, _end = int(m.group(1)), int(m.group(2))

        # ``start`` is 1-based, hence the ``- 1`` when indexing the list.
        target_line = xml_lines[start - 1]
        assert "101" in target_line
        assert "First chunk content" in target_line

    def test_splits_into_lines_correctly(self, sample_document):
        """Each chunk occupies exactly one line (no embedded newlines)."""
        rendered = _build_document_xml(sample_document)
        chunk_line_count = sum(
            1
            for line in rendered.split("\n")
            if "<![CDATA[" in line and "<chunk" in line
        )
        assert chunk_line_count == 3
# ── planner parsing / date normalization ───────────────────────────────
class TestPlannerHelpers:
    """Planner-response parsing and optional date-range normalization."""

    def test_parse_kb_search_plan_response_accepts_plain_json(self):
        payload = {
            "optimized_query": "ocv meeting decisions summary",
            "start_date": "2026-03-01",
            "end_date": "2026-03-31",
        }
        plan = _parse_kb_search_plan_response(json.dumps(payload))
        assert plan.optimized_query == "ocv meeting decisions summary"
        assert plan.start_date == "2026-03-01"
        assert plan.end_date == "2026-03-31"

    def test_parse_kb_search_plan_response_accepts_fenced_json(self):
        fenced_reply = """```json
{"optimized_query":"deel founders guide","start_date":null,"end_date":null}
```"""
        plan = _parse_kb_search_plan_response(fenced_reply)
        assert plan.optimized_query == "deel founders guide"
        assert plan.start_date is None
        assert plan.end_date is None

    def test_normalize_optional_date_range_returns_none_when_absent(self):
        bounds = _normalize_optional_date_range(None, None)
        start_date, end_date = bounds
        assert start_date is None
        assert end_date is None

    def test_normalize_optional_date_range_resolves_single_bound(self):
        bounds = _normalize_optional_date_range("2026-03-01", None)
        start_date, end_date = bounds
        assert start_date is not None
        assert end_date is not None
        assert start_date.date().isoformat() == "2026-03-01"
        assert end_date >= start_date
class FakeLLM:
    """LLM stand-in that replays one canned response and records every call."""

    def __init__(self, response_text: str):
        self.calls: list[dict] = []
        self.response_text = response_text

    async def ainvoke(self, messages, config=None):
        """Log the call, then return the canned text wrapped in an ``AIMessage``."""
        call_record = {"messages": messages, "config": config}
        self.calls.append(call_record)
        return AIMessage(content=self.response_text)
class FakeBudgetLLM:
    """Test double exposing a fixed input-token limit and a character-based counter."""

    def __init__(self, *, max_input_tokens: int):
        self._max_input_tokens_value = max_input_tokens

    def _get_max_input_tokens(self) -> int:
        """Return the limit given at construction."""
        return self._max_input_tokens_value

    def _count_tokens(self, messages) -> int:
        """Sum the lengths of each message's ``"content"`` (missing counts as empty)."""
        # Deterministic, simple proxy for tests: count characters as tokens.
        total = 0
        for message in messages:
            total += len(message.get("content", ""))
        return total
class TestKnowledgePriorityMiddlewarePlanner:
    """Planner path of ``KnowledgePriorityMiddleware`` plus conversation rendering."""

    @pytest.fixture(autouse=True)
    def _disable_planner_runnable(self, monkeypatch):
        # ``FakeLLM`` is a duck-typed mock; ``create_agent`` (used when the
        # planner Runnable path is enabled) calls ``.bind()`` on the LLM,
        # which the mock does not implement. Pin the flag off so the
        # planner falls through to the legacy ``self.llm.ainvoke`` path
        # these tests assert against (``llm.calls[0]["config"]``).
        monkeypatch.setenv("SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE", "false")

    @staticmethod
    def _record_kwargs(monkeypatch, attr: str) -> dict:
        """Patch ``ks.<attr>`` with an async stub; return the dict its kwargs land in."""
        seen: dict = {}

        async def _stub(**kwargs):
            seen.update(kwargs)
            return []

        monkeypatch.setattr(ks, attr, _stub)
        return seen

    @staticmethod
    def _record_calls(monkeypatch, attr: str) -> list:
        """Patch ``ks.<attr>`` with an async stub; return a list that grows per call."""
        hits: list = []

        async def _stub(**kwargs):
            hits.append(True)
            return []

        monkeypatch.setattr(ks, attr, _stub)
        return hits

    @staticmethod
    def _planner_reply(optimized_query, start_date, end_date, **extra) -> str:
        """Serialize a planner response; ``extra`` keys follow the three base keys."""
        return json.dumps(
            {
                "optimized_query": optimized_query,
                "start_date": start_date,
                "end_date": end_date,
                **extra,
            }
        )

    @staticmethod
    def _state(question: str) -> dict:
        """Agent state holding a single human message."""
        return {"messages": [HumanMessage(content=question)]}

    def test_render_recent_conversation_prefers_latest_messages_under_budget(self):
        history = [
            HumanMessage(content="old user context " * 40),
            AIMessage(content="old assistant answer " * 35),
            HumanMessage(content="recent user context " * 20),
            AIMessage(content="recent assistant answer " * 18),
            HumanMessage(content="latest question"),
        ]
        budget_llm = FakeBudgetLLM(max_input_tokens=900)

        rendered = _render_recent_conversation(
            history,
            llm=budget_llm,
            user_text="latest question",
        )

        assert "recent user context" in rendered
        assert "recent assistant answer" in rendered
        assert "latest question" not in rendered
        user_at = rendered.index("recent user context")
        assistant_at = rendered.index("recent assistant answer")
        assert user_at < assistant_at

    def test_render_recent_conversation_falls_back_to_legacy_without_budgeting(self):
        history = [
            HumanMessage(content="message one"),
            AIMessage(content="message two"),
            HumanMessage(content="latest question"),
        ]

        rendered = _render_recent_conversation(
            history,
            llm=None,
            user_text="latest question",
        )

        assert "user: message one" in rendered
        assert "assistant: message two" in rendered
        assert "latest question" not in rendered

    async def test_middleware_uses_optimized_query_and_dates(self, monkeypatch):
        search_kwargs = self._record_kwargs(monkeypatch, "search_knowledge_base")
        llm = FakeLLM(
            self._planner_reply(
                "ocv meeting decisions action items", "2026-03-01", "2026-03-31"
            )
        )
        middleware = KnowledgePriorityMiddleware(llm=llm, search_space_id=37)

        result = await middleware.abefore_agent(
            self._state("what happened in our OCV meeting last month?"),
            runtime=None,
        )

        assert result is not None
        assert search_kwargs["query"] == "ocv meeting decisions action items"
        assert search_kwargs["start_date"] is not None
        assert search_kwargs["end_date"] is not None
        assert search_kwargs["start_date"].date().isoformat() == "2026-03-01"
        assert search_kwargs["end_date"].date().isoformat() == "2026-03-31"
        assert llm.calls[0]["config"] == {"tags": ["surfsense:internal"]}

    async def test_middleware_falls_back_when_planner_returns_invalid_json(
        self,
        monkeypatch,
    ):
        search_kwargs = self._record_kwargs(monkeypatch, "search_knowledge_base")
        middleware = KnowledgePriorityMiddleware(
            llm=FakeLLM("not json"),
            search_space_id=37,
        )

        await middleware.abefore_agent(
            self._state("summarize founders guide by deel"),
            runtime=None,
        )

        # The raw user text is used as the query when the plan cannot be parsed.
        assert search_kwargs["query"] == "summarize founders guide by deel"
        assert search_kwargs["start_date"] is None
        assert search_kwargs["end_date"] is None

    async def test_middleware_passes_none_dates_when_planner_returns_nulls(
        self,
        monkeypatch,
    ):
        search_kwargs = self._record_kwargs(monkeypatch, "search_knowledge_base")
        middleware = KnowledgePriorityMiddleware(
            llm=FakeLLM(self._planner_reply("deel founders guide summary", None, None)),
            search_space_id=37,
        )

        await middleware.abefore_agent(
            self._state("summarize founders guide by deel"),
            runtime=None,
        )

        assert search_kwargs["query"] == "deel founders guide summary"
        assert search_kwargs["start_date"] is None
        assert search_kwargs["end_date"] is None

    async def test_middleware_routes_to_recency_browse_when_flagged(
        self,
        monkeypatch,
    ):
        """When the planner sets is_recency_query=true, browse_recent_documents
        is called instead of search_knowledge_base."""
        browse_kwargs = self._record_kwargs(monkeypatch, "browse_recent_documents")
        search_hits = self._record_calls(monkeypatch, "search_knowledge_base")
        llm = FakeLLM(
            self._planner_reply(
                "latest uploaded file", None, None, is_recency_query=True
            )
        )
        middleware = KnowledgePriorityMiddleware(llm=llm, search_space_id=42)

        result = await middleware.abefore_agent(
            self._state("what's my latest file?"),
            runtime=None,
        )

        assert result is not None
        assert browse_kwargs["search_space_id"] == 42
        assert not search_hits

    async def test_middleware_uses_hybrid_search_when_not_recency(
        self,
        monkeypatch,
    ):
        """When is_recency_query is false (default), hybrid search is used."""
        browse_hits = self._record_calls(monkeypatch, "browse_recent_documents")
        search_kwargs = self._record_kwargs(monkeypatch, "search_knowledge_base")
        llm = FakeLLM(
            self._planner_reply(
                "quarterly revenue report analysis",
                None,
                None,
                is_recency_query=False,
            )
        )
        middleware = KnowledgePriorityMiddleware(llm=llm, search_space_id=42)

        await middleware.abefore_agent(
            self._state("find the quarterly revenue report"),
            runtime=None,
        )

        assert search_kwargs["query"] == "quarterly revenue report analysis"
        assert not browse_hits
# ── KBSearchPlan schema ────────────────────────────────────────────────
class TestKBSearchPlanSchema:
    """Defaulting and parsing of ``KBSearchPlan.is_recency_query``."""

    def test_is_recency_query_defaults_to_false(self):
        plan = KBSearchPlan(optimized_query="test query")
        assert plan.is_recency_query is False

    def test_is_recency_query_parses_true(self):
        payload = {
            "optimized_query": "latest uploaded file",
            "start_date": None,
            "end_date": None,
            "is_recency_query": True,
        }
        plan = _parse_kb_search_plan_response(json.dumps(payload))
        assert plan.is_recency_query is True
        assert plan.optimized_query == "latest uploaded file"

    def test_missing_is_recency_query_defaults_to_false(self):
        # The key is omitted from the payload entirely.
        payload = {
            "optimized_query": "meeting notes",
            "start_date": None,
            "end_date": None,
        }
        plan = _parse_kb_search_plan_response(json.dumps(payload))
        assert plan.is_recency_query is False
# ── mentioned_document_ids cross-turn drain ────────────────────────────
class TestKnowledgePriorityMentionDrain:
    """Regression tests for the cross-turn ``mentioned_document_ids`` drain.

    The compiled-agent cache reuses one :class:`KnowledgePriorityMiddleware`
    instance across the turns of a thread, so ``mentioned_document_ids`` can
    reach the middleware two ways:

    1. Through the constructor closure (``__init__(mentioned_document_ids=...)``),
       seeded by the cache-miss build on turn 1.
    2. Through ``runtime.context.mentioned_document_ids``, supplied fresh on
       every turn by the streaming task.

    Without the drain fix, an empty ``runtime.context.mentioned_document_ids``
    on turn 2 fell through to the closure (``[]`` is falsy in Python) and
    replayed turn 1's mentions. These tests pin the corrected behaviour: the
    runtime path is authoritative even when empty, and the closure is drained
    the first time the runtime path fires so no later turn can resurrect
    stale state.
    """

    @staticmethod
    def _make_runtime(mention_ids: list[int]):
        """Minimal runtime stub exposing only ``runtime.context.mentioned_document_ids``."""
        from types import SimpleNamespace

        context = SimpleNamespace(mentioned_document_ids=mention_ids)
        return SimpleNamespace(context=context)

    @staticmethod
    def _planner_llm() -> "FakeLLM":
        # Planner returns a stable, non-recency plan so we always land in
        # the hybrid-search branch (where ``fetch_mentioned_documents`` is
        # invoked alongside the main search).
        plan = {
            "optimized_query": "follow up question",
            "start_date": None,
            "end_date": None,
            "is_recency_query": False,
        }
        return FakeLLM(json.dumps(plan))

    @staticmethod
    def _stub_retrieval(monkeypatch) -> list[list[int]]:
        """Patch both ``ks`` retrieval functions with empty-result stubs.

        Returns the list that collects the ``document_ids`` of each
        ``fetch_mentioned_documents`` call.
        """
        fetched: list[list[int]] = []

        async def fake_fetch_mentioned_documents(*, document_ids, search_space_id):
            fetched.append(list(document_ids))
            return []

        async def fake_search_knowledge_base(**_kwargs):
            return []

        monkeypatch.setattr(
            ks, "fetch_mentioned_documents", fake_fetch_mentioned_documents
        )
        monkeypatch.setattr(ks, "search_knowledge_base", fake_search_knowledge_base)
        return fetched

    def _middleware(self, closure_mentions: list[int]):
        """Middleware for search space 42 whose closure is seeded with ``closure_mentions``."""
        return KnowledgePriorityMiddleware(
            llm=self._planner_llm(),
            search_space_id=42,
            mentioned_document_ids=closure_mentions,
        )

    async def test_runtime_context_overrides_closure_and_drains_it(self, monkeypatch):
        """Turn 1 with mentions in BOTH closure and runtime context: the
        runtime path wins AND the closure is drained so a future turn
        cannot replay it.
        """
        fetched_ids = self._stub_retrieval(monkeypatch)
        middleware = self._middleware([1, 2, 3])

        await middleware.abefore_agent(
            {"messages": [HumanMessage(content="what is in those docs?")]},
            runtime=self._make_runtime([1, 2, 3]),
        )

        assert fetched_ids == [[1, 2, 3]], (
            "runtime.context mentions must be the source of truth on turn 1"
        )
        assert middleware.mentioned_document_ids == [], (
            "closure must be drained the first time the runtime path fires "
            "so no later turn can replay stale mentions"
        )

    async def test_empty_runtime_context_does_not_replay_closure_mentions(
        self, monkeypatch
    ):
        """Regression: turn 2 with NO mentions must not surface turn 1's
        mentions from the constructor closure.

        Before the fix, ``if ctx_mentions:`` treated an empty list as
        absent and fell through to ``elif self.mentioned_document_ids:``,
        replaying turn 1's mentions. This test pins down the corrected
        behaviour.
        """
        fetched_ids = self._stub_retrieval(monkeypatch)

        # Simulate a cached middleware instance whose closure was seeded
        # by a previous turn's cache-miss build (mentions=[1,2,3]).
        middleware = self._middleware([1, 2, 3])

        # Turn 2: streaming task supplies an EMPTY mention list (no
        # mentions on this follow-up turn).
        await middleware.abefore_agent(
            {"messages": [HumanMessage(content="what about the next steps?")]},
            runtime=self._make_runtime([]),
        )

        assert fetched_ids == [], (
            "fetch_mentioned_documents must NOT be called when the runtime "
            "context says there are no mentions for this turn"
        )

    async def test_legacy_path_fires_only_when_runtime_context_absent(
        self, monkeypatch
    ):
        """Backward-compat: if a caller doesn't supply runtime.context (old
        non-streaming code path), the closure-injected mentions are still
        honoured exactly once and then drained.
        """
        fetched_ids = self._stub_retrieval(monkeypatch)
        middleware = self._middleware([7, 8])

        # First call: no runtime → legacy path uses the closure.
        # Second call: still no runtime — closure already drained, so no replay.
        for question in ("initial question", "follow up"):
            await middleware.abefore_agent(
                {"messages": [HumanMessage(content=question)]},
                runtime=None,
            )

        assert fetched_ids == [[7, 8]], (
            "legacy path must honour the closure exactly once and then drain it"
        )
        assert middleware.mentioned_document_ids == []

View file

@ -0,0 +1,85 @@
"""Behavior tests for finalize-time citation resolution.
The finalize step is the single server-side seam that turns the model's bare
``[n]`` ordinals into renderer-ready ``[citation:<payload>]`` markers, using the
registry captured from the run's final state. These tests pin that contract:
known ordinals resolve, unknown ones drop, foreign markers survive, and a
serialized (dict) registry is accepted just like a live one.
"""
from __future__ import annotations
from app.agents.chat.multi_agent_chat.shared.citations import (
CitationRegistry,
CitationSourceType,
)
from app.tasks.chat.streaming.flows.shared.assistant_finalize import _resolve_citations
def _registry_with_chunk(chunk_id: int = 42) -> CitationRegistry:
    """Return a fresh registry with exactly one KB-chunk source registered.

    The single source belongs to document 1 and carries ``chunk_id``; the
    tests in this module cite it as ``[1]``.
    """
    source = {"document_id": 1, "chunk_id": chunk_id}
    seeded = CitationRegistry()
    seeded.register(CitationSourceType.KB_CHUNK, source)
    return seeded
def _text(value: str) -> list[dict]:
return [{"type": "text", "text": value}]
def test_known_ordinal_resolves_to_chunk_marker():
    """A registered ``[1]`` ordinal is rewritten to its chunk's marker."""
    parts = _text("Launch is March 10 [1].")
    registry = _registry_with_chunk(42)

    resolved = _resolve_citations(parts, registry)

    assert resolved[0]["text"] == "Launch is March 10 [citation:42]."
def test_unknown_ordinal_is_dropped():
    """An ordinal with no registry entry is removed from the text."""
    parts = _text("Unsupported claim [9].")
    registry = _registry_with_chunk(42)

    resolved = _resolve_citations(parts, registry)

    # Only the bracketed ordinal disappears; the space before it stays.
    assert resolved[0]["text"] == "Unsupported claim ."
def test_foreign_citation_marker_is_preserved():
    """An existing ``[citation:...]`` marker passes through unchanged."""
    original_text = "From the web [citation:https://example.com]."
    parts = _text(original_text)
    registry = _registry_with_chunk(42)

    resolved = _resolve_citations(parts, registry)

    assert resolved[0]["text"] == "From the web [citation:https://example.com]."
def test_serialized_registry_is_accepted():
    """A ``model_dump()``-ed registry resolves ordinals like a live one."""
    live_registry = _registry_with_chunk(7)
    dumped = live_registry.model_dump()

    resolved = _resolve_citations(_text("See [1]."), dumped)

    assert resolved[0]["text"] == "See [citation:7]."
def test_empty_registry_leaves_text_untouched():
    """With nothing registered, bare ordinals are left exactly as written."""
    parts = _text("No sources here [1].")
    empty_registry = CitationRegistry()

    resolved = _resolve_citations(parts, empty_registry)

    assert resolved[0]["text"] == "No sources here [1]."
def test_missing_registry_is_a_noop():
    """Passing ``None`` as the registry leaves the text unchanged."""
    parts = _text("Nothing to resolve [1].")

    resolved = _resolve_citations(parts, None)

    assert resolved[0]["text"] == "Nothing to resolve [1]."
def test_non_text_parts_are_left_alone():
    """Only text parts are rewritten; a tool-call part keeps its ``[1]``."""
    tool_call_part = {
        "type": "tool_call",
        "name": "search_knowledge_base",
        "args": {"q": "[1]"},
    }
    text_part = {"type": "text", "text": "Result [1]."}

    resolved = _resolve_citations(
        [tool_call_part, text_part], _registry_with_chunk(5)
    )

    # The ordinal-looking string inside the tool-call args is untouched...
    assert resolved[0]["args"]["q"] == "[1]"
    # ...while the same ordinal in the text part is resolved.
    assert resolved[1]["text"] == "Result [citation:5]."