mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-11 08:42:39 +02:00
Merge remote-tracking branch 'upstream/dev' into feat/split-auto-free-premium
This commit is contained in:
commit
872065f90d
15 changed files with 1857 additions and 545 deletions
|
|
@ -599,8 +599,17 @@ class VercelStreamingService:
|
||||||
Format the start of tool input streaming.
|
Format the start of tool input streaming.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
tool_call_id: The unique tool call identifier (synthetic, derived
|
tool_call_id: The unique tool call identifier. May be EITHER the
|
||||||
from LangGraph ``run_id`` so the frontend has a stable card id).
|
synthetic ``call_<run_id>`` id derived from LangGraph
|
||||||
|
``run_id`` (legacy / ``SURFSENSE_ENABLE_STREAM_PARITY_V2``
|
||||||
|
OFF, or the unmatched-fallback path under parity_v2) OR
|
||||||
|
the authoritative LangChain ``tool_call.id`` (parity_v2
|
||||||
|
path: when the provider streams ``tool_call_chunks`` we
|
||||||
|
register the ``index`` and reuse the lc-id as the card
|
||||||
|
id so live ``tool-input-delta`` events can be routed
|
||||||
|
without a downstream join). Either way, the same id is
|
||||||
|
preserved across ``tool-input-start`` / ``-delta`` /
|
||||||
|
``-available`` / ``tool-output-available`` for one call.
|
||||||
tool_name: The name of the tool being called.
|
tool_name: The name of the tool being called.
|
||||||
langchain_tool_call_id: Optional authoritative LangChain
|
langchain_tool_call_id: Optional authoritative LangChain
|
||||||
``tool_call.id``. When set, surfaces as
|
``tool_call.id``. When set, surfaces as
|
||||||
|
|
|
||||||
|
|
@ -473,6 +473,42 @@ def _emit_stream_terminal_error(
|
||||||
return streaming_service.format_error(message, error_code=error_code)
|
return streaming_service.format_error(message, error_code=error_code)
|
||||||
|
|
||||||
|
|
||||||
|
def _legacy_match_lc_id(
|
||||||
|
pending_tool_call_chunks: list[dict[str, Any]],
|
||||||
|
tool_name: str,
|
||||||
|
run_id: str,
|
||||||
|
lc_tool_call_id_by_run: dict[str, str],
|
||||||
|
) -> str | None:
|
||||||
|
"""Best-effort match a buffered ``tool_call_chunk`` to a tool name.
|
||||||
|
|
||||||
|
Pure extract of the legacy in-line match used at ``on_tool_start`` for
|
||||||
|
parity_v2-OFF and unmatched (chunk path didn't register an index for
|
||||||
|
this call) tools. Pops the next id-bearing chunk whose ``name``
|
||||||
|
matches ``tool_name`` (or any id-bearing chunk as a fallback) and
|
||||||
|
returns its id. Mutates ``pending_tool_call_chunks`` and
|
||||||
|
``lc_tool_call_id_by_run`` in place.
|
||||||
|
"""
|
||||||
|
matched_idx: int | None = None
|
||||||
|
for idx, tcc in enumerate(pending_tool_call_chunks):
|
||||||
|
if tcc.get("name") == tool_name and tcc.get("id"):
|
||||||
|
matched_idx = idx
|
||||||
|
break
|
||||||
|
if matched_idx is None:
|
||||||
|
for idx, tcc in enumerate(pending_tool_call_chunks):
|
||||||
|
if tcc.get("id"):
|
||||||
|
matched_idx = idx
|
||||||
|
break
|
||||||
|
if matched_idx is None:
|
||||||
|
return None
|
||||||
|
matched = pending_tool_call_chunks.pop(matched_idx)
|
||||||
|
candidate = matched.get("id")
|
||||||
|
if isinstance(candidate, str) and candidate:
|
||||||
|
if run_id:
|
||||||
|
lc_tool_call_id_by_run[run_id] = candidate
|
||||||
|
return candidate
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
async def _stream_agent_events(
|
async def _stream_agent_events(
|
||||||
agent: Any,
|
agent: Any,
|
||||||
config: dict[str, Any],
|
config: dict[str, Any],
|
||||||
|
|
@ -538,10 +574,28 @@ async def _stream_agent_events(
|
||||||
# ``tool_call_chunks`` from ``on_chat_model_stream``, key them by
|
# ``tool_call_chunks`` from ``on_chat_model_stream``, key them by
|
||||||
# name, and pop the next unconsumed entry at ``on_tool_start``. The
|
# name, and pop the next unconsumed entry at ``on_tool_start``. The
|
||||||
# authoritative id is later filled in at ``on_tool_end`` from
|
# authoritative id is later filled in at ``on_tool_end`` from
|
||||||
# ``ToolMessage.tool_call_id``.
|
# ``ToolMessage.tool_call_id``. Under parity_v2 we ALSO short-circuit
|
||||||
|
# this list for chunks that already registered into ``index_to_meta``
|
||||||
|
# below — so this list is reserved for the parity_v2-OFF / unmatched
|
||||||
|
# fallback path only and never re-pops a chunk we already streamed.
|
||||||
pending_tool_call_chunks: list[dict[str, Any]] = []
|
pending_tool_call_chunks: list[dict[str, Any]] = []
|
||||||
lc_tool_call_id_by_run: dict[str, str] = {}
|
lc_tool_call_id_by_run: dict[str, str] = {}
|
||||||
|
|
||||||
|
# parity_v2 only: live tool-call argument streaming. ``index_to_meta``
|
||||||
|
# is keyed by the chunk's ``index`` field — LangChain
|
||||||
|
# ``ToolCallChunk``s for the same call share an index but only the
|
||||||
|
# first chunk carries id+name (subsequent ones are id=None,
|
||||||
|
# name=None, args="<delta>"). We register an index when both id and
|
||||||
|
# name are observed on a chunk (per ToolCallChunk semantics they
|
||||||
|
# arrive together on the first chunk), then route every later chunk
|
||||||
|
# at that index to the same ``ui_id`` as a ``tool-input-delta``.
|
||||||
|
# ``ui_tool_call_id_by_run`` maps LangGraph ``run_id`` to the
|
||||||
|
# ``ui_id`` used for that call's ``tool-input-start`` so the matching
|
||||||
|
# ``tool-output-available`` (emitted from ``on_tool_end``) lands on
|
||||||
|
# the same card.
|
||||||
|
index_to_meta: dict[int, dict[str, str]] = {}
|
||||||
|
ui_tool_call_id_by_run: dict[str, str] = {}
|
||||||
|
|
||||||
# Per-tool-end mutable cache for the LangChain tool_call_id resolved
|
# Per-tool-end mutable cache for the LangChain tool_call_id resolved
|
||||||
# at ``on_tool_end``. ``_emit_tool_output`` reads this so every
|
# at ``on_tool_end``. ``_emit_tool_output`` reads this so every
|
||||||
# ``format_tool_output_available`` call automatically carries the
|
# ``format_tool_output_available`` call automatically carries the
|
||||||
|
|
@ -587,13 +641,6 @@ async def _stream_agent_events(
|
||||||
continue
|
continue
|
||||||
parts = _extract_chunk_parts(chunk)
|
parts = _extract_chunk_parts(chunk)
|
||||||
|
|
||||||
# Accumulate any tool_call_chunks for best-effort
|
|
||||||
# correlation with ``on_tool_start`` below. We don't emit
|
|
||||||
# anything here; the matching is done at tool-start time.
|
|
||||||
if parity_v2 and parts["tool_call_chunks"]:
|
|
||||||
for tcc in parts["tool_call_chunks"]:
|
|
||||||
pending_tool_call_chunks.append(tcc)
|
|
||||||
|
|
||||||
reasoning_delta = parts["reasoning"]
|
reasoning_delta = parts["reasoning"]
|
||||||
text_delta = parts["text"]
|
text_delta = parts["text"]
|
||||||
|
|
||||||
|
|
@ -639,6 +686,71 @@ async def _stream_agent_events(
|
||||||
yield streaming_service.format_text_delta(current_text_id, text_delta)
|
yield streaming_service.format_text_delta(current_text_id, text_delta)
|
||||||
accumulated_text += text_delta
|
accumulated_text += text_delta
|
||||||
|
|
||||||
|
# Live tool-call argument streaming. Runs AFTER text/reasoning
|
||||||
|
# processing so chunks containing both stay in their natural
|
||||||
|
# wire order (text → text-end → tool-input-start). Active
|
||||||
|
# text/reasoning are closed inside the registration branch
|
||||||
|
# before ``tool-input-start`` so the frontend sees a clean
|
||||||
|
# part boundary even when providers interleave.
|
||||||
|
if parity_v2 and parts["tool_call_chunks"]:
|
||||||
|
for tcc in parts["tool_call_chunks"]:
|
||||||
|
idx = tcc.get("index")
|
||||||
|
|
||||||
|
# Register this index when we first see id+name
|
||||||
|
# TOGETHER. Per LangChain ToolCallChunk semantics the
|
||||||
|
# first chunk for a tool call carries both fields
|
||||||
|
# together; later chunks have id=None, name=None and
|
||||||
|
# only ``args``. Requiring BOTH keeps wire
|
||||||
|
# ``tool-input-start`` always carrying a real
|
||||||
|
# toolName (assistant-ui's typed tool-part dispatch
|
||||||
|
# keys off it).
|
||||||
|
if idx is not None and idx not in index_to_meta:
|
||||||
|
lc_id = tcc.get("id")
|
||||||
|
name = tcc.get("name")
|
||||||
|
if lc_id and name:
|
||||||
|
ui_id = lc_id
|
||||||
|
|
||||||
|
# Close active text/reasoning so wire
|
||||||
|
# ordering stays clean even on providers
|
||||||
|
# that interleave text and tool-call chunks
|
||||||
|
# within the same stream window.
|
||||||
|
if current_text_id is not None:
|
||||||
|
yield streaming_service.format_text_end(current_text_id)
|
||||||
|
current_text_id = None
|
||||||
|
if current_reasoning_id is not None:
|
||||||
|
yield streaming_service.format_reasoning_end(
|
||||||
|
current_reasoning_id
|
||||||
|
)
|
||||||
|
current_reasoning_id = None
|
||||||
|
|
||||||
|
index_to_meta[idx] = {
|
||||||
|
"ui_id": ui_id,
|
||||||
|
"lc_id": lc_id,
|
||||||
|
"name": name,
|
||||||
|
}
|
||||||
|
yield streaming_service.format_tool_input_start(
|
||||||
|
ui_id,
|
||||||
|
name,
|
||||||
|
langchain_tool_call_id=lc_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Emit args delta for any chunk at a registered
|
||||||
|
# index (including idless continuations). Once an
|
||||||
|
# index is owned by ``index_to_meta`` we DO NOT
|
||||||
|
# append to ``pending_tool_call_chunks`` — that list
|
||||||
|
# is reserved for the parity_v2-OFF / unmatched
|
||||||
|
# fallback path so it never re-pops chunks already
|
||||||
|
# consumed here (skip-append).
|
||||||
|
meta = index_to_meta.get(idx) if idx is not None else None
|
||||||
|
if meta:
|
||||||
|
args_chunk = tcc.get("args") or ""
|
||||||
|
if args_chunk:
|
||||||
|
yield streaming_service.format_tool_input_delta(
|
||||||
|
meta["ui_id"], args_chunk
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
pending_tool_call_chunks.append(tcc)
|
||||||
|
|
||||||
elif event_type == "on_tool_start":
|
elif event_type == "on_tool_start":
|
||||||
active_tool_depth += 1
|
active_tool_depth += 1
|
||||||
tool_name = event.get("name", "unknown_tool")
|
tool_name = event.get("name", "unknown_tool")
|
||||||
|
|
@ -969,44 +1081,65 @@ async def _stream_agent_events(
|
||||||
status="in_progress",
|
status="in_progress",
|
||||||
)
|
)
|
||||||
|
|
||||||
tool_call_id = (
|
# Resolve the card identity. If the chunk-emission loop
|
||||||
f"call_{run_id[:32]}"
|
# already registered an ``index`` for this tool call (parity_v2
|
||||||
if run_id
|
# path), reuse the same ui_id so the card sees:
|
||||||
else streaming_service.generate_tool_call_id()
|
# tool-input-start → deltas… → tool-input-available →
|
||||||
)
|
# tool-output-available all keyed by lc_id. Otherwise fall
|
||||||
|
# back to the synthetic ``call_<run_id>`` id and the legacy
|
||||||
# Best-effort attach the LangChain ``tool_call_id``. We
|
# best-effort match against ``pending_tool_call_chunks``.
|
||||||
# pop the first chunk in ``pending_tool_call_chunks`` whose
|
matched_meta: dict[str, str] | None = None
|
||||||
# name matches; if none match (the chunked args may not yet
|
if parity_v2:
|
||||||
# carry a ``name`` field, or the model skipped the chunked
|
# FIFO over indices 0,1,2…; first unassigned same-name
|
||||||
# form) we leave ``langchainToolCallId`` unset for now and
|
# match wins. Handles parallel same-name calls (e.g. two
|
||||||
# fill it in authoritatively at ``on_tool_end`` from
|
# write_file calls) deterministically as long as the
|
||||||
# ``ToolMessage.tool_call_id``.
|
# model interleaves on_tool_start in the same order it
|
||||||
langchain_tool_call_id: str | None = None
|
# streamed the args.
|
||||||
if parity_v2 and pending_tool_call_chunks:
|
taken_ui_ids = set(ui_tool_call_id_by_run.values())
|
||||||
matched_idx: int | None = None
|
for meta in index_to_meta.values():
|
||||||
for idx, tcc in enumerate(pending_tool_call_chunks):
|
if meta["name"] == tool_name and meta["ui_id"] not in taken_ui_ids:
|
||||||
if tcc.get("name") == tool_name and tcc.get("id"):
|
matched_meta = meta
|
||||||
matched_idx = idx
|
|
||||||
break
|
break
|
||||||
if matched_idx is None:
|
|
||||||
for idx, tcc in enumerate(pending_tool_call_chunks):
|
|
||||||
if tcc.get("id"):
|
|
||||||
matched_idx = idx
|
|
||||||
break
|
|
||||||
if matched_idx is not None:
|
|
||||||
matched = pending_tool_call_chunks.pop(matched_idx)
|
|
||||||
candidate = matched.get("id")
|
|
||||||
if isinstance(candidate, str) and candidate:
|
|
||||||
langchain_tool_call_id = candidate
|
|
||||||
if run_id:
|
|
||||||
lc_tool_call_id_by_run[run_id] = candidate
|
|
||||||
|
|
||||||
yield streaming_service.format_tool_input_start(
|
tool_call_id: str
|
||||||
tool_call_id,
|
langchain_tool_call_id: str | None = None
|
||||||
tool_name,
|
if matched_meta is not None:
|
||||||
langchain_tool_call_id=langchain_tool_call_id,
|
tool_call_id = matched_meta["ui_id"]
|
||||||
)
|
langchain_tool_call_id = matched_meta["lc_id"]
|
||||||
|
# ``tool-input-start`` already fired during chunk
|
||||||
|
# emission — skip the duplicate. No pruning is needed
|
||||||
|
# because the chunk-emission loop intentionally never
|
||||||
|
# appends registered-index chunks to
|
||||||
|
# ``pending_tool_call_chunks`` (skip-append).
|
||||||
|
if run_id:
|
||||||
|
lc_tool_call_id_by_run[run_id] = matched_meta["lc_id"]
|
||||||
|
else:
|
||||||
|
tool_call_id = (
|
||||||
|
f"call_{run_id[:32]}"
|
||||||
|
if run_id
|
||||||
|
else streaming_service.generate_tool_call_id()
|
||||||
|
)
|
||||||
|
# Legacy fallback: parity_v2 OFF, or parity_v2 ON but the
|
||||||
|
# provider didn't stream tool_call_chunks for this call
|
||||||
|
# (no index registered). Run the existing best-effort
|
||||||
|
# match BEFORE emitting start so we still attach an
|
||||||
|
# authoritative ``langchainToolCallId`` when possible.
|
||||||
|
if parity_v2:
|
||||||
|
langchain_tool_call_id = _legacy_match_lc_id(
|
||||||
|
pending_tool_call_chunks,
|
||||||
|
tool_name,
|
||||||
|
run_id,
|
||||||
|
lc_tool_call_id_by_run,
|
||||||
|
)
|
||||||
|
yield streaming_service.format_tool_input_start(
|
||||||
|
tool_call_id,
|
||||||
|
tool_name,
|
||||||
|
langchain_tool_call_id=langchain_tool_call_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
if run_id:
|
||||||
|
ui_tool_call_id_by_run[run_id] = tool_call_id
|
||||||
|
|
||||||
# Sanitize tool_input: strip runtime-injected non-serializable
|
# Sanitize tool_input: strip runtime-injected non-serializable
|
||||||
# values (e.g. LangChain ToolRuntime) before sending over SSE.
|
# values (e.g. LangChain ToolRuntime) before sending over SSE.
|
||||||
if isinstance(tool_input, dict):
|
if isinstance(tool_input, dict):
|
||||||
|
|
@ -1059,7 +1192,15 @@ async def _stream_agent_events(
|
||||||
result.write_succeeded = True
|
result.write_succeeded = True
|
||||||
result.verification_succeeded = True
|
result.verification_succeeded = True
|
||||||
|
|
||||||
tool_call_id = f"call_{run_id[:32]}" if run_id else "call_unknown"
|
# Look up the SAME card id used at on_tool_start (either the
|
||||||
|
# parity_v2 lc-id-derived ui_id or the legacy synthetic
|
||||||
|
# ``call_<run_id>``) so the output event always lands on the
|
||||||
|
# same card as start/delta/available. Fallback preserves the
|
||||||
|
# legacy synthetic shape for parity_v2-OFF / unknown-run paths.
|
||||||
|
tool_call_id = ui_tool_call_id_by_run.get(
|
||||||
|
run_id,
|
||||||
|
f"call_{run_id[:32]}" if run_id else "call_unknown",
|
||||||
|
)
|
||||||
original_step_id = tool_step_ids.get(
|
original_step_id = tool_step_ids.get(
|
||||||
run_id, f"{step_prefix}-unknown-{run_id[:8]}"
|
run_id, f"{step_prefix}-unknown-{run_id[:8]}"
|
||||||
)
|
)
|
||||||
|
|
@ -1070,17 +1211,22 @@ async def _stream_agent_events(
|
||||||
# at ``on_tool_start`` time (kept in ``lc_tool_call_id_by_run``)
|
# at ``on_tool_start`` time (kept in ``lc_tool_call_id_by_run``)
|
||||||
# if the output isn't a ToolMessage. The value is stored in
|
# if the output isn't a ToolMessage. The value is stored in
|
||||||
# ``current_lc_tool_call_id`` so ``_emit_tool_output``
|
# ``current_lc_tool_call_id`` so ``_emit_tool_output``
|
||||||
# picks it up for every output emit below. Stays None when
|
# picks it up for every output emit below.
|
||||||
# parity_v2 is off so legacy emit paths are untouched.
|
#
|
||||||
|
# Emitted in BOTH parity_v2 and legacy modes: the chat tool
|
||||||
|
# card needs the LangChain id to match against the
|
||||||
|
# ``data-action-log`` SSE event (keyed by ``lc_tool_call_id``)
|
||||||
|
# so the inline Revert button can light up. Reading
|
||||||
|
# ``raw_output.tool_call_id`` is a cheap, non-mutating attribute
|
||||||
|
# access that is safe regardless of feature-flag state.
|
||||||
current_lc_tool_call_id["value"] = None
|
current_lc_tool_call_id["value"] = None
|
||||||
if parity_v2:
|
authoritative = getattr(raw_output, "tool_call_id", None)
|
||||||
authoritative = getattr(raw_output, "tool_call_id", None)
|
if isinstance(authoritative, str) and authoritative:
|
||||||
if isinstance(authoritative, str) and authoritative:
|
current_lc_tool_call_id["value"] = authoritative
|
||||||
current_lc_tool_call_id["value"] = authoritative
|
if run_id:
|
||||||
if run_id:
|
lc_tool_call_id_by_run[run_id] = authoritative
|
||||||
lc_tool_call_id_by_run[run_id] = authoritative
|
elif run_id and run_id in lc_tool_call_id_by_run:
|
||||||
elif run_id and run_id in lc_tool_call_id_by_run:
|
current_lc_tool_call_id["value"] = lc_tool_call_id_by_run[run_id]
|
||||||
current_lc_tool_call_id["value"] = lc_tool_call_id_by_run[run_id]
|
|
||||||
|
|
||||||
if tool_name == "read_file":
|
if tool_name == "read_file":
|
||||||
yield streaming_service.format_thinking_step(
|
yield streaming_service.format_thinking_step(
|
||||||
|
|
|
||||||
|
|
@ -183,3 +183,46 @@ class TestDefensive:
|
||||||
assert out["text"] == ""
|
assert out["text"] == ""
|
||||||
assert out["reasoning"] == ""
|
assert out["reasoning"] == ""
|
||||||
assert out["tool_call_chunks"] == []
|
assert out["tool_call_chunks"] == []
|
||||||
|
|
||||||
|
|
||||||
|
class TestIdlessContinuationChunks:
    """Id-less continuation chunks must pass through extraction unchanged.

    Per LangChain ``ToolCallChunk`` semantics only the FIRST chunk of a
    tool call carries id+name; later chunks for the same call have
    ``id=None, name=None`` and only ``args`` + ``index``. Live tool-call
    argument streaming needs ``_extract_chunk_parts`` to hand those
    continuation chunks back untouched so they can be routed by ``index``.
    """

    def test_idless_continuation_chunk_preserved_verbatim(self) -> None:
        continuation = {"id": None, "name": None, "args": '_path":"/x"}', "index": 0}
        extracted = _extract_chunk_parts(_FakeChunk(tool_call_chunks=[continuation]))

        assert len(extracted["tool_call_chunks"]) == 1
        only = extracted["tool_call_chunks"][0]
        assert only.get("id") is None
        assert only.get("name") is None
        assert only.get("args") == '_path":"/x"}'
        assert only.get("index") == 0

    def test_first_then_idless_sequence_preserves_index(self) -> None:
        """The opening chunk and its continuation share one ``index``;
        the index-routing loop in ``stream_new_chat`` depends on it."""
        opening = _FakeChunk(
            tool_call_chunks=[
                {"id": "lc-1", "name": "write_file", "args": '{"file', "index": 0}
            ]
        )
        continuation = _FakeChunk(
            tool_call_chunks=[
                {"id": None, "name": None, "args": '_path":"/x"}', "index": 0}
            ]
        )

        opening_parts = _extract_chunk_parts(opening)
        continuation_parts = _extract_chunk_parts(continuation)

        assert opening_parts["tool_call_chunks"][0]["index"] == 0
        assert continuation_parts["tool_call_chunks"][0]["index"] == 0
        assert continuation_parts["tool_call_chunks"][0].get("id") is None
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,527 @@
|
||||||
|
"""Unit tests for live tool-call argument streaming.
|
||||||
|
|
||||||
|
Pins the wire format that ``_stream_agent_events`` emits when
|
||||||
|
``SURFSENSE_ENABLE_STREAM_PARITY_V2=true``: ``tool-input-start`` →
|
||||||
|
``tool-input-delta``... → ``tool-input-available`` → ``tool-output-available``
|
||||||
|
all keyed by the same LangChain ``tool_call.id``.
|
||||||
|
|
||||||
|
Identity is tracked in ``index_to_meta`` (per-chunk ``index``) and
|
||||||
|
``ui_tool_call_id_by_run`` (LangGraph ``run_id``); both are private to
|
||||||
|
``_stream_agent_events`` so we exercise them via the public wire output.
|
||||||
|
|
||||||
|
These tests also lock in the legacy / parity_v2-OFF behaviour so the
|
||||||
|
synthetic ``call_<run_id>`` shape stays stable for older clients.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from collections.abc import AsyncGenerator
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import app.tasks.chat.stream_new_chat as stream_module
|
||||||
|
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||||
|
from app.services.new_streaming_service import VercelStreamingService
|
||||||
|
from app.tasks.chat.stream_new_chat import (
|
||||||
|
StreamResult,
|
||||||
|
_legacy_match_lc_id,
|
||||||
|
_stream_agent_events,
|
||||||
|
)
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _FakeChunk:
|
||||||
|
"""Minimal stand-in for ``AIMessageChunk``."""
|
||||||
|
|
||||||
|
content: Any = ""
|
||||||
|
additional_kwargs: dict[str, Any] = field(default_factory=dict)
|
||||||
|
tool_call_chunks: list[dict[str, Any]] = field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _FakeToolMessage:
|
||||||
|
"""Stand-in for ``ToolMessage`` returned by ``on_tool_end``."""
|
||||||
|
|
||||||
|
content: Any
|
||||||
|
tool_call_id: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeAgentState:
|
||||||
|
"""Stand-in for ``StateSnapshot`` returned by ``aget_state``."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
# Empty values keeps the cloud-fallback safety-net branch a no-op,
|
||||||
|
# and an empty ``tasks`` list keeps the post-stream interrupt
|
||||||
|
# check a no-op too.
|
||||||
|
self.values: dict[str, Any] = {}
|
||||||
|
self.tasks: list[Any] = []
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeAgent:
|
||||||
|
"""Replays a list of ``astream_events`` events."""
|
||||||
|
|
||||||
|
def __init__(self, events: list[dict[str, Any]]) -> None:
|
||||||
|
self._events = events
|
||||||
|
|
||||||
|
async def astream_events( # type: ignore[no-untyped-def]
|
||||||
|
self, _input_data: Any, *, config: dict[str, Any], version: str
|
||||||
|
) -> AsyncGenerator[dict[str, Any], None]:
|
||||||
|
del config, version # unused, contract-compatible
|
||||||
|
for ev in self._events:
|
||||||
|
yield ev
|
||||||
|
|
||||||
|
async def aget_state(self, _config: dict[str, Any]) -> _FakeAgentState:
|
||||||
|
# Called once after astream_events drains so the cloud-fallback
|
||||||
|
# safety net can inspect staged filesystem work. The fake stays
|
||||||
|
# empty so the safety net is a no-op.
|
||||||
|
return _FakeAgentState()
|
||||||
|
|
||||||
|
|
||||||
|
def _model_stream(
    *,
    text: str = "",
    reasoning: str = "",
    tool_call_chunks: list[dict[str, Any]] | None = None,
    tags: list[str] | None = None,
) -> dict[str, Any]:
    """Build one ``on_chat_model_stream`` event wrapping a ``_FakeChunk``.

    Args:
        text: Value for the chunk's ``content``.
        reasoning: When non-empty, stored under ``reasoning_content`` in
            the chunk's ``additional_kwargs``; otherwise the chunk gets
            empty ``additional_kwargs``.
        tool_call_chunks: Tool-call chunk dicts; copied into a new list
            (``None`` becomes an empty list).
        tags: Event tags (``None`` or empty becomes ``[]``).

    Returns:
        The event dict with ``event``, ``tags`` and ``data["chunk"]``.
    """
    # Reasoning piggybacks on ``additional_kwargs``. Tests that need typed
    # content blocks can pass them via ``text``; most tests only exercise
    # tool_call_chunks routing, so this is sufficient.
    extra_kwargs: dict[str, Any] = (
        {"reasoning_content": reasoning} if reasoning else {}
    )
    return {
        "event": "on_chat_model_stream",
        "tags": tags or [],
        "data": {
            "chunk": _FakeChunk(
                content=text,
                additional_kwargs=extra_kwargs,
                tool_call_chunks=list(tool_call_chunks or []),
            )
        },
    }
|
||||||
|
|
||||||
|
|
||||||
|
def _tool_start(
|
||||||
|
*,
|
||||||
|
name: str,
|
||||||
|
run_id: str,
|
||||||
|
input_payload: dict[str, Any] | None = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"event": "on_tool_start",
|
||||||
|
"name": name,
|
||||||
|
"run_id": run_id,
|
||||||
|
"data": {"input": input_payload or {}},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _tool_end(
    *,
    name: str,
    run_id: str,
    tool_call_id: str | None = None,
    output: Any = "ok",
) -> dict[str, Any]:
    """Build an ``on_tool_end`` event whose output is a ``_FakeToolMessage``.

    A string ``output`` is used as the message content as-is; any other
    value is serialised with ``json.dumps`` first.
    """
    if isinstance(output, str):
        content = output
    else:
        content = json.dumps(output)
    message = _FakeToolMessage(content=content, tool_call_id=tool_call_id)
    return {
        "event": "on_tool_end",
        "name": name,
        "run_id": run_id,
        "data": {"output": message},
    }
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def parity_v2_on(monkeypatch: pytest.MonkeyPatch) -> None:
    """Patch ``stream_module.get_flags`` to report stream parity v2 enabled."""

    def _flags_enabled() -> AgentFeatureFlags:
        return AgentFeatureFlags(enable_stream_parity_v2=True)

    monkeypatch.setattr(stream_module, "get_flags", _flags_enabled)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def parity_v2_off(monkeypatch: pytest.MonkeyPatch) -> None:
    """Patch ``stream_module.get_flags`` to report stream parity v2 disabled."""

    def _flags_disabled() -> AgentFeatureFlags:
        return AgentFeatureFlags(enable_stream_parity_v2=False)

    monkeypatch.setattr(stream_module, "get_flags", _flags_disabled)
|
||||||
|
|
||||||
|
|
||||||
|
async def _drain(events: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Replay ``events`` through ``_stream_agent_events`` and decode the output.

    Collects every yielded SSE string first, then keeps only ``data: ``
    frames whose body is non-empty, is not ``[DONE]`` and parses as JSON.
    Returns the parsed payloads in emission order.
    """
    fake_agent = _FakeAgent(events)
    streaming_service = VercelStreamingService()
    stream_result = StreamResult()
    run_config = {"configurable": {"thread_id": "test-thread"}}

    frames: list[str] = [
        frame
        async for frame in _stream_agent_events(
            fake_agent,
            run_config,
            {},
            streaming_service,
            stream_result,
            step_prefix="thinking",
        )
    ]

    data_prefix = "data: "
    payloads: list[dict[str, Any]] = []
    for frame in frames:
        if not frame.startswith(data_prefix):
            continue
        body = frame[len(data_prefix) :].rstrip("\n")
        if body in ("", "[DONE]"):
            continue
        try:
            decoded = json.loads(body)
        except json.JSONDecodeError:
            # Non-JSON data frames are ignored on purpose.
            continue
        payloads.append(decoded)
    return payloads
|
||||||
|
|
||||||
|
|
||||||
|
def _types(payloads: list[dict[str, Any]]) -> list[str]:
|
||||||
|
return [p.get("type", "?") for p in payloads]
|
||||||
|
|
||||||
|
|
||||||
|
def _of_type(payloads: list[dict[str, Any]], type_name: str) -> list[dict[str, Any]]:
|
||||||
|
return [p for p in payloads if p.get("type") == type_name]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helper: ``_legacy_match_lc_id`` is a pure refactor; assert behaviour.
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestLegacyMatch:
    """Behavioural checks for the ``_legacy_match_lc_id`` helper."""

    def test_pops_first_id_bearing_chunk_with_matching_name(self) -> None:
        pending: list[dict[str, Any]] = [
            {"id": "x1", "name": "ls"},
            {"id": "y1", "name": "write_file"},
        ]
        by_run: dict[str, str] = {}

        matched = _legacy_match_lc_id(pending, "write_file", "run-1", by_run)

        assert matched == "y1"
        # Only the matched chunk is consumed; the other stays buffered.
        assert pending == [{"id": "x1", "name": "ls"}]
        assert by_run == {"run-1": "y1"}

    def test_falls_back_to_any_id_bearing_when_name_mismatches(self) -> None:
        pending: list[dict[str, Any]] = [{"id": "anon", "name": None}]
        by_run: dict[str, str] = {}

        matched = _legacy_match_lc_id(pending, "ls", "run-2", by_run)

        assert matched == "anon"
        assert pending == []

    def test_returns_none_when_no_id_bearing_chunk(self) -> None:
        pending: list[dict[str, Any]] = [{"id": None, "name": None}]
        by_run: dict[str, str] = {}

        assert _legacy_match_lc_id(pending, "ls", "run-3", by_run) is None
        # Nothing matched, so neither container is touched.
        assert pending == [{"id": None, "name": None}]
        assert by_run == {}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# parity_v2 wire format tests.
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_idless_chunk_merging_by_index(parity_v2_on: None) -> None:
    """The first chunk carries id+name; later id-less chunks at the same
    ``index`` attach to the SAME ``tool-input-start`` ui id, and each
    chunk produces one ``tool-input-delta``."""
    opening_chunk = {"id": "lc-1", "name": "write_file", "args": '{"file', "index": 0}
    continuation_chunk = {"id": None, "name": None, "args": '_path":"/x"}', "index": 0}
    script = [
        _model_stream(tool_call_chunks=[opening_chunk]),
        _model_stream(tool_call_chunks=[continuation_chunk]),
        _tool_start(
            name="write_file", run_id="run-A", input_payload={"file_path": "/x"}
        ),
        _tool_end(name="write_file", run_id="run-A", tool_call_id="lc-1"),
    ]

    wire = await _drain(script)

    start_events = _of_type(wire, "tool-input-start")
    delta_events = _of_type(wire, "tool-input-delta")
    available_events = _of_type(wire, "tool-input-available")
    output_events = _of_type(wire, "tool-output-available")

    assert len(start_events) == 1
    assert start_events[0]["toolCallId"] == "lc-1"
    assert start_events[0]["toolName"] == "write_file"
    assert start_events[0]["langchainToolCallId"] == "lc-1"

    assert [ev["inputTextDelta"] for ev in delta_events] == ['{"file', '_path":"/x"}']
    assert all(ev["toolCallId"] == "lc-1" for ev in delta_events)

    assert len(available_events) == 1
    assert available_events[0]["toolCallId"] == "lc-1"

    assert len(output_events) == 1
    assert output_events[0]["toolCallId"] == "lc-1"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_two_interleaved_tool_calls_route_by_index(
    parity_v2_on: None,
) -> None:
    """Two calls to the same tool, at different indices, each receive
    only their own deltas."""
    opening_pair = [
        {"id": "lc-A", "name": "write_file", "args": '{"a":1', "index": 0},
        {"id": "lc-B", "name": "write_file", "args": '{"b":2', "index": 1},
    ]
    closing_pair = [
        {"id": None, "name": None, "args": "}", "index": 0},
        {"id": None, "name": None, "args": "}", "index": 1},
    ]
    script = [
        _model_stream(tool_call_chunks=opening_pair),
        _model_stream(tool_call_chunks=closing_pair),
        _tool_start(name="write_file", run_id="run-A", input_payload={"a": 1}),
        _tool_end(name="write_file", run_id="run-A", tool_call_id="lc-A"),
        _tool_start(name="write_file", run_id="run-B", input_payload={"b": 2}),
        _tool_end(name="write_file", run_id="run-B", tool_call_id="lc-B"),
    ]

    wire = await _drain(script)

    start_events = _of_type(wire, "tool-input-start")
    delta_events = _of_type(wire, "tool-input-delta")
    output_events = _of_type(wire, "tool-output-available")

    assert {ev["toolCallId"] for ev in start_events} == {"lc-A", "lc-B"}

    # Group delta text per card id; an unexpected id raises KeyError.
    deltas_by_card: dict[str, list[str]] = {"lc-A": [], "lc-B": []}
    for ev in delta_events:
        deltas_by_card[ev["toolCallId"]].append(ev["inputTextDelta"])
    assert deltas_by_card["lc-A"] == ['{"a":1', "}"]
    assert deltas_by_card["lc-B"] == ['{"b":2', "}"]

    assert {ev["toolCallId"] for ev in output_events} == {"lc-A", "lc-B"}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_identity_stable_across_lifecycle(parity_v2_on: None) -> None:
    """One tool call keeps a single ``toolCallId`` for its whole lifecycle.

    The id carried by ``tool-input-start`` must be the very same id on
    ``tool-input-available`` and on ``tool-output-available``.
    """
    ls_chunk = {"id": "lc-9", "name": "ls", "args": '{"path":"/"}', "index": 0}
    events = [
        _model_stream(tool_call_chunks=[ls_chunk]),
        _tool_start(name="ls", run_id="run-X", input_payload={"path": "/"}),
        _tool_end(name="ls", run_id="run-X", tool_call_id="lc-9"),
    ]
    payloads = await _drain(events)

    lifecycle_types = {
        "tool-input-start",
        "tool-input-available",
        "tool-output-available",
    }
    # Collapse the ids of every lifecycle event into a set: exactly one
    # distinct id may remain.
    observed_ids = {
        payload["toolCallId"]
        for payload in payloads
        if payload.get("type") in lifecycle_types
    }
    assert observed_ids == {"lc-9"}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_no_duplicate_tool_input_start(parity_v2_on: None) -> None:
    """``on_tool_start`` must not re-announce an already started card.

    The streamed chunk is expected to have produced ``tool-input-start``
    for this call already, so the later tool-start event MUST NOT emit a
    second one.
    """
    write_chunk = {"id": "lc-1", "name": "write_file", "args": "{}", "index": 0}
    events = [
        _model_stream(tool_call_chunks=[write_chunk]),
        _tool_start(name="write_file", run_id="run-A", input_payload={}),
        _tool_end(name="write_file", run_id="run-A", tool_call_id="lc-1"),
    ]
    payloads = await _drain(events)

    start_events = _of_type(payloads, "tool-input-start")
    # Exactly one start event, and it is keyed by the LangChain id.
    assert len(start_events) == 1
    first_start = start_events[0]
    assert first_start["toolCallId"] == "lc-1"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_active_text_closes_before_early_tool_input_start(
    parity_v2_on: None,
) -> None:
    """An open text part is closed before the first tool card opens.

    A text chunk is followed by a tool-call chunk in a later stream event;
    the wire MUST show ``text-end`` ahead of the FIRST ``tool-input-start``
    so the frontend sees a clean part boundary.
    """
    write_chunk = {"id": "lc-1", "name": "write_file", "args": "{}", "index": 0}
    events = [
        _model_stream(text="Working on it"),
        _model_stream(tool_call_chunks=[write_chunk]),
        _tool_start(name="write_file", run_id="run-A", input_payload={}),
        _tool_end(name="write_file", run_id="run-A", tool_call_id="lc-1"),
    ]
    wire_types = _types(await _drain(events))

    # ``index`` returns the first occurrence, which is what the contract
    # is about; a missing event type raises and fails the test.
    assert wire_types.index("text-end") < wire_types.index("tool-input-start")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_mixed_text_and_tool_chunk_preserve_order(
    parity_v2_on: None,
) -> None:
    """A chunk with both text and tool-call data is emitted text first.

    When one AIMessageChunk carries ``text`` content AND
    ``tool_call_chunks``, the text delta goes out FIRST, the text part is
    closed, and only then ``tool-input-start`` + ``tool-input-delta``
    follow.
    """
    mixed_chunk = {
        "id": "lc-1",
        "name": "write_file",
        "args": '{"file_path":"/x"}',
        "index": 0,
    }
    events = [
        _model_stream(text="I'll update it", tool_call_chunks=[mixed_chunk]),
        _tool_start(
            name="write_file", run_id="run-A", input_payload={"file_path": "/x"}
        ),
        _tool_end(name="write_file", run_id="run-A", tool_call_id="lc-1"),
    ]
    wire_types = _types(await _drain(events))

    # First occurrence of each event type must appear in this exact order.
    expected_sequence = [
        "text-start",
        "text-delta",
        "text-end",
        "tool-input-start",
        "tool-input-delta",
    ]
    for earlier, later in zip(expected_sequence, expected_sequence[1:]):
        assert wire_types.index(earlier) < wire_types.index(later)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_parity_v2_off_preserves_legacy_shape(
    parity_v2_off: None,
) -> None:
    """With the flag OFF the stream keeps its legacy shape.

    No ``tool-input-delta`` events are emitted and the ``toolCallId`` is
    the synthetic ``call_<run_id>`` (NOT the lc id), while
    ``tool-output-available`` still carries ``langchainToolCallId``.
    """
    events = [
        _model_stream(
            tool_call_chunks=[
                {"id": "lc-1", "name": "ls", "args": '{"path":"/"}', "index": 0}
            ]
        ),
        _tool_start(name="ls", run_id="run-A", input_payload={"path": "/"}),
        _tool_end(name="ls", run_id="run-A", tool_call_id="lc-1"),
    ]
    payloads = await _drain(events)

    assert _of_type(payloads, "tool-input-delta") == []
    starts = _of_type(payloads, "tool-input-start")
    assert len(starts) == 1
    assert starts[0]["toolCallId"].startswith("call_run-A")
    # No ``langchainToolCallId`` propagation on ``tool-input-start`` in
    # legacy mode (the start event fires before the ToolMessage is
    # available, so we can't extract the authoritative LangChain id yet).
    assert "langchainToolCallId" not in starts[0]
    output = _of_type(payloads, "tool-output-available")
    # Guard the index below: a missing event should be reported as a
    # readable assertion failure, not as an IndexError.
    assert output, "expected at least one tool-output-available event"
    assert output[0]["toolCallId"].startswith("call_run-A")
    # ``tool-output-available`` MUST carry ``langchainToolCallId`` even
    # in legacy mode: the chat tool card uses it to backfill the
    # LangChain id and join against the ``data-action-log`` SSE event
    # (keyed by ``lc_tool_call_id``) so the inline Revert button can
    # light up. Sourced from the returned ``ToolMessage.tool_call_id``,
    # which is populated regardless of feature-flag state.
    assert output[0]["langchainToolCallId"] == "lc-1"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_skip_append_prevents_stale_id_reuse(
    parity_v2_on: None,
) -> None:
    """Two same-name tools never share a LangChain id.

    The SECOND tool's ``langchainToolCallId`` must NOT be taken from the
    first tool's chunk: ``pending_tool_call_chunks`` must stay empty for
    chunks that were registered by index.
    """
    indexed_chunks = [
        {"id": "lc-A", "name": "write_file", "args": "{}", "index": 0},
        {"id": "lc-B", "name": "write_file", "args": "{}", "index": 1},
    ]
    events = [
        _model_stream(tool_call_chunks=indexed_chunks),
        _tool_start(name="write_file", run_id="run-1", input_payload={}),
        _tool_end(name="write_file", run_id="run-1", tool_call_id="lc-A"),
        _tool_start(name="write_file", run_id="run-2", input_payload={}),
        _tool_end(name="write_file", run_id="run-2", tool_call_id="lc-B"),
    ]
    payloads = await _drain(events)

    # Each lc id opened its own card ...
    opened_cards = {
        event["toolCallId"] for event in _of_type(payloads, "tool-input-start")
    }
    assert opened_cards == {"lc-A", "lc-B"}

    # ... and each tool result landed on its respective card.
    resolved_cards = {
        event["toolCallId"] for event in _of_type(payloads, "tool-output-available")
    }
    assert resolved_cards == {"lc-A", "lc-B"}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_registration_waits_for_both_id_and_name(
    parity_v2_on: None,
) -> None:
    """A chunk that has an id but no name yet must NOT emit ``tool-input-start``."""
    id_only_chunk = {"id": "lc-1", "name": None, "args": "", "index": 0}
    payloads = await _drain([_model_stream(tool_call_chunks=[id_only_chunk])])

    start_events = _of_type(payloads, "tool-input-start")
    assert start_events == []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_unmatched_fallback_still_attaches_lc_id(
    parity_v2_on: None,
) -> None:
    """Chunks without an ``index`` still get their LangChain id attached.

    parity_v2 is ON but the provider sent no ``index``: the legacy
    fallback path must still emit ``tool-input-start`` with the matching
    ``langchainToolCallId``, keyed by the synthetic ``call_<run_id>`` id.
    """
    # No index on the chunk → not registered into index_to_meta;
    # falls through to ``pending_tool_call_chunks`` so the legacy
    # match path can pop it at on_tool_start.
    unindexed_chunk = {"id": "lc-orphan", "name": "ls", "args": ""}
    events = [
        _model_stream(tool_call_chunks=[unindexed_chunk]),
        _tool_start(name="ls", run_id="run-1", input_payload={"path": "/"}),
        _tool_end(name="ls", run_id="run-1", tool_call_id="lc-orphan"),
    ]
    payloads = await _drain(events)

    start_events = _of_type(payloads, "tool-input-start")
    assert len(start_events) == 1
    fallback_start = start_events[0]
    assert fallback_start["toolCallId"].startswith("call_run-1")
    assert fallback_start["langchainToolCallId"] == "lc-orphan"
|
||||||
|
|
@ -14,13 +14,6 @@ import { useCallback, useEffect, useMemo, useRef, useState } from "react";
|
||||||
import { toast } from "sonner";
|
import { toast } from "sonner";
|
||||||
import { z } from "zod";
|
import { z } from "zod";
|
||||||
import { disabledToolsAtom } from "@/atoms/agent-tools/agent-tools.atoms";
|
import { disabledToolsAtom } from "@/atoms/agent-tools/agent-tools.atoms";
|
||||||
import {
|
|
||||||
agentActionsByChatTurnIdAtom,
|
|
||||||
markAgentActionRevertedAtom,
|
|
||||||
resetAgentActionMapAtom,
|
|
||||||
updateAgentActionReversibleAtom,
|
|
||||||
upsertAgentActionAtom,
|
|
||||||
} from "@/atoms/chat/agent-actions.atom";
|
|
||||||
import {
|
import {
|
||||||
clearTargetCommentIdAtom,
|
clearTargetCommentIdAtom,
|
||||||
currentThreadAtom,
|
currentThreadAtom,
|
||||||
|
|
@ -56,6 +49,12 @@ import {
|
||||||
type TokenUsageData,
|
type TokenUsageData,
|
||||||
TokenUsageProvider,
|
TokenUsageProvider,
|
||||||
} from "@/components/assistant-ui/token-usage-context";
|
} from "@/components/assistant-ui/token-usage-context";
|
||||||
|
import {
|
||||||
|
applyActionLogSse,
|
||||||
|
applyActionLogUpdatedSse,
|
||||||
|
markActionRevertedInCache,
|
||||||
|
useAgentActionsQuery,
|
||||||
|
} from "@/hooks/use-agent-actions-query";
|
||||||
import { useChatSessionStateSync } from "@/hooks/use-chat-session-state";
|
import { useChatSessionStateSync } from "@/hooks/use-chat-session-state";
|
||||||
import { useMessagesSync } from "@/hooks/use-messages-sync";
|
import { useMessagesSync } from "@/hooks/use-messages-sync";
|
||||||
import { getAgentFilesystemSelection } from "@/lib/agent-filesystem";
|
import { getAgentFilesystemSelection } from "@/lib/agent-filesystem";
|
||||||
|
|
@ -76,12 +75,12 @@ import {
|
||||||
addToolCall,
|
addToolCall,
|
||||||
appendReasoning,
|
appendReasoning,
|
||||||
appendText,
|
appendText,
|
||||||
|
appendToolInputDelta,
|
||||||
buildContentForPersistence,
|
buildContentForPersistence,
|
||||||
buildContentForUI,
|
buildContentForUI,
|
||||||
type ContentPartsState,
|
type ContentPartsState,
|
||||||
endReasoning,
|
endReasoning,
|
||||||
FrameBatchedUpdater,
|
FrameBatchedUpdater,
|
||||||
findToolCallIdByLcId,
|
|
||||||
readSSEStream,
|
readSSEStream,
|
||||||
type SSEEvent,
|
type SSEEvent,
|
||||||
type ThinkingStepData,
|
type ThinkingStepData,
|
||||||
|
|
@ -511,14 +510,6 @@ export default function NewChatPage() {
|
||||||
const setAgentCreatedDocuments = useSetAtom(agentCreatedDocumentsAtom);
|
const setAgentCreatedDocuments = useSetAtom(agentCreatedDocumentsAtom);
|
||||||
const pendingUserImageUrls = useAtomValue(pendingUserImageDataUrlsAtom);
|
const pendingUserImageUrls = useAtomValue(pendingUserImageDataUrlsAtom);
|
||||||
const setPendingUserImageUrls = useSetAtom(pendingUserImageDataUrlsAtom);
|
const setPendingUserImageUrls = useSetAtom(pendingUserImageDataUrlsAtom);
|
||||||
// Agent action log SSE side-channel.
|
|
||||||
const upsertAgentAction = useSetAtom(upsertAgentActionAtom);
|
|
||||||
const updateAgentActionReversible = useSetAtom(updateAgentActionReversibleAtom);
|
|
||||||
const markAgentActionReverted = useSetAtom(markAgentActionRevertedAtom);
|
|
||||||
const resetAgentActionMap = useSetAtom(resetAgentActionMapAtom);
|
|
||||||
// Chat-turn-keyed action map for the edit-from-position pre-flight
|
|
||||||
// that decides whether to show the confirmation dialog.
|
|
||||||
const agentActionsByChatTurnId = useAtomValue(agentActionsByChatTurnIdAtom);
|
|
||||||
// Edit dialog state. Holds the message id being edited and
|
// Edit dialog state. Holds the message id being edited and
|
||||||
// the (already extracted) regenerate args so we can resume the edit
|
// the (already extracted) regenerate args so we can resume the edit
|
||||||
// after the user picks "revert all" / "continue" / "cancel".
|
// after the user picks "revert all" / "continue" / "cancel".
|
||||||
|
|
@ -547,6 +538,11 @@ export default function NewChatPage() {
|
||||||
content: unknown;
|
content: unknown;
|
||||||
author_id: string | null;
|
author_id: string | null;
|
||||||
created_at: string;
|
created_at: string;
|
||||||
|
// Forwarded so ``convertToThreadMessage`` can rebuild the
|
||||||
|
// ``metadata.custom.chatTurnId`` on the
|
||||||
|
// ``ThreadMessageLike``. Required by the inline Revert
|
||||||
|
// button's per-turn fallback.
|
||||||
|
turn_id?: string | null;
|
||||||
}[]
|
}[]
|
||||||
) => {
|
) => {
|
||||||
if (isRunning) {
|
if (isRunning) {
|
||||||
|
|
@ -579,6 +575,11 @@ export default function NewChatPage() {
|
||||||
created_at: msg.created_at,
|
created_at: msg.created_at,
|
||||||
author_display_name: member?.user_display_name ?? existingAuthor?.displayName ?? null,
|
author_display_name: member?.user_display_name ?? existingAuthor?.displayName ?? null,
|
||||||
author_avatar_url: member?.user_avatar_url ?? existingAuthor?.avatarUrl ?? null,
|
author_avatar_url: member?.user_avatar_url ?? existingAuthor?.avatarUrl ?? null,
|
||||||
|
// Forward the per-turn correlation id so the
|
||||||
|
// inline Revert button's ``(chat_turn_id,
|
||||||
|
// tool_name, position)`` fallback survives the
|
||||||
|
// post-stream Zero re-sync.
|
||||||
|
turn_id: msg.turn_id ?? null,
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
@ -595,6 +596,13 @@ export default function NewChatPage() {
|
||||||
return Number.isNaN(parsed) ? 0 : parsed;
|
return Number.isNaN(parsed) ? 0 : parsed;
|
||||||
}, [params.search_space_id]);
|
}, [params.search_space_id]);
|
||||||
|
|
||||||
|
// Unified store for agent-action rows (the same react-query cache
|
||||||
|
// the agent-actions sheet, the inline Revert button, and the
|
||||||
|
// per-turn Revert button all read). Hydrates from
|
||||||
|
// ``GET /threads/{id}/actions`` and is updated incrementally by the
|
||||||
|
// SSE handlers + revert-batch results below — no atom side-channel.
|
||||||
|
const { items: agentActionItems } = useAgentActionsQuery(threadId);
|
||||||
|
|
||||||
// Extract chat_id from URL params
|
// Extract chat_id from URL params
|
||||||
const urlChatId = useMemo(() => {
|
const urlChatId = useMemo(() => {
|
||||||
const id = params.chat_id;
|
const id = params.chat_id;
|
||||||
|
|
@ -738,7 +746,8 @@ export default function NewChatPage() {
|
||||||
clearPlanOwnerRegistry();
|
clearPlanOwnerRegistry();
|
||||||
closeReportPanel();
|
closeReportPanel();
|
||||||
closeEditorPanel();
|
closeEditorPanel();
|
||||||
resetAgentActionMap();
|
// Note: agent-action data is keyed by threadId in react-query so
|
||||||
|
// switching threads naturally swaps caches; no explicit reset.
|
||||||
|
|
||||||
try {
|
try {
|
||||||
if (urlChatId > 0) {
|
if (urlChatId > 0) {
|
||||||
|
|
@ -807,7 +816,6 @@ export default function NewChatPage() {
|
||||||
removeChatTab,
|
removeChatTab,
|
||||||
searchSpaceId,
|
searchSpaceId,
|
||||||
tokenUsageStore,
|
tokenUsageStore,
|
||||||
resetAgentActionMap,
|
|
||||||
]);
|
]);
|
||||||
|
|
||||||
// Initialize on mount, and re-init when switching search spaces (even if urlChatId is the same)
|
// Initialize on mount, and re-init when switching search spaces (even if urlChatId is the same)
|
||||||
|
|
@ -1145,6 +1153,15 @@ export default function NewChatPage() {
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
const scheduleFlush = () => batcher.schedule(flushMessages);
|
const scheduleFlush = () => batcher.schedule(flushMessages);
|
||||||
|
// Force-flush helper: ``batcher.flush()`` is a no-op when
|
||||||
|
// ``dirty=false`` (e.g. a tool starts before any text
|
||||||
|
// streamed). ``scheduleFlush(); batcher.flush()`` sets
|
||||||
|
// the dirty bit FIRST so terminal events render
|
||||||
|
// promptly without the 50ms throttle delay.
|
||||||
|
const forceFlush = () => {
|
||||||
|
scheduleFlush();
|
||||||
|
batcher.flush();
|
||||||
|
};
|
||||||
|
|
||||||
for await (const parsed of readSSEStream(response)) {
|
for await (const parsed of readSSEStream(response)) {
|
||||||
switch (parsed.type) {
|
switch (parsed.type) {
|
||||||
|
|
@ -1181,13 +1198,23 @@ export default function NewChatPage() {
|
||||||
false,
|
false,
|
||||||
parsed.langchainToolCallId
|
parsed.langchainToolCallId
|
||||||
);
|
);
|
||||||
batcher.flush();
|
forceFlush();
|
||||||
|
break;
|
||||||
|
|
||||||
|
case "tool-input-delta":
|
||||||
|
// High-frequency event: deltas can fire dozens
|
||||||
|
// of times per call, so use throttled
|
||||||
|
// scheduleFlush (NOT forceFlush) to coalesce.
|
||||||
|
appendToolInputDelta(contentPartsState, parsed.toolCallId, parsed.inputTextDelta);
|
||||||
|
scheduleFlush();
|
||||||
break;
|
break;
|
||||||
|
|
||||||
case "tool-input-available": {
|
case "tool-input-available": {
|
||||||
|
const finalArgsText = JSON.stringify(parsed.input ?? {}, null, 2);
|
||||||
if (toolCallIndices.has(parsed.toolCallId)) {
|
if (toolCallIndices.has(parsed.toolCallId)) {
|
||||||
updateToolCall(contentPartsState, parsed.toolCallId, {
|
updateToolCall(contentPartsState, parsed.toolCallId, {
|
||||||
args: parsed.input || {},
|
args: parsed.input || {},
|
||||||
|
argsText: finalArgsText,
|
||||||
langchainToolCallId: parsed.langchainToolCallId,
|
langchainToolCallId: parsed.langchainToolCallId,
|
||||||
});
|
});
|
||||||
} else {
|
} else {
|
||||||
|
|
@ -1200,8 +1227,14 @@ export default function NewChatPage() {
|
||||||
false,
|
false,
|
||||||
parsed.langchainToolCallId
|
parsed.langchainToolCallId
|
||||||
);
|
);
|
||||||
|
// addToolCall doesn't accept argsText today;
|
||||||
|
// backfill via updateToolCall so the new card
|
||||||
|
// renders pretty-printed JSON.
|
||||||
|
updateToolCall(contentPartsState, parsed.toolCallId, {
|
||||||
|
argsText: finalArgsText,
|
||||||
|
});
|
||||||
}
|
}
|
||||||
batcher.flush();
|
forceFlush();
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1220,7 +1253,7 @@ export default function NewChatPage() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
batcher.flush();
|
forceFlush();
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1316,34 +1349,17 @@ export default function NewChatPage() {
|
||||||
}
|
}
|
||||||
|
|
||||||
case "data-action-log": {
|
case "data-action-log": {
|
||||||
const al = parsed.data;
|
applyActionLogSse(queryClient, currentThreadId, searchSpaceId, parsed.data);
|
||||||
const matchedToolCallId = al.lc_tool_call_id
|
|
||||||
? findToolCallIdByLcId(contentPartsState, al.lc_tool_call_id)
|
|
||||||
: null;
|
|
||||||
upsertAgentAction({
|
|
||||||
action: {
|
|
||||||
id: al.id,
|
|
||||||
threadId: currentThreadId,
|
|
||||||
lcToolCallId: al.lc_tool_call_id,
|
|
||||||
chatTurnId: al.chat_turn_id,
|
|
||||||
toolName: al.tool_name,
|
|
||||||
reversible: al.reversible,
|
|
||||||
reverseDescriptorPresent: al.reverse_descriptor_present,
|
|
||||||
error: al.error,
|
|
||||||
revertedByActionId: null,
|
|
||||||
isRevertAction: false,
|
|
||||||
createdAt: al.created_at,
|
|
||||||
},
|
|
||||||
toolCallId: matchedToolCallId,
|
|
||||||
});
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
case "data-action-log-updated": {
|
case "data-action-log-updated": {
|
||||||
updateAgentActionReversible({
|
applyActionLogUpdatedSse(
|
||||||
id: parsed.data.id,
|
queryClient,
|
||||||
reversible: parsed.data.reversible,
|
currentThreadId,
|
||||||
});
|
parsed.data.id,
|
||||||
|
parsed.data.reversible
|
||||||
|
);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1559,6 +1575,15 @@ export default function NewChatPage() {
|
||||||
toolName: String(p.toolName),
|
toolName: String(p.toolName),
|
||||||
args: (p.args as Record<string, unknown>) ?? {},
|
args: (p.args as Record<string, unknown>) ?? {},
|
||||||
result: p.result as unknown,
|
result: p.result as unknown,
|
||||||
|
// Restore argsText so persisted pretty-printed
|
||||||
|
// JSON survives reloads (assistant-ui prefers
|
||||||
|
// supplied argsText over JSON.stringify(args)).
|
||||||
|
// langchainToolCallId restoration also fixes a
|
||||||
|
// pre-existing dropped-id bug on resume.
|
||||||
|
...(typeof p.argsText === "string" ? { argsText: p.argsText } : {}),
|
||||||
|
...(typeof p.langchainToolCallId === "string"
|
||||||
|
? { langchainToolCallId: p.langchainToolCallId }
|
||||||
|
: {}),
|
||||||
});
|
});
|
||||||
contentPartsState.currentTextPartIndex = -1;
|
contentPartsState.currentTextPartIndex = -1;
|
||||||
} else if (p.type === "data-thinking-steps") {
|
} else if (p.type === "data-thinking-steps") {
|
||||||
|
|
@ -1580,7 +1605,12 @@ export default function NewChatPage() {
|
||||||
const editedAction = decisions[0].edited_action;
|
const editedAction = decisions[0].edited_action;
|
||||||
for (const part of contentParts) {
|
for (const part of contentParts) {
|
||||||
if (part.type === "tool-call" && part.toolName === editedAction.name) {
|
if (part.type === "tool-call" && part.toolName === editedAction.name) {
|
||||||
part.args = { ...part.args, ...editedAction.args };
|
const mergedArgs = { ...part.args, ...editedAction.args };
|
||||||
|
part.args = mergedArgs;
|
||||||
|
// Sync argsText so the rendered card shows the
|
||||||
|
// edited inputs — assistant-ui prefers caller-
|
||||||
|
// supplied argsText over JSON.stringify(args).
|
||||||
|
part.argsText = JSON.stringify(mergedArgs, null, 2);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -1637,6 +1667,10 @@ export default function NewChatPage() {
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
const scheduleFlush = () => batcher.schedule(flushMessages);
|
const scheduleFlush = () => batcher.schedule(flushMessages);
|
||||||
|
const forceFlush = () => {
|
||||||
|
scheduleFlush();
|
||||||
|
batcher.flush();
|
||||||
|
};
|
||||||
|
|
||||||
for await (const parsed of readSSEStream(response)) {
|
for await (const parsed of readSSEStream(response)) {
|
||||||
switch (parsed.type) {
|
switch (parsed.type) {
|
||||||
|
|
@ -1673,13 +1707,20 @@ export default function NewChatPage() {
|
||||||
false,
|
false,
|
||||||
parsed.langchainToolCallId
|
parsed.langchainToolCallId
|
||||||
);
|
);
|
||||||
batcher.flush();
|
forceFlush();
|
||||||
break;
|
break;
|
||||||
|
|
||||||
case "tool-input-available":
|
case "tool-input-delta":
|
||||||
|
appendToolInputDelta(contentPartsState, parsed.toolCallId, parsed.inputTextDelta);
|
||||||
|
scheduleFlush();
|
||||||
|
break;
|
||||||
|
|
||||||
|
case "tool-input-available": {
|
||||||
|
const finalArgsText = JSON.stringify(parsed.input ?? {}, null, 2);
|
||||||
if (toolCallIndices.has(parsed.toolCallId)) {
|
if (toolCallIndices.has(parsed.toolCallId)) {
|
||||||
updateToolCall(contentPartsState, parsed.toolCallId, {
|
updateToolCall(contentPartsState, parsed.toolCallId, {
|
||||||
args: parsed.input || {},
|
args: parsed.input || {},
|
||||||
|
argsText: finalArgsText,
|
||||||
langchainToolCallId: parsed.langchainToolCallId,
|
langchainToolCallId: parsed.langchainToolCallId,
|
||||||
});
|
});
|
||||||
} else {
|
} else {
|
||||||
|
|
@ -1692,9 +1733,13 @@ export default function NewChatPage() {
|
||||||
false,
|
false,
|
||||||
parsed.langchainToolCallId
|
parsed.langchainToolCallId
|
||||||
);
|
);
|
||||||
|
updateToolCall(contentPartsState, parsed.toolCallId, {
|
||||||
|
argsText: finalArgsText,
|
||||||
|
});
|
||||||
}
|
}
|
||||||
batcher.flush();
|
forceFlush();
|
||||||
break;
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
case "tool-output-available":
|
case "tool-output-available":
|
||||||
updateToolCall(contentPartsState, parsed.toolCallId, {
|
updateToolCall(contentPartsState, parsed.toolCallId, {
|
||||||
|
|
@ -1702,7 +1747,7 @@ export default function NewChatPage() {
|
||||||
langchainToolCallId: parsed.langchainToolCallId,
|
langchainToolCallId: parsed.langchainToolCallId,
|
||||||
});
|
});
|
||||||
markInterruptsCompleted(contentParts);
|
markInterruptsCompleted(contentParts);
|
||||||
batcher.flush();
|
forceFlush();
|
||||||
break;
|
break;
|
||||||
|
|
||||||
case "data-thinking-step": {
|
case "data-thinking-step": {
|
||||||
|
|
@ -1762,34 +1807,17 @@ export default function NewChatPage() {
|
||||||
}
|
}
|
||||||
|
|
||||||
case "data-action-log": {
|
case "data-action-log": {
|
||||||
const al = parsed.data;
|
applyActionLogSse(queryClient, resumeThreadId, searchSpaceId, parsed.data);
|
||||||
const matchedToolCallId = al.lc_tool_call_id
|
|
||||||
? findToolCallIdByLcId(contentPartsState, al.lc_tool_call_id)
|
|
||||||
: null;
|
|
||||||
upsertAgentAction({
|
|
||||||
action: {
|
|
||||||
id: al.id,
|
|
||||||
threadId: resumeThreadId,
|
|
||||||
lcToolCallId: al.lc_tool_call_id,
|
|
||||||
chatTurnId: al.chat_turn_id,
|
|
||||||
toolName: al.tool_name,
|
|
||||||
reversible: al.reversible,
|
|
||||||
reverseDescriptorPresent: al.reverse_descriptor_present,
|
|
||||||
error: al.error,
|
|
||||||
revertedByActionId: null,
|
|
||||||
isRevertAction: false,
|
|
||||||
createdAt: al.created_at,
|
|
||||||
},
|
|
||||||
toolCallId: matchedToolCallId,
|
|
||||||
});
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
case "data-action-log-updated": {
|
case "data-action-log-updated": {
|
||||||
updateAgentActionReversible({
|
applyActionLogUpdatedSse(
|
||||||
id: parsed.data.id,
|
queryClient,
|
||||||
reversible: parsed.data.reversible,
|
resumeThreadId,
|
||||||
});
|
parsed.data.id,
|
||||||
|
parsed.data.reversible
|
||||||
|
);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1902,6 +1930,11 @@ export default function NewChatPage() {
|
||||||
return {
|
return {
|
||||||
...part,
|
...part,
|
||||||
args: decision.edited_action.args, // Update displayed args
|
args: decision.edited_action.args, // Update displayed args
|
||||||
|
// Sync argsText so the rendered card shows
|
||||||
|
// the edited inputs — assistant-ui prefers
|
||||||
|
// caller-supplied argsText over
|
||||||
|
// JSON.stringify(args).
|
||||||
|
argsText: JSON.stringify(decision.edited_action.args, null, 2),
|
||||||
result: {
|
result: {
|
||||||
...(part.result as Record<string, unknown>),
|
...(part.result as Record<string, unknown>),
|
||||||
__decided__: decisionType,
|
__decided__: decisionType,
|
||||||
|
|
@ -2127,6 +2160,10 @@ export default function NewChatPage() {
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
const scheduleFlush = () => batcher.schedule(flushMessages);
|
const scheduleFlush = () => batcher.schedule(flushMessages);
|
||||||
|
const forceFlush = () => {
|
||||||
|
scheduleFlush();
|
||||||
|
batcher.flush();
|
||||||
|
};
|
||||||
|
|
||||||
for await (const parsed of readSSEStream(response)) {
|
for await (const parsed of readSSEStream(response)) {
|
||||||
switch (parsed.type) {
|
switch (parsed.type) {
|
||||||
|
|
@ -2163,13 +2200,20 @@ export default function NewChatPage() {
|
||||||
false,
|
false,
|
||||||
parsed.langchainToolCallId
|
parsed.langchainToolCallId
|
||||||
);
|
);
|
||||||
batcher.flush();
|
forceFlush();
|
||||||
break;
|
break;
|
||||||
|
|
||||||
case "tool-input-available":
|
case "tool-input-delta":
|
||||||
|
appendToolInputDelta(contentPartsState, parsed.toolCallId, parsed.inputTextDelta);
|
||||||
|
scheduleFlush();
|
||||||
|
break;
|
||||||
|
|
||||||
|
case "tool-input-available": {
|
||||||
|
const finalArgsText = JSON.stringify(parsed.input ?? {}, null, 2);
|
||||||
if (toolCallIndices.has(parsed.toolCallId)) {
|
if (toolCallIndices.has(parsed.toolCallId)) {
|
||||||
updateToolCall(contentPartsState, parsed.toolCallId, {
|
updateToolCall(contentPartsState, parsed.toolCallId, {
|
||||||
args: parsed.input || {},
|
args: parsed.input || {},
|
||||||
|
argsText: finalArgsText,
|
||||||
langchainToolCallId: parsed.langchainToolCallId,
|
langchainToolCallId: parsed.langchainToolCallId,
|
||||||
});
|
});
|
||||||
} else {
|
} else {
|
||||||
|
|
@ -2182,9 +2226,13 @@ export default function NewChatPage() {
|
||||||
false,
|
false,
|
||||||
parsed.langchainToolCallId
|
parsed.langchainToolCallId
|
||||||
);
|
);
|
||||||
|
updateToolCall(contentPartsState, parsed.toolCallId, {
|
||||||
|
argsText: finalArgsText,
|
||||||
|
});
|
||||||
}
|
}
|
||||||
batcher.flush();
|
forceFlush();
|
||||||
break;
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
case "tool-output-available":
|
case "tool-output-available":
|
||||||
updateToolCall(contentPartsState, parsed.toolCallId, {
|
updateToolCall(contentPartsState, parsed.toolCallId, {
|
||||||
|
|
@ -2201,7 +2249,7 @@ export default function NewChatPage() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
batcher.flush();
|
forceFlush();
|
||||||
break;
|
break;
|
||||||
|
|
||||||
case "data-thinking-step": {
|
case "data-thinking-step": {
|
||||||
|
|
@ -2217,34 +2265,21 @@ export default function NewChatPage() {
|
||||||
}
|
}
|
||||||
|
|
||||||
case "data-action-log": {
|
case "data-action-log": {
|
||||||
const al = parsed.data;
|
if (threadId !== null) {
|
||||||
const matchedToolCallId = al.lc_tool_call_id
|
applyActionLogSse(queryClient, threadId, searchSpaceId, parsed.data);
|
||||||
? findToolCallIdByLcId(contentPartsState, al.lc_tool_call_id)
|
}
|
||||||
: null;
|
|
||||||
upsertAgentAction({
|
|
||||||
action: {
|
|
||||||
id: al.id,
|
|
||||||
threadId,
|
|
||||||
lcToolCallId: al.lc_tool_call_id,
|
|
||||||
chatTurnId: al.chat_turn_id,
|
|
||||||
toolName: al.tool_name,
|
|
||||||
reversible: al.reversible,
|
|
||||||
reverseDescriptorPresent: al.reverse_descriptor_present,
|
|
||||||
error: al.error,
|
|
||||||
revertedByActionId: null,
|
|
||||||
isRevertAction: false,
|
|
||||||
createdAt: al.created_at,
|
|
||||||
},
|
|
||||||
toolCallId: matchedToolCallId,
|
|
||||||
});
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
case "data-action-log-updated": {
|
case "data-action-log-updated": {
|
||||||
updateAgentActionReversible({
|
if (threadId !== null) {
|
||||||
id: parsed.data.id,
|
applyActionLogUpdatedSse(
|
||||||
reversible: parsed.data.reversible,
|
queryClient,
|
||||||
});
|
threadId,
|
||||||
|
parsed.data.id,
|
||||||
|
parsed.data.reversible
|
||||||
|
);
|
||||||
|
}
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -2281,12 +2316,16 @@ export default function NewChatPage() {
|
||||||
: `Reverted ${summary.reverted} downstream actions before regenerating.`
|
: `Reverted ${summary.reverted} downstream actions before regenerating.`
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
for (const r of summary.results) {
|
if (threadId !== null) {
|
||||||
if (r.status === "reverted" || r.status === "already_reverted") {
|
for (const r of summary.results) {
|
||||||
markAgentActionReverted({
|
if (r.status === "reverted" || r.status === "already_reverted") {
|
||||||
id: r.action_id,
|
markActionRevertedInCache(
|
||||||
newActionId: r.new_action_id ?? null,
|
queryClient,
|
||||||
});
|
threadId,
|
||||||
|
r.action_id,
|
||||||
|
r.new_action_id ?? null
|
||||||
|
);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
|
|
@ -2459,16 +2498,26 @@ export default function NewChatPage() {
|
||||||
const downstream = messages.slice(editedIndex + 1);
|
const downstream = messages.slice(editedIndex + 1);
|
||||||
downstreamTotalCount = downstream.length;
|
downstreamTotalCount = downstream.length;
|
||||||
const seenTurns = new Set<string>();
|
const seenTurns = new Set<string>();
|
||||||
|
const downstreamTurnIds = new Set<string>();
|
||||||
for (const m of downstream) {
|
for (const m of downstream) {
|
||||||
const meta = (m.metadata ?? {}) as { custom?: { chatTurnId?: string } };
|
const meta = (m.metadata ?? {}) as { custom?: { chatTurnId?: string } };
|
||||||
const tid = meta.custom?.chatTurnId;
|
const tid = meta.custom?.chatTurnId;
|
||||||
if (!tid || seenTurns.has(tid)) continue;
|
if (!tid || seenTurns.has(tid)) continue;
|
||||||
seenTurns.add(tid);
|
seenTurns.add(tid);
|
||||||
const turnActions = agentActionsByChatTurnId.get(tid) ?? [];
|
downstreamTurnIds.add(tid);
|
||||||
for (const a of turnActions) {
|
}
|
||||||
if (a.reversible && a.revertedByActionId === null && !a.isRevertAction && !a.error) {
|
// Source of truth: the unified react-query cache. Every
|
||||||
downstreamReversibleCount += 1;
|
// action whose ``chat_turn_id`` belongs to the slice we're
|
||||||
}
|
// about to drop counts toward the prompt.
|
||||||
|
for (const a of agentActionItems) {
|
||||||
|
if (!a.chat_turn_id || !downstreamTurnIds.has(a.chat_turn_id)) continue;
|
||||||
|
if (
|
||||||
|
a.reversible &&
|
||||||
|
(a.reverted_by_action_id === null || a.reverted_by_action_id === undefined) &&
|
||||||
|
!a.is_revert_action &&
|
||||||
|
(a.error === null || a.error === undefined)
|
||||||
|
) {
|
||||||
|
downstreamReversibleCount += 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -2492,7 +2541,7 @@ export default function NewChatPage() {
|
||||||
downstreamTotalCount,
|
downstreamTotalCount,
|
||||||
});
|
});
|
||||||
},
|
},
|
||||||
[handleRegenerate, messages, agentActionsByChatTurnId]
|
[handleRegenerate, messages, agentActionItems]
|
||||||
);
|
);
|
||||||
|
|
||||||
const handleEditDialogChoice = useCallback(
|
const handleEditDialogChoice = useCallback(
|
||||||
|
|
|
||||||
|
|
@ -1,194 +0,0 @@
|
||||||
"use client";
|
|
||||||
|
|
||||||
import { atom } from "jotai";
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Minimal per-row projection of ``AgentActionLog`` that the tool card
|
|
||||||
* needs to decide whether to render a Revert button.
|
|
||||||
*
|
|
||||||
* Fields are deliberately a subset of the full ``AgentAction`` so the
|
|
||||||
* SSE side-channel (``data-action-log`` / ``data-action-log-updated``)
|
|
||||||
* can populate them without depending on the REST endpoint
|
|
||||||
* ``GET /threads/.../actions`` (which 503s when
|
|
||||||
* ``SURFSENSE_ENABLE_ACTION_LOG`` is off).
|
|
||||||
*/
|
|
||||||
export interface AgentActionLite {
|
|
||||||
id: number;
|
|
||||||
threadId: number | null;
|
|
||||||
lcToolCallId: string | null;
|
|
||||||
chatTurnId: string | null;
|
|
||||||
toolName: string;
|
|
||||||
reversible: boolean;
|
|
||||||
reverseDescriptorPresent: boolean;
|
|
||||||
error: boolean;
|
|
||||||
revertedByActionId: number | null;
|
|
||||||
isRevertAction: boolean;
|
|
||||||
createdAt: string | null;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Map keyed off the LangChain ``tool_call.id`` (mirrors ``ContentPart
|
|
||||||
* tool-call.langchainToolCallId``).
|
|
||||||
*/
|
|
||||||
export const agentActionByLcIdAtom = atom<Map<string, AgentActionLite>>(new Map());
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Parallel map keyed off the synthetic chat-card ``toolCallId``
|
|
||||||
* (``call_<run-id>``) so ``ToolFallback`` (which only receives the
|
|
||||||
* synthetic id from assistant-ui) can join its card to the action log.
|
|
||||||
*
|
|
||||||
* Both maps are kept in sync by ``upsertAgentActionAtom``.
|
|
||||||
*/
|
|
||||||
export const agentActionByToolCallIdAtom = atom<Map<string, AgentActionLite>>(new Map());
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Index keyed by ``chat_turn_id`` so the per-turn revert UI can answer
|
|
||||||
* "how many reversible actions does this assistant turn contain?" in
|
|
||||||
* O(1). Each entry's array is ordered by insertion (which
|
|
||||||
* for a single turn matches ``created_at`` because action-log writes
|
|
||||||
* happen synchronously).
|
|
||||||
*/
|
|
||||||
export const agentActionsByChatTurnIdAtom = atom<Map<string, AgentActionLite[]>>(new Map());
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Action to upsert one ``AgentActionLite`` row.
|
|
||||||
*
|
|
||||||
* ``toolCallId`` is the synthetic card id (``call_<run-id>`` from
|
|
||||||
* ``stream_new_chat.py``). When provided alongside ``lcToolCallId``, the
|
|
||||||
* action is indexed under BOTH ids so the tool card can perform the
|
|
||||||
* lookup without going via the streaming state.
|
|
||||||
*/
|
|
||||||
export const upsertAgentActionAtom = atom(
|
|
||||||
null,
|
|
||||||
(_get, set, payload: { action: AgentActionLite; toolCallId?: string | null }) => {
|
|
||||||
const { action, toolCallId } = payload;
|
|
||||||
const upsertInto = (
|
|
||||||
prev: Map<string, AgentActionLite>,
|
|
||||||
key: string
|
|
||||||
): Map<string, AgentActionLite> => {
|
|
||||||
const next = new Map(prev);
|
|
||||||
const existing = next.get(key);
|
|
||||||
next.set(key, {
|
|
||||||
...action,
|
|
||||||
// Preserve the local "reverted" bookkeeping if a reversibility
|
|
||||||
// flip arrives AFTER the user already reverted via the REST
|
|
||||||
// route. We never want a stale ``reversible=true`` event to
|
|
||||||
// resurrect a Reverted card.
|
|
||||||
revertedByActionId: existing?.revertedByActionId ?? action.revertedByActionId,
|
|
||||||
isRevertAction: existing?.isRevertAction ?? action.isRevertAction,
|
|
||||||
});
|
|
||||||
return next;
|
|
||||||
};
|
|
||||||
if (action.lcToolCallId) {
|
|
||||||
set(agentActionByLcIdAtom, (prev) => upsertInto(prev, action.lcToolCallId as string));
|
|
||||||
}
|
|
||||||
if (toolCallId) {
|
|
||||||
set(agentActionByToolCallIdAtom, (prev) => upsertInto(prev, toolCallId));
|
|
||||||
}
|
|
||||||
if (action.chatTurnId) {
|
|
||||||
set(agentActionsByChatTurnIdAtom, (prev) => {
|
|
||||||
const next = new Map(prev);
|
|
||||||
const turnId = action.chatTurnId as string;
|
|
||||||
const existing = next.get(turnId) ?? [];
|
|
||||||
const priorEntry = existing.find((row) => row.id === action.id);
|
|
||||||
const merged: AgentActionLite = {
|
|
||||||
...action,
|
|
||||||
revertedByActionId: priorEntry?.revertedByActionId ?? action.revertedByActionId,
|
|
||||||
isRevertAction: priorEntry?.isRevertAction ?? action.isRevertAction,
|
|
||||||
};
|
|
||||||
const others = existing.filter((row) => row.id !== action.id);
|
|
||||||
next.set(turnId, [...others, merged]);
|
|
||||||
return next;
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
);
|
|
||||||
|
|
||||||
function mutateById(
|
|
||||||
prev: Map<string, AgentActionLite>,
|
|
||||||
id: number,
|
|
||||||
mutator: (entry: AgentActionLite) => AgentActionLite
|
|
||||||
): Map<string, AgentActionLite> {
|
|
||||||
let mutated = false;
|
|
||||||
const next = new Map(prev);
|
|
||||||
for (const [key, value] of next) {
|
|
||||||
if (value.id === id) {
|
|
||||||
next.set(key, mutator(value));
|
|
||||||
mutated = true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return mutated ? next : prev;
|
|
||||||
}
|
|
||||||
|
|
||||||
function mutateByIdInTurnIndex(
|
|
||||||
prev: Map<string, AgentActionLite[]>,
|
|
||||||
id: number,
|
|
||||||
mutator: (entry: AgentActionLite) => AgentActionLite
|
|
||||||
): Map<string, AgentActionLite[]> {
|
|
||||||
let mutated = false;
|
|
||||||
const next = new Map(prev);
|
|
||||||
for (const [key, list] of next) {
|
|
||||||
let listMutated = false;
|
|
||||||
const updated = list.map((row) => {
|
|
||||||
if (row.id === id) {
|
|
||||||
listMutated = true;
|
|
||||||
return mutator(row);
|
|
||||||
}
|
|
||||||
return row;
|
|
||||||
});
|
|
||||||
if (listMutated) {
|
|
||||||
next.set(key, updated);
|
|
||||||
mutated = true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return mutated ? next : prev;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Action to flip an existing entry's ``reversible`` flag, keyed by the
|
|
||||||
* AgentActionLog row id (the SSE ``data-action-log-updated`` payload
|
|
||||||
* does NOT carry ``lcToolCallId``).
|
|
||||||
*/
|
|
||||||
export const updateAgentActionReversibleAtom = atom(
|
|
||||||
null,
|
|
||||||
(_get, set, payload: { id: number; reversible: boolean }) => {
|
|
||||||
const apply = (entry: AgentActionLite): AgentActionLite => ({
|
|
||||||
...entry,
|
|
||||||
reversible: payload.reversible,
|
|
||||||
});
|
|
||||||
set(agentActionByLcIdAtom, (prev) => mutateById(prev, payload.id, apply));
|
|
||||||
set(agentActionByToolCallIdAtom, (prev) => mutateById(prev, payload.id, apply));
|
|
||||||
set(agentActionsByChatTurnIdAtom, (prev) => mutateByIdInTurnIndex(prev, payload.id, apply));
|
|
||||||
}
|
|
||||||
);
|
|
||||||
|
|
||||||
/** Action to mark an existing entry as reverted (post-revert call). */
|
|
||||||
export const markAgentActionRevertedAtom = atom(
|
|
||||||
null,
|
|
||||||
(_get, set, payload: { id: number; newActionId: number | null }) => {
|
|
||||||
const apply = (entry: AgentActionLite): AgentActionLite => ({
|
|
||||||
...entry,
|
|
||||||
revertedByActionId: payload.newActionId ?? -1,
|
|
||||||
});
|
|
||||||
set(agentActionByLcIdAtom, (prev) => mutateById(prev, payload.id, apply));
|
|
||||||
set(agentActionByToolCallIdAtom, (prev) => mutateById(prev, payload.id, apply));
|
|
||||||
set(agentActionsByChatTurnIdAtom, (prev) => mutateByIdInTurnIndex(prev, payload.id, apply));
|
|
||||||
}
|
|
||||||
);
|
|
||||||
|
|
||||||
/** Mark every action in a turn as reverted, given a list of (id, newActionId) pairs. */
|
|
||||||
export const markAgentActionsRevertedBatchAtom = atom(
|
|
||||||
null,
|
|
||||||
(_get, set, payload: { entries: Array<{ id: number; newActionId: number | null }> }) => {
|
|
||||||
for (const entry of payload.entries) {
|
|
||||||
set(markAgentActionRevertedAtom, entry);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
);
|
|
||||||
|
|
||||||
/** Reset all maps (e.g. when the active thread changes). */
|
|
||||||
export const resetAgentActionMapAtom = atom(null, (_get, set) => {
|
|
||||||
set(agentActionByLcIdAtom, new Map());
|
|
||||||
set(agentActionByToolCallIdAtom, new Map());
|
|
||||||
set(agentActionsByChatTurnIdAtom, new Map());
|
|
||||||
});
|
|
||||||
|
|
@ -1,9 +1,9 @@
|
||||||
"use client";
|
"use client";
|
||||||
|
|
||||||
import { useQuery, useQueryClient } from "@tanstack/react-query";
|
import { useQueryClient } from "@tanstack/react-query";
|
||||||
import { useAtom, useAtomValue } from "jotai";
|
import { useAtom, useAtomValue } from "jotai";
|
||||||
import { Activity, RefreshCcw } from "lucide-react";
|
import { Activity, RefreshCcw } from "lucide-react";
|
||||||
import { useCallback, useMemo } from "react";
|
import { useCallback } from "react";
|
||||||
import { actionLogSheetAtom } from "@/atoms/agent/action-log-sheet.atom";
|
import { actionLogSheetAtom } from "@/atoms/agent/action-log-sheet.atom";
|
||||||
import { agentFlagsAtom } from "@/atoms/agent/agent-flags-query.atom";
|
import { agentFlagsAtom } from "@/atoms/agent/agent-flags-query.atom";
|
||||||
import { Badge } from "@/components/ui/badge";
|
import { Badge } from "@/components/ui/badge";
|
||||||
|
|
@ -17,15 +17,12 @@ import {
|
||||||
SheetTitle,
|
SheetTitle,
|
||||||
} from "@/components/ui/sheet";
|
} from "@/components/ui/sheet";
|
||||||
import { Skeleton } from "@/components/ui/skeleton";
|
import { Skeleton } from "@/components/ui/skeleton";
|
||||||
import { agentActionsApiService } from "@/lib/apis/agent-actions-api.service";
|
import {
|
||||||
|
agentActionsQueryKey,
|
||||||
|
useAgentActionsQuery,
|
||||||
|
} from "@/hooks/use-agent-actions-query";
|
||||||
import { ActionLogItem } from "./action-log-item";
|
import { ActionLogItem } from "./action-log-item";
|
||||||
|
|
||||||
const ACTION_LOG_PAGE_SIZE = 50;
|
|
||||||
|
|
||||||
function actionLogQueryKey(threadId: number) {
|
|
||||||
return ["agent-actions", threadId] as const;
|
|
||||||
}
|
|
||||||
|
|
||||||
function EmptyState() {
|
function EmptyState() {
|
||||||
return (
|
return (
|
||||||
<div className="flex flex-1 flex-col items-center justify-center gap-3 px-6 text-center">
|
<div className="flex flex-1 flex-col items-center justify-center gap-3 px-6 text-center">
|
||||||
|
|
@ -85,25 +82,17 @@ export function ActionLogSheet() {
|
||||||
|
|
||||||
const threadId = state.threadId;
|
const threadId = state.threadId;
|
||||||
|
|
||||||
const { data, isLoading, isFetching, isError, error, refetch } = useQuery({
|
const { data, items, isLoading, isFetching, isError, error, refetch } = useAgentActionsQuery(
|
||||||
queryKey: threadId !== null ? actionLogQueryKey(threadId) : ["agent-actions", "none"],
|
threadId,
|
||||||
queryFn: () =>
|
{ enabled: state.open && actionLogEnabled }
|
||||||
agentActionsApiService.listForThread(threadId as number, {
|
);
|
||||||
page: 0,
|
|
||||||
pageSize: ACTION_LOG_PAGE_SIZE,
|
|
||||||
}),
|
|
||||||
enabled: state.open && threadId !== null && actionLogEnabled,
|
|
||||||
staleTime: 15 * 1000,
|
|
||||||
});
|
|
||||||
|
|
||||||
const handleRevertSuccess = useCallback(() => {
|
const handleRevertSuccess = useCallback(() => {
|
||||||
if (threadId !== null) {
|
if (threadId !== null) {
|
||||||
queryClient.invalidateQueries({ queryKey: actionLogQueryKey(threadId) });
|
queryClient.invalidateQueries({ queryKey: agentActionsQueryKey(threadId) });
|
||||||
}
|
}
|
||||||
}, [queryClient, threadId]);
|
}, [queryClient, threadId]);
|
||||||
|
|
||||||
const items = useMemo(() => data?.items ?? [], [data]);
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Sheet open={state.open} onOpenChange={(open) => setState((s) => ({ ...s, open }))}>
|
<Sheet open={state.open} onOpenChange={(open) => setState((s) => ({ ...s, open }))}>
|
||||||
<SheetContent
|
<SheetContent
|
||||||
|
|
|
||||||
|
|
@ -4,26 +4,22 @@
|
||||||
* "Revert turn" button rendered at the bottom of every completed
|
* "Revert turn" button rendered at the bottom of every completed
|
||||||
* assistant turn that has at least one reversible action.
|
* assistant turn that has at least one reversible action.
|
||||||
*
|
*
|
||||||
* The button reads the action map keyed by ``chat_turn_id`` from the
|
* The button reads from the unified ``useAgentActionsQuery`` cache
|
||||||
* SSE side-channel (``data-action-log`` events). It shows a confirmation
|
* (the SAME react-query cache the agent-actions sheet and the inline
|
||||||
* dialog summarising "N reversible / M total" and, on confirm, calls
|
* Revert button consume) filtered by ``chat_turn_id``. It shows a
|
||||||
* ``POST /threads/{id}/revert-turn/{chat_turn_id}``.
|
* confirmation dialog summarising "N reversible / M total" and, on
|
||||||
|
* confirm, calls ``POST /threads/{id}/revert-turn/{chat_turn_id}``.
|
||||||
*
|
*
|
||||||
* The route returns a per-action result list and never collapses the
|
* The route returns a per-action result list and never collapses the
|
||||||
* batch into a 4xx — so we render any failed/not_reversible rows inline
|
* batch into a 4xx — so we render any failed/not_reversible rows inline
|
||||||
* with their messages.
|
* with their messages.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
import { useAtomValue, useSetAtom } from "jotai";
|
import { useQueryClient } from "@tanstack/react-query";
|
||||||
import { selectAtom } from "jotai/utils";
|
import { useAtomValue } from "jotai";
|
||||||
import { CheckIcon, RotateCcw, XCircleIcon } from "lucide-react";
|
import { CheckIcon, RotateCcw, XCircleIcon } from "lucide-react";
|
||||||
import { useMemo, useState } from "react";
|
import { useMemo, useState } from "react";
|
||||||
import { toast } from "sonner";
|
import { toast } from "sonner";
|
||||||
import {
|
|
||||||
type AgentActionLite,
|
|
||||||
agentActionsByChatTurnIdAtom,
|
|
||||||
markAgentActionsRevertedBatchAtom,
|
|
||||||
} from "@/atoms/chat/agent-actions.atom";
|
|
||||||
import { chatSessionStateAtom } from "@/atoms/chat/chat-session-state.atom";
|
import { chatSessionStateAtom } from "@/atoms/chat/chat-session-state.atom";
|
||||||
import {
|
import {
|
||||||
AlertDialog,
|
AlertDialog,
|
||||||
|
|
@ -38,6 +34,10 @@ import {
|
||||||
} from "@/components/ui/alert-dialog";
|
} from "@/components/ui/alert-dialog";
|
||||||
import { Button } from "@/components/ui/button";
|
import { Button } from "@/components/ui/button";
|
||||||
import { getToolDisplayName } from "@/contracts/enums/toolIcons";
|
import { getToolDisplayName } from "@/contracts/enums/toolIcons";
|
||||||
|
import {
|
||||||
|
applyRevertTurnResultsToCache,
|
||||||
|
useAgentActionsQuery,
|
||||||
|
} from "@/hooks/use-agent-actions-query";
|
||||||
import {
|
import {
|
||||||
agentActionsApiService,
|
agentActionsApiService,
|
||||||
type RevertTurnActionResult,
|
type RevertTurnActionResult,
|
||||||
|
|
@ -49,49 +49,33 @@ interface RevertTurnButtonProps {
|
||||||
chatTurnId: string | null | undefined;
|
chatTurnId: string | null | undefined;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Empty-array sentinel so the per-turn ``selectAtom`` slice returns a
|
|
||||||
// stable reference when the turn has no recorded actions yet. Without
|
|
||||||
// this every render allocates a fresh ``[]`` and Jotai's
|
|
||||||
// equality check would re-render the button on unrelated turn updates.
|
|
||||||
const EMPTY_ACTIONS: readonly AgentActionLite[] = Object.freeze([]);
|
|
||||||
|
|
||||||
export function RevertTurnButton({ chatTurnId }: RevertTurnButtonProps) {
|
export function RevertTurnButton({ chatTurnId }: RevertTurnButtonProps) {
|
||||||
const session = useAtomValue(chatSessionStateAtom);
|
const session = useAtomValue(chatSessionStateAtom);
|
||||||
const markRevertedBatch = useSetAtom(markAgentActionsRevertedBatchAtom);
|
const threadId = session?.threadId ?? null;
|
||||||
|
const queryClient = useQueryClient();
|
||||||
|
const { findByChatTurnId } = useAgentActionsQuery(threadId);
|
||||||
const [isReverting, setIsReverting] = useState(false);
|
const [isReverting, setIsReverting] = useState(false);
|
||||||
const [confirmOpen, setConfirmOpen] = useState(false);
|
const [confirmOpen, setConfirmOpen] = useState(false);
|
||||||
const [resultsOpen, setResultsOpen] = useState(false);
|
const [resultsOpen, setResultsOpen] = useState(false);
|
||||||
const [results, setResults] = useState<RevertTurnActionResult[]>([]);
|
const [results, setResults] = useState<RevertTurnActionResult[]>([]);
|
||||||
|
|
||||||
// Subscribe ONLY to the slice of the global action map that belongs
|
const actions = useMemo(() => findByChatTurnId(chatTurnId), [findByChatTurnId, chatTurnId]);
|
||||||
// to ``chatTurnId``. Previously the button read the whole
|
|
||||||
// ``agentActionsByChatTurnIdAtom``, which meant every action
|
|
||||||
// upsert (one per tool call) re-rendered every Revert button on
|
|
||||||
// the page. With ``selectAtom`` we re-render only when our turn's
|
|
||||||
// list reference changes — and the upsert/mark atoms produce a
|
|
||||||
// fresh list reference for the affected turn only.
|
|
||||||
const sliceAtom = useMemo(
|
|
||||||
() =>
|
|
||||||
selectAtom(
|
|
||||||
agentActionsByChatTurnIdAtom,
|
|
||||||
(turnIndex) => (chatTurnId ? turnIndex.get(chatTurnId) : undefined) ?? EMPTY_ACTIONS
|
|
||||||
),
|
|
||||||
[chatTurnId]
|
|
||||||
);
|
|
||||||
const actions = useAtomValue(sliceAtom);
|
|
||||||
|
|
||||||
const reversibleCount = useMemo(
|
const reversibleCount = useMemo(
|
||||||
() =>
|
() =>
|
||||||
actions.filter(
|
actions.filter(
|
||||||
(a) => a.reversible && a.revertedByActionId === null && !a.isRevertAction && !a.error
|
(a) =>
|
||||||
|
a.reversible &&
|
||||||
|
(a.reverted_by_action_id === null || a.reverted_by_action_id === undefined) &&
|
||||||
|
!a.is_revert_action &&
|
||||||
|
(a.error === null || a.error === undefined)
|
||||||
).length,
|
).length,
|
||||||
[actions]
|
[actions]
|
||||||
);
|
);
|
||||||
const totalCount = useMemo(() => actions.filter((a) => !a.isRevertAction).length, [actions]);
|
const totalCount = useMemo(() => actions.filter((a) => !a.is_revert_action).length, [actions]);
|
||||||
|
|
||||||
if (!chatTurnId) return null;
|
if (!chatTurnId) return null;
|
||||||
if (reversibleCount === 0) return null;
|
if (reversibleCount === 0) return null;
|
||||||
const threadId = session?.threadId;
|
|
||||||
if (!threadId) return null;
|
if (!threadId) return null;
|
||||||
|
|
||||||
const handleRevertTurn = async () => {
|
const handleRevertTurn = async () => {
|
||||||
|
|
@ -103,7 +87,7 @@ export function RevertTurnButton({ chatTurnId }: RevertTurnButtonProps) {
|
||||||
.filter((r) => r.status === "reverted" || r.status === "already_reverted")
|
.filter((r) => r.status === "reverted" || r.status === "already_reverted")
|
||||||
.map((r) => ({ id: r.action_id, newActionId: r.new_action_id ?? null }));
|
.map((r) => ({ id: r.action_id, newActionId: r.new_action_id ?? null }));
|
||||||
if (revertedEntries.length > 0) {
|
if (revertedEntries.length > 0) {
|
||||||
markRevertedBatch({ entries: revertedEntries });
|
applyRevertTurnResultsToCache(queryClient, threadId, revertedEntries);
|
||||||
}
|
}
|
||||||
if (response.status === "ok") {
|
if (response.status === "ok") {
|
||||||
toast.success(
|
toast.success(
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,12 @@
|
||||||
import type { ToolCallMessagePartComponent } from "@assistant-ui/react";
|
|
||||||
import { useAtomValue, useSetAtom } from "jotai";
|
|
||||||
import { CheckIcon, ChevronDownIcon, ChevronUpIcon, RotateCcw, XCircleIcon } from "lucide-react";
|
|
||||||
import { useMemo, useState } from "react";
|
|
||||||
import { toast } from "sonner";
|
|
||||||
import {
|
import {
|
||||||
agentActionByToolCallIdAtom,
|
type ToolCallMessagePartComponent,
|
||||||
markAgentActionRevertedAtom,
|
useAuiState,
|
||||||
} from "@/atoms/chat/agent-actions.atom";
|
} from "@assistant-ui/react";
|
||||||
|
import { useQueryClient } from "@tanstack/react-query";
|
||||||
|
import { useAtomValue } from "jotai";
|
||||||
|
import { CheckIcon, ChevronDownIcon, RotateCcw, XCircleIcon } from "lucide-react";
|
||||||
|
import { useEffect, useMemo, useState } from "react";
|
||||||
|
import { toast } from "sonner";
|
||||||
import { chatSessionStateAtom } from "@/atoms/chat/chat-session-state.atom";
|
import { chatSessionStateAtom } from "@/atoms/chat/chat-session-state.atom";
|
||||||
import {
|
import {
|
||||||
DoomLoopApprovalToolUI,
|
DoomLoopApprovalToolUI,
|
||||||
|
|
@ -24,8 +24,17 @@ import {
|
||||||
AlertDialogTitle,
|
AlertDialogTitle,
|
||||||
AlertDialogTrigger,
|
AlertDialogTrigger,
|
||||||
} from "@/components/ui/alert-dialog";
|
} from "@/components/ui/alert-dialog";
|
||||||
|
import { Badge } from "@/components/ui/badge";
|
||||||
import { Button } from "@/components/ui/button";
|
import { Button } from "@/components/ui/button";
|
||||||
import { getToolDisplayName, getToolIcon } from "@/contracts/enums/toolIcons";
|
import { Card } from "@/components/ui/card";
|
||||||
|
import { Collapsible, CollapsibleContent, CollapsibleTrigger } from "@/components/ui/collapsible";
|
||||||
|
import { Separator } from "@/components/ui/separator";
|
||||||
|
import { Spinner } from "@/components/ui/spinner";
|
||||||
|
import { getToolDisplayName } from "@/contracts/enums/toolIcons";
|
||||||
|
import {
|
||||||
|
markActionRevertedInCache,
|
||||||
|
useAgentActionsQuery,
|
||||||
|
} from "@/hooks/use-agent-actions-query";
|
||||||
import { agentActionsApiService } from "@/lib/apis/agent-actions-api.service";
|
import { agentActionsApiService } from "@/lib/apis/agent-actions-api.service";
|
||||||
import { AppError } from "@/lib/error";
|
import { AppError } from "@/lib/error";
|
||||||
import { isInterruptResult } from "@/lib/hitl";
|
import { isInterruptResult } from "@/lib/hitl";
|
||||||
|
|
@ -34,31 +43,128 @@ import { cn } from "@/lib/utils";
|
||||||
/**
|
/**
|
||||||
* Inline Revert button rendered on a tool card when the matching
|
* Inline Revert button rendered on a tool card when the matching
|
||||||
* ``AgentActionLog`` row is reversible and hasn't been reverted yet.
|
* ``AgentActionLog`` row is reversible and hasn't been reverted yet.
|
||||||
* Reads from the SSE side-channel atom keyed by the synthetic
|
*
|
||||||
* ``toolCallId`` so it lights up even when ``GET /threads/.../actions``
|
* Reads from the unified ``useAgentActionsQuery`` cache — the SAME
|
||||||
* is gated behind ``SURFSENSE_ENABLE_ACTION_LOG=False`` (503).
|
* react-query cache the agent-actions sheet consumes. SSE events
|
||||||
|
* (``data-action-log`` / ``data-action-log-updated``) and
|
||||||
|
* ``POST /threads/{id}/revert/{id}`` responses both flow through the
|
||||||
|
* cache via ``setQueryData`` helpers, so the card and the sheet stay
|
||||||
|
* in lockstep on every code path: page reload, navigation, live
|
||||||
|
* stream, post-stream reversibility flip, and explicit revert clicks.
|
||||||
|
*
|
||||||
|
* Match key (in priority order):
|
||||||
|
* 1. ``a.tool_call_id === toolCallId`` — direct hit in parity_v2 when
|
||||||
|
* the model streamed ``tool_call_chunks`` so the card's synthetic
|
||||||
|
* id IS the LangChain id.
|
||||||
|
* 2. ``a.tool_call_id === langchainToolCallId`` — legacy mode (or
|
||||||
|
* parity_v2 with provider-side chunk emission) where the card's
|
||||||
|
* synthetic id is ``call_<run_id>`` and the LangChain id is
|
||||||
|
* backfilled onto the part by ``tool-output-available``.
|
||||||
|
* 3. ``(chat_turn_id, tool_name, position-within-turn)`` — fallback
|
||||||
|
* for cards whose synthetic id is ``call_<run_id>`` AND whose
|
||||||
|
* ``langchainToolCallId`` never got backfilled (provider emitted
|
||||||
|
* the tool_call as a single payload with no chunks AND streaming
|
||||||
|
* pre-dated the ``tool-output-available langchainToolCallId``
|
||||||
|
* backfill, e.g. older threads). Reads the parent message's
|
||||||
|
* ``chatTurnId`` and ``content`` via ``useAuiState`` so we can
|
||||||
|
* match position-by-tool-name within the turn against the
|
||||||
|
* action_log rows the server returned in ``created_at`` order.
|
||||||
*/
|
*/
|
||||||
function ToolCardRevertButton({ toolCallId }: { toolCallId: string }) {
|
function ToolCardRevertButton({
|
||||||
|
toolCallId,
|
||||||
|
toolName,
|
||||||
|
langchainToolCallId,
|
||||||
|
}: {
|
||||||
|
toolCallId: string;
|
||||||
|
toolName: string;
|
||||||
|
langchainToolCallId?: string;
|
||||||
|
}) {
|
||||||
const session = useAtomValue(chatSessionStateAtom);
|
const session = useAtomValue(chatSessionStateAtom);
|
||||||
const actionMap = useAtomValue(agentActionByToolCallIdAtom);
|
const threadId = session?.threadId ?? null;
|
||||||
const markReverted = useSetAtom(markAgentActionRevertedAtom);
|
const queryClient = useQueryClient();
|
||||||
const action = actionMap.get(toolCallId);
|
const { findByToolCallId, findByChatTurnAndTool } = useAgentActionsQuery(threadId);
|
||||||
|
|
||||||
|
// Parent message metadata, read via the narrowest possible
|
||||||
|
// selectors so this card doesn't re-render on every text-delta of
|
||||||
|
// every other part in the same message during streaming.
|
||||||
|
//
|
||||||
|
// IMPORTANT — ``useAuiState`` re-renders the component whenever the
|
||||||
|
// returned slice's identity changes. Returning ``message?.content``
|
||||||
|
// (an array) would re-render on every token because the runtime
|
||||||
|
// rebuilds the parts array. Returning a PRIMITIVE (the position
|
||||||
|
// number) lets ``useAuiState``'s ``Object.is`` check short-circuit
|
||||||
|
// when the position hasn't actually moved — which is the common
|
||||||
|
// case during text streaming, when only ``text``/``reasoning``
|
||||||
|
// parts are mutating and the same-toolName tool-call ordering is
|
||||||
|
// stable. (See Vercel React rule ``rerender-defer-reads``.)
|
||||||
|
const chatTurnId = useAuiState(({ message }) => {
|
||||||
|
const meta = message?.metadata as { custom?: { chatTurnId?: string } } | undefined;
|
||||||
|
return meta?.custom?.chatTurnId ?? null;
|
||||||
|
});
|
||||||
|
const positionInTurn = useAuiState(({ message }) => {
|
||||||
|
const content = message?.content;
|
||||||
|
if (!Array.isArray(content)) return -1;
|
||||||
|
let n = -1;
|
||||||
|
for (const part of content) {
|
||||||
|
if (
|
||||||
|
part &&
|
||||||
|
typeof part === "object" &&
|
||||||
|
(part as { type?: string }).type === "tool-call" &&
|
||||||
|
(part as { toolName?: string }).toolName === toolName
|
||||||
|
) {
|
||||||
|
n += 1;
|
||||||
|
if ((part as { toolCallId?: string }).toolCallId === toolCallId) return n;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return -1;
|
||||||
|
});
|
||||||
|
|
||||||
|
const action = useMemo(() => {
|
||||||
|
// Tier 1 + 2: O(1) Map-backed direct id match. Covers
|
||||||
|
// ~all parity_v2 streams and any legacy stream that backfilled
|
||||||
|
// ``langchainToolCallId`` via ``tool-output-available``.
|
||||||
|
const direct =
|
||||||
|
findByToolCallId(toolCallId) ?? findByToolCallId(langchainToolCallId);
|
||||||
|
if (direct) return direct;
|
||||||
|
// Tier 3: position-within-turn fallback. Only kicks in when the
|
||||||
|
// card has a synthetic ``call_<run_id>`` id AND no
|
||||||
|
// ``langchainToolCallId`` was ever backfilled — i.e. the tool
|
||||||
|
// was emitted as a single non-chunked payload AND streaming
|
||||||
|
// pre-dated the on_tool_end backfill.
|
||||||
|
if (!chatTurnId || positionInTurn < 0) return null;
|
||||||
|
const turnSameTool = findByChatTurnAndTool(chatTurnId, toolName);
|
||||||
|
return turnSameTool[positionInTurn] ?? null;
|
||||||
|
}, [
|
||||||
|
findByToolCallId,
|
||||||
|
findByChatTurnAndTool,
|
||||||
|
toolCallId,
|
||||||
|
langchainToolCallId,
|
||||||
|
chatTurnId,
|
||||||
|
toolName,
|
||||||
|
positionInTurn,
|
||||||
|
]);
|
||||||
|
|
||||||
const [isReverting, setIsReverting] = useState(false);
|
const [isReverting, setIsReverting] = useState(false);
|
||||||
const [confirmOpen, setConfirmOpen] = useState(false);
|
const [confirmOpen, setConfirmOpen] = useState(false);
|
||||||
|
|
||||||
if (!action) return null;
|
if (!action) return null;
|
||||||
if (!action.reversible) return null;
|
if (!action.reversible) return null;
|
||||||
if (action.revertedByActionId !== null) return null;
|
if (action.reverted_by_action_id !== null && action.reverted_by_action_id !== undefined)
|
||||||
if (action.isRevertAction) return null;
|
return null;
|
||||||
if (action.error) return null;
|
if (action.is_revert_action) return null;
|
||||||
const threadId = session?.threadId;
|
if (action.error !== null && action.error !== undefined) return null;
|
||||||
if (!threadId) return null;
|
if (!threadId) return null;
|
||||||
|
|
||||||
const handleRevert = async () => {
|
const handleRevert = async () => {
|
||||||
setIsReverting(true);
|
setIsReverting(true);
|
||||||
try {
|
try {
|
||||||
const response = await agentActionsApiService.revert(threadId, action.id);
|
const response = await agentActionsApiService.revert(threadId, action.id);
|
||||||
markReverted({ id: action.id, newActionId: response.new_action_id ?? null });
|
markActionRevertedInCache(
|
||||||
|
queryClient,
|
||||||
|
threadId,
|
||||||
|
action.id,
|
||||||
|
response.new_action_id ?? null
|
||||||
|
);
|
||||||
toast.success(response.message || "Action reverted.");
|
toast.success(response.message || "Action reverted.");
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
// 503 means revert is gated off on this deployment — hide the
|
// 503 means revert is gated off on this deployment — hide the
|
||||||
|
|
@ -91,8 +197,17 @@ function ToolCardRevertButton({ toolCallId }: { toolCallId: string }) {
|
||||||
e.stopPropagation();
|
e.stopPropagation();
|
||||||
setConfirmOpen(true);
|
setConfirmOpen(true);
|
||||||
}}
|
}}
|
||||||
|
disabled={isReverting}
|
||||||
>
|
>
|
||||||
<RotateCcw className="size-3.5" />
|
{isReverting ? (
|
||||||
|
// Spinner's typed props don't accept ``data-icon`` and
|
||||||
|
// it renders an <output>, not an <svg>, so Button's
|
||||||
|
// auto-sizing rule doesn't apply. Bare spinner +
|
||||||
|
// Button's gap handle layout.
|
||||||
|
<Spinner size="xs" />
|
||||||
|
) : (
|
||||||
|
<RotateCcw data-icon="inline-start" />
|
||||||
|
)}
|
||||||
Revert
|
Revert
|
||||||
</Button>
|
</Button>
|
||||||
</AlertDialogTrigger>
|
</AlertDialogTrigger>
|
||||||
|
|
@ -101,7 +216,7 @@ function ToolCardRevertButton({ toolCallId }: { toolCallId: string }) {
|
||||||
<AlertDialogTitle>Revert this action?</AlertDialogTitle>
|
<AlertDialogTitle>Revert this action?</AlertDialogTitle>
|
||||||
<AlertDialogDescription>
|
<AlertDialogDescription>
|
||||||
This will undo{" "}
|
This will undo{" "}
|
||||||
<span className="font-medium">{getToolDisplayName(action.toolName)}</span> and add a
|
<span className="font-medium">{getToolDisplayName(action.tool_name)}</span> and add a
|
||||||
new entry to the history. Your chat is preserved — only the changes the agent made to
|
new entry to the history. Your chat is preserved — only the changes the agent made to
|
||||||
your knowledge base or connected apps will be rolled back where possible.
|
your knowledge base or connected apps will be rolled back where possible.
|
||||||
</AlertDialogDescription>
|
</AlertDialogDescription>
|
||||||
|
|
@ -114,8 +229,10 @@ function ToolCardRevertButton({ toolCallId }: { toolCallId: string }) {
|
||||||
handleRevert();
|
handleRevert();
|
||||||
}}
|
}}
|
||||||
disabled={isReverting}
|
disabled={isReverting}
|
||||||
|
className="gap-1.5"
|
||||||
>
|
>
|
||||||
{isReverting ? "Reverting…" : "Revert"}
|
{isReverting && <Spinner size="xs" />}
|
||||||
|
Revert
|
||||||
</AlertDialogAction>
|
</AlertDialogAction>
|
||||||
</AlertDialogFooter>
|
</AlertDialogFooter>
|
||||||
</AlertDialogContent>
|
</AlertDialogContent>
|
||||||
|
|
@ -123,18 +240,49 @@ function ToolCardRevertButton({ toolCallId }: { toolCallId: string }) {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
const DefaultToolFallbackInner: ToolCallMessagePartComponent = ({
|
/**
|
||||||
toolCallId,
|
* Compact tool-call card.
|
||||||
toolName,
|
*
|
||||||
argsText,
|
* shadcn composition note: we intentionally use ``Card`` as a visual
|
||||||
result,
|
* frame WITHOUT ``CardHeader / CardContent``. The full composition's
|
||||||
status,
|
* ``p-6`` padding doesn't fit a compact collapsible header that IS the
|
||||||
}) => {
|
* trigger; using ``Card`` alone preserves the rounded border, shadow,
|
||||||
const [isExpanded, setIsExpanded] = useState(false);
|
* and ``bg-card`` token (semantic colors) without forcing a layout
|
||||||
|
* that doesn't fit. All status colors use semantic tokens — no manual
|
||||||
|
* dark-mode overrides, no raw hex.
|
||||||
|
*/
|
||||||
|
const DefaultToolFallbackInner: ToolCallMessagePartComponent = (props) => {
|
||||||
|
const { toolCallId, toolName, argsText, result, status } = props;
|
||||||
|
// ``langchainToolCallId`` is a SurfSense-specific extension the
|
||||||
|
// streaming pipeline attaches to the tool-call content part so
|
||||||
|
// the Revert button can resolve its ``AgentActionLog`` row even
|
||||||
|
// when only the LC id is known. assistant-ui's
|
||||||
|
// ``ToolCallMessagePartProps`` doesn't list it, but the runtime
|
||||||
|
// spreads ``{...part}`` so the prop reaches us at runtime.
|
||||||
|
const langchainToolCallId = (props as { langchainToolCallId?: string }).langchainToolCallId;
|
||||||
|
|
||||||
const isCancelled = status?.type === "incomplete" && status.reason === "cancelled";
|
const isCancelled = status?.type === "incomplete" && status.reason === "cancelled";
|
||||||
const isError = status?.type === "incomplete" && status.reason === "error";
|
const isError = status?.type === "incomplete" && status.reason === "error";
|
||||||
const isRunning = status?.type === "running" || status?.type === "requires-action";
|
const isRunning = status?.type === "running" || status?.type === "requires-action";
|
||||||
|
|
||||||
|
/*
|
||||||
|
Per-card expansion state. Initial value is ``isRunning`` so a
|
||||||
|
card streaming in mounts already-expanded (no flash of
|
||||||
|
collapsed → expanded on first paint), while a card loaded from
|
||||||
|
history (status="complete") mounts collapsed. The useEffect
|
||||||
|
below keeps this in lockstep with this card's own ``isRunning``
|
||||||
|
when it transitions: false → true auto-expands (e.g. a tool
|
||||||
|
that re-runs after edit), true → false auto-collapses once the
|
||||||
|
tool finishes. Because the dep is per-card ``isRunning`` and
|
||||||
|
not the chat-level streaming flag, sibling cards on the same
|
||||||
|
assistant turn each manage their own expansion independently.
|
||||||
|
Once ``isRunning`` is false the user controls expansion via
|
||||||
|
``onOpenChange``.
|
||||||
|
*/
|
||||||
|
const [isExpanded, setIsExpanded] = useState(isRunning);
|
||||||
|
useEffect(() => {
|
||||||
|
setIsExpanded(isRunning);
|
||||||
|
}, [isRunning]);
|
||||||
const errorData = status?.type === "incomplete" ? status.error : undefined;
|
const errorData = status?.type === "incomplete" ? status.error : undefined;
|
||||||
const serializedError = useMemo(
|
const serializedError = useMemo(
|
||||||
() => (errorData && typeof errorData !== "string" ? JSON.stringify(errorData) : null),
|
() => (errorData && typeof errorData !== "string" ? JSON.stringify(errorData) : null),
|
||||||
|
|
@ -160,108 +308,207 @@ const DefaultToolFallbackInner: ToolCallMessagePartComponent = ({
|
||||||
: serializedError
|
: serializedError
|
||||||
: null;
|
: null;
|
||||||
|
|
||||||
const Icon = getToolIcon(toolName);
|
|
||||||
const displayName = getToolDisplayName(toolName);
|
const displayName = getToolDisplayName(toolName);
|
||||||
|
const subtitle = errorReason ?? cancelledReason;
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div
|
<Card
|
||||||
className={cn(
|
className={cn(
|
||||||
"my-4 max-w-lg overflow-hidden rounded-2xl border bg-muted/30 select-none",
|
"my-4 max-w-lg overflow-hidden",
|
||||||
isCancelled && "opacity-60",
|
isCancelled && "opacity-60",
|
||||||
isError && "border-destructive/20 bg-destructive/5"
|
isError && "border-destructive/30"
|
||||||
)}
|
)}
|
||||||
>
|
>
|
||||||
<button
|
{/*
|
||||||
type="button"
|
``group`` lets the chevron (rendered as a sibling of the
|
||||||
onClick={() => setIsExpanded((prev) => !prev)}
|
main trigger button) read the Collapsible Root's
|
||||||
className="flex w-full items-center gap-3 px-5 py-4 text-left transition-colors hover:bg-muted/50 focus:outline-none focus-visible:outline-none"
|
``data-[state=open]`` for rotation. The Collapsible is
|
||||||
|
fully controlled via ``isExpanded`` — the useEffect
|
||||||
|
above syncs it to ``isRunning`` so the card auto-opens
|
||||||
|
while a tool streams in and auto-collapses once it
|
||||||
|
finishes. We deliberately DON'T pass ``disabled`` so
|
||||||
|
both triggers stay clickable; ``onOpenChange`` is wired
|
||||||
|
to a setter that no-ops while ``isRunning`` (see
|
||||||
|
``handleOpenChange`` below) which keeps the card pinned
|
||||||
|
open mid-stream without losing keyboard / pointer
|
||||||
|
affordance the moment streaming ends.
|
||||||
|
*/}
|
||||||
|
<Collapsible
|
||||||
|
className="group"
|
||||||
|
open={isExpanded}
|
||||||
|
onOpenChange={(next) => {
|
||||||
|
// Block manual collapse while the tool is still
|
||||||
|
// streaming — otherwise a stray click on either
|
||||||
|
// trigger would close the card and hide the live
|
||||||
|
// ``argsText`` panel mid-run. After streaming the
|
||||||
|
// user has full control again.
|
||||||
|
if (isRunning) return;
|
||||||
|
setIsExpanded(next);
|
||||||
|
}}
|
||||||
>
|
>
|
||||||
<div
|
{/*
|
||||||
className={cn(
|
Header row: main trigger on the left (icon + title
|
||||||
"flex size-8 shrink-0 items-center justify-center rounded-lg",
|
col), Revert + chevron-trigger on the right as
|
||||||
isError ? "bg-destructive/10" : isCancelled ? "bg-muted" : "bg-primary/10"
|
siblings of the main trigger. The chevron is wrapped
|
||||||
)}
|
in its OWN ``CollapsibleTrigger`` (Radix supports
|
||||||
>
|
multiple triggers per Root) so clicking the chevron
|
||||||
{isError ? (
|
toggles the same state as clicking the title row.
|
||||||
<XCircleIcon className="size-4 text-destructive" />
|
The Revert button stays a separate AlertDialog
|
||||||
) : isCancelled ? (
|
trigger and stops propagation in its onClick so it
|
||||||
<XCircleIcon className="size-4 text-muted-foreground" />
|
doesn't toggle the collapsible while opening the
|
||||||
) : isRunning ? (
|
confirm dialog. Keeping these as flat siblings —
|
||||||
<Icon className="size-4 text-primary animate-pulse" />
|
rather than nesting Revert / chevron inside the
|
||||||
) : (
|
title trigger — avoids invalid HTML
|
||||||
<CheckIcon className="size-4 text-primary" />
|
(button-in-button) and lets the Revert button
|
||||||
)}
|
render in BOTH the collapsed and expanded states.
|
||||||
</div>
|
*/}
|
||||||
|
<div className="flex items-stretch transition-colors hover:bg-muted/50">
|
||||||
|
<CollapsibleTrigger asChild>
|
||||||
|
<button
|
||||||
|
type="button"
|
||||||
|
className={cn(
|
||||||
|
"flex flex-1 min-w-0 items-center gap-3 py-4 pl-5 pr-2 text-left",
|
||||||
|
// Inset ring — Card's ``overflow-hidden`` would
|
||||||
|
// clip an ``offset-2`` ring; ``ring-inset``
|
||||||
|
// paints inside the button box.
|
||||||
|
"focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-inset",
|
||||||
|
"disabled:cursor-default"
|
||||||
|
)}
|
||||||
|
>
|
||||||
|
<div
|
||||||
|
className={cn(
|
||||||
|
"flex size-8 shrink-0 items-center justify-center rounded-lg",
|
||||||
|
isError ? "bg-destructive/10" : isCancelled ? "bg-muted" : "bg-primary/10"
|
||||||
|
)}
|
||||||
|
>
|
||||||
|
{isError ? (
|
||||||
|
<XCircleIcon className="size-4 text-destructive" />
|
||||||
|
) : isCancelled ? (
|
||||||
|
<XCircleIcon className="size-4 text-muted-foreground" />
|
||||||
|
) : isRunning ? (
|
||||||
|
<Spinner size="sm" className="text-primary" />
|
||||||
|
) : (
|
||||||
|
<CheckIcon className="size-4 text-primary" />
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
|
||||||
<div className="flex-1 min-w-0">
|
<div className="flex flex-1 min-w-0 flex-col gap-0.5">
|
||||||
<p
|
<div className="flex items-center gap-2">
|
||||||
className={cn(
|
<p
|
||||||
"text-sm font-semibold",
|
className={cn(
|
||||||
isError
|
"text-sm font-semibold truncate",
|
||||||
? "text-destructive"
|
isCancelled && "text-muted-foreground line-through",
|
||||||
: isCancelled
|
isError && "text-destructive"
|
||||||
? "text-muted-foreground line-through"
|
)}
|
||||||
: "text-foreground"
|
>
|
||||||
)}
|
{displayName}
|
||||||
>
|
</p>
|
||||||
{isRunning
|
{isRunning && <Badge variant="secondary">Running</Badge>}
|
||||||
? displayName
|
{isError && <Badge variant="destructive">Failed</Badge>}
|
||||||
: isCancelled
|
{isCancelled && <Badge variant="outline">Cancelled</Badge>}
|
||||||
? `Cancelled: ${displayName}`
|
</div>
|
||||||
: isError
|
{subtitle && (
|
||||||
? `Failed: ${displayName}`
|
<p
|
||||||
: displayName}
|
className={cn(
|
||||||
</p>
|
"text-xs truncate",
|
||||||
{isRunning && <p className="text-xs text-muted-foreground mt-0.5">Working…</p>}
|
isError ? "text-destructive/80" : "text-muted-foreground"
|
||||||
{cancelledReason && (
|
)}
|
||||||
<p className="text-xs text-muted-foreground mt-0.5 truncate">{cancelledReason}</p>
|
>
|
||||||
)}
|
{subtitle}
|
||||||
{errorReason && (
|
</p>
|
||||||
<p className="text-xs text-destructive/80 mt-0.5 truncate">{errorReason}</p>
|
)}
|
||||||
)}
|
</div>
|
||||||
</div>
|
</button>
|
||||||
|
</CollapsibleTrigger>
|
||||||
|
|
||||||
{!isRunning && (
|
{/*
|
||||||
<div className="shrink-0 text-muted-foreground">
|
Right-side controls. The Revert button is
|
||||||
{isExpanded ? (
|
visible whenever the matching action is
|
||||||
<ChevronDownIcon className="size-4" />
|
reversible — including the collapsed state —
|
||||||
) : (
|
but ``ToolCardRevertButton`` itself returns
|
||||||
<ChevronUpIcon className="size-4" />
|
``null`` while a tool is still running because
|
||||||
)}
|
no action-log row exists yet, so it doesn't
|
||||||
|
need an explicit ``isRunning`` gate here.
|
||||||
|
*/}
|
||||||
|
<div className="flex shrink-0 items-center gap-2 pl-2 pr-5">
|
||||||
|
<ToolCardRevertButton
|
||||||
|
toolCallId={toolCallId}
|
||||||
|
toolName={toolName}
|
||||||
|
langchainToolCallId={langchainToolCallId}
|
||||||
|
/>
|
||||||
|
<CollapsibleTrigger asChild>
|
||||||
|
<button
|
||||||
|
type="button"
|
||||||
|
aria-label={isExpanded ? "Collapse details" : "Expand details"}
|
||||||
|
className={cn(
|
||||||
|
"flex size-7 shrink-0 items-center justify-center rounded-md",
|
||||||
|
"text-muted-foreground hover:bg-muted hover:text-foreground",
|
||||||
|
"focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-inset",
|
||||||
|
"disabled:cursor-default"
|
||||||
|
)}
|
||||||
|
>
|
||||||
|
<ChevronDownIcon
|
||||||
|
className={cn(
|
||||||
|
"size-4 transition-transform duration-200",
|
||||||
|
"group-data-[state=open]:rotate-180"
|
||||||
|
)}
|
||||||
|
/>
|
||||||
|
</button>
|
||||||
|
</CollapsibleTrigger>
|
||||||
</div>
|
</div>
|
||||||
)}
|
</div>
|
||||||
</button>
|
|
||||||
|
|
||||||
{isExpanded && !isRunning && (
|
{/*
|
||||||
<>
|
CollapsibleContent body — auto-open while streaming
|
||||||
<div className="mx-5 h-px bg-border/50" />
|
(see ``open`` prop above) so the live ``argsText``
|
||||||
<div className="px-5 py-3 space-y-3">
|
streams into the Inputs panel directly, no need for
|
||||||
{argsText && (
|
a separate "Live input" panel. Native
|
||||||
<div>
|
``overflow-auto`` instead of ``ScrollArea`` because
|
||||||
<p className="text-xs font-medium text-muted-foreground mb-1">Inputs</p>
|
Radix's Viewport can let content bleed past
|
||||||
<pre className="text-xs text-foreground/80 whitespace-pre-wrap break-all">
|
``max-h-*`` in dynamic flex layouts. ``min-w-0`` on
|
||||||
{argsText}
|
the column wrappers guarantees ``break-all`` wraps
|
||||||
</pre>
|
correctly within the bounded ``max-w-lg`` Card.
|
||||||
|
*/}
|
||||||
|
<CollapsibleContent>
|
||||||
|
<Separator />
|
||||||
|
<div className="flex flex-col gap-3 px-5 py-3">
|
||||||
|
{(argsText || isRunning) && (
|
||||||
|
<div className="flex flex-col gap-1 min-w-0">
|
||||||
|
<p className="text-xs font-medium text-muted-foreground">Inputs</p>
|
||||||
|
<div className="max-h-48 overflow-auto rounded-md bg-muted/40">
|
||||||
|
{argsText ? (
|
||||||
|
<pre className="px-3 py-2 text-xs text-foreground/80 whitespace-pre-wrap break-all font-mono">
|
||||||
|
{argsText}
|
||||||
|
</pre>
|
||||||
|
) : (
|
||||||
|
// Bridges the brief gap between
|
||||||
|
// ``tool-input-start`` (creates the
|
||||||
|
// card, ``argsText`` undefined) and
|
||||||
|
// the first ``tool-input-delta``.
|
||||||
|
<p className="px-3 py-2 text-xs italic text-muted-foreground">
|
||||||
|
Waiting for input…
|
||||||
|
</p>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
{!isCancelled && result !== undefined && (
|
{!isCancelled && result !== undefined && (
|
||||||
<>
|
<>
|
||||||
<div className="h-px bg-border/30" />
|
<Separator />
|
||||||
<div>
|
<div className="flex flex-col gap-1 min-w-0">
|
||||||
<p className="text-xs font-medium text-muted-foreground mb-1">Result</p>
|
<p className="text-xs font-medium text-muted-foreground">Result</p>
|
||||||
<pre className="text-xs text-foreground/80 whitespace-pre-wrap break-all">
|
<div className="max-h-64 overflow-auto rounded-md bg-muted/40">
|
||||||
{typeof result === "string" ? result : serializedResult}
|
<pre className="px-3 py-2 text-xs text-foreground/80 whitespace-pre-wrap break-all font-mono">
|
||||||
</pre>
|
{typeof result === "string" ? result : serializedResult}
|
||||||
|
</pre>
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</>
|
</>
|
||||||
)}
|
)}
|
||||||
<div className="flex justify-end">
|
|
||||||
<ToolCardRevertButton toolCallId={toolCallId} />
|
|
||||||
</div>
|
|
||||||
</div>
|
</div>
|
||||||
</>
|
</CollapsibleContent>
|
||||||
)}
|
</Collapsible>
|
||||||
</div>
|
</Card>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -22,6 +22,7 @@ import {
|
||||||
addToolCall,
|
addToolCall,
|
||||||
appendReasoning,
|
appendReasoning,
|
||||||
appendText,
|
appendText,
|
||||||
|
appendToolInputDelta,
|
||||||
buildContentForUI,
|
buildContentForUI,
|
||||||
type ContentPartsState,
|
type ContentPartsState,
|
||||||
endReasoning,
|
endReasoning,
|
||||||
|
|
@ -188,6 +189,10 @@ export function FreeChatPage() {
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
const scheduleFlush = () => batcher.schedule(flushMessages);
|
const scheduleFlush = () => batcher.schedule(flushMessages);
|
||||||
|
const forceFlush = () => {
|
||||||
|
scheduleFlush();
|
||||||
|
batcher.flush();
|
||||||
|
};
|
||||||
|
|
||||||
try {
|
try {
|
||||||
for await (const parsed of readSSEStream(response)) {
|
for await (const parsed of readSSEStream(response)) {
|
||||||
|
|
@ -225,13 +230,20 @@ export function FreeChatPage() {
|
||||||
false,
|
false,
|
||||||
parsed.langchainToolCallId
|
parsed.langchainToolCallId
|
||||||
);
|
);
|
||||||
batcher.flush();
|
forceFlush();
|
||||||
break;
|
break;
|
||||||
|
|
||||||
case "tool-input-available":
|
case "tool-input-delta":
|
||||||
|
appendToolInputDelta(contentPartsState, parsed.toolCallId, parsed.inputTextDelta);
|
||||||
|
scheduleFlush();
|
||||||
|
break;
|
||||||
|
|
||||||
|
case "tool-input-available": {
|
||||||
|
const finalArgsText = JSON.stringify(parsed.input ?? {}, null, 2);
|
||||||
if (toolCallIndices.has(parsed.toolCallId)) {
|
if (toolCallIndices.has(parsed.toolCallId)) {
|
||||||
updateToolCall(contentPartsState, parsed.toolCallId, {
|
updateToolCall(contentPartsState, parsed.toolCallId, {
|
||||||
args: parsed.input || {},
|
args: parsed.input || {},
|
||||||
|
argsText: finalArgsText,
|
||||||
langchainToolCallId: parsed.langchainToolCallId,
|
langchainToolCallId: parsed.langchainToolCallId,
|
||||||
});
|
});
|
||||||
} else {
|
} else {
|
||||||
|
|
@ -244,16 +256,20 @@ export function FreeChatPage() {
|
||||||
false,
|
false,
|
||||||
parsed.langchainToolCallId
|
parsed.langchainToolCallId
|
||||||
);
|
);
|
||||||
|
updateToolCall(contentPartsState, parsed.toolCallId, {
|
||||||
|
argsText: finalArgsText,
|
||||||
|
});
|
||||||
}
|
}
|
||||||
batcher.flush();
|
forceFlush();
|
||||||
break;
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
case "tool-output-available":
|
case "tool-output-available":
|
||||||
updateToolCall(contentPartsState, parsed.toolCallId, {
|
updateToolCall(contentPartsState, parsed.toolCallId, {
|
||||||
result: parsed.output,
|
result: parsed.output,
|
||||||
langchainToolCallId: parsed.langchainToolCallId,
|
langchainToolCallId: parsed.langchainToolCallId,
|
||||||
});
|
});
|
||||||
batcher.flush();
|
forceFlush();
|
||||||
break;
|
break;
|
||||||
|
|
||||||
case "data-thinking-step": {
|
case "data-thinking-step": {
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,13 @@
|
||||||
import { z } from "zod";
|
import { z } from "zod";
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Raw message from database (real-time sync)
|
* Raw message from database (real-time sync).
|
||||||
|
*
|
||||||
|
* ``turn_id`` is included so consumers (e.g. ``convertToThreadMessage``)
|
||||||
|
* can populate ``metadata.custom.chatTurnId`` on the
|
||||||
|
* ``ThreadMessageLike`` even after the live-collab Zero re-sync. The
|
||||||
|
* inline Revert button's ``(chat_turn_id, tool_name, position)``
|
||||||
|
* fallback in tool-fallback.tsx depends on it.
|
||||||
*/
|
*/
|
||||||
export const rawMessage = z.object({
|
export const rawMessage = z.object({
|
||||||
id: z.number(),
|
id: z.number(),
|
||||||
|
|
@ -10,6 +16,7 @@ export const rawMessage = z.object({
|
||||||
content: z.unknown(),
|
content: z.unknown(),
|
||||||
author_id: z.string().nullable(),
|
author_id: z.string().nullable(),
|
||||||
created_at: z.string(),
|
created_at: z.string(),
|
||||||
|
turn_id: z.string().nullable().optional(),
|
||||||
});
|
});
|
||||||
|
|
||||||
export type RawMessage = z.infer<typeof rawMessage>;
|
export type RawMessage = z.infer<typeof rawMessage>;
|
||||||
|
|
|
||||||
416
surfsense_web/hooks/use-agent-actions-query.ts
Normal file
416
surfsense_web/hooks/use-agent-actions-query.ts
Normal file
|
|
@ -0,0 +1,416 @@
|
||||||
|
"use client";
|
||||||
|
|
||||||
|
import { type QueryClient, useQuery } from "@tanstack/react-query";
|
||||||
|
import { useCallback, useEffect, useMemo, useRef } from "react";
|
||||||
|
import {
|
||||||
|
type AgentAction,
|
||||||
|
type AgentActionListResponse,
|
||||||
|
agentActionsApiService,
|
||||||
|
} from "@/lib/apis/agent-actions-api.service";
|
||||||
|
|
||||||
|
// =============================================================================
|
||||||
|
// DIAGNOSTIC LOGGING — gated behind a single switch. Flip ``RevertDebug``
|
||||||
|
// to ``true`` to trace the full SSE → cache → card → button pipeline in
|
||||||
|
// the browser console. Off by default so we don't spam production. The
|
||||||
|
// infrastructure stays in place because the underlying id-mismatch
|
||||||
|
// failure mode is rare-but-real and surfaces only at runtime.
|
||||||
|
// =============================================================================
|
||||||
|
// Master switch for the tracing helper below. Keep ``false`` in
// committed code; flip locally to trace the revert pipeline.
const RevertDebug = false;

/**
 * Console tracer for the revert pipeline.
 *
 * Does nothing unless ``RevertDebug`` is on AND a ``window`` global
 * exists. Every line is prefixed with ``[RevertDebug]``.
 */
const dbg = (...args: unknown[]) => {
	const tracingOff = !RevertDebug || typeof window === "undefined";
	if (tracingOff) return;
	// eslint-disable-next-line no-console
	console.log("[RevertDebug]", ...args);
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Unified store for ``AgentActionLog`` rows scoped to one thread.
|
||||||
|
*
|
||||||
|
* Replaces the previous SSE side-channel atom mess
|
||||||
|
* (``agentActionByLcIdAtom`` / ``agentActionByToolCallIdAtom`` /
|
||||||
|
* ``agentActionsByChatTurnIdAtom``) and the standalone hydration hook.
|
||||||
|
* One react-query cache entry is now the single source of truth for:
|
||||||
|
*
|
||||||
|
* * the inline Revert button on every tool-call card
|
||||||
|
* * the per-turn "Revert turn" button under each assistant message
|
||||||
|
* * the edit-from-position pre-flight that decides whether to show
|
||||||
|
* the confirmation dialog
|
||||||
|
* * the agent-actions sheet
|
||||||
|
*
|
||||||
|
* The cache is hydrated by ``GET /threads/{id}/actions`` (sized to
|
||||||
|
* 200, the server max) and updated incrementally by helpers that turn
|
||||||
|
* SSE events / revert RPC responses into ``setQueryData`` mutations.
|
||||||
|
* That keeps the card and the sheet in lockstep on every code path —
|
||||||
|
* page reload, navigation, live stream, post-stream reversibility flip,
|
||||||
|
* and explicit revert clicks.
|
||||||
|
*/
|
||||||
|
|
||||||
|
// Page size for the per-thread action list: passed as ``pageSize`` to
// ``agentActionsApiService.listForThread`` and recorded as ``page_size``
// when ``applyActionLogSse`` seeds an empty cache entry. The module
// comment above gives 200 as the server max.
export const ACTION_LOG_PAGE_SIZE = 200;
|
||||||
|
|
||||||
|
/**
 * Build the react-query key under which a thread's action list is cached.
 *
 * @param threadId - numeric thread id, or ``null`` when no thread is selected.
 * @returns ``["agent-actions", threadId]`` for a real thread and
 *   ``["agent-actions", "none"]`` for ``null``.
 */
export function agentActionsQueryKey(threadId: number | null) {
	if (threadId === null) {
		return ["agent-actions", "none"] as const;
	}
	return ["agent-actions", threadId] as const;
}
|
||||||
|
|
||||||
|
/**
 * The fields of the SSE ``data-action-log`` payload that this module
 * consumes (see ``applyActionLogSse``).
 */
export interface ActionLogSseEvent {
	// --- identity -------------------------------------------------------
	/** Row id; compared against ``AgentAction.id`` to find a cached row. */
	id: number;
	/** Name of the tool that was called. */
	tool_name: string;
	/** Stored on the cached row as ``tool_call_id``; may be ``null``. */
	lc_tool_call_id: string | null;
	/** Stored on the cached row as ``chat_turn_id``; may be ``null``. */
	chat_turn_id: string | null;
	/** When ``null``, ``applyActionLogSse`` substitutes the current time. */
	created_at: string | null;

	// --- status flags ---------------------------------------------------
	/** Whether the action can currently be reverted. */
	reversible: boolean;
	/**
	 * Presence flag only: ``applyActionLogSse`` maps ``true`` to an
	 * empty-object ``reverse_descriptor`` placeholder and ``false`` to ``null``.
	 */
	reverse_descriptor_present: boolean;
	/**
	 * Presence flag only: ``applyActionLogSse`` maps ``true`` to an
	 * empty-object ``error`` placeholder and ``false`` to ``null``.
	 */
	error: boolean;
}
|
||||||
|
|
||||||
|
/**
 * Upsert one ``data-action-log`` SSE event into the per-thread
 * action-list cache entry.
 *
 * Three outcomes, decided inside the ``setQueryData`` updater:
 *
 *   * no cache entry yet — seed a single-item list (page 0,
 *     ``has_more: false``) holding a placeholder row;
 *   * a row with the same ``id`` is already cached — refresh
 *     ``reversible`` and, when the event carries them, ``tool_call_id``
 *     / ``chat_turn_id``; every other field on the row is kept;
 *   * otherwise — prepend a placeholder row and bump ``total`` by one.
 *
 * The placeholder only knows what the SSE payload carries:
 * ``user_id``, ``args``, ``result_id`` and ``reverse_of`` are ``null``,
 * ``reverse_descriptor`` / ``error`` are ``{}`` when the matching event
 * flag is true (``null`` otherwise), and the row starts out as not
 * reverted and not a revert action.
 *
 * @param queryClient - react-query client whose cache is updated.
 * @param threadId - thread the event belongs to; selects the cache key.
 * @param searchSpaceId - copied onto the placeholder row.
 * @param event - the SSE payload to apply.
 */
export function applyActionLogSse(
	queryClient: QueryClient,
	threadId: number,
	searchSpaceId: number,
	event: ActionLogSseEvent
): void {
	dbg("applyActionLogSse: incoming SSE event", {
		threadId,
		searchSpaceId,
		event,
	});

	// Row synthesised purely from the SSE payload plus the ids we were given.
	const buildPlaceholder = (): AgentAction => ({
		id: event.id,
		thread_id: threadId,
		user_id: null,
		search_space_id: searchSpaceId,
		tool_name: event.tool_name,
		args: null,
		result_id: null,
		reversible: event.reversible,
		reverse_descriptor: event.reverse_descriptor_present ? {} : null,
		error: event.error ? {} : null,
		reverse_of: null,
		reverted_by_action_id: null,
		is_revert_action: false,
		tool_call_id: event.lc_tool_call_id,
		chat_turn_id: event.chat_turn_id,
		created_at: event.created_at ?? new Date().toISOString(),
	});

	const upsert = (
		cached: AgentActionListResponse | undefined
	): AgentActionListResponse => {
		// Nothing cached for this thread yet: start a one-row list.
		if (!cached) {
			return {
				items: [buildPlaceholder()],
				total: 1,
				page: 0,
				page_size: ACTION_LOG_PAGE_SIZE,
				has_more: false,
			};
		}

		const hitIndex = cached.items.findIndex((row) => row.id === event.id);

		if (hitIndex < 0) {
			const fresh = buildPlaceholder();
			dbg("applyActionLogSse: appended new placeholder", {
				id: event.id,
				tool_call_id: fresh.tool_call_id,
				tool_name: fresh.tool_name,
				reversible: fresh.reversible,
				cacheSizeAfter: cached.items.length + 1,
			});
			// The REST list is newest-first (per the original author's
			// note), so the new row goes at the front to keep the
			// same ordering a later refetch would produce.
			return {
				...cached,
				items: [fresh, ...cached.items],
				total: cached.total + 1,
			};
		}

		// Known row: overwrite only what the event can tell us; a null
		// id in the event never erases an id we already hold.
		const nextItems = cached.items.map((row, position) =>
			position === hitIndex
				? {
						...row,
						reversible: event.reversible,
						tool_call_id: event.lc_tool_call_id ?? row.tool_call_id,
						chat_turn_id: event.chat_turn_id ?? row.chat_turn_id,
					}
				: row
		);
		dbg("applyActionLogSse: merged into existing entry", {
			id: event.id,
			tool_call_id: nextItems[hitIndex]?.tool_call_id,
			reversible: nextItems[hitIndex]?.reversible,
		});
		return { ...cached, items: nextItems };
	};

	queryClient.setQueryData<AgentActionListResponse>(agentActionsQueryKey(threadId), upsert);
}
|
||||||
|
|
||||||
|
/**
 * Apply a ``data-action-log-updated`` SSE event — a change of one
 * row's ``reversible`` flag — to the per-thread cache entry.
 *
 * The update is dropped (cache object returned untouched, with a debug
 * trace) when the thread has no cache entry or no cached row has the
 * given ``id``.
 *
 * @param queryClient - react-query client whose cache is updated.
 * @param threadId - thread whose cache entry is targeted.
 * @param id - id of the action row to update.
 * @param reversible - new value for the row's ``reversible`` flag.
 */
export function applyActionLogUpdatedSse(
	queryClient: QueryClient,
	threadId: number,
	id: number,
	reversible: boolean
): void {
	dbg("applyActionLogUpdatedSse: reversibility flip", {
		threadId,
		id,
		reversible,
	});

	queryClient.setQueryData<AgentActionListResponse>(
		agentActionsQueryKey(threadId),
		(cached) => {
			if (!cached) {
				dbg("applyActionLogUpdatedSse: NO prev cache for thread; flip dropped", {
					threadId,
					id,
				});
				return cached;
			}

			const isTarget = (row: AgentAction) => row.id === id;

			if (!cached.items.some(isTarget)) {
				dbg("applyActionLogUpdatedSse: id not in cache; flip dropped", {
					threadId,
					id,
					cacheSize: cached.items.length,
					cacheIds: cached.items.map((row) => row.id),
				});
				// Same reference back so the cache entry's data is left as-is.
				return cached;
			}

			return {
				...cached,
				items: cached.items.map((row) => (isTarget(row) ? { ...row, reversible } : row)),
			};
		}
	);
}
|
||||||
|
|
||||||
|
/**
 * Mark one cached action row as reverted by setting its
 * ``reverted_by_action_id``.
 *
 * No-op when the thread has no cache entry or the row is not cached;
 * in both cases the existing cache object is returned unchanged.
 *
 * @param queryClient - react-query client whose cache is updated.
 * @param threadId - thread whose cache entry is targeted.
 * @param id - id of the action that was reverted.
 * @param newActionId - id of the revert row created by the server, or
 *   ``null`` when unknown (stored as the sentinel ``-1``).
 */
export function markActionRevertedInCache(
	queryClient: QueryClient,
	threadId: number,
	id: number,
	newActionId: number | null
): void {
	// ``-1`` stands for "reverted, but the id of the new row is unknown".
	const revertedBy = newActionId ?? -1;

	queryClient.setQueryData<AgentActionListResponse>(
		agentActionsQueryKey(threadId),
		(cached) => {
			if (!cached) return cached;
			if (!cached.items.some((row) => row.id === id)) return cached;
			return {
				...cached,
				items: cached.items.map((row) =>
					row.id === id ? { ...row, reverted_by_action_id: revertedBy } : row
				),
			};
		}
	);
}
|
||||||
|
|
||||||
|
/**
 * Batch form of ``markActionRevertedInCache``: set
 * ``reverted_by_action_id`` on every cached row whose id appears in
 * ``entries``. Rows not listed are left alone.
 *
 * Returns early for an empty ``entries`` array. The cache object is
 * returned unchanged when the thread has no cache entry or none of the
 * listed ids is cached.
 *
 * @param queryClient - react-query client whose cache is updated.
 * @param threadId - thread whose cache entry is targeted.
 * @param entries - ``id`` of each reverted action paired with the id of
 *   the revert row that replaced it (``null`` → sentinel ``-1``). If an
 *   id is listed more than once, the last entry wins.
 */
export function applyRevertTurnResultsToCache(
	queryClient: QueryClient,
	threadId: number,
	entries: Array<{ id: number; newActionId: number | null }>
): void {
	if (entries.length === 0) return;

	queryClient.setQueryData<AgentActionListResponse>(
		agentActionsQueryKey(threadId),
		(cached) => {
			if (!cached) return cached;

			// reverted action id -> id of the row that reverted it
			const revertedBy = new Map<number, number | null>();
			for (const { id, newActionId } of entries) {
				revertedBy.set(id, newActionId);
			}

			if (!cached.items.some((row) => revertedBy.has(row.id))) return cached;

			return {
				...cached,
				items: cached.items.map((row) =>
					revertedBy.has(row.id)
						? { ...row, reverted_by_action_id: revertedBy.get(row.id) ?? -1 }
						: row
				),
			};
		}
	);
}
|
||||||
|
|
||||||
|
/**
 * Read-side hook over a thread's agent-action log. Shared by the card,
 * the turn button, the sheet, and the edit-from-position pre-flight.
 *
 * Exposes the raw query state together with lookup helpers so callers
 * never have to dig through ``data.items`` themselves:
 *
 * - ``items`` — the fetched rows (empty array until data arrives).
 * - ``findByToolCallId`` — single row for a tool call id, or ``null``.
 * - ``findByChatTurnId`` — every row belonging to a chat turn.
 * - ``findByChatTurnAndTool`` — non-revert rows for a (turn, tool)
 *   pair, oldest first.
 *
 * ``enabled`` is the only option; pass ``false`` to keep the query
 * dormant. The query is also dormant whenever ``threadId`` is ``null``.
 */
export function useAgentActionsQuery(
	threadId: number | null,
	options: { enabled?: boolean } = {}
) {
	const enabled = threadId !== null && (options.enabled ?? true);

	const query = useQuery({
		queryKey: agentActionsQueryKey(threadId),
		enabled,
		staleTime: 15 * 1000,
		queryFn: async () => {
			dbg("useAgentActionsQuery: REST fetch START", {
				threadId,
				pageSize: ACTION_LOG_PAGE_SIZE,
			});
			// ``enabled`` guarantees a non-null thread id by the time
			// this runs, hence the cast.
			const response = await agentActionsApiService.listForThread(threadId as number, {
				page: 0,
				pageSize: ACTION_LOG_PAGE_SIZE,
			});
			dbg("useAgentActionsQuery: REST fetch DONE", {
				threadId,
				total: response.total,
				returned: response.items.length,
				items: response.items.map((action) => ({
					id: action.id,
					tool_name: action.tool_name,
					tool_call_id: action.tool_call_id,
					reversible: action.reversible,
					reverted_by_action_id: action.reverted_by_action_id,
					is_revert_action: action.is_revert_action,
				})),
			});
			return response;
		},
	});

	const items = useMemo(() => query.data?.items ?? [], [query.data]);

	// Build the tool-call-id index once per ``items`` change so each
	// card's lookup is a Map hit rather than a scan of the whole list on
	// every render.
	const byToolCallId = useMemo(() => {
		const index = new Map<string, AgentAction>();
		items.forEach((action) => {
			if (action.tool_call_id) index.set(action.tool_call_id, action);
		});
		return index;
	}, [items]);

	// Group by (chat_turn_id, tool_name) and order each group oldest
	// first, so the positional fallback in ``tool-fallback.tsx`` is a
	// single Map hit. Revert rows are left out so positions line up with
	// the agent's original execution order.
	const byTurnAndTool = useMemo(() => {
		const groups = new Map<string, AgentAction[]>();
		for (const action of items) {
			if (action.is_revert_action || !action.chat_turn_id) continue;
			const groupKey = `${action.chat_turn_id}::${action.tool_name}`;
			const existing = groups.get(groupKey);
			if (existing === undefined) {
				groups.set(groupKey, [action]);
			} else {
				existing.push(action);
			}
		}
		const createdAtMs = (action: AgentAction) => new Date(action.created_at).getTime();
		groups.forEach((group) => {
			group.sort((left, right) => createdAtMs(left) - createdAtMs(right));
		});
		return groups;
	}, [items]);

	// Debug aid: log a snapshot of the cache whenever the thread or the
	// row count changes. The last logged shape lives on a ref so an
	// unchanged cache does not log again.
	const lastSnapshotRef = useRef<{ threadId: number | null; size: number } | null>(null);
	useEffect(() => {
		const previous = lastSnapshotRef.current;
		const unchanged =
			previous !== null && previous.threadId === threadId && previous.size === items.length;
		if (unchanged) return;

		dbg("useAgentActionsQuery: cache snapshot", {
			threadId,
			enabled,
			itemCount: items.length,
			itemKeys: items.slice(0, 8).map((action) => ({
				id: action.id,
				tool_name: action.tool_name,
				tool_call_id: action.tool_call_id,
				chat_turn_id: action.chat_turn_id,
				reversible: action.reversible,
			})),
		});
		lastSnapshotRef.current = { threadId, size: items.length };
	}, [threadId, enabled, items]);

	const findByToolCallId = useCallback(
		(toolCallId: string | null | undefined): AgentAction | null => {
			if (!toolCallId) return null;
			const hit = byToolCallId.get(toolCallId);
			if (hit) return hit;
			// Only report a miss when there was something to search.
			if (items.length > 0) {
				dbg("findByToolCallId: MISS", {
					queriedToolCallId: toolCallId,
					itemCount: items.length,
					availableToolCallIds: Array.from(byToolCallId.keys()),
				});
			}
			return null;
		},
		[byToolCallId, items.length]
	);

	const findByChatTurnId = useCallback(
		(chatTurnId: string | null | undefined): AgentAction[] => {
			if (!chatTurnId) return [];
			// Deliberately un-indexed: only the "Revert turn" button
			// aggregates per turn, so a linear scan is cheap enough.
			return items.filter((action) => action.chat_turn_id === chatTurnId);
		},
		[items]
	);

	const findByChatTurnAndTool = useCallback(
		(
			chatTurnId: string | null | undefined,
			toolName: string | null | undefined
		): AgentAction[] => {
			if (!chatTurnId || !toolName) return [];
			const group = byTurnAndTool.get(`${chatTurnId}::${toolName}`);
			return group ?? [];
		},
		[byTurnAndTool]
	);

	return {
		...query,
		items,
		findByToolCallId,
		findByChatTurnId,
		findByChatTurnAndTool,
	};
}
|
||||||
|
|
@ -31,6 +31,14 @@ export function useMessagesSync(
|
||||||
content: msg.content,
|
content: msg.content,
|
||||||
author_id: msg.authorId ?? null,
|
author_id: msg.authorId ?? null,
|
||||||
created_at: new Date(msg.createdAt).toISOString(),
|
created_at: new Date(msg.createdAt).toISOString(),
|
||||||
|
// Forward the per-turn correlation id so post-stream Zero
|
||||||
|
// re-syncs preserve ``metadata.custom.chatTurnId`` on the
|
||||||
|
// converted ``ThreadMessageLike``. Without this the inline
|
||||||
|
// Revert button's ``(chat_turn_id, tool_name, position)``
|
||||||
|
// fallback breaks the moment Zero overwrites the messages
|
||||||
|
// state after a live stream completes (see
|
||||||
|
// ``handleSyncedMessagesUpdate`` in the chat page).
|
||||||
|
turn_id: msg.turnId ?? null,
|
||||||
}));
|
}));
|
||||||
|
|
||||||
onMessagesUpdateRef.current(mapped);
|
onMessagesUpdateRef.current(mapped);
|
||||||
|
|
|
||||||
|
|
@ -16,6 +16,23 @@ export type ContentPart =
|
||||||
toolName: string;
|
toolName: string;
|
||||||
args: Record<string, unknown>;
|
args: Record<string, unknown>;
|
||||||
result?: unknown;
|
result?: unknown;
|
||||||
|
/**
|
||||||
|
* Live / finalized JSON text for the tool's input arguments.
|
||||||
|
*
|
||||||
|
* - During streaming: accumulated partial JSON text from
|
||||||
|
* ``tool-input-delta`` events (may be invalid JSON
|
||||||
|
* mid-stream). assistant-ui's argsText parser tolerates
|
||||||
|
* invalid JSON gracefully (changelog 0.7.32 / 0.7.78).
|
||||||
|
* - On completion (``tool-input-available``): replaced with
|
||||||
|
* ``JSON.stringify(input, null, 2)`` so the post-stream
|
||||||
|
* card renders pretty-printed JSON instead of the
|
||||||
|
* model's possibly-fragmented formatting.
|
||||||
|
*
|
||||||
|
* Per assistant-ui ``ThreadMessageLike`` precedence
|
||||||
|
* (changelog 0.11.6 ``d318c83``), when ``argsText`` is
|
||||||
|
* supplied it wins over ``JSON.stringify(args)``.
|
||||||
|
*/
|
||||||
|
argsText?: string;
|
||||||
/**
|
/**
|
||||||
* Authoritative LangChain ``tool_call.id`` propagated by the backend
|
* Authoritative LangChain ``tool_call.id`` propagated by the backend
|
||||||
* via ``langchainToolCallId`` on tool-input-start/available and
|
* via ``langchainToolCallId`` on tool-input-start/available and
|
||||||
|
|
@ -282,12 +299,22 @@ export function findToolCallIdByLcId(
|
||||||
export function updateToolCall(
|
export function updateToolCall(
|
||||||
state: ContentPartsState,
|
state: ContentPartsState,
|
||||||
toolCallId: string,
|
toolCallId: string,
|
||||||
update: { args?: Record<string, unknown>; result?: unknown; langchainToolCallId?: string }
|
update: {
|
||||||
|
args?: Record<string, unknown>;
|
||||||
|
argsText?: string;
|
||||||
|
result?: unknown;
|
||||||
|
langchainToolCallId?: string;
|
||||||
|
}
|
||||||
): void {
|
): void {
|
||||||
const index = state.toolCallIndices.get(toolCallId);
|
const index = state.toolCallIndices.get(toolCallId);
|
||||||
if (index !== undefined && state.contentParts[index]?.type === "tool-call") {
|
if (index !== undefined && state.contentParts[index]?.type === "tool-call") {
|
||||||
const tc = state.contentParts[index] as ContentPart & { type: "tool-call" };
|
const tc = state.contentParts[index] as ContentPart & { type: "tool-call" };
|
||||||
if (update.args) tc.args = update.args;
|
if (update.args) tc.args = update.args;
|
||||||
|
// ``!== undefined`` (NOT a truthy check): an explicit empty
|
||||||
|
// string CAN clear, and a finalization with
|
||||||
|
// ``JSON.stringify({}, null, 2) === "{}"`` (truthy but
|
||||||
|
// represents an empty-input call) still applies.
|
||||||
|
if (update.argsText !== undefined) tc.argsText = update.argsText;
|
||||||
if (update.result !== undefined) tc.result = update.result;
|
if (update.result !== undefined) tc.result = update.result;
|
||||||
// Only backfill langchainToolCallId if not already set — the
|
// Only backfill langchainToolCallId if not already set — the
|
||||||
// authoritative ``on_tool_end`` value should override an earlier
|
// authoritative ``on_tool_end`` value should override an earlier
|
||||||
|
|
@ -299,6 +326,25 @@ export function updateToolCall(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Concatenate one streamed args-delta chunk onto the ``argsText`` of
 * the tool call registered under ``toolCallId``.
 *
 * Silently does nothing when ``toolCallId`` has no registered index, or
 * when the content part at that index is not a ``tool-call`` — in both
 * cases there is no card for the delta to land on.
 *
 * @param state - Content-parts state holding the parts and the id → index map.
 * @param toolCallId - Id of the tool call the delta belongs to.
 * @param delta - Text fragment to append.
 */
export function appendToolInputDelta(
	state: ContentPartsState,
	toolCallId: string,
	delta: string
): void {
	const position = state.toolCallIndices.get(toolCallId);
	if (position === undefined) return;

	const part = state.contentParts[position];
	if (part?.type === "tool-call") {
		// First delta for this card starts from an empty string.
		part.argsText = `${part.argsText ?? ""}${delta}`;
	}
}
|
||||||
|
|
||||||
function _hasInterruptResult(part: ContentPart): boolean {
|
function _hasInterruptResult(part: ContentPart): boolean {
|
||||||
if (part.type !== "tool-call") return false;
|
if (part.type !== "tool-call") return false;
|
||||||
const r = (part as { result?: unknown }).result;
|
const r = (part as { result?: unknown }).result;
|
||||||
|
|
@ -371,6 +417,18 @@ export type SSEEvent =
|
||||||
/** Authoritative LangChain ``tool_call.id``. Optional. */
|
/** Authoritative LangChain ``tool_call.id``. Optional. */
|
||||||
langchainToolCallId?: string;
|
langchainToolCallId?: string;
|
||||||
}
|
}
|
||||||
|
| {
|
||||||
|
/**
|
||||||
|
* Live tool-call argument delta. Concatenated into
|
||||||
|
* ``argsText`` on the matching ``tool-call`` content part
|
||||||
|
* by ``appendToolInputDelta``. parity_v2 only — the legacy
|
||||||
|
* code path emits ``tool-input-available`` without prior
|
||||||
|
* deltas.
|
||||||
|
*/
|
||||||
|
type: "tool-input-delta";
|
||||||
|
toolCallId: string;
|
||||||
|
inputTextDelta: string;
|
||||||
|
}
|
||||||
| {
|
| {
|
||||||
type: "tool-input-available";
|
type: "tool-input-available";
|
||||||
toolCallId: string;
|
toolCallId: string;
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,13 @@ export const newChatMessageTable = table("new_chat_messages")
|
||||||
threadId: number().from("thread_id"),
|
threadId: number().from("thread_id"),
|
||||||
authorId: string().optional().from("author_id"),
|
authorId: string().optional().from("author_id"),
|
||||||
createdAt: number().from("created_at"),
|
createdAt: number().from("created_at"),
|
||||||
|
// Per-turn correlation id sourced from ``configurable.turn_id``
|
||||||
|
// at streaming time. Required by the inline Revert button's
|
||||||
|
// (chat_turn_id, tool_name, position) fallback in tool-fallback.tsx
|
||||||
|
// — without it the live-collab Zero sync would clobber the
|
||||||
|
// metadata we set during streaming and the button would vanish
|
||||||
|
// the moment Zero re-syncs after the stream finishes.
|
||||||
|
turnId: string().optional().from("turn_id"),
|
||||||
})
|
})
|
||||||
.primaryKey("id");
|
.primaryKey("id");
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue