mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-28 21:49:40 +02:00
citations: conversation-scoped registry with finalize-time [n] resolution
Add the checkpointed CitationRegistry (load/merge helpers + state field) and a lightweight CitationStateMiddleware so subagents can register into the same conversation registry. Resolve [n] -> [citation:<payload>] at stream finalize from the registry, polymorphically by source type.
This commit is contained in:
parent
265888d21c
commit
04a76b163b
9 changed files with 305 additions and 1 deletions
|
|
@ -9,11 +9,13 @@ from .markers import to_frontend_payload
|
|||
from .models import CitationEntry, CitationSourceType
|
||||
from .normalizer import normalize_citations
|
||||
from .registry import CitationRegistry, make_key
|
||||
from .state import load_registry
|
||||
|
||||
__all__ = [
|
||||
"CitationEntry",
|
||||
"CitationRegistry",
|
||||
"CitationSourceType",
|
||||
"load_registry",
|
||||
"make_key",
|
||||
"normalize_citations",
|
||||
"to_frontend_payload",
|
||||
|
|
|
|||
|
|
@ -57,5 +57,35 @@ class CitationRegistry(BaseModel):
|
|||
"""Map ``[n]`` back to its source; unknown → ``None`` so bad citations drop."""
|
||||
return self.by_n.get(n)
|
||||
|
||||
def merge(self, other: CitationRegistry) -> CitationRegistry:
|
||||
"""Union ``self`` with ``other`` (find-or-create), returning a new registry.
|
||||
|
||||
Needed because separate branches (parent + subagents, parallel tool calls)
|
||||
each register into a registry forked from the same base. A plain replace
|
||||
would drop one branch's mappings; this unions them so ``[n]`` stays globally
|
||||
consistent and no source is lost:
|
||||
|
||||
- A source already in ``self`` keeps its existing ``[n]``.
|
||||
- A source only in ``other`` keeps its ``[n]`` when that slot is free.
|
||||
- A collision (same ``[n]``, different source on each side) re-mints the
|
||||
``other`` entry to a fresh ``[n]`` and advances ``next_n`` past both.
|
||||
|
||||
Pure: neither registry is mutated. Entries are folded in ascending ``[n]``
|
||||
order so the result is deterministic.
|
||||
"""
|
||||
merged = self.model_copy(deep=True)
|
||||
for n in sorted(other.by_n):
|
||||
entry = other.by_n[n]
|
||||
key = make_key(entry.source_type, entry.locator)
|
||||
if key in merged.by_key:
|
||||
continue
|
||||
if n in merged.by_n:
|
||||
merged.register(entry.source_type, entry.locator, entry.display)
|
||||
else:
|
||||
merged.by_n[n] = entry.model_copy(deep=True)
|
||||
merged.by_key[key] = n
|
||||
merged.next_n = max(merged.next_n, n + 1)
|
||||
return merged
|
||||
|
||||
|
||||
__all__ = ["CitationRegistry", "make_key"]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,26 @@
|
|||
"""Read the conversation's ``CitationRegistry`` out of graph state.
|
||||
|
||||
The registry is checkpointed, so it may come back as a live ``CitationRegistry``
|
||||
or a plain dict (after (de)serialization). Both the search tool and the read
|
||||
path load it the same way before registering new ``[n]`` and writing it back.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from .registry import CitationRegistry
|
||||
|
||||
|
||||
def load_registry(state: Mapping[str, Any] | None) -> CitationRegistry:
|
||||
"""Return the registry from ``state``, tolerating a serialized dict or absence."""
|
||||
raw = state.get("citation_registry") if state else None
|
||||
if isinstance(raw, CitationRegistry):
|
||||
return raw
|
||||
if isinstance(raw, dict):
|
||||
return CitationRegistry.model_validate(raw)
|
||||
return CitationRegistry()
|
||||
|
||||
|
||||
__all__ = ["load_registry"]
|
||||
|
|
@ -0,0 +1,50 @@
|
|||
"""Contribute the ``citation_registry`` state channel to a subagent.
|
||||
|
||||
The conversation's ``[n]`` -> source registry lives on graph state behind a
|
||||
merge reducer (see :mod:`app.agents.chat.multi_agent_chat.shared.state.reducers`).
|
||||
The orchestrator and the KB subagent get that channel for free via the filesystem
|
||||
state schema, but a citable subagent that does *not* use the filesystem (e.g.
|
||||
``research``) still needs the channel declared so its tools can register ``[n]``
|
||||
via ``Command(update={"citation_registry": ...})`` and have it merge back up.
|
||||
|
||||
This middleware adds *only* that channel — no tools, no behavior — so any subagent
|
||||
that mints citations can opt in without inheriting filesystem semantics.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated, NotRequired
|
||||
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from app.agents.chat.multi_agent_chat.shared.citations import CitationRegistry
|
||||
from app.agents.chat.multi_agent_chat.shared.state.reducers import (
|
||||
_citation_registry_merge_reducer,
|
||||
)
|
||||
|
||||
|
||||
class CitationState(TypedDict):
|
||||
"""State carrying just the per-conversation ``[n]`` -> source registry."""
|
||||
|
||||
citation_registry: NotRequired[
|
||||
Annotated[CitationRegistry, _citation_registry_merge_reducer]
|
||||
]
|
||||
|
||||
|
||||
class CitationStateMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||
"""Declare the ``citation_registry`` channel; no tools, no hooks."""
|
||||
|
||||
tools = ()
|
||||
state_schema = CitationState
|
||||
|
||||
|
||||
def build_citation_state_mw() -> CitationStateMiddleware:
|
||||
return CitationStateMiddleware()
|
||||
|
||||
|
||||
__all__ = [
|
||||
"CitationState",
|
||||
"CitationStateMiddleware",
|
||||
"build_citation_state_mw",
|
||||
]
|
||||
|
|
@ -81,6 +81,7 @@ async def stream_agent_events(
|
|||
result.final_message_parts = final_assistant_parts_from_messages(
|
||||
state_values.get("messages")
|
||||
)
|
||||
result.citation_registry = state_values.get("citation_registry")
|
||||
|
||||
# Safety net: if astream_events was cancelled before
|
||||
# KnowledgeBasePersistenceMiddleware.aafter_agent ran, any staged work
|
||||
|
|
|
|||
|
|
@ -22,8 +22,12 @@ Never raises (best-effort, logs only).
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from app.agents.chat.multi_agent_chat.shared.citations import (
|
||||
CitationRegistry,
|
||||
normalize_citations,
|
||||
)
|
||||
from app.tasks.chat.streaming.shared.stream_result import StreamResult
|
||||
from app.utils.perf import get_perf_logger
|
||||
|
||||
|
|
@ -33,6 +37,35 @@ if TYPE_CHECKING:
|
|||
_perf_log = get_perf_logger()
|
||||
|
||||
|
||||
def _as_registry(raw: Any) -> CitationRegistry | None:
|
||||
"""Coerce the captured state value into a registry, tolerating a serialized dict."""
|
||||
if isinstance(raw, CitationRegistry):
|
||||
return raw
|
||||
if isinstance(raw, dict):
|
||||
try:
|
||||
return CitationRegistry.model_validate(raw)
|
||||
except Exception:
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
def _resolve_citations(
|
||||
content_payload: list[dict[str, Any]], raw_registry: Any
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Rewrite ``[n]`` -> ``[citation:<payload>]`` in each text part before persisting.
|
||||
|
||||
No-op when the turn registered no citable sources; ``web_search``'s existing
|
||||
``[citation:url]`` markers pass through untouched (the regex matches bare ``[n]``).
|
||||
"""
|
||||
registry = _as_registry(raw_registry)
|
||||
if registry is None or not registry.by_n:
|
||||
return content_payload
|
||||
for part in content_payload:
|
||||
if part.get("type") == "text" and isinstance(part.get("text"), str):
|
||||
part["text"] = normalize_citations(part["text"], registry)
|
||||
return content_payload
|
||||
|
||||
|
||||
async def finalize_assistant_message(
|
||||
*,
|
||||
stream_result: StreamResult | None,
|
||||
|
|
@ -79,6 +112,9 @@ async def finalize_assistant_message(
|
|||
content_payload,
|
||||
stream_result.final_message_parts,
|
||||
)
|
||||
content_payload = _resolve_citations(
|
||||
content_payload, stream_result.citation_registry
|
||||
)
|
||||
|
||||
if builder_stats is not None:
|
||||
_perf_log.info(
|
||||
|
|
|
|||
|
|
@ -39,3 +39,7 @@ class StreamResult:
|
|||
# state. Used after streaming completes as a provider-agnostic persistence
|
||||
# backfill when no text chunks reached the live stream.
|
||||
final_message_parts: list[dict[str, Any]] = field(default_factory=list)
|
||||
# Per-conversation citation registry captured from the final LangGraph state
|
||||
# (a ``CitationRegistry`` or its serialized dict). Read at finalize to rewrite
|
||||
# the model's ``[n]`` ordinals into ``[citation:<payload>]`` markers.
|
||||
citation_registry: Any | None = field(default=None, repr=False)
|
||||
|
|
|
|||
|
|
@ -102,3 +102,73 @@ def test_make_key_is_stable_and_type_prefixed() -> None:
|
|||
|
||||
assert key_a == key_b
|
||||
assert key_a.startswith("kb_chunk|")
|
||||
|
||||
|
||||
def _kb(registry: CitationRegistry, chunk_id: int) -> int:
|
||||
return registry.register(
|
||||
CitationSourceType.KB_CHUNK, {"document_id": 1, "chunk_id": chunk_id}
|
||||
)
|
||||
|
||||
|
||||
def test_merge_unions_disjoint_registries_preserving_labels() -> None:
|
||||
left = CitationRegistry()
|
||||
_kb(left, 10) # [1]
|
||||
_kb(left, 11) # [2]
|
||||
|
||||
# A branch that forked from `left`, then registered its own chunk at [3].
|
||||
right = left.model_copy(deep=True)
|
||||
third = _kb(right, 12) # [3]
|
||||
assert third == 3
|
||||
|
||||
merged = left.merge(right)
|
||||
|
||||
assert merged.resolve(1).locator["chunk_id"] == 10
|
||||
assert merged.resolve(2).locator["chunk_id"] == 11
|
||||
assert merged.resolve(3).locator["chunk_id"] == 12
|
||||
assert merged.next_n == 4
|
||||
|
||||
|
||||
def test_merge_keeps_one_label_for_a_shared_source() -> None:
|
||||
left = CitationRegistry()
|
||||
_kb(left, 10) # [1]
|
||||
right = CitationRegistry()
|
||||
_kb(right, 10) # also [1], same source
|
||||
|
||||
merged = left.merge(right)
|
||||
|
||||
assert len(merged.by_n) == 1
|
||||
assert merged.resolve(1).locator["chunk_id"] == 10
|
||||
assert merged.next_n == 2
|
||||
|
||||
|
||||
def test_merge_remints_on_collision_without_losing_sources() -> None:
|
||||
# Two branches forked from the same base [1], each minting a *different*
|
||||
# source at [2]. Merge must keep both sources, re-minting one.
|
||||
base = CitationRegistry()
|
||||
_kb(base, 10) # [1]
|
||||
|
||||
left = base.model_copy(deep=True)
|
||||
_kb(left, 11) # [2] -> chunk 11
|
||||
|
||||
right = base.model_copy(deep=True)
|
||||
_kb(right, 12) # [2] -> chunk 12 (collision)
|
||||
|
||||
merged = left.merge(right)
|
||||
|
||||
chunk_ids = {entry.locator["chunk_id"] for entry in merged.by_n.values()}
|
||||
assert chunk_ids == {10, 11, 12}
|
||||
assert merged.resolve(2).locator["chunk_id"] == 11 # left wins the slot
|
||||
assert merged.resolve(3).locator["chunk_id"] == 12 # right re-minted
|
||||
assert merged.next_n == 4
|
||||
|
||||
|
||||
def test_merge_does_not_mutate_inputs() -> None:
|
||||
left = CitationRegistry()
|
||||
_kb(left, 10)
|
||||
right = CitationRegistry()
|
||||
_kb(right, 11)
|
||||
|
||||
left.merge(right)
|
||||
|
||||
assert list(left.by_n) == [1]
|
||||
assert list(right.by_n) == [1]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,85 @@
|
|||
"""Behavior tests for finalize-time citation resolution.
|
||||
|
||||
The finalize step is the single server-side seam that turns the model's bare
|
||||
``[n]`` ordinals into renderer-ready ``[citation:<payload>]`` markers, using the
|
||||
registry captured from the run's final state. These tests pin that contract:
|
||||
known ordinals resolve, unknown ones drop, foreign markers survive, and a
|
||||
serialized (dict) registry is accepted just like a live one.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.chat.multi_agent_chat.shared.citations import (
|
||||
CitationRegistry,
|
||||
CitationSourceType,
|
||||
)
|
||||
from app.tasks.chat.streaming.flows.shared.assistant_finalize import _resolve_citations
|
||||
|
||||
|
||||
def _registry_with_chunk(chunk_id: int = 42) -> CitationRegistry:
|
||||
registry = CitationRegistry()
|
||||
registry.register(
|
||||
CitationSourceType.KB_CHUNK, {"document_id": 1, "chunk_id": chunk_id}
|
||||
)
|
||||
return registry
|
||||
|
||||
|
||||
def _text(value: str) -> list[dict]:
|
||||
return [{"type": "text", "text": value}]
|
||||
|
||||
|
||||
def test_known_ordinal_resolves_to_chunk_marker():
|
||||
payload = _resolve_citations(
|
||||
_text("Launch is March 10 [1]."), _registry_with_chunk(42)
|
||||
)
|
||||
|
||||
assert payload[0]["text"] == "Launch is March 10 [citation:42]."
|
||||
|
||||
|
||||
def test_unknown_ordinal_is_dropped():
|
||||
payload = _resolve_citations(
|
||||
_text("Unsupported claim [9]."), _registry_with_chunk(42)
|
||||
)
|
||||
|
||||
assert payload[0]["text"] == "Unsupported claim ."
|
||||
|
||||
|
||||
def test_foreign_citation_marker_is_preserved():
|
||||
payload = _resolve_citations(
|
||||
_text("From the web [citation:https://example.com]."),
|
||||
_registry_with_chunk(42),
|
||||
)
|
||||
|
||||
assert payload[0]["text"] == "From the web [citation:https://example.com]."
|
||||
|
||||
|
||||
def test_serialized_registry_is_accepted():
|
||||
serialized = _registry_with_chunk(7).model_dump()
|
||||
|
||||
payload = _resolve_citations(_text("See [1]."), serialized)
|
||||
|
||||
assert payload[0]["text"] == "See [citation:7]."
|
||||
|
||||
|
||||
def test_empty_registry_leaves_text_untouched():
|
||||
payload = _resolve_citations(_text("No sources here [1]."), CitationRegistry())
|
||||
|
||||
assert payload[0]["text"] == "No sources here [1]."
|
||||
|
||||
|
||||
def test_missing_registry_is_a_noop():
|
||||
payload = _resolve_citations(_text("Nothing to resolve [1]."), None)
|
||||
|
||||
assert payload[0]["text"] == "Nothing to resolve [1]."
|
||||
|
||||
|
||||
def test_non_text_parts_are_left_alone():
|
||||
parts = [
|
||||
{"type": "tool_call", "name": "search_knowledge_base", "args": {"q": "[1]"}},
|
||||
{"type": "text", "text": "Result [1]."},
|
||||
]
|
||||
|
||||
payload = _resolve_citations(parts, _registry_with_chunk(5))
|
||||
|
||||
assert payload[0]["args"]["q"] == "[1]"
|
||||
assert payload[1]["text"] == "Result [citation:5]."
|
||||
Loading…
Add table
Add a link
Reference in a new issue