mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-03 21:02:40 +02:00
feat: enhance tool input streaming and agent action handling for improved chat experience
Some checks are pending
Build and Push Docker Images / tag_release (push) Waiting to run
Build and Push Docker Images / build (./surfsense_backend, ./surfsense_backend/Dockerfile, backend, surfsense-backend, ubuntu-24.04-arm, linux/arm64, arm64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_backend, ./surfsense_backend/Dockerfile, backend, surfsense-backend, ubuntu-latest, linux/amd64, amd64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_web, ./surfsense_web/Dockerfile, web, surfsense-web, ubuntu-24.04-arm, linux/arm64, arm64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_web, ./surfsense_web/Dockerfile, web, surfsense-web, ubuntu-latest, linux/amd64, amd64) (push) Blocked by required conditions
Build and Push Docker Images / create_manifest (backend, surfsense-backend) (push) Blocked by required conditions
Build and Push Docker Images / create_manifest (web, surfsense-web) (push) Blocked by required conditions
Some checks are pending
Build and Push Docker Images / tag_release (push) Waiting to run
Build and Push Docker Images / build (./surfsense_backend, ./surfsense_backend/Dockerfile, backend, surfsense-backend, ubuntu-24.04-arm, linux/arm64, arm64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_backend, ./surfsense_backend/Dockerfile, backend, surfsense-backend, ubuntu-latest, linux/amd64, amd64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_web, ./surfsense_web/Dockerfile, web, surfsense-web, ubuntu-24.04-arm, linux/arm64, arm64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_web, ./surfsense_web/Dockerfile, web, surfsense-web, ubuntu-latest, linux/amd64, amd64) (push) Blocked by required conditions
Build and Push Docker Images / create_manifest (backend, surfsense-backend) (push) Blocked by required conditions
Build and Push Docker Images / create_manifest (web, surfsense-web) (push) Blocked by required conditions
This commit is contained in:
parent
a688895115
commit
e651c41372
15 changed files with 1857 additions and 545 deletions
|
|
@ -338,6 +338,42 @@ def _log_file_contract(stage: str, result: StreamResult, **extra: Any) -> None:
|
|||
)
|
||||
|
||||
|
||||
def _legacy_match_lc_id(
|
||||
pending_tool_call_chunks: list[dict[str, Any]],
|
||||
tool_name: str,
|
||||
run_id: str,
|
||||
lc_tool_call_id_by_run: dict[str, str],
|
||||
) -> str | None:
|
||||
"""Best-effort match a buffered ``tool_call_chunk`` to a tool name.
|
||||
|
||||
Pure extract of the legacy in-line match used at ``on_tool_start`` for
|
||||
parity_v2-OFF and unmatched (chunk path didn't register an index for
|
||||
this call) tools. Pops the next id-bearing chunk whose ``name``
|
||||
matches ``tool_name`` (or any id-bearing chunk as a fallback) and
|
||||
returns its id. Mutates ``pending_tool_call_chunks`` and
|
||||
``lc_tool_call_id_by_run`` in place.
|
||||
"""
|
||||
matched_idx: int | None = None
|
||||
for idx, tcc in enumerate(pending_tool_call_chunks):
|
||||
if tcc.get("name") == tool_name and tcc.get("id"):
|
||||
matched_idx = idx
|
||||
break
|
||||
if matched_idx is None:
|
||||
for idx, tcc in enumerate(pending_tool_call_chunks):
|
||||
if tcc.get("id"):
|
||||
matched_idx = idx
|
||||
break
|
||||
if matched_idx is None:
|
||||
return None
|
||||
matched = pending_tool_call_chunks.pop(matched_idx)
|
||||
candidate = matched.get("id")
|
||||
if isinstance(candidate, str) and candidate:
|
||||
if run_id:
|
||||
lc_tool_call_id_by_run[run_id] = candidate
|
||||
return candidate
|
||||
return None
|
||||
|
||||
|
||||
async def _stream_agent_events(
|
||||
agent: Any,
|
||||
config: dict[str, Any],
|
||||
|
|
@ -403,10 +439,28 @@ async def _stream_agent_events(
|
|||
# ``tool_call_chunks`` from ``on_chat_model_stream``, key them by
|
||||
# name, and pop the next unconsumed entry at ``on_tool_start``. The
|
||||
# authoritative id is later filled in at ``on_tool_end`` from
|
||||
# ``ToolMessage.tool_call_id``.
|
||||
# ``ToolMessage.tool_call_id``. Under parity_v2 we ALSO short-circuit
|
||||
# this list for chunks that already registered into ``index_to_meta``
|
||||
# below — so this list is reserved for the parity_v2-OFF / unmatched
|
||||
# fallback path only and never re-pops a chunk we already streamed.
|
||||
pending_tool_call_chunks: list[dict[str, Any]] = []
|
||||
lc_tool_call_id_by_run: dict[str, str] = {}
|
||||
|
||||
# parity_v2 only: live tool-call argument streaming. ``index_to_meta``
|
||||
# is keyed by the chunk's ``index`` field — LangChain
|
||||
# ``ToolCallChunk``s for the same call share an index but only the
|
||||
# first chunk carries id+name (subsequent ones are id=None,
|
||||
# name=None, args="<delta>"). We register an index when both id and
|
||||
# name are observed on a chunk (per ToolCallChunk semantics they
|
||||
# arrive together on the first chunk), then route every later chunk
|
||||
# at that index to the same ``ui_id`` as a ``tool-input-delta``.
|
||||
# ``ui_tool_call_id_by_run`` maps LangGraph ``run_id`` to the
|
||||
# ``ui_id`` used for that call's ``tool-input-start`` so the matching
|
||||
# ``tool-output-available`` (emitted from ``on_tool_end``) lands on
|
||||
# the same card.
|
||||
index_to_meta: dict[int, dict[str, str]] = {}
|
||||
ui_tool_call_id_by_run: dict[str, str] = {}
|
||||
|
||||
# Per-tool-end mutable cache for the LangChain tool_call_id resolved
|
||||
# at ``on_tool_end``. ``_emit_tool_output`` reads this so every
|
||||
# ``format_tool_output_available`` call automatically carries the
|
||||
|
|
@ -452,13 +506,6 @@ async def _stream_agent_events(
|
|||
continue
|
||||
parts = _extract_chunk_parts(chunk)
|
||||
|
||||
# Accumulate any tool_call_chunks for best-effort
|
||||
# correlation with ``on_tool_start`` below. We don't emit
|
||||
# anything here; the matching is done at tool-start time.
|
||||
if parity_v2 and parts["tool_call_chunks"]:
|
||||
for tcc in parts["tool_call_chunks"]:
|
||||
pending_tool_call_chunks.append(tcc)
|
||||
|
||||
reasoning_delta = parts["reasoning"]
|
||||
text_delta = parts["text"]
|
||||
|
||||
|
|
@ -504,6 +551,71 @@ async def _stream_agent_events(
|
|||
yield streaming_service.format_text_delta(current_text_id, text_delta)
|
||||
accumulated_text += text_delta
|
||||
|
||||
# Live tool-call argument streaming. Runs AFTER text/reasoning
|
||||
# processing so chunks containing both stay in their natural
|
||||
# wire order (text → text-end → tool-input-start). Active
|
||||
# text/reasoning are closed inside the registration branch
|
||||
# before ``tool-input-start`` so the frontend sees a clean
|
||||
# part boundary even when providers interleave.
|
||||
if parity_v2 and parts["tool_call_chunks"]:
|
||||
for tcc in parts["tool_call_chunks"]:
|
||||
idx = tcc.get("index")
|
||||
|
||||
# Register this index when we first see id+name
|
||||
# TOGETHER. Per LangChain ToolCallChunk semantics the
|
||||
# first chunk for a tool call carries both fields
|
||||
# together; later chunks have id=None, name=None and
|
||||
# only ``args``. Requiring BOTH keeps wire
|
||||
# ``tool-input-start`` always carrying a real
|
||||
# toolName (assistant-ui's typed tool-part dispatch
|
||||
# keys off it).
|
||||
if idx is not None and idx not in index_to_meta:
|
||||
lc_id = tcc.get("id")
|
||||
name = tcc.get("name")
|
||||
if lc_id and name:
|
||||
ui_id = lc_id
|
||||
|
||||
# Close active text/reasoning so wire
|
||||
# ordering stays clean even on providers
|
||||
# that interleave text and tool-call chunks
|
||||
# within the same stream window.
|
||||
if current_text_id is not None:
|
||||
yield streaming_service.format_text_end(current_text_id)
|
||||
current_text_id = None
|
||||
if current_reasoning_id is not None:
|
||||
yield streaming_service.format_reasoning_end(
|
||||
current_reasoning_id
|
||||
)
|
||||
current_reasoning_id = None
|
||||
|
||||
index_to_meta[idx] = {
|
||||
"ui_id": ui_id,
|
||||
"lc_id": lc_id,
|
||||
"name": name,
|
||||
}
|
||||
yield streaming_service.format_tool_input_start(
|
||||
ui_id,
|
||||
name,
|
||||
langchain_tool_call_id=lc_id,
|
||||
)
|
||||
|
||||
# Emit args delta for any chunk at a registered
|
||||
# index (including idless continuations). Once an
|
||||
# index is owned by ``index_to_meta`` we DO NOT
|
||||
# append to ``pending_tool_call_chunks`` — that list
|
||||
# is reserved for the parity_v2-OFF / unmatched
|
||||
# fallback path so it never re-pops chunks already
|
||||
# consumed here (skip-append).
|
||||
meta = index_to_meta.get(idx) if idx is not None else None
|
||||
if meta:
|
||||
args_chunk = tcc.get("args") or ""
|
||||
if args_chunk:
|
||||
yield streaming_service.format_tool_input_delta(
|
||||
meta["ui_id"], args_chunk
|
||||
)
|
||||
else:
|
||||
pending_tool_call_chunks.append(tcc)
|
||||
|
||||
elif event_type == "on_tool_start":
|
||||
active_tool_depth += 1
|
||||
tool_name = event.get("name", "unknown_tool")
|
||||
|
|
@ -834,44 +946,65 @@ async def _stream_agent_events(
|
|||
status="in_progress",
|
||||
)
|
||||
|
||||
tool_call_id = (
|
||||
f"call_{run_id[:32]}"
|
||||
if run_id
|
||||
else streaming_service.generate_tool_call_id()
|
||||
)
|
||||
|
||||
# Best-effort attach the LangChain ``tool_call_id``. We
|
||||
# pop the first chunk in ``pending_tool_call_chunks`` whose
|
||||
# name matches; if none match (the chunked args may not yet
|
||||
# carry a ``name`` field, or the model skipped the chunked
|
||||
# form) we leave ``langchainToolCallId`` unset for now and
|
||||
# fill it in authoritatively at ``on_tool_end`` from
|
||||
# ``ToolMessage.tool_call_id``.
|
||||
langchain_tool_call_id: str | None = None
|
||||
if parity_v2 and pending_tool_call_chunks:
|
||||
matched_idx: int | None = None
|
||||
for idx, tcc in enumerate(pending_tool_call_chunks):
|
||||
if tcc.get("name") == tool_name and tcc.get("id"):
|
||||
matched_idx = idx
|
||||
# Resolve the card identity. If the chunk-emission loop
|
||||
# already registered an ``index`` for this tool call (parity_v2
|
||||
# path), reuse the same ui_id so the card sees:
|
||||
# tool-input-start → deltas… → tool-input-available →
|
||||
# tool-output-available all keyed by lc_id. Otherwise fall
|
||||
# back to the synthetic ``call_<run_id>`` id and the legacy
|
||||
# best-effort match against ``pending_tool_call_chunks``.
|
||||
matched_meta: dict[str, str] | None = None
|
||||
if parity_v2:
|
||||
# FIFO over indices 0,1,2…; first unassigned same-name
|
||||
# match wins. Handles parallel same-name calls (e.g. two
|
||||
# write_file calls) deterministically as long as the
|
||||
# model interleaves on_tool_start in the same order it
|
||||
# streamed the args.
|
||||
taken_ui_ids = set(ui_tool_call_id_by_run.values())
|
||||
for meta in index_to_meta.values():
|
||||
if meta["name"] == tool_name and meta["ui_id"] not in taken_ui_ids:
|
||||
matched_meta = meta
|
||||
break
|
||||
if matched_idx is None:
|
||||
for idx, tcc in enumerate(pending_tool_call_chunks):
|
||||
if tcc.get("id"):
|
||||
matched_idx = idx
|
||||
break
|
||||
if matched_idx is not None:
|
||||
matched = pending_tool_call_chunks.pop(matched_idx)
|
||||
candidate = matched.get("id")
|
||||
if isinstance(candidate, str) and candidate:
|
||||
langchain_tool_call_id = candidate
|
||||
if run_id:
|
||||
lc_tool_call_id_by_run[run_id] = candidate
|
||||
|
||||
yield streaming_service.format_tool_input_start(
|
||||
tool_call_id,
|
||||
tool_name,
|
||||
langchain_tool_call_id=langchain_tool_call_id,
|
||||
)
|
||||
tool_call_id: str
|
||||
langchain_tool_call_id: str | None = None
|
||||
if matched_meta is not None:
|
||||
tool_call_id = matched_meta["ui_id"]
|
||||
langchain_tool_call_id = matched_meta["lc_id"]
|
||||
# ``tool-input-start`` already fired during chunk
|
||||
# emission — skip the duplicate. No pruning is needed
|
||||
# because the chunk-emission loop intentionally never
|
||||
# appends registered-index chunks to
|
||||
# ``pending_tool_call_chunks`` (skip-append).
|
||||
if run_id:
|
||||
lc_tool_call_id_by_run[run_id] = matched_meta["lc_id"]
|
||||
else:
|
||||
tool_call_id = (
|
||||
f"call_{run_id[:32]}"
|
||||
if run_id
|
||||
else streaming_service.generate_tool_call_id()
|
||||
)
|
||||
# Legacy fallback: parity_v2 OFF, or parity_v2 ON but the
|
||||
# provider didn't stream tool_call_chunks for this call
|
||||
# (no index registered). Run the existing best-effort
|
||||
# match BEFORE emitting start so we still attach an
|
||||
# authoritative ``langchainToolCallId`` when possible.
|
||||
if parity_v2:
|
||||
langchain_tool_call_id = _legacy_match_lc_id(
|
||||
pending_tool_call_chunks,
|
||||
tool_name,
|
||||
run_id,
|
||||
lc_tool_call_id_by_run,
|
||||
)
|
||||
yield streaming_service.format_tool_input_start(
|
||||
tool_call_id,
|
||||
tool_name,
|
||||
langchain_tool_call_id=langchain_tool_call_id,
|
||||
)
|
||||
|
||||
if run_id:
|
||||
ui_tool_call_id_by_run[run_id] = tool_call_id
|
||||
|
||||
# Sanitize tool_input: strip runtime-injected non-serializable
|
||||
# values (e.g. LangChain ToolRuntime) before sending over SSE.
|
||||
if isinstance(tool_input, dict):
|
||||
|
|
@ -924,7 +1057,15 @@ async def _stream_agent_events(
|
|||
result.write_succeeded = True
|
||||
result.verification_succeeded = True
|
||||
|
||||
tool_call_id = f"call_{run_id[:32]}" if run_id else "call_unknown"
|
||||
# Look up the SAME card id used at on_tool_start (either the
|
||||
# parity_v2 lc-id-derived ui_id or the legacy synthetic
|
||||
# ``call_<run_id>``) so the output event always lands on the
|
||||
# same card as start/delta/available. Fallback preserves the
|
||||
# legacy synthetic shape for parity_v2-OFF / unknown-run paths.
|
||||
tool_call_id = ui_tool_call_id_by_run.get(
|
||||
run_id,
|
||||
f"call_{run_id[:32]}" if run_id else "call_unknown",
|
||||
)
|
||||
original_step_id = tool_step_ids.get(
|
||||
run_id, f"{step_prefix}-unknown-{run_id[:8]}"
|
||||
)
|
||||
|
|
@ -935,17 +1076,22 @@ async def _stream_agent_events(
|
|||
# at ``on_tool_start`` time (kept in ``lc_tool_call_id_by_run``)
|
||||
# if the output isn't a ToolMessage. The value is stored in
|
||||
# ``current_lc_tool_call_id`` so ``_emit_tool_output``
|
||||
# picks it up for every output emit below. Stays None when
|
||||
# parity_v2 is off so legacy emit paths are untouched.
|
||||
# picks it up for every output emit below.
|
||||
#
|
||||
# Emitted in BOTH parity_v2 and legacy modes: the chat tool
|
||||
# card needs the LangChain id to match against the
|
||||
# ``data-action-log`` SSE event (keyed by ``lc_tool_call_id``)
|
||||
# so the inline Revert button can light up. Reading
|
||||
# ``raw_output.tool_call_id`` is a cheap, non-mutating attribute
|
||||
# access that is safe regardless of feature-flag state.
|
||||
current_lc_tool_call_id["value"] = None
|
||||
if parity_v2:
|
||||
authoritative = getattr(raw_output, "tool_call_id", None)
|
||||
if isinstance(authoritative, str) and authoritative:
|
||||
current_lc_tool_call_id["value"] = authoritative
|
||||
if run_id:
|
||||
lc_tool_call_id_by_run[run_id] = authoritative
|
||||
elif run_id and run_id in lc_tool_call_id_by_run:
|
||||
current_lc_tool_call_id["value"] = lc_tool_call_id_by_run[run_id]
|
||||
authoritative = getattr(raw_output, "tool_call_id", None)
|
||||
if isinstance(authoritative, str) and authoritative:
|
||||
current_lc_tool_call_id["value"] = authoritative
|
||||
if run_id:
|
||||
lc_tool_call_id_by_run[run_id] = authoritative
|
||||
elif run_id and run_id in lc_tool_call_id_by_run:
|
||||
current_lc_tool_call_id["value"] = lc_tool_call_id_by_run[run_id]
|
||||
|
||||
if tool_name == "read_file":
|
||||
yield streaming_service.format_thinking_step(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue