feat: enhance tool input streaming and agent action handling for improved chat experience
Some checks are pending
Build and Push Docker Images / tag_release (push) Waiting to run
Build and Push Docker Images / build (./surfsense_backend, ./surfsense_backend/Dockerfile, backend, surfsense-backend, ubuntu-24.04-arm, linux/arm64, arm64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_backend, ./surfsense_backend/Dockerfile, backend, surfsense-backend, ubuntu-latest, linux/amd64, amd64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_web, ./surfsense_web/Dockerfile, web, surfsense-web, ubuntu-24.04-arm, linux/arm64, arm64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_web, ./surfsense_web/Dockerfile, web, surfsense-web, ubuntu-latest, linux/amd64, amd64) (push) Blocked by required conditions
Build and Push Docker Images / create_manifest (backend, surfsense-backend) (push) Blocked by required conditions
Build and Push Docker Images / create_manifest (web, surfsense-web) (push) Blocked by required conditions

This commit is contained in:
DESKTOP-RTLN3BA\$punk 2026-04-30 03:13:58 -07:00
parent a688895115
commit e651c41372
15 changed files with 1857 additions and 545 deletions

View file

@ -338,6 +338,42 @@ def _log_file_contract(stage: str, result: StreamResult, **extra: Any) -> None:
)
def _legacy_match_lc_id(
pending_tool_call_chunks: list[dict[str, Any]],
tool_name: str,
run_id: str,
lc_tool_call_id_by_run: dict[str, str],
) -> str | None:
"""Best-effort match a buffered ``tool_call_chunk`` to a tool name.
Pure extract of the legacy in-line match used at ``on_tool_start`` for
parity_v2-OFF and unmatched (chunk path didn't register an index for
this call) tools. Pops the next id-bearing chunk whose ``name``
matches ``tool_name`` (or any id-bearing chunk as a fallback) and
returns its id. Mutates ``pending_tool_call_chunks`` and
``lc_tool_call_id_by_run`` in place.
"""
matched_idx: int | None = None
for idx, tcc in enumerate(pending_tool_call_chunks):
if tcc.get("name") == tool_name and tcc.get("id"):
matched_idx = idx
break
if matched_idx is None:
for idx, tcc in enumerate(pending_tool_call_chunks):
if tcc.get("id"):
matched_idx = idx
break
if matched_idx is None:
return None
matched = pending_tool_call_chunks.pop(matched_idx)
candidate = matched.get("id")
if isinstance(candidate, str) and candidate:
if run_id:
lc_tool_call_id_by_run[run_id] = candidate
return candidate
return None
async def _stream_agent_events(
agent: Any,
config: dict[str, Any],
@ -403,10 +439,28 @@ async def _stream_agent_events(
# ``tool_call_chunks`` from ``on_chat_model_stream``, key them by
# name, and pop the next unconsumed entry at ``on_tool_start``. The
# authoritative id is later filled in at ``on_tool_end`` from
# ``ToolMessage.tool_call_id``.
# ``ToolMessage.tool_call_id``. Under parity_v2 we ALSO short-circuit
# this list for chunks that already registered into ``index_to_meta``
# below — so this list is reserved for the parity_v2-OFF / unmatched
# fallback path only and never re-pops a chunk we already streamed.
pending_tool_call_chunks: list[dict[str, Any]] = []
lc_tool_call_id_by_run: dict[str, str] = {}
# parity_v2 only: live tool-call argument streaming. ``index_to_meta``
# is keyed by the chunk's ``index`` field — LangChain
# ``ToolCallChunk``s for the same call share an index but only the
# first chunk carries id+name (subsequent ones are id=None,
# name=None, args="<delta>"). We register an index when both id and
# name are observed on a chunk (per ToolCallChunk semantics they
# arrive together on the first chunk), then route every later chunk
# at that index to the same ``ui_id`` as a ``tool-input-delta``.
# ``ui_tool_call_id_by_run`` maps LangGraph ``run_id`` to the
# ``ui_id`` used for that call's ``tool-input-start`` so the matching
# ``tool-output-available`` (emitted from ``on_tool_end``) lands on
# the same card.
index_to_meta: dict[int, dict[str, str]] = {}
ui_tool_call_id_by_run: dict[str, str] = {}
# Per-tool-end mutable cache for the LangChain tool_call_id resolved
# at ``on_tool_end``. ``_emit_tool_output`` reads this so every
# ``format_tool_output_available`` call automatically carries the
@ -452,13 +506,6 @@ async def _stream_agent_events(
continue
parts = _extract_chunk_parts(chunk)
# Accumulate any tool_call_chunks for best-effort
# correlation with ``on_tool_start`` below. We don't emit
# anything here; the matching is done at tool-start time.
if parity_v2 and parts["tool_call_chunks"]:
for tcc in parts["tool_call_chunks"]:
pending_tool_call_chunks.append(tcc)
reasoning_delta = parts["reasoning"]
text_delta = parts["text"]
@ -504,6 +551,71 @@ async def _stream_agent_events(
yield streaming_service.format_text_delta(current_text_id, text_delta)
accumulated_text += text_delta
# Live tool-call argument streaming. Runs AFTER text/reasoning
# processing so chunks containing both stay in their natural
# wire order (text → text-end → tool-input-start). Active
# text/reasoning are closed inside the registration branch
# before ``tool-input-start`` so the frontend sees a clean
# part boundary even when providers interleave.
if parity_v2 and parts["tool_call_chunks"]:
for tcc in parts["tool_call_chunks"]:
idx = tcc.get("index")
# Register this index when we first see id+name
# TOGETHER. Per LangChain ToolCallChunk semantics the
# first chunk for a tool call carries both fields
# together; later chunks have id=None, name=None and
# only ``args``. Requiring BOTH keeps wire
# ``tool-input-start`` always carrying a real
# toolName (assistant-ui's typed tool-part dispatch
# keys off it).
if idx is not None and idx not in index_to_meta:
lc_id = tcc.get("id")
name = tcc.get("name")
if lc_id and name:
ui_id = lc_id
# Close active text/reasoning so wire
# ordering stays clean even on providers
# that interleave text and tool-call chunks
# within the same stream window.
if current_text_id is not None:
yield streaming_service.format_text_end(current_text_id)
current_text_id = None
if current_reasoning_id is not None:
yield streaming_service.format_reasoning_end(
current_reasoning_id
)
current_reasoning_id = None
index_to_meta[idx] = {
"ui_id": ui_id,
"lc_id": lc_id,
"name": name,
}
yield streaming_service.format_tool_input_start(
ui_id,
name,
langchain_tool_call_id=lc_id,
)
# Emit args delta for any chunk at a registered
# index (including idless continuations). Once an
# index is owned by ``index_to_meta`` we DO NOT
# append to ``pending_tool_call_chunks`` — that list
# is reserved for the parity_v2-OFF / unmatched
# fallback path so it never re-pops chunks already
# consumed here (skip-append).
meta = index_to_meta.get(idx) if idx is not None else None
if meta:
args_chunk = tcc.get("args") or ""
if args_chunk:
yield streaming_service.format_tool_input_delta(
meta["ui_id"], args_chunk
)
else:
pending_tool_call_chunks.append(tcc)
elif event_type == "on_tool_start":
active_tool_depth += 1
tool_name = event.get("name", "unknown_tool")
@ -834,44 +946,65 @@ async def _stream_agent_events(
status="in_progress",
)
tool_call_id = (
f"call_{run_id[:32]}"
if run_id
else streaming_service.generate_tool_call_id()
)
# Best-effort attach the LangChain ``tool_call_id``. We
# pop the first chunk in ``pending_tool_call_chunks`` whose
# name matches; if none match (the chunked args may not yet
# carry a ``name`` field, or the model skipped the chunked
# form) we leave ``langchainToolCallId`` unset for now and
# fill it in authoritatively at ``on_tool_end`` from
# ``ToolMessage.tool_call_id``.
langchain_tool_call_id: str | None = None
if parity_v2 and pending_tool_call_chunks:
matched_idx: int | None = None
for idx, tcc in enumerate(pending_tool_call_chunks):
if tcc.get("name") == tool_name and tcc.get("id"):
matched_idx = idx
# Resolve the card identity. If the chunk-emission loop
# already registered an ``index`` for this tool call (parity_v2
# path), reuse the same ui_id so the card sees:
# tool-input-start → deltas… → tool-input-available →
# tool-output-available all keyed by lc_id. Otherwise fall
# back to the synthetic ``call_<run_id>`` id and the legacy
# best-effort match against ``pending_tool_call_chunks``.
matched_meta: dict[str, str] | None = None
if parity_v2:
# FIFO over indices 0,1,2…; first unassigned same-name
# match wins. Handles parallel same-name calls (e.g. two
# write_file calls) deterministically as long as the
# model interleaves on_tool_start in the same order it
# streamed the args.
taken_ui_ids = set(ui_tool_call_id_by_run.values())
for meta in index_to_meta.values():
if meta["name"] == tool_name and meta["ui_id"] not in taken_ui_ids:
matched_meta = meta
break
if matched_idx is None:
for idx, tcc in enumerate(pending_tool_call_chunks):
if tcc.get("id"):
matched_idx = idx
break
if matched_idx is not None:
matched = pending_tool_call_chunks.pop(matched_idx)
candidate = matched.get("id")
if isinstance(candidate, str) and candidate:
langchain_tool_call_id = candidate
if run_id:
lc_tool_call_id_by_run[run_id] = candidate
yield streaming_service.format_tool_input_start(
tool_call_id,
tool_name,
langchain_tool_call_id=langchain_tool_call_id,
)
tool_call_id: str
langchain_tool_call_id: str | None = None
if matched_meta is not None:
tool_call_id = matched_meta["ui_id"]
langchain_tool_call_id = matched_meta["lc_id"]
# ``tool-input-start`` already fired during chunk
# emission — skip the duplicate. No pruning is needed
# because the chunk-emission loop intentionally never
# appends registered-index chunks to
# ``pending_tool_call_chunks`` (skip-append).
if run_id:
lc_tool_call_id_by_run[run_id] = matched_meta["lc_id"]
else:
tool_call_id = (
f"call_{run_id[:32]}"
if run_id
else streaming_service.generate_tool_call_id()
)
# Legacy fallback: parity_v2 OFF, or parity_v2 ON but the
# provider didn't stream tool_call_chunks for this call
# (no index registered). Run the existing best-effort
# match BEFORE emitting start so we still attach an
# authoritative ``langchainToolCallId`` when possible.
if parity_v2:
langchain_tool_call_id = _legacy_match_lc_id(
pending_tool_call_chunks,
tool_name,
run_id,
lc_tool_call_id_by_run,
)
yield streaming_service.format_tool_input_start(
tool_call_id,
tool_name,
langchain_tool_call_id=langchain_tool_call_id,
)
if run_id:
ui_tool_call_id_by_run[run_id] = tool_call_id
# Sanitize tool_input: strip runtime-injected non-serializable
# values (e.g. LangChain ToolRuntime) before sending over SSE.
if isinstance(tool_input, dict):
@ -924,7 +1057,15 @@ async def _stream_agent_events(
result.write_succeeded = True
result.verification_succeeded = True
tool_call_id = f"call_{run_id[:32]}" if run_id else "call_unknown"
# Look up the SAME card id used at on_tool_start (either the
# parity_v2 lc-id-derived ui_id or the legacy synthetic
# ``call_<run_id>``) so the output event always lands on the
# same card as start/delta/available. Fallback preserves the
# legacy synthetic shape for parity_v2-OFF / unknown-run paths.
tool_call_id = ui_tool_call_id_by_run.get(
run_id,
f"call_{run_id[:32]}" if run_id else "call_unknown",
)
original_step_id = tool_step_ids.get(
run_id, f"{step_prefix}-unknown-{run_id[:8]}"
)
@ -935,17 +1076,22 @@ async def _stream_agent_events(
# at ``on_tool_start`` time (kept in ``lc_tool_call_id_by_run``)
# if the output isn't a ToolMessage. The value is stored in
# ``current_lc_tool_call_id`` so ``_emit_tool_output``
# picks it up for every output emit below. Stays None when
# parity_v2 is off so legacy emit paths are untouched.
# picks it up for every output emit below.
#
# Emitted in BOTH parity_v2 and legacy modes: the chat tool
# card needs the LangChain id to match against the
# ``data-action-log`` SSE event (keyed by ``lc_tool_call_id``)
# so the inline Revert button can light up. Reading
# ``raw_output.tool_call_id`` is a cheap, non-mutating attribute
# access that is safe regardless of feature-flag state.
current_lc_tool_call_id["value"] = None
if parity_v2:
authoritative = getattr(raw_output, "tool_call_id", None)
if isinstance(authoritative, str) and authoritative:
current_lc_tool_call_id["value"] = authoritative
if run_id:
lc_tool_call_id_by_run[run_id] = authoritative
elif run_id and run_id in lc_tool_call_id_by_run:
current_lc_tool_call_id["value"] = lc_tool_call_id_by_run[run_id]
authoritative = getattr(raw_output, "tool_call_id", None)
if isinstance(authoritative, str) and authoritative:
current_lc_tool_call_id["value"] = authoritative
if run_id:
lc_tool_call_id_by_run[run_id] = authoritative
elif run_id and run_id in lc_tool_call_id_by_run:
current_lc_tool_call_id["value"] = lc_tool_call_id_by_run[run_id]
if tool_name == "read_file":
yield streaming_service.format_thinking_step(