Merge span metadata into persisted tool-call and thinking parts.

This commit is contained in:
CREDO23 2026-05-08 22:48:07 +02:00
parent e802de2333
commit 3dbcac4b9d

View file

@ -51,6 +51,15 @@ logger = logging.getLogger(__name__)
_MEANINGFUL_PART_TYPES: frozenset[str] = frozenset({"text", "reasoning", "tool-call"})
def _merge_tool_part_metadata(part: dict[str, Any], metadata: dict[str, Any] | None) -> None:
if not metadata:
return
md = part.setdefault("metadata", {})
for k, v in metadata.items():
if k not in md:
md[k] = v
class AssistantContentBuilder:
"""Server-side projection of ``surfsense_web/lib/chat/streaming-state.ts``.
@ -177,6 +186,8 @@ class AssistantContentBuilder:
ui_id: str,
tool_name: str,
langchain_tool_call_id: str | None,
*,
metadata: dict[str, Any] | None = None,
) -> None:
"""Register a tool-call card. Args are filled in by later events."""
if not ui_id:
@ -187,11 +198,11 @@ class AssistantContentBuilder:
# (the canonical path). The FE de-dupes via ``toolCallIndices``;
# we mirror that here.
if ui_id in self._tool_call_idx_by_ui_id:
if langchain_tool_call_id:
idx = self._tool_call_idx_by_ui_id[ui_id]
part = self.parts[idx]
if not part.get("langchainToolCallId"):
part["langchainToolCallId"] = langchain_tool_call_id
idx = self._tool_call_idx_by_ui_id[ui_id]
part = self.parts[idx]
if langchain_tool_call_id and not part.get("langchainToolCallId"):
part["langchainToolCallId"] = langchain_tool_call_id
_merge_tool_part_metadata(part, metadata)
return
part: dict[str, Any] = {
@ -202,6 +213,8 @@ class AssistantContentBuilder:
}
if langchain_tool_call_id:
part["langchainToolCallId"] = langchain_tool_call_id
if metadata:
part["metadata"] = dict(metadata)
self.parts.append(part)
self._tool_call_idx_by_ui_id[ui_id] = len(self.parts) - 1
@ -235,6 +248,8 @@ class AssistantContentBuilder:
tool_name: str,
args: dict[str, Any],
langchain_tool_call_id: str | None,
*,
metadata: dict[str, Any] | None = None,
) -> None:
"""Finalize the tool-call card's input.
@ -264,6 +279,7 @@ class AssistantContentBuilder:
part["argsText"] = final_args_text
if langchain_tool_call_id and not part.get("langchainToolCallId"):
part["langchainToolCallId"] = langchain_tool_call_id
_merge_tool_part_metadata(part, metadata)
return
# No prior tool-input-start: register the card now.
@ -276,6 +292,7 @@ class AssistantContentBuilder:
}
if langchain_tool_call_id:
new_part["langchainToolCallId"] = langchain_tool_call_id
_merge_tool_part_metadata(new_part, metadata)
self.parts.append(new_part)
self._tool_call_idx_by_ui_id[ui_id] = len(self.parts) - 1
@ -287,6 +304,8 @@ class AssistantContentBuilder:
ui_id: str,
output: Any,
langchain_tool_call_id: str | None,
*,
metadata: dict[str, Any] | None = None,
) -> None:
"""Attach the tool's output (``result``) to the matching card.
@ -305,6 +324,7 @@ class AssistantContentBuilder:
part["result"] = output
if langchain_tool_call_id and not part.get("langchainToolCallId"):
part["langchainToolCallId"] = langchain_tool_call_id
_merge_tool_part_metadata(part, metadata)
# ------------------------------------------------------------------
# Thinking steps & step separators
@ -316,6 +336,8 @@ class AssistantContentBuilder:
title: str,
status: str,
items: list[str] | None,
*,
metadata: dict[str, Any] | None = None,
) -> None:
"""Update / insert the singleton ``data-thinking-steps`` part.
@ -328,12 +350,14 @@ class AssistantContentBuilder:
if not step_id:
return
new_step = {
new_step: dict[str, Any] = {
"id": step_id,
"title": title or "",
"status": status or "in_progress",
"items": list(items) if items else [],
}
if metadata:
new_step["metadata"] = dict(metadata)
# Find existing data-thinking-steps part.
existing_idx = -1
@ -347,6 +371,8 @@ class AssistantContentBuilder:
replaced = False
for i, step in enumerate(current_steps):
if step.get("id") == step_id:
if not metadata and step.get("metadata"):
new_step["metadata"] = dict(step["metadata"])
current_steps[i] = new_step
replaced = True
break