feat: improve agent streaming

DESKTOP-RTLN3BA\$punk 2026-04-29 07:20:31 -07:00
parent afb4b09cde
commit c110f5b955
60 changed files with 8068 additions and 303 deletions

@@ -30,6 +30,7 @@ from sqlalchemy.orm import selectinload
from app.agents.new_chat.chat_deepagent import create_surfsense_deep_agent
from app.agents.new_chat.checkpointer import get_checkpointer
from app.agents.new_chat.feature_flags import get_flags
from app.agents.new_chat.filesystem_selection import FilesystemMode, FilesystemSelection
from app.agents.new_chat.llm_config import (
AgentConfig,
@@ -70,6 +71,91 @@ _background_tasks: set[asyncio.Task] = set()
_perf_log = get_perf_logger()
def _extract_chunk_parts(chunk: Any) -> dict[str, Any]:
"""Decompose an ``AIMessageChunk`` into typed text/reasoning/tool-call parts.
Returns a dict with three keys:
* ``text``: concatenated string content (empty string if the chunk
contributes none).
* ``reasoning``: concatenated reasoning content (empty string if the
chunk contributes none).
* ``tool_call_chunks``: flat list of LangChain ``tool_call_chunk``
dicts surfaced from either the typed-block list or the
``tool_call_chunks`` attribute.
Background
----------
``AIMessageChunk.content`` can be:
* a ``str`` (most providers), or
* a ``list`` of typed blocks ``{type: 'text' | 'reasoning' |
'tool_call_chunk' | 'tool_use' | ..., text/content/...}`` for
Anthropic, Bedrock, and several reasoning configurations.
Reasoning may also live under
``chunk.additional_kwargs['reasoning_content']`` (some providers
surface it that way instead of as a typed block). Tool-call chunks
may live under ``chunk.tool_call_chunks`` even when ``content`` is a
plain string.
Earlier versions only handled the ``isinstance(content, str)`` branch
and silently dropped reasoning blocks + tool-call chunks emitted by
LangChain ``AIMessageChunk``s.
"""
out: dict[str, Any] = {"text": "", "reasoning": "", "tool_call_chunks": []}
if chunk is None:
return out
content = getattr(chunk, "content", None)
if isinstance(content, str):
if content:
out["text"] = content
elif isinstance(content, list):
text_parts: list[str] = []
reasoning_parts: list[str] = []
for block in content:
if not isinstance(block, dict):
continue
block_type = block.get("type")
if block_type == "text":
value = block.get("text") or block.get("content") or ""
if isinstance(value, str) and value:
text_parts.append(value)
elif block_type == "reasoning":
value = (
block.get("reasoning")
or block.get("text")
or block.get("content")
or ""
)
if isinstance(value, str) and value:
reasoning_parts.append(value)
elif block_type in ("tool_call_chunk", "tool_use"):
out["tool_call_chunks"].append(block)
if text_parts:
out["text"] = "".join(text_parts)
if reasoning_parts:
out["reasoning"] = "".join(reasoning_parts)
additional = getattr(chunk, "additional_kwargs", None) or {}
if isinstance(additional, dict):
extra_reasoning = additional.get("reasoning_content")
if isinstance(extra_reasoning, str) and extra_reasoning:
existing = out["reasoning"]
out["reasoning"] = (
(existing + extra_reasoning) if existing else extra_reasoning
)
extra_tool_chunks = getattr(chunk, "tool_call_chunks", None)
if isinstance(extra_tool_chunks, list):
for tcc in extra_tool_chunks:
if isinstance(tcc, dict):
out["tool_call_chunks"].append(tcc)
return out
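# Illustrative sketch of how a typed-block chunk decomposes (assumes
# ``langchain_core``'s ``AIMessageChunk``; the block shapes mirror the
# docstring above):
#
#     from langchain_core.messages import AIMessageChunk
#     parts = _extract_chunk_parts(
#         AIMessageChunk(content=[
#             {"type": "reasoning", "reasoning": "Weighing sources..."},
#             {"type": "text", "text": "Here is the answer."},
#         ])
#     )
#     assert parts["reasoning"] == "Weighing sources..."
#     assert parts["text"] == "Here is the answer."
#     assert parts["tool_call_chunks"] == []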
def format_mentioned_surfsense_docs_as_context(
documents: list[SurfsenseDocsDocument],
) -> str:
@@ -266,6 +352,7 @@ async def _stream_agent_events(
fallback_commit_search_space_id: int | None = None,
fallback_commit_created_by_id: str | None = None,
fallback_commit_filesystem_mode: FilesystemMode = FilesystemMode.CLOUD,
fallback_commit_thread_id: int | None = None,
) -> AsyncGenerator[str, None]:
"""Shared async generator that streams and formats astream_events from the agent.
@@ -298,6 +385,41 @@ async def _stream_agent_events(
active_tool_depth: int = 0 # Track nesting: >0 means we're inside a tool
called_update_memory: bool = False
# Reasoning-block streaming. We open a reasoning block on the
# first reasoning delta of a step, append deltas as they arrive, and
# close it when text starts (the model has switched to writing its
# answer) or ``on_chat_model_end`` fires for the model node. Reuses
# the same Vercel format-helpers as text-start/delta/end.
current_reasoning_id: str | None = None
# Streaming-parity v2 feature flag. When OFF we keep the legacy
# shape: str-only content, no reasoning blocks, no
# ``langchainToolCallId`` propagation. The schema migrations
# (135 / 136) ship unconditionally because they're forward-compatible.
parity_v2 = bool(get_flags().enable_stream_parity_v2)
# Best-effort attach of LangChain ``tool_call_id`` to the synthetic
# ``call_<run_id>`` card id we already emit. We accumulate
# ``tool_call_chunks`` from ``on_chat_model_stream``, key them by
# name, and pop the next unconsumed entry at ``on_tool_start``. The
# authoritative id is later filled in at ``on_tool_end`` from
# ``ToolMessage.tool_call_id``.
pending_tool_call_chunks: list[dict[str, Any]] = []
lc_tool_call_id_by_run: dict[str, str] = {}
# Per-tool-end mutable cache for the LangChain tool_call_id resolved
# at ``on_tool_end``. ``_emit_tool_output`` reads this so every
# ``format_tool_output_available`` call automatically carries the
# authoritative id without duplicating the kwarg at every call site.
current_lc_tool_call_id: dict[str, str | None] = {"value": None}
def _emit_tool_output(call_id: str, output: Any) -> str:
return streaming_service.format_tool_output_available(
call_id,
output,
langchain_tool_call_id=current_lc_tool_call_id["value"],
)
def next_thinking_step_id() -> str:
nonlocal thinking_step_counter
thinking_step_counter += 1
@@ -326,22 +448,61 @@ async def _stream_agent_events(
if "surfsense:internal" in event.get("tags", []):
continue # Suppress middleware-internal LLM tokens (e.g. KB search classification)
chunk = event.get("data", {}).get("chunk")
if chunk and hasattr(chunk, "content"):
content = chunk.content
if content and isinstance(content, str):
if current_text_id is None:
completion_event = complete_current_step()
if completion_event:
yield completion_event
if just_finished_tool:
last_active_step_id = None
last_active_step_title = ""
last_active_step_items = []
just_finished_tool = False
current_text_id = streaming_service.generate_text_id()
yield streaming_service.format_text_start(current_text_id)
yield streaming_service.format_text_delta(current_text_id, content)
accumulated_text += content
if not chunk:
continue
parts = _extract_chunk_parts(chunk)
# Accumulate any tool_call_chunks for best-effort
# correlation with ``on_tool_start`` below. We don't emit
# anything here; the matching is done at tool-start time.
if parity_v2 and parts["tool_call_chunks"]:
for tcc in parts["tool_call_chunks"]:
pending_tool_call_chunks.append(tcc)
reasoning_delta = parts["reasoning"]
text_delta = parts["text"]
# Reasoning streaming. Open a reasoning block on first
# delta; append every subsequent delta until text begins.
# When text starts we close the reasoning block first so the
# frontend sees the natural hand-off. Gated behind the
# parity-v2 flag so legacy deployments keep today's shape.
if parity_v2 and reasoning_delta:
if current_text_id is not None:
yield streaming_service.format_text_end(current_text_id)
current_text_id = None
if current_reasoning_id is None:
completion_event = complete_current_step()
if completion_event:
yield completion_event
if just_finished_tool:
last_active_step_id = None
last_active_step_title = ""
last_active_step_items = []
just_finished_tool = False
current_reasoning_id = streaming_service.generate_reasoning_id()
yield streaming_service.format_reasoning_start(current_reasoning_id)
yield streaming_service.format_reasoning_delta(
current_reasoning_id, reasoning_delta
)
if text_delta:
if current_reasoning_id is not None:
yield streaming_service.format_reasoning_end(current_reasoning_id)
current_reasoning_id = None
if current_text_id is None:
completion_event = complete_current_step()
if completion_event:
yield completion_event
if just_finished_tool:
last_active_step_id = None
last_active_step_title = ""
last_active_step_items = []
just_finished_tool = False
current_text_id = streaming_service.generate_text_id()
yield streaming_service.format_text_start(current_text_id)
yield streaming_service.format_text_delta(current_text_id, text_delta)
accumulated_text += text_delta
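# Net frame ordering for a reasoning-capable model, sketched (names
# follow the Vercel format-helpers used above; the exact wire encoding
# lives in streaming_service):
#   reasoning-start -> reasoning-delta* -> reasoning-end
#     -> text-start -> text-delta* -> text-end
# An open block of one kind always closes before the other kind starts,
# so the frontend sees a clean hand-off in either direction.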
elif event_type == "on_tool_start":
active_tool_depth += 1
@@ -581,7 +742,39 @@ async def _stream_agent_events(
if run_id
else streaming_service.generate_tool_call_id()
)
yield streaming_service.format_tool_input_start(tool_call_id, tool_name)
# Best-effort attach the LangChain ``tool_call_id``. We
# pop the first chunk in ``pending_tool_call_chunks`` whose
# name matches; if none match (the chunked args may not yet
# carry a ``name`` field, or the model skipped the chunked
# form) we leave ``langchainToolCallId`` unset for now and
# fill it in authoritatively at ``on_tool_end`` from
# ``ToolMessage.tool_call_id``.
langchain_tool_call_id: str | None = None
if parity_v2 and pending_tool_call_chunks:
matched_idx: int | None = None
for idx, tcc in enumerate(pending_tool_call_chunks):
if tcc.get("name") == tool_name and tcc.get("id"):
matched_idx = idx
break
if matched_idx is None:
for idx, tcc in enumerate(pending_tool_call_chunks):
if tcc.get("id"):
matched_idx = idx
break
if matched_idx is not None:
matched = pending_tool_call_chunks.pop(matched_idx)
candidate = matched.get("id")
if isinstance(candidate, str) and candidate:
langchain_tool_call_id = candidate
if run_id:
lc_tool_call_id_by_run[run_id] = candidate
yield streaming_service.format_tool_input_start(
tool_call_id,
tool_name,
langchain_tool_call_id=langchain_tool_call_id,
)
# Sanitize tool_input: strip runtime-injected non-serializable
# values (e.g. LangChain ToolRuntime) before sending over SSE.
if isinstance(tool_input, dict):
@@ -598,6 +791,7 @@ async def _stream_agent_events(
tool_call_id,
tool_name,
_safe_input,
langchain_tool_call_id=langchain_tool_call_id,
)
elif event_type == "on_tool_end":
@@ -639,6 +833,23 @@ async def _stream_agent_events(
)
completed_step_ids.add(original_step_id)
# Authoritative LangChain tool_call_id from the returned
# ``ToolMessage``. Falls back to whatever we matched
# at ``on_tool_start`` time (kept in ``lc_tool_call_id_by_run``)
# if the output isn't a ToolMessage. The value is stored in
# ``current_lc_tool_call_id`` so ``_emit_tool_output``
# picks it up for every output emit below. Stays None when
# parity_v2 is off so legacy emit paths are untouched.
current_lc_tool_call_id["value"] = None
if parity_v2:
authoritative = getattr(raw_output, "tool_call_id", None)
if isinstance(authoritative, str) and authoritative:
current_lc_tool_call_id["value"] = authoritative
if run_id:
lc_tool_call_id_by_run[run_id] = authoritative
elif run_id and run_id in lc_tool_call_id_by_run:
current_lc_tool_call_id["value"] = lc_tool_call_id_by_run[run_id]
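# Id-resolution lifecycle across the three hooks, sketched:
#   on_chat_model_stream: buffer ``tool_call_chunks`` (name + id)
#   on_tool_start:        pop the best name match -> provisional id
#   on_tool_end:          ``ToolMessage.tool_call_id`` -> authoritative id,
#                         cached in ``current_lc_tool_call_id`` so every
#                         ``_emit_tool_output`` below carries it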
if tool_name == "read_file":
yield streaming_service.format_thinking_step(
step_id=original_step_id,
@@ -938,7 +1149,7 @@ async def _stream_agent_events(
last_active_step_items = []
if tool_name == "generate_podcast":
yield streaming_service.format_tool_output_available(
yield _emit_tool_output(
tool_call_id,
tool_output
if isinstance(tool_output, dict)
@@ -963,7 +1174,7 @@ async def _stream_agent_events(
"error",
)
elif tool_name == "generate_video_presentation":
yield streaming_service.format_tool_output_available(
yield _emit_tool_output(
tool_call_id,
tool_output
if isinstance(tool_output, dict)
@@ -991,7 +1202,7 @@ async def _stream_agent_events(
"error",
)
elif tool_name == "generate_image":
yield streaming_service.format_tool_output_available(
yield _emit_tool_output(
tool_call_id,
tool_output
if isinstance(tool_output, dict)
@@ -1018,12 +1229,12 @@ async def _stream_agent_events(
display_output["content_preview"] = (
content[:500] + "..." if len(content) > 500 else content
)
yield streaming_service.format_tool_output_available(
yield _emit_tool_output(
tool_call_id,
display_output,
)
else:
yield streaming_service.format_tool_output_available(
yield _emit_tool_output(
tool_call_id,
{"result": tool_output},
)
@@ -1051,7 +1262,7 @@ async def _stream_agent_events(
)
result_text = _tool_output_to_text(tool_output)
if _tool_output_has_error(tool_output):
yield streaming_service.format_tool_output_available(
yield _emit_tool_output(
tool_call_id,
{
"status": "error",
@@ -1060,7 +1271,7 @@ async def _stream_agent_events(
},
)
else:
yield streaming_service.format_tool_output_available(
yield _emit_tool_output(
tool_call_id,
{
"status": "completed",
@@ -1070,7 +1281,7 @@ async def _stream_agent_events(
)
elif tool_name == "generate_report":
# Stream the full report result so frontend can render the ReportCard
yield streaming_service.format_tool_output_available(
yield _emit_tool_output(
tool_call_id,
tool_output
if isinstance(tool_output, dict)
@@ -1097,7 +1308,7 @@ async def _stream_agent_events(
"error",
)
elif tool_name == "generate_resume":
yield streaming_service.format_tool_output_available(
yield _emit_tool_output(
tool_call_id,
tool_output
if isinstance(tool_output, dict)
@@ -1148,7 +1359,7 @@ async def _stream_agent_events(
"update_confluence_page",
"delete_confluence_page",
):
yield streaming_service.format_tool_output_available(
yield _emit_tool_output(
tool_call_id,
tool_output
if isinstance(tool_output, dict)
@@ -1176,7 +1387,7 @@ async def _stream_agent_events(
if fpath and fpath not in result.sandbox_files:
result.sandbox_files.append(fpath)
yield streaming_service.format_tool_output_available(
yield _emit_tool_output(
tool_call_id,
{
"exit_code": exit_code,
@@ -1211,12 +1422,12 @@ async def _stream_agent_events(
citations[chunk_url]["snippet"] = (
content[:200] + "..." if len(content) > 200 else content
)
yield streaming_service.format_tool_output_available(
yield _emit_tool_output(
tool_call_id,
{"status": "completed", "citations": citations},
)
else:
yield streaming_service.format_tool_output_available(
yield _emit_tool_output(
tool_call_id,
{"status": "completed", "result_length": len(str(tool_output))},
)
@@ -1274,6 +1485,25 @@ async def _stream_agent_events(
},
)
elif event_type == "on_custom_event" and event.get("name") == "action_log":
# Surface a freshly committed AgentActionLog row so the chat
# tool card can render its Revert button immediately.
data = event.get("data", {})
if data.get("id") is not None:
yield streaming_service.format_data("action-log", data)
elif (
event_type == "on_custom_event"
and event.get("name") == "action_log_updated"
):
# Emitted when reversibility flips in kb_persistence after the
# SAVEPOINT for a destructive op (rm/rmdir/move/edit/write) commits.
# The frontend uses this to enable the card's Revert button without
# re-fetching the actions list.
data = event.get("data", {})
if data.get("id") is not None:
yield streaming_service.format_data("action-log-updated", data)
elif event_type in ("on_chain_end", "on_agent_end"):
if current_text_id is not None:
yield streaming_service.format_text_end(current_text_id)
@@ -1291,11 +1521,12 @@ async def _stream_agent_events(
# Safety net: if astream_events was cancelled before
# KnowledgeBasePersistenceMiddleware.aafter_agent ran, any staged work
# (dirty_paths / staged_dirs / pending_moves) will still be in the
# checkpointed state. Run the SAME shared commit helper here so the
# turn's writes don't get lost on client disconnect, then push the
# delta back into the graph using `as_node=...` so reducers fire as if
# the after_agent hook produced it.
# (dirty_paths / staged_dirs / pending_moves / pending_deletes /
# pending_dir_deletes) will still be in the checkpointed state. Run
# the SAME shared commit helper here so the turn's writes don't get
# lost on client disconnect, then push the delta back into the graph
# using `as_node=...` so reducers fire as if the after_agent hook
# produced it.
if (
fallback_commit_filesystem_mode == FilesystemMode.CLOUD
and fallback_commit_search_space_id is not None
@@ -1303,6 +1534,8 @@ async def _stream_agent_events(
(state_values.get("dirty_paths") or [])
or (state_values.get("staged_dirs") or [])
or (state_values.get("pending_moves") or [])
or (state_values.get("pending_deletes") or [])
or (state_values.get("pending_dir_deletes") or [])
)
):
try:
@@ -1311,6 +1544,7 @@ async def _stream_agent_events(
search_space_id=fallback_commit_search_space_id,
created_by_id=fallback_commit_created_by_id,
filesystem_mode=fallback_commit_filesystem_mode,
thread_id=fallback_commit_thread_id,
dispatch_events=False,
)
if delta:
@@ -1726,6 +1960,17 @@ async def stream_new_chat(
yield streaming_service.format_message_start()
yield streaming_service.format_start_step()
# Surface the per-turn correlation id at the very start of the
# stream so the frontend can stamp it onto the in-flight
# assistant message and replay it via ``appendMessage``
# for durable storage. Tool/action-log events DO carry it later,
# but pure-text turns never produce action-log events; this
# event guarantees the frontend learns the turn id regardless.
yield streaming_service.format_data(
"turn-info",
{"chat_turn_id": stream_result.turn_id},
)
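# One plausible wire shape for this frame, as a sketch (assuming
# ``format_data`` emits Vercel-AI-style SSE data parts; the
# authoritative encoding lives in ``streaming_service.format_data``):
#   data: {"type": "data-turn-info", "data": {"chat_turn_id": 123}}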
# Initial thinking step - analyzing the request
if mentioned_surfsense_docs:
initial_title = "Analyzing referenced content"
@@ -1876,6 +2121,7 @@ async def stream_new_chat(
if filesystem_selection
else FilesystemMode.CLOUD
),
fallback_commit_thread_id=chat_id,
):
if not _first_event_logged:
_perf_log.info(
@@ -2308,6 +2554,13 @@ async def stream_resume_chat(
yield streaming_service.format_message_start()
yield streaming_service.format_start_step()
# Same rationale as ``stream_new_chat``: emit the turn id so
# resumed streams can be persisted with their correlation id
# intact.
yield streaming_service.format_data(
"turn-info",
{"chat_turn_id": stream_result.turn_id},
)
_t_stream_start = time.perf_counter()
_first_event_logged = False
@@ -2325,6 +2578,7 @@ async def stream_resume_chat(
if filesystem_selection
else FilesystemMode.CLOUD
),
fallback_commit_thread_id=chat_id,
):
if not _first_event_logged:
_perf_log.info(