citations: conversation-scoped registry with finalize-time [n] resolution

Add the checkpointed CitationRegistry (load/merge helpers + state field)
and a lightweight CitationStateMiddleware so subagents can register into
the same conversation registry. Resolve [n] -> [citation:<payload>] at
stream finalize from the registry, polymorphically by source type.
This commit is contained in:
CREDO23 2026-06-25 15:26:25 +02:00
parent 265888d21c
commit 04a76b163b
9 changed files with 305 additions and 1 deletions

View file

@ -102,3 +102,73 @@ def test_make_key_is_stable_and_type_prefixed() -> None:
assert key_a == key_b
assert key_a.startswith("kb_chunk|")
def _kb(registry: CitationRegistry, chunk_id: int) -> int:
return registry.register(
CitationSourceType.KB_CHUNK, {"document_id": 1, "chunk_id": chunk_id}
)
def test_merge_unions_disjoint_registries_preserving_labels() -> None:
left = CitationRegistry()
_kb(left, 10) # [1]
_kb(left, 11) # [2]
# A branch that forked from `left`, then registered its own chunk at [3].
right = left.model_copy(deep=True)
third = _kb(right, 12) # [3]
assert third == 3
merged = left.merge(right)
assert merged.resolve(1).locator["chunk_id"] == 10
assert merged.resolve(2).locator["chunk_id"] == 11
assert merged.resolve(3).locator["chunk_id"] == 12
assert merged.next_n == 4
def test_merge_keeps_one_label_for_a_shared_source() -> None:
left = CitationRegistry()
_kb(left, 10) # [1]
right = CitationRegistry()
_kb(right, 10) # also [1], same source
merged = left.merge(right)
assert len(merged.by_n) == 1
assert merged.resolve(1).locator["chunk_id"] == 10
assert merged.next_n == 2
def test_merge_remints_on_collision_without_losing_sources() -> None:
# Two branches forked from the same base [1], each minting a *different*
# source at [2]. Merge must keep both sources, re-minting one.
base = CitationRegistry()
_kb(base, 10) # [1]
left = base.model_copy(deep=True)
_kb(left, 11) # [2] -> chunk 11
right = base.model_copy(deep=True)
_kb(right, 12) # [2] -> chunk 12 (collision)
merged = left.merge(right)
chunk_ids = {entry.locator["chunk_id"] for entry in merged.by_n.values()}
assert chunk_ids == {10, 11, 12}
assert merged.resolve(2).locator["chunk_id"] == 11 # left wins the slot
assert merged.resolve(3).locator["chunk_id"] == 12 # right re-minted
assert merged.next_n == 4
def test_merge_does_not_mutate_inputs() -> None:
left = CitationRegistry()
_kb(left, 10)
right = CitationRegistry()
_kb(right, 11)
left.merge(right)
assert list(left.by_n) == [1]
assert list(right.by_n) == [1]

View file

@ -0,0 +1,85 @@
"""Behavior tests for finalize-time citation resolution.
The finalize step is the single server-side seam that turns the model's bare
``[n]`` ordinals into renderer-ready ``[citation:<payload>]`` markers, using the
registry captured from the run's final state. These tests pin that contract:
known ordinals resolve, unknown ones drop, foreign markers survive, and a
serialized (dict) registry is accepted just like a live one.
"""
from __future__ import annotations
from app.agents.chat.multi_agent_chat.shared.citations import (
CitationRegistry,
CitationSourceType,
)
from app.tasks.chat.streaming.flows.shared.assistant_finalize import _resolve_citations
def _registry_with_chunk(chunk_id: int = 42) -> CitationRegistry:
registry = CitationRegistry()
registry.register(
CitationSourceType.KB_CHUNK, {"document_id": 1, "chunk_id": chunk_id}
)
return registry
def _text(value: str) -> list[dict]:
return [{"type": "text", "text": value}]
def test_known_ordinal_resolves_to_chunk_marker():
payload = _resolve_citations(
_text("Launch is March 10 [1]."), _registry_with_chunk(42)
)
assert payload[0]["text"] == "Launch is March 10 [citation:42]."
def test_unknown_ordinal_is_dropped():
payload = _resolve_citations(
_text("Unsupported claim [9]."), _registry_with_chunk(42)
)
assert payload[0]["text"] == "Unsupported claim ."
def test_foreign_citation_marker_is_preserved():
payload = _resolve_citations(
_text("From the web [citation:https://example.com]."),
_registry_with_chunk(42),
)
assert payload[0]["text"] == "From the web [citation:https://example.com]."
def test_serialized_registry_is_accepted():
serialized = _registry_with_chunk(7).model_dump()
payload = _resolve_citations(_text("See [1]."), serialized)
assert payload[0]["text"] == "See [citation:7]."
def test_empty_registry_leaves_text_untouched():
payload = _resolve_citations(_text("No sources here [1]."), CitationRegistry())
assert payload[0]["text"] == "No sources here [1]."
def test_missing_registry_is_a_noop():
payload = _resolve_citations(_text("Nothing to resolve [1]."), None)
assert payload[0]["text"] == "Nothing to resolve [1]."
def test_non_text_parts_are_left_alone():
parts = [
{"type": "tool_call", "name": "search_knowledge_base", "args": {"q": "[1]"}},
{"type": "text", "text": "Result [1]."},
]
payload = _resolve_citations(parts, _registry_with_chunk(5))
assert payload[0]["args"]["q"] == "[1]"
assert payload[1]["text"] == "Result [citation:5]."