From 04a76b163b277e4384a5609c980455c93aa83e5d Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Thu, 25 Jun 2026 15:26:25 +0200 Subject: [PATCH] citations: conversation-scoped registry with finalize-time [n] resolution Add the checkpointed CitationRegistry (load/merge helpers + state field) and a lightweight CitationStateMiddleware so subagents can register into the same conversation registry. Resolve [n] -> [citation:] at stream finalize from the registry, polymorphically by source type. --- .../shared/citations/__init__.py | 2 + .../shared/citations/registry.py | 30 +++++++ .../shared/citations/state.py | 26 ++++++ .../shared/middleware/citation_state.py | 50 +++++++++++ .../tasks/chat/streaming/agent/event_loop.py | 1 + .../flows/shared/assistant_finalize.py | 38 ++++++++- .../chat/streaming/shared/stream_result.py | 4 + .../shared/citations/test_registry.py | 70 +++++++++++++++ .../test_assistant_finalize_citations.py | 85 +++++++++++++++++++ 9 files changed, 305 insertions(+), 1 deletion(-) create mode 100644 surfsense_backend/app/agents/chat/multi_agent_chat/shared/citations/state.py create mode 100644 surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/citation_state.py create mode 100644 surfsense_backend/tests/unit/tasks/chat/streaming/flows/shared/test_assistant_finalize_citations.py diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/shared/citations/__init__.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/citations/__init__.py index 91640483b..a329d6042 100644 --- a/surfsense_backend/app/agents/chat/multi_agent_chat/shared/citations/__init__.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/citations/__init__.py @@ -9,11 +9,13 @@ from .markers import to_frontend_payload from .models import CitationEntry, CitationSourceType from .normalizer import normalize_citations from .registry import CitationRegistry, make_key +from .state import load_registry __all__ = [ "CitationEntry", "CitationRegistry", 
def merge(self, other: CitationRegistry) -> CitationRegistry:
    """Return a fresh registry containing every source from both sides.

    Parallel branches (parent and subagents, concurrent tool calls) each
    extend a registry forked from a common base, so replacing one with the
    other would lose mappings. The union keeps ``[n]`` labels stable:

    - a source ``self`` already knows keeps the ``[n]`` it has here;
    - a source only ``other`` knows keeps its ``[n]`` if that slot is unused;
    - if the slot is held by a different source, the ``other`` entry is
      re-registered through ``register`` and so receives a new ``[n]``.

    Neither input is modified: the result starts from a deep copy of
    ``self`` and adopted entries are deep-copied too. ``other`` is walked in
    ascending ``[n]`` order so the outcome is deterministic.
    """
    combined = self.model_copy(deep=True)
    for label, incoming in sorted(other.by_n.items(), key=lambda pair: pair[0]):
        source_key = make_key(incoming.source_type, incoming.locator)
        if source_key in combined.by_key:
            # Same source already present: the label on this side wins.
            continue
        if label not in combined.by_n:
            # Free slot: adopt the other side's label verbatim and keep
            # ``next_n`` ahead of every label in use.
            combined.by_n[label] = incoming.model_copy(deep=True)
            combined.by_key[source_key] = label
            combined.next_n = max(combined.next_n, label + 1)
            continue
        # Slot occupied by a different source: mint a new label instead.
        combined.register(incoming.source_type, incoming.locator, incoming.display)
    return combined
def load_registry(state: Mapping[str, Any] | None) -> CitationRegistry:
    """Fetch the conversation's citation registry from graph ``state``.

    The value stored under ``"citation_registry"`` is handled by shape:

    - a ``CitationRegistry`` instance is returned as-is (not copied);
    - a ``dict`` (the serialized form) is validated into a registry;
    - anything else, or a missing/empty ``state``, yields a new empty
      registry.
    """
    if not state:
        return CitationRegistry()

    stored = state.get("citation_registry")
    if isinstance(stored, CitationRegistry):
        return stored
    if isinstance(stored, dict):
        return CitationRegistry.model_validate(stored)
    return CitationRegistry()
class CitationState(TypedDict):
    """State carrying just the per-conversation ``[n]`` -> source registry."""

    # ``NotRequired``: the key may be missing from state entirely.
    # The ``Annotated`` metadata attaches ``_citation_registry_merge_reducer``
    # (defined in the shared reducers module) to this channel.
    citation_registry: NotRequired[
        Annotated[CitationRegistry, _citation_registry_merge_reducer]
    ]


class CitationStateMiddleware(AgentMiddleware):  # type: ignore[type-arg]
    """Declare the ``citation_registry`` channel; no tools, no hooks."""

    # Contributes no tools of its own.
    tools = ()
    # The only thing this middleware adds: the state schema above.
    state_schema = CitationState


def build_citation_state_mw() -> CitationStateMiddleware:
    """Return a new ``CitationStateMiddleware`` instance."""
    return CitationStateMiddleware()
b/surfsense_backend/app/tasks/chat/streaming/flows/shared/assistant_finalize.py @@ -22,8 +22,12 @@ Never raises (best-effort, logs only). from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any +from app.agents.chat.multi_agent_chat.shared.citations import ( + CitationRegistry, + normalize_citations, +) from app.tasks.chat.streaming.shared.stream_result import StreamResult from app.utils.perf import get_perf_logger @@ -33,6 +37,35 @@ if TYPE_CHECKING: _perf_log = get_perf_logger() +def _as_registry(raw: Any) -> CitationRegistry | None: + """Coerce the captured state value into a registry, tolerating a serialized dict.""" + if isinstance(raw, CitationRegistry): + return raw + if isinstance(raw, dict): + try: + return CitationRegistry.model_validate(raw) + except Exception: + return None + return None + + +def _resolve_citations( + content_payload: list[dict[str, Any]], raw_registry: Any +) -> list[dict[str, Any]]: + """Rewrite ``[n]`` -> ``[citation:]`` in each text part before persisting. + + No-op when the turn registered no citable sources; ``web_search``'s existing + ``[citation:url]`` markers pass through untouched (the regex matches bare ``[n]``). 
+ """ + registry = _as_registry(raw_registry) + if registry is None or not registry.by_n: + return content_payload + for part in content_payload: + if part.get("type") == "text" and isinstance(part.get("text"), str): + part["text"] = normalize_citations(part["text"], registry) + return content_payload + + async def finalize_assistant_message( *, stream_result: StreamResult | None, @@ -79,6 +112,9 @@ async def finalize_assistant_message( content_payload, stream_result.final_message_parts, ) + content_payload = _resolve_citations( + content_payload, stream_result.citation_registry + ) if builder_stats is not None: _perf_log.info( diff --git a/surfsense_backend/app/tasks/chat/streaming/shared/stream_result.py b/surfsense_backend/app/tasks/chat/streaming/shared/stream_result.py index 5e164070a..96fc75708 100644 --- a/surfsense_backend/app/tasks/chat/streaming/shared/stream_result.py +++ b/surfsense_backend/app/tasks/chat/streaming/shared/stream_result.py @@ -39,3 +39,7 @@ class StreamResult: # state. Used after streaming completes as a provider-agnostic persistence # backfill when no text chunks reached the live stream. final_message_parts: list[dict[str, Any]] = field(default_factory=list) + # Per-conversation citation registry captured from the final LangGraph state + # (a ``CitationRegistry`` or its serialized dict). Read at finalize to rewrite + # the model's ``[n]`` ordinals into ``[citation:]`` markers. 
def _kb(registry: CitationRegistry, chunk_id: int) -> int:
    """Register KB chunk ``chunk_id`` (document 1) and return ``register``'s result."""
    return registry.register(
        CitationSourceType.KB_CHUNK, {"document_id": 1, "chunk_id": chunk_id}
    )


def test_merge_unions_disjoint_registries_preserving_labels() -> None:
    """A branch that only adds a new label merges without relabeling anything."""
    left = CitationRegistry()
    _kb(left, 10)  # [1]
    _kb(left, 11)  # [2]

    # A branch that forked from `left`, then registered its own chunk at [3].
    right = left.model_copy(deep=True)
    third = _kb(right, 12)  # [3]
    assert third == 3

    merged = left.merge(right)

    assert merged.resolve(1).locator["chunk_id"] == 10
    assert merged.resolve(2).locator["chunk_id"] == 11
    assert merged.resolve(3).locator["chunk_id"] == 12
    assert merged.next_n == 4


def test_merge_keeps_one_label_for_a_shared_source() -> None:
    """The same source registered on both sides ends up under a single label."""
    left = CitationRegistry()
    _kb(left, 10)  # [1]
    right = CitationRegistry()
    _kb(right, 10)  # also [1], same source

    merged = left.merge(right)

    assert len(merged.by_n) == 1
    assert merged.resolve(1).locator["chunk_id"] == 10
    assert merged.next_n == 2


def test_merge_remints_on_collision_without_losing_sources() -> None:
    """Two different sources minted at the same label both survive the merge."""
    # Two branches forked from the same base [1], each minting a *different*
    # source at [2]. Merge must keep both sources, re-minting one.
    base = CitationRegistry()
    _kb(base, 10)  # [1]

    left = base.model_copy(deep=True)
    _kb(left, 11)  # [2] -> chunk 11

    right = base.model_copy(deep=True)
    _kb(right, 12)  # [2] -> chunk 12 (collision)

    merged = left.merge(right)

    chunk_ids = {entry.locator["chunk_id"] for entry in merged.by_n.values()}
    assert chunk_ids == {10, 11, 12}
    assert merged.resolve(2).locator["chunk_id"] == 11  # left wins the slot
    assert merged.resolve(3).locator["chunk_id"] == 12  # right re-minted
    assert merged.next_n == 4


def test_merge_does_not_mutate_inputs() -> None:
    """``merge`` leaves both operands exactly as they were."""
    left = CitationRegistry()
    _kb(left, 10)
    right = CitationRegistry()
    _kb(right, 11)

    # Snapshot the full state, not just the labels: checking ``by_n`` keys
    # alone would miss changes to ``by_key``, ``next_n`` or entry contents.
    left_before = left.model_dump()
    right_before = right.model_dump()

    merged = left.merge(right)

    assert merged is not left
    assert merged is not right
    assert left.model_dump() == left_before
    assert right.model_dump() == right_before
    assert list(left.by_n) == [1]
    assert list(right.by_n) == [1]
def _registry_with_chunk(chunk_id: int = 42) -> CitationRegistry:
    """Build a registry holding exactly one KB chunk (document 1, ``chunk_id``)."""
    single = CitationRegistry()
    locator = {"document_id": 1, "chunk_id": chunk_id}
    single.register(CitationSourceType.KB_CHUNK, locator)
    return single


def _text(value: str) -> list[dict]:
    """Wrap ``value`` as a one-element payload containing a single text part."""
    part = {"type": "text", "text": value}
    return [part]


def test_known_ordinal_resolves_to_chunk_marker():
    """A registered ``[n]`` is rewritten to the chunk's citation marker."""
    parts = _text("Launch is March 10 [1].")

    resolved = _resolve_citations(parts, _registry_with_chunk(42))

    assert resolved[0]["text"] == "Launch is March 10 [citation:42]."


def test_unknown_ordinal_is_dropped():
    """An ``[n]`` with no registry entry is removed from the text."""
    parts = _text("Unsupported claim [9].")

    resolved = _resolve_citations(parts, _registry_with_chunk(42))

    assert resolved[0]["text"] == "Unsupported claim ."


def test_foreign_citation_marker_is_preserved():
    """An existing ``[citation:...]`` marker is not touched."""
    parts = _text("From the web [citation:https://example.com].")

    resolved = _resolve_citations(parts, _registry_with_chunk(42))

    assert resolved[0]["text"] == "From the web [citation:https://example.com]."


def test_serialized_registry_is_accepted():
    """A registry passed as its ``model_dump()`` dict resolves like a live one."""
    as_dict = _registry_with_chunk(7).model_dump()
    parts = _text("See [1].")

    resolved = _resolve_citations(parts, as_dict)

    assert resolved[0]["text"] == "See [citation:7]."


def test_empty_registry_leaves_text_untouched():
    """With no registered sources the text is returned unchanged."""
    parts = _text("No sources here [1].")

    resolved = _resolve_citations(parts, CitationRegistry())

    assert resolved[0]["text"] == "No sources here [1]."


def test_missing_registry_is_a_noop():
    """A ``None`` registry leaves the text unchanged."""
    parts = _text("Nothing to resolve [1].")

    resolved = _resolve_citations(parts, None)

    assert resolved[0]["text"] == "Nothing to resolve [1]."
def test_non_text_parts_are_left_alone():
    """Only ``"text"`` parts are rewritten; a tool call's args keep their ``[1]``."""
    tool_part = {
        "type": "tool_call",
        "name": "search_knowledge_base",
        "args": {"q": "[1]"},
    }
    text_part = {"type": "text", "text": "Result [1]."}

    resolved = _resolve_citations([tool_part, text_part], _registry_with_chunk(5))

    assert resolved[0]["args"]["q"] == "[1]"
    assert resolved[1]["text"] == "Result [citation:5]."