citations: conversation-scoped registry with finalize-time [n] resolution

Add the checkpointed CitationRegistry (load/merge helpers + state field)
and a lightweight CitationStateMiddleware so subagents can register into
the same conversation registry. Resolve [n] -> [citation:<payload>] at
stream finalize from the registry, polymorphically by source type.
This commit is contained in:
CREDO23 2026-06-25 15:26:25 +02:00
parent 265888d21c
commit 04a76b163b
9 changed files with 305 additions and 1 deletions

View file

@ -9,11 +9,13 @@ from .markers import to_frontend_payload
from .models import CitationEntry, CitationSourceType
from .normalizer import normalize_citations
from .registry import CitationRegistry, make_key
from .state import load_registry
__all__ = [
"CitationEntry",
"CitationRegistry",
"CitationSourceType",
"load_registry",
"make_key",
"normalize_citations",
"to_frontend_payload",

View file

@ -57,5 +57,35 @@ class CitationRegistry(BaseModel):
"""Map ``[n]`` back to its source; unknown → ``None`` so bad citations drop."""
return self.by_n.get(n)
def merge(self, other: CitationRegistry) -> CitationRegistry:
    """Union ``self`` with ``other`` (find-or-create), returning a new registry.

    Needed because separate branches (parent + subagents, parallel tool calls)
    each register into a registry forked from the same base. A plain replace
    would drop one branch's mappings; this unions them so ``[n]`` stays globally
    consistent and no source is lost:

    - A source already in ``self`` keeps its existing ``[n]``.
    - A source only in ``other`` keeps its ``[n]`` when that slot is free.
    - A collision (same ``[n]``, different source on each side) re-mints the
      ``other`` entry to a fresh ``[n]`` via ``register``.

    Pure: neither registry is mutated, and every entry taken from ``other`` is
    deep-copied first so the result shares no objects with it. Entries are
    folded in ascending ``[n]`` order so the result is deterministic.

    Args:
        other: The registry whose entries are folded into a copy of ``self``.

    Returns:
        A new registry holding the union of both sides.
    """
    merged = self.model_copy(deep=True)
    for n in sorted(other.by_n):
        source_entry = other.by_n[n]
        key = make_key(source_entry.source_type, source_entry.locator)
        if key in merged.by_key:
            # Same source already known: its existing label wins.
            continue
        # Copy before either insertion path so the collision branch does not
        # hand ``other``'s locator/display objects to the merged registry.
        entry = source_entry.model_copy(deep=True)
        if n in merged.by_n:
            # Slot taken by a different source: mint a fresh label instead.
            merged.register(entry.source_type, entry.locator, entry.display)
        else:
            merged.by_n[n] = entry
            merged.by_key[key] = n
        # Keep the counter past every label seen so far (no-op when it already is).
        merged.next_n = max(merged.next_n, n + 1)
    return merged
__all__ = ["CitationRegistry", "make_key"]

View file

@ -0,0 +1,26 @@
"""Read the conversation's ``CitationRegistry`` out of graph state.
The registry is checkpointed, so it may come back as a live ``CitationRegistry``
or a plain dict (after (de)serialization). Both the search tool and the read
path load it the same way before registering new ``[n]`` and writing it back.
"""
from __future__ import annotations
from collections.abc import Mapping
from typing import Any
from .registry import CitationRegistry
def load_registry(state: Mapping[str, Any] | None) -> CitationRegistry:
    """Fetch the ``citation_registry`` value from ``state`` as a ``CitationRegistry``.

    A live ``CitationRegistry`` is returned as-is, a ``dict`` is validated into
    one, and anything else (including a missing or empty ``state``) yields a
    fresh empty registry.
    """
    stored = None if not state else state.get("citation_registry")
    if isinstance(stored, CitationRegistry):
        return stored
    if not isinstance(stored, dict):
        return CitationRegistry()
    return CitationRegistry.model_validate(stored)
__all__ = ["load_registry"]

View file

@ -0,0 +1,50 @@
"""Contribute the ``citation_registry`` state channel to a subagent.
The conversation's ``[n]`` -> source registry lives on graph state behind a
merge reducer (see :mod:`app.agents.chat.multi_agent_chat.shared.state.reducers`).
The orchestrator and the KB subagent get that channel for free via the filesystem
state schema, but a citable subagent that does *not* use the filesystem (e.g.
``research``) still needs the channel declared so its tools can register ``[n]``
via ``Command(update={"citation_registry": ...})`` and have it merge back up.
This middleware adds *only* that channel — no tools, no behavior — so any subagent
that mints citations can opt in without inheriting filesystem semantics.
"""
from __future__ import annotations
from typing import Annotated, NotRequired
from langchain.agents.middleware import AgentMiddleware
from typing_extensions import TypedDict
from app.agents.chat.multi_agent_chat.shared.citations import CitationRegistry
from app.agents.chat.multi_agent_chat.shared.state.reducers import (
_citation_registry_merge_reducer,
)
class CitationState(TypedDict):
    """State carrying just the per-conversation ``[n]`` -> source registry."""

    # Optional key (``NotRequired``). The ``Annotated`` metadata attaches
    # ``_citation_registry_merge_reducer`` to the channel, so concurrent
    # updates are merged rather than replaced (per the module docstring).
    citation_registry: NotRequired[
        Annotated[CitationRegistry, _citation_registry_merge_reducer]
    ]
class CitationStateMiddleware(AgentMiddleware): # type: ignore[type-arg]
    """Declare the ``citation_registry`` channel; no tools, no hooks."""

    # Empty on purpose: this middleware contributes no tools.
    tools = ()
    # The only thing this middleware adds: a schema with the single
    # ``citation_registry`` key.
    state_schema = CitationState
def build_citation_state_mw() -> CitationStateMiddleware:
    """Build a new ``CitationStateMiddleware`` instance (takes no configuration)."""
    middleware = CitationStateMiddleware()
    return middleware
__all__ = [
"CitationState",
"CitationStateMiddleware",
"build_citation_state_mw",
]

View file

@ -81,6 +81,7 @@ async def stream_agent_events(
result.final_message_parts = final_assistant_parts_from_messages(
state_values.get("messages")
)
result.citation_registry = state_values.get("citation_registry")
# Safety net: if astream_events was cancelled before
# KnowledgeBasePersistenceMiddleware.aafter_agent ran, any staged work

View file

@ -22,8 +22,12 @@ Never raises (best-effort, logs only).
from __future__ import annotations
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any
from app.agents.chat.multi_agent_chat.shared.citations import (
CitationRegistry,
normalize_citations,
)
from app.tasks.chat.streaming.shared.stream_result import StreamResult
from app.utils.perf import get_perf_logger
@ -33,6 +37,35 @@ if TYPE_CHECKING:
_perf_log = get_perf_logger()
def _as_registry(raw: Any) -> CitationRegistry | None:
    """Coerce the captured state value into a registry, tolerating a serialized dict.

    Args:
        raw: The ``citation_registry`` value captured from the final graph
            state: a live ``CitationRegistry``, its serialized ``dict``, or
            anything else (treated as absent).

    Returns:
        The registry, or ``None`` when ``raw`` is neither form or the dict
        fails validation. Never raises; a validation failure is logged so a
        corrupt registry does not disappear silently.
    """
    import logging

    if isinstance(raw, CitationRegistry):
        return raw
    if not isinstance(raw, dict):
        return None
    try:
        return CitationRegistry.model_validate(raw)
    except Exception:
        # Best-effort boundary: finalize must not fail over citations, but the
        # dropped registry means ``[n]`` markers stay unresolved — leave a trace.
        logging.getLogger(__name__).warning(
            "Ignoring unreadable citation registry; [n] markers will not be resolved",
            exc_info=True,
        )
        return None
def _resolve_citations(
    content_payload: list[dict[str, Any]], raw_registry: Any
) -> list[dict[str, Any]]:
    """Rewrite ``[n]`` -> ``[citation:<payload>]`` in each text part before persisting.

    No-op when the turn registered no citable sources; ``web_search``'s existing
    ``[citation:url]`` markers pass through untouched (the regex matches bare ``[n]``).

    Args:
        content_payload: Message parts; only dict parts with ``type == "text"``
            and a string ``text`` are rewritten.
        raw_registry: A ``CitationRegistry``, its serialized dict, or ``None``.

    Returns:
        ``content_payload`` itself when there is nothing to resolve; otherwise a
        new list in which rewritten text parts are shallow copies. The input
        list and its part dicts are never modified, since they may be shared
        with other holders of the same payload.
    """
    registry = _as_registry(raw_registry)
    if registry is None or not registry.by_n:
        return content_payload

    resolved: list[dict[str, Any]] = []
    for part in content_payload:
        if (
            isinstance(part, dict)  # tolerate malformed parts instead of raising
            and part.get("type") == "text"
            and isinstance(part.get("text"), str)
        ):
            part = {**part, "text": normalize_citations(part["text"], registry)}
        resolved.append(part)
    return resolved
async def finalize_assistant_message(
*,
stream_result: StreamResult | None,
@ -79,6 +112,9 @@ async def finalize_assistant_message(
content_payload,
stream_result.final_message_parts,
)
content_payload = _resolve_citations(
content_payload, stream_result.citation_registry
)
if builder_stats is not None:
_perf_log.info(

View file

@ -39,3 +39,7 @@ class StreamResult:
# state. Used after streaming completes as a provider-agnostic persistence
# backfill when no text chunks reached the live stream.
final_message_parts: list[dict[str, Any]] = field(default_factory=list)
# Per-conversation citation registry captured from the final LangGraph state
# (a ``CitationRegistry`` or its serialized dict). Read at finalize to rewrite
# the model's ``[n]`` ordinals into ``[citation:<payload>]`` markers.
citation_registry: Any | None = field(default=None, repr=False)

View file

@ -102,3 +102,73 @@ def test_make_key_is_stable_and_type_prefixed() -> None:
assert key_a == key_b
assert key_a.startswith("kb_chunk|")
def _kb(registry: CitationRegistry, chunk_id: int) -> int:
    """Register KB chunk ``chunk_id`` (document 1) and return what ``register`` returns."""
    locator = {"document_id": 1, "chunk_id": chunk_id}
    return registry.register(CitationSourceType.KB_CHUNK, locator)
def test_merge_unions_disjoint_registries_preserving_labels() -> None:
    """A fork that adds one new source merges back with every label unchanged."""
    left = CitationRegistry()
    for chunk_id in (10, 11):  # minted as [1] and [2]
        _kb(left, chunk_id)

    # A branch that forked from `left`, then registered its own chunk at [3].
    right = left.model_copy(deep=True)
    assert _kb(right, 12) == 3

    merged = left.merge(right)

    resolved = {n: merged.resolve(n).locator["chunk_id"] for n in (1, 2, 3)}
    assert resolved == {1: 10, 2: 11, 3: 12}
    assert merged.next_n == 4
def test_merge_keeps_one_label_for_a_shared_source() -> None:
    """The same source registered on both sides ends up under a single label."""
    left, right = CitationRegistry(), CitationRegistry()
    _kb(left, 10)  # [1]
    _kb(right, 10)  # also [1], same source

    merged = left.merge(right)

    assert len(merged.by_n) == 1
    only_entry = merged.resolve(1)
    assert only_entry.locator["chunk_id"] == 10
    assert merged.next_n == 2
def test_merge_remints_on_collision_without_losing_sources() -> None:
    """Two forks minting different sources at the same label both survive the merge."""
    # Two branches forked from the same base [1], each minting a *different*
    # source at [2]. Merge must keep both sources, re-minting one.
    base = CitationRegistry()
    _kb(base, 10)  # [1]
    left, right = base.model_copy(deep=True), base.model_copy(deep=True)
    _kb(left, 11)  # [2] -> chunk 11
    _kb(right, 12)  # [2] -> chunk 12 (collision)

    merged = left.merge(right)

    seen_chunks = {item.locator["chunk_id"] for item in merged.by_n.values()}
    assert seen_chunks == {10, 11, 12}
    assert merged.resolve(2).locator["chunk_id"] == 11  # left wins the slot
    assert merged.resolve(3).locator["chunk_id"] == 12  # right re-minted
    assert merged.next_n == 4
def test_merge_does_not_mutate_inputs() -> None:
    """``merge`` leaves both operands exactly as they were.

    Both sides hold a different source at ``[1]``, so the merge goes through
    the collision path. Comparing full ``model_dump()`` snapshots (not just the
    ``by_n`` keys) also catches a changed entry, ``by_key`` or ``next_n``.
    """
    left = CitationRegistry()
    _kb(left, 10)
    right = CitationRegistry()
    _kb(right, 11)
    left_before = left.model_dump()
    right_before = right.model_dump()

    left.merge(right)

    assert list(left.by_n) == [1]
    assert list(right.by_n) == [1]
    assert left.model_dump() == left_before
    assert right.model_dump() == right_before

View file

@ -0,0 +1,85 @@
"""Behavior tests for finalize-time citation resolution.
The finalize step is the single server-side seam that turns the model's bare
``[n]`` ordinals into renderer-ready ``[citation:<payload>]`` markers, using the
registry captured from the run's final state. These tests pin that contract:
known ordinals resolve, unknown ones drop, foreign markers survive, and a
serialized (dict) registry is accepted just like a live one.
"""
from __future__ import annotations
from app.agents.chat.multi_agent_chat.shared.citations import (
CitationRegistry,
CitationSourceType,
)
from app.tasks.chat.streaming.flows.shared.assistant_finalize import _resolve_citations
def _registry_with_chunk(chunk_id: int = 42) -> CitationRegistry:
    """Return a new registry with one KB chunk (document 1, ``chunk_id``) registered."""
    fresh = CitationRegistry()
    locator = {"document_id": 1, "chunk_id": chunk_id}
    fresh.register(CitationSourceType.KB_CHUNK, locator)
    return fresh
def _text(value: str) -> list[dict]:
return [{"type": "text", "text": value}]
def test_known_ordinal_resolves_to_chunk_marker():
    """A registered ``[1]`` is rewritten to the chunk's ``[citation:...]`` marker."""
    registry = _registry_with_chunk(42)
    parts = _text("Launch is March 10 [1].")
    resolved = _resolve_citations(parts, registry)
    assert resolved[0]["text"] == "Launch is March 10 [citation:42]."
def test_unknown_ordinal_is_dropped():
    """An ordinal the registry does not know is removed from the text."""
    registry = _registry_with_chunk(42)
    parts = _text("Unsupported claim [9].")
    resolved = _resolve_citations(parts, registry)
    assert resolved[0]["text"] == "Unsupported claim ."
def test_foreign_citation_marker_is_preserved():
    """An existing ``[citation:url]`` marker is left exactly as written."""
    original = "From the web [citation:https://example.com]."
    resolved = _resolve_citations(_text(original), _registry_with_chunk(42))
    assert resolved[0]["text"] == "From the web [citation:https://example.com]."
def test_serialized_registry_is_accepted():
    """A ``model_dump()``-ed registry resolves ordinals just like a live one."""
    as_dict = _registry_with_chunk(7).model_dump()
    resolved = _resolve_citations(_text("See [1]."), as_dict)
    first_part = resolved[0]
    assert first_part["text"] == "See [citation:7]."
def test_empty_registry_leaves_text_untouched():
    """With nothing registered, even a bare ``[1]`` stays in the text."""
    empty = CitationRegistry()
    resolved = _resolve_citations(_text("No sources here [1]."), empty)
    assert resolved[0]["text"] == "No sources here [1]."
def test_missing_registry_is_a_noop():
    """Passing ``None`` as the registry leaves the text unchanged."""
    parts = _text("Nothing to resolve [1].")
    resolved = _resolve_citations(parts, None)
    assert resolved[0]["text"] == "Nothing to resolve [1]."
def test_non_text_parts_are_left_alone():
    """Only text parts are rewritten; a tool call's args keep their ``[1]``."""
    tool_call = {
        "type": "tool_call",
        "name": "search_knowledge_base",
        "args": {"q": "[1]"},
    }
    text_part = {"type": "text", "text": "Result [1]."}
    resolved = _resolve_citations([tool_call, text_part], _registry_with_chunk(5))
    assert resolved[0]["args"]["q"] == "[1]"
    assert resolved[1]["text"] == "Result [citation:5]."