Merge remote-tracking branch 'upstream/dev' into feat/split-auto-free-premium

This commit is contained in:
Anish Sarkar 2026-04-30 16:23:05 +05:30
commit 872065f90d
15 changed files with 1857 additions and 545 deletions

View file

@ -599,8 +599,17 @@ class VercelStreamingService:
Format the start of tool input streaming. Format the start of tool input streaming.
Args: Args:
tool_call_id: The unique tool call identifier (synthetic, derived tool_call_id: The unique tool call identifier. May be EITHER the
from LangGraph ``run_id`` so the frontend has a stable card id). synthetic ``call_<run_id>`` id derived from LangGraph
``run_id`` (legacy / ``SURFSENSE_ENABLE_STREAM_PARITY_V2``
OFF, or the unmatched-fallback path under parity_v2) OR
the authoritative LangChain ``tool_call.id`` (parity_v2
path: when the provider streams ``tool_call_chunks`` we
register the ``index`` and reuse the lc-id as the card
id so live ``tool-input-delta`` events can be routed
without a downstream join). Either way, the same id is
preserved across ``tool-input-start`` / ``-delta`` /
``-available`` / ``tool-output-available`` for one call.
tool_name: The name of the tool being called. tool_name: The name of the tool being called.
langchain_tool_call_id: Optional authoritative LangChain langchain_tool_call_id: Optional authoritative LangChain
``tool_call.id``. When set, surfaces as ``tool_call.id``. When set, surfaces as

View file

@ -473,6 +473,42 @@ def _emit_stream_terminal_error(
return streaming_service.format_error(message, error_code=error_code) return streaming_service.format_error(message, error_code=error_code)
def _legacy_match_lc_id(
pending_tool_call_chunks: list[dict[str, Any]],
tool_name: str,
run_id: str,
lc_tool_call_id_by_run: dict[str, str],
) -> str | None:
"""Best-effort match a buffered ``tool_call_chunk`` to a tool name.
Pure extract of the legacy in-line match used at ``on_tool_start`` for
parity_v2-OFF and unmatched (chunk path didn't register an index for
this call) tools. Pops the next id-bearing chunk whose ``name``
matches ``tool_name`` (or any id-bearing chunk as a fallback) and
returns its id. Mutates ``pending_tool_call_chunks`` and
``lc_tool_call_id_by_run`` in place.
"""
matched_idx: int | None = None
for idx, tcc in enumerate(pending_tool_call_chunks):
if tcc.get("name") == tool_name and tcc.get("id"):
matched_idx = idx
break
if matched_idx is None:
for idx, tcc in enumerate(pending_tool_call_chunks):
if tcc.get("id"):
matched_idx = idx
break
if matched_idx is None:
return None
matched = pending_tool_call_chunks.pop(matched_idx)
candidate = matched.get("id")
if isinstance(candidate, str) and candidate:
if run_id:
lc_tool_call_id_by_run[run_id] = candidate
return candidate
return None
async def _stream_agent_events( async def _stream_agent_events(
agent: Any, agent: Any,
config: dict[str, Any], config: dict[str, Any],
@ -538,10 +574,28 @@ async def _stream_agent_events(
# ``tool_call_chunks`` from ``on_chat_model_stream``, key them by # ``tool_call_chunks`` from ``on_chat_model_stream``, key them by
# name, and pop the next unconsumed entry at ``on_tool_start``. The # name, and pop the next unconsumed entry at ``on_tool_start``. The
# authoritative id is later filled in at ``on_tool_end`` from # authoritative id is later filled in at ``on_tool_end`` from
# ``ToolMessage.tool_call_id``. # ``ToolMessage.tool_call_id``. Under parity_v2 we ALSO short-circuit
# this list for chunks that already registered into ``index_to_meta``
# below — so this list is reserved for the parity_v2-OFF / unmatched
# fallback path only and never re-pops a chunk we already streamed.
pending_tool_call_chunks: list[dict[str, Any]] = [] pending_tool_call_chunks: list[dict[str, Any]] = []
lc_tool_call_id_by_run: dict[str, str] = {} lc_tool_call_id_by_run: dict[str, str] = {}
# parity_v2 only: live tool-call argument streaming. ``index_to_meta``
# is keyed by the chunk's ``index`` field — LangChain
# ``ToolCallChunk``s for the same call share an index but only the
# first chunk carries id+name (subsequent ones are id=None,
# name=None, args="<delta>"). We register an index when both id and
# name are observed on a chunk (per ToolCallChunk semantics they
# arrive together on the first chunk), then route every later chunk
# at that index to the same ``ui_id`` as a ``tool-input-delta``.
# ``ui_tool_call_id_by_run`` maps LangGraph ``run_id`` to the
# ``ui_id`` used for that call's ``tool-input-start`` so the matching
# ``tool-output-available`` (emitted from ``on_tool_end``) lands on
# the same card.
index_to_meta: dict[int, dict[str, str]] = {}
ui_tool_call_id_by_run: dict[str, str] = {}
# Per-tool-end mutable cache for the LangChain tool_call_id resolved # Per-tool-end mutable cache for the LangChain tool_call_id resolved
# at ``on_tool_end``. ``_emit_tool_output`` reads this so every # at ``on_tool_end``. ``_emit_tool_output`` reads this so every
# ``format_tool_output_available`` call automatically carries the # ``format_tool_output_available`` call automatically carries the
@ -587,13 +641,6 @@ async def _stream_agent_events(
continue continue
parts = _extract_chunk_parts(chunk) parts = _extract_chunk_parts(chunk)
# Accumulate any tool_call_chunks for best-effort
# correlation with ``on_tool_start`` below. We don't emit
# anything here; the matching is done at tool-start time.
if parity_v2 and parts["tool_call_chunks"]:
for tcc in parts["tool_call_chunks"]:
pending_tool_call_chunks.append(tcc)
reasoning_delta = parts["reasoning"] reasoning_delta = parts["reasoning"]
text_delta = parts["text"] text_delta = parts["text"]
@ -639,6 +686,71 @@ async def _stream_agent_events(
yield streaming_service.format_text_delta(current_text_id, text_delta) yield streaming_service.format_text_delta(current_text_id, text_delta)
accumulated_text += text_delta accumulated_text += text_delta
# Live tool-call argument streaming. Runs AFTER text/reasoning
# processing so chunks containing both stay in their natural
# wire order (text → text-end → tool-input-start). Active
# text/reasoning are closed inside the registration branch
# before ``tool-input-start`` so the frontend sees a clean
# part boundary even when providers interleave.
if parity_v2 and parts["tool_call_chunks"]:
for tcc in parts["tool_call_chunks"]:
idx = tcc.get("index")
# Register this index when we first see id+name
# TOGETHER. Per LangChain ToolCallChunk semantics the
# first chunk for a tool call carries both fields
# together; later chunks have id=None, name=None and
# only ``args``. Requiring BOTH keeps wire
# ``tool-input-start`` always carrying a real
# toolName (assistant-ui's typed tool-part dispatch
# keys off it).
if idx is not None and idx not in index_to_meta:
lc_id = tcc.get("id")
name = tcc.get("name")
if lc_id and name:
ui_id = lc_id
# Close active text/reasoning so wire
# ordering stays clean even on providers
# that interleave text and tool-call chunks
# within the same stream window.
if current_text_id is not None:
yield streaming_service.format_text_end(current_text_id)
current_text_id = None
if current_reasoning_id is not None:
yield streaming_service.format_reasoning_end(
current_reasoning_id
)
current_reasoning_id = None
index_to_meta[idx] = {
"ui_id": ui_id,
"lc_id": lc_id,
"name": name,
}
yield streaming_service.format_tool_input_start(
ui_id,
name,
langchain_tool_call_id=lc_id,
)
# Emit args delta for any chunk at a registered
# index (including idless continuations). Once an
# index is owned by ``index_to_meta`` we DO NOT
# append to ``pending_tool_call_chunks`` — that list
# is reserved for the parity_v2-OFF / unmatched
# fallback path so it never re-pops chunks already
# consumed here (skip-append).
meta = index_to_meta.get(idx) if idx is not None else None
if meta:
args_chunk = tcc.get("args") or ""
if args_chunk:
yield streaming_service.format_tool_input_delta(
meta["ui_id"], args_chunk
)
else:
pending_tool_call_chunks.append(tcc)
elif event_type == "on_tool_start": elif event_type == "on_tool_start":
active_tool_depth += 1 active_tool_depth += 1
tool_name = event.get("name", "unknown_tool") tool_name = event.get("name", "unknown_tool")
@ -969,44 +1081,65 @@ async def _stream_agent_events(
status="in_progress", status="in_progress",
) )
tool_call_id = ( # Resolve the card identity. If the chunk-emission loop
f"call_{run_id[:32]}" # already registered an ``index`` for this tool call (parity_v2
if run_id # path), reuse the same ui_id so the card sees:
else streaming_service.generate_tool_call_id() # tool-input-start → deltas… → tool-input-available →
) # tool-output-available all keyed by lc_id. Otherwise fall
# back to the synthetic ``call_<run_id>`` id and the legacy
# Best-effort attach the LangChain ``tool_call_id``. We # best-effort match against ``pending_tool_call_chunks``.
# pop the first chunk in ``pending_tool_call_chunks`` whose matched_meta: dict[str, str] | None = None
# name matches; if none match (the chunked args may not yet if parity_v2:
# carry a ``name`` field, or the model skipped the chunked # FIFO over indices 0,1,2…; first unassigned same-name
# form) we leave ``langchainToolCallId`` unset for now and # match wins. Handles parallel same-name calls (e.g. two
# fill it in authoritatively at ``on_tool_end`` from # write_file calls) deterministically as long as the
# ``ToolMessage.tool_call_id``. # model interleaves on_tool_start in the same order it
langchain_tool_call_id: str | None = None # streamed the args.
if parity_v2 and pending_tool_call_chunks: taken_ui_ids = set(ui_tool_call_id_by_run.values())
matched_idx: int | None = None for meta in index_to_meta.values():
for idx, tcc in enumerate(pending_tool_call_chunks): if meta["name"] == tool_name and meta["ui_id"] not in taken_ui_ids:
if tcc.get("name") == tool_name and tcc.get("id"): matched_meta = meta
matched_idx = idx
break break
if matched_idx is None:
for idx, tcc in enumerate(pending_tool_call_chunks):
if tcc.get("id"):
matched_idx = idx
break
if matched_idx is not None:
matched = pending_tool_call_chunks.pop(matched_idx)
candidate = matched.get("id")
if isinstance(candidate, str) and candidate:
langchain_tool_call_id = candidate
if run_id:
lc_tool_call_id_by_run[run_id] = candidate
yield streaming_service.format_tool_input_start( tool_call_id: str
tool_call_id, langchain_tool_call_id: str | None = None
tool_name, if matched_meta is not None:
langchain_tool_call_id=langchain_tool_call_id, tool_call_id = matched_meta["ui_id"]
) langchain_tool_call_id = matched_meta["lc_id"]
# ``tool-input-start`` already fired during chunk
# emission — skip the duplicate. No pruning is needed
# because the chunk-emission loop intentionally never
# appends registered-index chunks to
# ``pending_tool_call_chunks`` (skip-append).
if run_id:
lc_tool_call_id_by_run[run_id] = matched_meta["lc_id"]
else:
tool_call_id = (
f"call_{run_id[:32]}"
if run_id
else streaming_service.generate_tool_call_id()
)
# Legacy fallback: parity_v2 OFF, or parity_v2 ON but the
# provider didn't stream tool_call_chunks for this call
# (no index registered). Run the existing best-effort
# match BEFORE emitting start so we still attach an
# authoritative ``langchainToolCallId`` when possible.
if parity_v2:
langchain_tool_call_id = _legacy_match_lc_id(
pending_tool_call_chunks,
tool_name,
run_id,
lc_tool_call_id_by_run,
)
yield streaming_service.format_tool_input_start(
tool_call_id,
tool_name,
langchain_tool_call_id=langchain_tool_call_id,
)
if run_id:
ui_tool_call_id_by_run[run_id] = tool_call_id
# Sanitize tool_input: strip runtime-injected non-serializable # Sanitize tool_input: strip runtime-injected non-serializable
# values (e.g. LangChain ToolRuntime) before sending over SSE. # values (e.g. LangChain ToolRuntime) before sending over SSE.
if isinstance(tool_input, dict): if isinstance(tool_input, dict):
@ -1059,7 +1192,15 @@ async def _stream_agent_events(
result.write_succeeded = True result.write_succeeded = True
result.verification_succeeded = True result.verification_succeeded = True
tool_call_id = f"call_{run_id[:32]}" if run_id else "call_unknown" # Look up the SAME card id used at on_tool_start (either the
# parity_v2 lc-id-derived ui_id or the legacy synthetic
# ``call_<run_id>``) so the output event always lands on the
# same card as start/delta/available. Fallback preserves the
# legacy synthetic shape for parity_v2-OFF / unknown-run paths.
tool_call_id = ui_tool_call_id_by_run.get(
run_id,
f"call_{run_id[:32]}" if run_id else "call_unknown",
)
original_step_id = tool_step_ids.get( original_step_id = tool_step_ids.get(
run_id, f"{step_prefix}-unknown-{run_id[:8]}" run_id, f"{step_prefix}-unknown-{run_id[:8]}"
) )
@ -1070,17 +1211,22 @@ async def _stream_agent_events(
# at ``on_tool_start`` time (kept in ``lc_tool_call_id_by_run``) # at ``on_tool_start`` time (kept in ``lc_tool_call_id_by_run``)
# if the output isn't a ToolMessage. The value is stored in # if the output isn't a ToolMessage. The value is stored in
# ``current_lc_tool_call_id`` so ``_emit_tool_output`` # ``current_lc_tool_call_id`` so ``_emit_tool_output``
# picks it up for every output emit below. Stays None when # picks it up for every output emit below.
# parity_v2 is off so legacy emit paths are untouched. #
# Emitted in BOTH parity_v2 and legacy modes: the chat tool
# card needs the LangChain id to match against the
# ``data-action-log`` SSE event (keyed by ``lc_tool_call_id``)
# so the inline Revert button can light up. Reading
# ``raw_output.tool_call_id`` is a cheap, non-mutating attribute
# access that is safe regardless of feature-flag state.
current_lc_tool_call_id["value"] = None current_lc_tool_call_id["value"] = None
if parity_v2: authoritative = getattr(raw_output, "tool_call_id", None)
authoritative = getattr(raw_output, "tool_call_id", None) if isinstance(authoritative, str) and authoritative:
if isinstance(authoritative, str) and authoritative: current_lc_tool_call_id["value"] = authoritative
current_lc_tool_call_id["value"] = authoritative if run_id:
if run_id: lc_tool_call_id_by_run[run_id] = authoritative
lc_tool_call_id_by_run[run_id] = authoritative elif run_id and run_id in lc_tool_call_id_by_run:
elif run_id and run_id in lc_tool_call_id_by_run: current_lc_tool_call_id["value"] = lc_tool_call_id_by_run[run_id]
current_lc_tool_call_id["value"] = lc_tool_call_id_by_run[run_id]
if tool_name == "read_file": if tool_name == "read_file":
yield streaming_service.format_thinking_step( yield streaming_service.format_thinking_step(

View file

@ -183,3 +183,46 @@ class TestDefensive:
assert out["text"] == "" assert out["text"] == ""
assert out["reasoning"] == "" assert out["reasoning"] == ""
assert out["tool_call_chunks"] == [] assert out["tool_call_chunks"] == []
class TestIdlessContinuationChunks:
    """LangChain ``ToolCallChunk`` semantics: only the FIRST chunk of a
    tool call carries id+name; continuation chunks for the same call
    arrive with ``id=None, name=None`` plus ``args`` and ``index``.
    Live tool-call argument streaming routes those continuations purely
    by ``index``, so ``_extract_chunk_parts`` must pass them through
    unmodified.
    """

    def test_idless_continuation_chunk_preserved_verbatim(self) -> None:
        raw = {"id": None, "name": None, "args": '_path":"/x"}', "index": 0}
        parts = _extract_chunk_parts(_FakeChunk(tool_call_chunks=[raw]))
        chunks = parts["tool_call_chunks"]
        assert len(chunks) == 1
        preserved = chunks[0]
        # Every field survives untouched — including the None markers.
        assert preserved.get("id") is None
        assert preserved.get("name") is None
        assert preserved.get("args") == '_path":"/x"}'
        assert preserved.get("index") == 0

    def test_first_then_idless_sequence_preserves_index(self) -> None:
        """Both chunks of one call share an ``index`` key — the
        index-routing loop in ``stream_new_chat`` depends on it."""
        opening = {"id": "lc-1", "name": "write_file", "args": '{"file', "index": 0}
        continuation = {"id": None, "name": None, "args": '_path":"/x"}', "index": 0}
        head = _extract_chunk_parts(_FakeChunk(tool_call_chunks=[opening]))
        tail = _extract_chunk_parts(_FakeChunk(tool_call_chunks=[continuation]))
        assert head["tool_call_chunks"][0]["index"] == 0
        assert tail["tool_call_chunks"][0]["index"] == 0
        assert tail["tool_call_chunks"][0].get("id") is None

View file

@ -0,0 +1,527 @@
"""Unit tests for live tool-call argument streaming.
Pins the wire format that ``_stream_agent_events`` emits when
``SURFSENSE_ENABLE_STREAM_PARITY_V2=true``: ``tool-input-start``
``tool-input-delta``... ``tool-input-available`` ``tool-output-available``
all keyed by the same LangChain ``tool_call.id``.
Identity is tracked in ``index_to_meta`` (per-chunk ``index``) and
``ui_tool_call_id_by_run`` (LangGraph ``run_id``); both are private to
``_stream_agent_events`` so we exercise them via the public wire output.
These tests also lock in the legacy / parity_v2-OFF behaviour so the
synthetic ``call_<run_id>`` shape stays stable for older clients.
"""
from __future__ import annotations
import json
from collections.abc import AsyncGenerator
from dataclasses import dataclass, field
from typing import Any
import pytest
import app.tasks.chat.stream_new_chat as stream_module
from app.agents.new_chat.feature_flags import AgentFeatureFlags
from app.services.new_streaming_service import VercelStreamingService
from app.tasks.chat.stream_new_chat import (
StreamResult,
_legacy_match_lc_id,
_stream_agent_events,
)
pytestmark = pytest.mark.unit
@dataclass
class _FakeChunk:
    """Minimal stand-in for ``AIMessageChunk``."""

    # Plain text content of the chunk ("" when the chunk only carries
    # tool-call or reasoning data).
    content: Any = ""
    # Provider extras; tests set {"reasoning_content": ...} here to
    # exercise the reasoning-delta extraction path.
    additional_kwargs: dict[str, Any] = field(default_factory=dict)
    # Raw LangChain-style ToolCallChunk dicts (id / name / args / index).
    tool_call_chunks: list[dict[str, Any]] = field(default_factory=list)
@dataclass
class _FakeToolMessage:
    """Stand-in for ``ToolMessage`` returned by ``on_tool_end``."""

    # Serialized tool output, as the streaming layer would read it.
    content: Any
    # Authoritative LangChain ``tool_call.id``; None mimics providers
    # that omit it.
    tool_call_id: str | None = None
class _FakeAgentState:
"""Stand-in for ``StateSnapshot`` returned by ``aget_state``."""
def __init__(self) -> None:
# Empty values keeps the cloud-fallback safety-net branch a no-op,
# and an empty ``tasks`` list keeps the post-stream interrupt
# check a no-op too.
self.values: dict[str, Any] = {}
self.tasks: list[Any] = []
class _FakeAgent:
"""Replays a list of ``astream_events`` events."""
def __init__(self, events: list[dict[str, Any]]) -> None:
self._events = events
async def astream_events( # type: ignore[no-untyped-def]
self, _input_data: Any, *, config: dict[str, Any], version: str
) -> AsyncGenerator[dict[str, Any], None]:
del config, version # unused, contract-compatible
for ev in self._events:
yield ev
async def aget_state(self, _config: dict[str, Any]) -> _FakeAgentState:
# Called once after astream_events drains so the cloud-fallback
# safety net can inspect staged filesystem work. The fake stays
# empty so the safety net is a no-op.
return _FakeAgentState()
def _model_stream(
    *,
    text: str = "",
    reasoning: str = "",
    tool_call_chunks: list[dict[str, Any]] | None = None,
    tags: list[str] | None = None,
) -> dict[str, Any]:
    """Build an ``on_chat_model_stream`` event wrapping a ``_FakeChunk``.

    ``reasoning`` piggybacks via the ``additional_kwargs`` path
    (``reasoning_content``); a typed-block content list is not modelled
    here because most tests only check ``tool_call_chunks`` routing.
    """
    extra_kwargs: dict[str, Any] = {}
    if reasoning:
        extra_kwargs["reasoning_content"] = reasoning
    chunk = _FakeChunk(
        content=text,
        additional_kwargs=extra_kwargs,
        tool_call_chunks=list(tool_call_chunks or []),
    )
    return {
        "event": "on_chat_model_stream",
        "tags": tags or [],
        "data": {"chunk": chunk},
    }
def _tool_start(
*,
name: str,
run_id: str,
input_payload: dict[str, Any] | None = None,
) -> dict[str, Any]:
return {
"event": "on_tool_start",
"name": name,
"run_id": run_id,
"data": {"input": input_payload or {}},
}
def _tool_end(
    *,
    name: str,
    run_id: str,
    tool_call_id: str | None = None,
    output: Any = "ok",
) -> dict[str, Any]:
    """Build an ``on_tool_end`` event carrying a ``_FakeToolMessage``.

    Non-string ``output`` values are JSON-encoded first, mirroring how
    real tool results arrive as serialized ``ToolMessage`` content.
    """
    if isinstance(output, str):
        content = output
    else:
        content = json.dumps(output)
    message = _FakeToolMessage(content=content, tool_call_id=tool_call_id)
    return {
        "event": "on_tool_end",
        "name": name,
        "run_id": run_id,
        "data": {"output": message},
    }
@pytest.fixture
def parity_v2_on(monkeypatch: pytest.MonkeyPatch) -> None:
    """Force the parity_v2 streaming path for one test.

    ``_stream_agent_events`` reads the flag via ``get_flags()`` at call
    time, so patching the module attribute is sufficient — no env vars.
    """
    monkeypatch.setattr(
        stream_module,
        "get_flags",
        lambda: AgentFeatureFlags(enable_stream_parity_v2=True),
    )
@pytest.fixture
def parity_v2_off(monkeypatch: pytest.MonkeyPatch) -> None:
    """Force the legacy (parity_v2 OFF) streaming path for one test.

    Mirrors ``parity_v2_on``: patches ``get_flags`` on the module under
    test so the flag read inside ``_stream_agent_events`` sees False.
    """
    monkeypatch.setattr(
        stream_module,
        "get_flags",
        lambda: AgentFeatureFlags(enable_stream_parity_v2=False),
    )
async def _drain(events: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Run ``_stream_agent_events`` against a fake agent and return the
    parsed JSON bodies of every SSE ``data:`` frame it yielded.
    """
    service = VercelStreamingService()
    result = StreamResult()
    config = {"configurable": {"thread_id": "test-thread"}}

    raw_frames: list[str] = []
    async for frame in _stream_agent_events(
        _FakeAgent(events), config, {}, service, result, step_prefix="thinking"
    ):
        raw_frames.append(frame)

    prefix = "data: "
    payloads: list[dict[str, Any]] = []
    for frame in raw_frames:
        if not frame.startswith(prefix):
            continue
        body = frame[len(prefix) :].rstrip("\n")
        # Terminal sentinel and keep-alive blanks carry no JSON body.
        if not body or body == "[DONE]":
            continue
        try:
            payloads.append(json.loads(body))
        except json.JSONDecodeError:
            continue
    return payloads
def _types(payloads: list[dict[str, Any]]) -> list[str]:
return [p.get("type", "?") for p in payloads]
def _of_type(payloads: list[dict[str, Any]], type_name: str) -> list[dict[str, Any]]:
return [p for p in payloads if p.get("type") == type_name]
# ---------------------------------------------------------------------------
# Helper: ``_legacy_match_lc_id`` is a pure refactor; assert behaviour.
# ---------------------------------------------------------------------------
class TestLegacyMatch:
    """``_legacy_match_lc_id`` is a pure extract of the legacy inline
    match; these tests pin its pop / record / return contract."""

    def test_pops_first_id_bearing_chunk_with_matching_name(self) -> None:
        pending: list[dict[str, Any]] = [
            {"id": "x1", "name": "ls"},
            {"id": "y1", "name": "write_file"},
        ]
        by_run: dict[str, str] = {}
        assert _legacy_match_lc_id(pending, "write_file", "run-1", by_run) == "y1"
        # The matched chunk is consumed; the non-matching one survives.
        assert pending == [{"id": "x1", "name": "ls"}]
        assert by_run == {"run-1": "y1"}

    def test_falls_back_to_any_id_bearing_when_name_mismatches(self) -> None:
        pending: list[dict[str, Any]] = [{"id": "anon", "name": None}]
        by_run: dict[str, str] = {}
        assert _legacy_match_lc_id(pending, "ls", "run-2", by_run) == "anon"
        assert pending == []

    def test_returns_none_when_no_id_bearing_chunk(self) -> None:
        pending: list[dict[str, Any]] = [{"id": None, "name": None}]
        by_run: dict[str, str] = {}
        assert _legacy_match_lc_id(pending, "ls", "run-3", by_run) is None
        # Idless chunks stay buffered and nothing is recorded.
        assert pending == [{"id": None, "name": None}]
        assert by_run == {}
# ---------------------------------------------------------------------------
# parity_v2 wire format tests.
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_idless_chunk_merging_by_index(parity_v2_on: None) -> None:
    """First chunk carries id+name; later idless chunks at the same
    ``index`` merge into the SAME ``tool-input-start`` ui id and emit
    one ``tool-input-delta`` per chunk."""
    events = [
        # Registration chunk: id + name arrive together (ToolCallChunk
        # semantics for the first chunk of a call).
        _model_stream(
            tool_call_chunks=[
                {"id": "lc-1", "name": "write_file", "args": '{"file', "index": 0}
            ],
        ),
        # Continuation chunk: idless/nameless, routed purely by index.
        _model_stream(
            tool_call_chunks=[
                {"id": None, "name": None, "args": '_path":"/x"}', "index": 0}
            ],
        ),
        _tool_start(
            name="write_file", run_id="run-A", input_payload={"file_path": "/x"}
        ),
        _tool_end(name="write_file", run_id="run-A", tool_call_id="lc-1"),
    ]
    payloads = await _drain(events)
    starts = _of_type(payloads, "tool-input-start")
    deltas = _of_type(payloads, "tool-input-delta")
    available = _of_type(payloads, "tool-input-available")
    output = _of_type(payloads, "tool-output-available")
    # Exactly one card, keyed by the LangChain id end-to-end.
    assert len(starts) == 1
    assert starts[0]["toolCallId"] == "lc-1"
    assert starts[0]["toolName"] == "write_file"
    assert starts[0]["langchainToolCallId"] == "lc-1"
    # One delta per chunk, in stream order, all on the same card.
    assert [d["inputTextDelta"] for d in deltas] == ['{"file', '_path":"/x"}']
    assert all(d["toolCallId"] == "lc-1" for d in deltas)
    assert len(available) == 1
    assert available[0]["toolCallId"] == "lc-1"
    assert len(output) == 1
    assert output[0]["toolCallId"] == "lc-1"
@pytest.mark.asyncio
async def test_two_interleaved_tool_calls_route_by_index(
    parity_v2_on: None,
) -> None:
    """Two same-name calls with distinct indices keep their deltas
    routed to the right card."""
    events = [
        # Both calls register in one model chunk, at indices 0 and 1.
        _model_stream(
            tool_call_chunks=[
                {"id": "lc-A", "name": "write_file", "args": '{"a":1', "index": 0},
                {"id": "lc-B", "name": "write_file", "args": '{"b":2', "index": 1},
            ]
        ),
        # Idless continuations interleave across both indices.
        _model_stream(
            tool_call_chunks=[
                {"id": None, "name": None, "args": "}", "index": 0},
                {"id": None, "name": None, "args": "}", "index": 1},
            ]
        ),
        _tool_start(name="write_file", run_id="run-A", input_payload={"a": 1}),
        _tool_end(name="write_file", run_id="run-A", tool_call_id="lc-A"),
        _tool_start(name="write_file", run_id="run-B", input_payload={"b": 2}),
        _tool_end(name="write_file", run_id="run-B", tool_call_id="lc-B"),
    ]
    payloads = await _drain(events)
    starts = _of_type(payloads, "tool-input-start")
    deltas = _of_type(payloads, "tool-input-delta")
    output = _of_type(payloads, "tool-output-available")
    assert {s["toolCallId"] for s in starts} == {"lc-A", "lc-B"}
    # Group deltas per card and verify neither stream leaked into the
    # other despite the interleaving.
    by_id: dict[str, list[str]] = {"lc-A": [], "lc-B": []}
    for d in deltas:
        by_id[d["toolCallId"]].append(d["inputTextDelta"])
    assert by_id["lc-A"] == ['{"a":1', "}"]
    assert by_id["lc-B"] == ['{"b":2', "}"]
    assert {o["toolCallId"] for o in output} == {"lc-A", "lc-B"}
@pytest.mark.asyncio
async def test_identity_stable_across_lifecycle(parity_v2_on: None) -> None:
    """Whatever id ``tool-input-start`` chose must be the SAME id used
    on ``tool-input-available`` AND ``tool-output-available``."""
    events = [
        _model_stream(
            tool_call_chunks=[
                {"id": "lc-9", "name": "ls", "args": '{"path":"/"}', "index": 0}
            ]
        ),
        _tool_start(name="ls", run_id="run-X", input_payload={"path": "/"}),
        _tool_end(name="ls", run_id="run-X", tool_call_id="lc-9"),
    ]
    payloads = await _drain(events)
    # Collapse all lifecycle events to their card ids; exactly one
    # shared id is expected across the whole start/available/output arc.
    relevant = [
        p
        for p in payloads
        if p.get("type")
        in {"tool-input-start", "tool-input-available", "tool-output-available"}
    ]
    assert {p["toolCallId"] for p in relevant} == {"lc-9"}
@pytest.mark.asyncio
async def test_no_duplicate_tool_input_start(parity_v2_on: None) -> None:
    """When the chunk-emission loop already fired ``tool-input-start``
    for this run, ``on_tool_start`` MUST NOT emit a second one."""
    events = [
        _model_stream(
            tool_call_chunks=[
                {"id": "lc-1", "name": "write_file", "args": "{}", "index": 0}
            ]
        ),
        _tool_start(name="write_file", run_id="run-A", input_payload={}),
        _tool_end(name="write_file", run_id="run-A", tool_call_id="lc-1"),
    ]
    payloads = await _drain(events)
    starts = _of_type(payloads, "tool-input-start")
    # A duplicate would break the frontend card (two parts, one call).
    assert len(starts) == 1
    assert starts[0]["toolCallId"] == "lc-1"
@pytest.mark.asyncio
async def test_active_text_closes_before_early_tool_input_start(
    parity_v2_on: None,
) -> None:
    """Streaming a text-delta then a tool-call chunk in subsequent
    chunks: the wire MUST contain ``text-end`` before the FIRST
    ``tool-input-start`` (clean part boundary on the frontend)."""
    events = [
        _model_stream(text="Working on it"),
        _model_stream(
            tool_call_chunks=[
                {"id": "lc-1", "name": "write_file", "args": "{}", "index": 0}
            ]
        ),
        _tool_start(name="write_file", run_id="run-A", input_payload={}),
        _tool_end(name="write_file", run_id="run-A", tool_call_id="lc-1"),
    ]
    types = _types(await _drain(events))
    # Pure ordering assertion — ``index`` finds the first occurrence of
    # each event type on the wire.
    text_end_idx = types.index("text-end")
    start_idx = types.index("tool-input-start")
    assert text_end_idx < start_idx
@pytest.mark.asyncio
async def test_mixed_text_and_tool_chunk_preserve_order(
    parity_v2_on: None,
) -> None:
    """One AIMessageChunk that carries BOTH ``text`` content AND
    ``tool_call_chunks`` should emit the text delta FIRST, then close
    text, then ``tool-input-start``+``tool-input-delta``."""
    events = [
        # Single chunk carrying text and a complete tool-call chunk.
        _model_stream(
            text="I'll update it",
            tool_call_chunks=[
                {
                    "id": "lc-1",
                    "name": "write_file",
                    "args": '{"file_path":"/x"}',
                    "index": 0,
                }
            ],
        ),
        _tool_start(
            name="write_file", run_id="run-A", input_payload={"file_path": "/x"}
        ),
        _tool_end(name="write_file", run_id="run-A", tool_call_id="lc-1"),
    ]
    types = _types(await _drain(events))
    # text-start … text-delta … text-end … tool-input-start … tool-input-delta
    assert types.index("text-start") < types.index("text-delta")
    assert types.index("text-delta") < types.index("text-end")
    assert types.index("text-end") < types.index("tool-input-start")
    assert types.index("tool-input-start") < types.index("tool-input-delta")
@pytest.mark.asyncio
async def test_parity_v2_off_preserves_legacy_shape(
    parity_v2_off: None,
) -> None:
    """When the flag is OFF, no deltas are emitted and the ``toolCallId``
    is ``call_<run_id>`` (NOT the lc id)."""
    events = [
        _model_stream(
            tool_call_chunks=[
                {"id": "lc-1", "name": "ls", "args": '{"path":"/"}', "index": 0}
            ]
        ),
        _tool_start(name="ls", run_id="run-A", input_payload={"path": "/"}),
        _tool_end(name="ls", run_id="run-A", tool_call_id="lc-1"),
    ]
    payloads = await _drain(events)
    # Legacy mode never streams argument deltas.
    assert _of_type(payloads, "tool-input-delta") == []
    starts = _of_type(payloads, "tool-input-start")
    assert len(starts) == 1
    assert starts[0]["toolCallId"].startswith("call_run-A")
    # No ``langchainToolCallId`` propagation on ``tool-input-start`` in
    # legacy mode (the start event fires before the ToolMessage is
    # available, so we can't extract the authoritative LangChain id yet).
    assert "langchainToolCallId" not in starts[0]
    output = _of_type(payloads, "tool-output-available")
    assert output[0]["toolCallId"].startswith("call_run-A")
    # ``tool-output-available`` MUST carry ``langchainToolCallId`` even
    # in legacy mode: the chat tool card uses it to backfill the
    # LangChain id and join against the ``data-action-log`` SSE event
    # (keyed by ``lc_tool_call_id``) so the inline Revert button can
    # light up. Sourced from the returned ``ToolMessage.tool_call_id``,
    # which is populated regardless of feature-flag state.
    assert output[0]["langchainToolCallId"] == "lc-1"
@pytest.mark.asyncio
async def test_skip_append_prevents_stale_id_reuse(
    parity_v2_on: None,
) -> None:
    """Two same-name tools: the SECOND tool's ``langchainToolCallId``
    must NOT come from the first tool's chunk (``pending_tool_call_chunks``
    must stay empty for indexed-registered chunks)."""
    events = [
        # Both chunks register by index, so neither should land in the
        # legacy pending buffer (the "skip-append" invariant).
        _model_stream(
            tool_call_chunks=[
                {"id": "lc-A", "name": "write_file", "args": "{}", "index": 0},
                {"id": "lc-B", "name": "write_file", "args": "{}", "index": 1},
            ]
        ),
        _tool_start(name="write_file", run_id="run-1", input_payload={}),
        _tool_end(name="write_file", run_id="run-1", tool_call_id="lc-A"),
        _tool_start(name="write_file", run_id="run-2", input_payload={}),
        _tool_end(name="write_file", run_id="run-2", tool_call_id="lc-B"),
    ]
    payloads = await _drain(events)
    starts = _of_type(payloads, "tool-input-start")
    # Two distinct lc ids, each its own card.
    assert {s["toolCallId"] for s in starts} == {"lc-A", "lc-B"}
    # Each tool-output-available landed on its respective card.
    output = _of_type(payloads, "tool-output-available")
    assert {o["toolCallId"] for o in output} == {"lc-A", "lc-B"}
@pytest.mark.asyncio
async def test_registration_waits_for_both_id_and_name(
    parity_v2_on: None,
) -> None:
    """An id-only chunk (no name yet) must NOT emit ``tool-input-start``."""
    events = [
        # id present, name missing: registration requires BOTH so the
        # wire event always carries a real toolName.
        _model_stream(
            tool_call_chunks=[{"id": "lc-1", "name": None, "args": "", "index": 0}]
        ),
    ]
    payloads = await _drain(events)
    assert _of_type(payloads, "tool-input-start") == []
@pytest.mark.asyncio
async def test_unmatched_fallback_still_attaches_lc_id(
    parity_v2_on: None,
) -> None:
    """parity_v2 ON, but the provider didn't include an ``index``: the
    legacy fallback path must still emit ``tool-input-start`` with the
    matching ``langchainToolCallId``."""
    events = [
        # No index on the chunk → not registered into index_to_meta;
        # falls through to ``pending_tool_call_chunks`` so the legacy
        # match path can pop it at on_tool_start.
        _model_stream(tool_call_chunks=[{"id": "lc-orphan", "name": "ls", "args": ""}]),
        _tool_start(name="ls", run_id="run-1", input_payload={"path": "/"}),
        _tool_end(name="ls", run_id="run-1", tool_call_id="lc-orphan"),
    ]
    payloads = await _drain(events)
    starts = _of_type(payloads, "tool-input-start")
    assert len(starts) == 1
    # Card id falls back to the synthetic shape, but the authoritative
    # LangChain id is still attached for downstream joins.
    assert starts[0]["toolCallId"].startswith("call_run-1")
    assert starts[0]["langchainToolCallId"] == "lc-orphan"

View file

@ -14,13 +14,6 @@ import { useCallback, useEffect, useMemo, useRef, useState } from "react";
import { toast } from "sonner"; import { toast } from "sonner";
import { z } from "zod"; import { z } from "zod";
import { disabledToolsAtom } from "@/atoms/agent-tools/agent-tools.atoms"; import { disabledToolsAtom } from "@/atoms/agent-tools/agent-tools.atoms";
import {
agentActionsByChatTurnIdAtom,
markAgentActionRevertedAtom,
resetAgentActionMapAtom,
updateAgentActionReversibleAtom,
upsertAgentActionAtom,
} from "@/atoms/chat/agent-actions.atom";
import { import {
clearTargetCommentIdAtom, clearTargetCommentIdAtom,
currentThreadAtom, currentThreadAtom,
@ -56,6 +49,12 @@ import {
type TokenUsageData, type TokenUsageData,
TokenUsageProvider, TokenUsageProvider,
} from "@/components/assistant-ui/token-usage-context"; } from "@/components/assistant-ui/token-usage-context";
import {
applyActionLogSse,
applyActionLogUpdatedSse,
markActionRevertedInCache,
useAgentActionsQuery,
} from "@/hooks/use-agent-actions-query";
import { useChatSessionStateSync } from "@/hooks/use-chat-session-state"; import { useChatSessionStateSync } from "@/hooks/use-chat-session-state";
import { useMessagesSync } from "@/hooks/use-messages-sync"; import { useMessagesSync } from "@/hooks/use-messages-sync";
import { getAgentFilesystemSelection } from "@/lib/agent-filesystem"; import { getAgentFilesystemSelection } from "@/lib/agent-filesystem";
@ -76,12 +75,12 @@ import {
addToolCall, addToolCall,
appendReasoning, appendReasoning,
appendText, appendText,
appendToolInputDelta,
buildContentForPersistence, buildContentForPersistence,
buildContentForUI, buildContentForUI,
type ContentPartsState, type ContentPartsState,
endReasoning, endReasoning,
FrameBatchedUpdater, FrameBatchedUpdater,
findToolCallIdByLcId,
readSSEStream, readSSEStream,
type SSEEvent, type SSEEvent,
type ThinkingStepData, type ThinkingStepData,
@ -511,14 +510,6 @@ export default function NewChatPage() {
const setAgentCreatedDocuments = useSetAtom(agentCreatedDocumentsAtom); const setAgentCreatedDocuments = useSetAtom(agentCreatedDocumentsAtom);
const pendingUserImageUrls = useAtomValue(pendingUserImageDataUrlsAtom); const pendingUserImageUrls = useAtomValue(pendingUserImageDataUrlsAtom);
const setPendingUserImageUrls = useSetAtom(pendingUserImageDataUrlsAtom); const setPendingUserImageUrls = useSetAtom(pendingUserImageDataUrlsAtom);
// Agent action log SSE side-channel.
const upsertAgentAction = useSetAtom(upsertAgentActionAtom);
const updateAgentActionReversible = useSetAtom(updateAgentActionReversibleAtom);
const markAgentActionReverted = useSetAtom(markAgentActionRevertedAtom);
const resetAgentActionMap = useSetAtom(resetAgentActionMapAtom);
// Chat-turn-keyed action map for the edit-from-position pre-flight
// that decides whether to show the confirmation dialog.
const agentActionsByChatTurnId = useAtomValue(agentActionsByChatTurnIdAtom);
// Edit dialog state. Holds the message id being edited and // Edit dialog state. Holds the message id being edited and
// the (already extracted) regenerate args so we can resume the edit // the (already extracted) regenerate args so we can resume the edit
// after the user picks "revert all" / "continue" / "cancel". // after the user picks "revert all" / "continue" / "cancel".
@ -547,6 +538,11 @@ export default function NewChatPage() {
content: unknown; content: unknown;
author_id: string | null; author_id: string | null;
created_at: string; created_at: string;
// Forwarded so ``convertToThreadMessage`` can rebuild the
// ``metadata.custom.chatTurnId`` on the
// ``ThreadMessageLike``. Required by the inline Revert
// button's per-turn fallback.
turn_id?: string | null;
}[] }[]
) => { ) => {
if (isRunning) { if (isRunning) {
@ -579,6 +575,11 @@ export default function NewChatPage() {
created_at: msg.created_at, created_at: msg.created_at,
author_display_name: member?.user_display_name ?? existingAuthor?.displayName ?? null, author_display_name: member?.user_display_name ?? existingAuthor?.displayName ?? null,
author_avatar_url: member?.user_avatar_url ?? existingAuthor?.avatarUrl ?? null, author_avatar_url: member?.user_avatar_url ?? existingAuthor?.avatarUrl ?? null,
// Forward the per-turn correlation id so the
// inline Revert button's ``(chat_turn_id,
// tool_name, position)`` fallback survives the
// post-stream Zero re-sync.
turn_id: msg.turn_id ?? null,
}); });
}); });
}); });
@ -595,6 +596,13 @@ export default function NewChatPage() {
return Number.isNaN(parsed) ? 0 : parsed; return Number.isNaN(parsed) ? 0 : parsed;
}, [params.search_space_id]); }, [params.search_space_id]);
// Unified store for agent-action rows (the same react-query cache
// the agent-actions sheet, the inline Revert button, and the
// per-turn Revert button all read). Hydrates from
// ``GET /threads/{id}/actions`` and is updated incrementally by the
// SSE handlers + revert-batch results below — no atom side-channel.
const { items: agentActionItems } = useAgentActionsQuery(threadId);
// Extract chat_id from URL params // Extract chat_id from URL params
const urlChatId = useMemo(() => { const urlChatId = useMemo(() => {
const id = params.chat_id; const id = params.chat_id;
@ -738,7 +746,8 @@ export default function NewChatPage() {
clearPlanOwnerRegistry(); clearPlanOwnerRegistry();
closeReportPanel(); closeReportPanel();
closeEditorPanel(); closeEditorPanel();
resetAgentActionMap(); // Note: agent-action data is keyed by threadId in react-query so
// switching threads naturally swaps caches; no explicit reset.
try { try {
if (urlChatId > 0) { if (urlChatId > 0) {
@ -807,7 +816,6 @@ export default function NewChatPage() {
removeChatTab, removeChatTab,
searchSpaceId, searchSpaceId,
tokenUsageStore, tokenUsageStore,
resetAgentActionMap,
]); ]);
// Initialize on mount, and re-init when switching search spaces (even if urlChatId is the same) // Initialize on mount, and re-init when switching search spaces (even if urlChatId is the same)
@ -1145,6 +1153,15 @@ export default function NewChatPage() {
); );
}; };
const scheduleFlush = () => batcher.schedule(flushMessages); const scheduleFlush = () => batcher.schedule(flushMessages);
// Force-flush helper: ``batcher.flush()`` is a no-op when
// ``dirty=false`` (e.g. a tool starts before any text
// streamed). ``scheduleFlush(); batcher.flush()`` sets
// the dirty bit FIRST so terminal events render
// promptly without the 50ms throttle delay.
const forceFlush = () => {
scheduleFlush();
batcher.flush();
};
for await (const parsed of readSSEStream(response)) { for await (const parsed of readSSEStream(response)) {
switch (parsed.type) { switch (parsed.type) {
@ -1181,13 +1198,23 @@ export default function NewChatPage() {
false, false,
parsed.langchainToolCallId parsed.langchainToolCallId
); );
batcher.flush(); forceFlush();
break;
case "tool-input-delta":
// High-frequency event: deltas can fire dozens
// of times per call, so use throttled
// scheduleFlush (NOT forceFlush) to coalesce.
appendToolInputDelta(contentPartsState, parsed.toolCallId, parsed.inputTextDelta);
scheduleFlush();
break; break;
case "tool-input-available": { case "tool-input-available": {
const finalArgsText = JSON.stringify(parsed.input ?? {}, null, 2);
if (toolCallIndices.has(parsed.toolCallId)) { if (toolCallIndices.has(parsed.toolCallId)) {
updateToolCall(contentPartsState, parsed.toolCallId, { updateToolCall(contentPartsState, parsed.toolCallId, {
args: parsed.input || {}, args: parsed.input || {},
argsText: finalArgsText,
langchainToolCallId: parsed.langchainToolCallId, langchainToolCallId: parsed.langchainToolCallId,
}); });
} else { } else {
@ -1200,8 +1227,14 @@ export default function NewChatPage() {
false, false,
parsed.langchainToolCallId parsed.langchainToolCallId
); );
// addToolCall doesn't accept argsText today;
// backfill via updateToolCall so the new card
// renders pretty-printed JSON.
updateToolCall(contentPartsState, parsed.toolCallId, {
argsText: finalArgsText,
});
} }
batcher.flush(); forceFlush();
break; break;
} }
@ -1220,7 +1253,7 @@ export default function NewChatPage() {
} }
} }
} }
batcher.flush(); forceFlush();
break; break;
} }
@ -1316,34 +1349,17 @@ export default function NewChatPage() {
} }
case "data-action-log": { case "data-action-log": {
const al = parsed.data; applyActionLogSse(queryClient, currentThreadId, searchSpaceId, parsed.data);
const matchedToolCallId = al.lc_tool_call_id
? findToolCallIdByLcId(contentPartsState, al.lc_tool_call_id)
: null;
upsertAgentAction({
action: {
id: al.id,
threadId: currentThreadId,
lcToolCallId: al.lc_tool_call_id,
chatTurnId: al.chat_turn_id,
toolName: al.tool_name,
reversible: al.reversible,
reverseDescriptorPresent: al.reverse_descriptor_present,
error: al.error,
revertedByActionId: null,
isRevertAction: false,
createdAt: al.created_at,
},
toolCallId: matchedToolCallId,
});
break; break;
} }
case "data-action-log-updated": { case "data-action-log-updated": {
updateAgentActionReversible({ applyActionLogUpdatedSse(
id: parsed.data.id, queryClient,
reversible: parsed.data.reversible, currentThreadId,
}); parsed.data.id,
parsed.data.reversible
);
break; break;
} }
@ -1559,6 +1575,15 @@ export default function NewChatPage() {
toolName: String(p.toolName), toolName: String(p.toolName),
args: (p.args as Record<string, unknown>) ?? {}, args: (p.args as Record<string, unknown>) ?? {},
result: p.result as unknown, result: p.result as unknown,
// Restore argsText so persisted pretty-printed
// JSON survives reloads (assistant-ui prefers
// supplied argsText over JSON.stringify(args)).
// langchainToolCallId restoration also fixes a
// pre-existing dropped-id bug on resume.
...(typeof p.argsText === "string" ? { argsText: p.argsText } : {}),
...(typeof p.langchainToolCallId === "string"
? { langchainToolCallId: p.langchainToolCallId }
: {}),
}); });
contentPartsState.currentTextPartIndex = -1; contentPartsState.currentTextPartIndex = -1;
} else if (p.type === "data-thinking-steps") { } else if (p.type === "data-thinking-steps") {
@ -1580,7 +1605,12 @@ export default function NewChatPage() {
const editedAction = decisions[0].edited_action; const editedAction = decisions[0].edited_action;
for (const part of contentParts) { for (const part of contentParts) {
if (part.type === "tool-call" && part.toolName === editedAction.name) { if (part.type === "tool-call" && part.toolName === editedAction.name) {
part.args = { ...part.args, ...editedAction.args }; const mergedArgs = { ...part.args, ...editedAction.args };
part.args = mergedArgs;
// Sync argsText so the rendered card shows the
// edited inputs — assistant-ui prefers caller-
// supplied argsText over JSON.stringify(args).
part.argsText = JSON.stringify(mergedArgs, null, 2);
break; break;
} }
} }
@ -1637,6 +1667,10 @@ export default function NewChatPage() {
); );
}; };
const scheduleFlush = () => batcher.schedule(flushMessages); const scheduleFlush = () => batcher.schedule(flushMessages);
const forceFlush = () => {
scheduleFlush();
batcher.flush();
};
for await (const parsed of readSSEStream(response)) { for await (const parsed of readSSEStream(response)) {
switch (parsed.type) { switch (parsed.type) {
@ -1673,13 +1707,20 @@ export default function NewChatPage() {
false, false,
parsed.langchainToolCallId parsed.langchainToolCallId
); );
batcher.flush(); forceFlush();
break; break;
case "tool-input-available": case "tool-input-delta":
appendToolInputDelta(contentPartsState, parsed.toolCallId, parsed.inputTextDelta);
scheduleFlush();
break;
case "tool-input-available": {
const finalArgsText = JSON.stringify(parsed.input ?? {}, null, 2);
if (toolCallIndices.has(parsed.toolCallId)) { if (toolCallIndices.has(parsed.toolCallId)) {
updateToolCall(contentPartsState, parsed.toolCallId, { updateToolCall(contentPartsState, parsed.toolCallId, {
args: parsed.input || {}, args: parsed.input || {},
argsText: finalArgsText,
langchainToolCallId: parsed.langchainToolCallId, langchainToolCallId: parsed.langchainToolCallId,
}); });
} else { } else {
@ -1692,9 +1733,13 @@ export default function NewChatPage() {
false, false,
parsed.langchainToolCallId parsed.langchainToolCallId
); );
updateToolCall(contentPartsState, parsed.toolCallId, {
argsText: finalArgsText,
});
} }
batcher.flush(); forceFlush();
break; break;
}
case "tool-output-available": case "tool-output-available":
updateToolCall(contentPartsState, parsed.toolCallId, { updateToolCall(contentPartsState, parsed.toolCallId, {
@ -1702,7 +1747,7 @@ export default function NewChatPage() {
langchainToolCallId: parsed.langchainToolCallId, langchainToolCallId: parsed.langchainToolCallId,
}); });
markInterruptsCompleted(contentParts); markInterruptsCompleted(contentParts);
batcher.flush(); forceFlush();
break; break;
case "data-thinking-step": { case "data-thinking-step": {
@ -1762,34 +1807,17 @@ export default function NewChatPage() {
} }
case "data-action-log": { case "data-action-log": {
const al = parsed.data; applyActionLogSse(queryClient, resumeThreadId, searchSpaceId, parsed.data);
const matchedToolCallId = al.lc_tool_call_id
? findToolCallIdByLcId(contentPartsState, al.lc_tool_call_id)
: null;
upsertAgentAction({
action: {
id: al.id,
threadId: resumeThreadId,
lcToolCallId: al.lc_tool_call_id,
chatTurnId: al.chat_turn_id,
toolName: al.tool_name,
reversible: al.reversible,
reverseDescriptorPresent: al.reverse_descriptor_present,
error: al.error,
revertedByActionId: null,
isRevertAction: false,
createdAt: al.created_at,
},
toolCallId: matchedToolCallId,
});
break; break;
} }
case "data-action-log-updated": { case "data-action-log-updated": {
updateAgentActionReversible({ applyActionLogUpdatedSse(
id: parsed.data.id, queryClient,
reversible: parsed.data.reversible, resumeThreadId,
}); parsed.data.id,
parsed.data.reversible
);
break; break;
} }
@ -1902,6 +1930,11 @@ export default function NewChatPage() {
return { return {
...part, ...part,
args: decision.edited_action.args, // Update displayed args args: decision.edited_action.args, // Update displayed args
// Sync argsText so the rendered card shows
// the edited inputs — assistant-ui prefers
// caller-supplied argsText over
// JSON.stringify(args).
argsText: JSON.stringify(decision.edited_action.args, null, 2),
result: { result: {
...(part.result as Record<string, unknown>), ...(part.result as Record<string, unknown>),
__decided__: decisionType, __decided__: decisionType,
@ -2127,6 +2160,10 @@ export default function NewChatPage() {
); );
}; };
const scheduleFlush = () => batcher.schedule(flushMessages); const scheduleFlush = () => batcher.schedule(flushMessages);
const forceFlush = () => {
scheduleFlush();
batcher.flush();
};
for await (const parsed of readSSEStream(response)) { for await (const parsed of readSSEStream(response)) {
switch (parsed.type) { switch (parsed.type) {
@ -2163,13 +2200,20 @@ export default function NewChatPage() {
false, false,
parsed.langchainToolCallId parsed.langchainToolCallId
); );
batcher.flush(); forceFlush();
break; break;
case "tool-input-available": case "tool-input-delta":
appendToolInputDelta(contentPartsState, parsed.toolCallId, parsed.inputTextDelta);
scheduleFlush();
break;
case "tool-input-available": {
const finalArgsText = JSON.stringify(parsed.input ?? {}, null, 2);
if (toolCallIndices.has(parsed.toolCallId)) { if (toolCallIndices.has(parsed.toolCallId)) {
updateToolCall(contentPartsState, parsed.toolCallId, { updateToolCall(contentPartsState, parsed.toolCallId, {
args: parsed.input || {}, args: parsed.input || {},
argsText: finalArgsText,
langchainToolCallId: parsed.langchainToolCallId, langchainToolCallId: parsed.langchainToolCallId,
}); });
} else { } else {
@ -2182,9 +2226,13 @@ export default function NewChatPage() {
false, false,
parsed.langchainToolCallId parsed.langchainToolCallId
); );
updateToolCall(contentPartsState, parsed.toolCallId, {
argsText: finalArgsText,
});
} }
batcher.flush(); forceFlush();
break; break;
}
case "tool-output-available": case "tool-output-available":
updateToolCall(contentPartsState, parsed.toolCallId, { updateToolCall(contentPartsState, parsed.toolCallId, {
@ -2201,7 +2249,7 @@ export default function NewChatPage() {
} }
} }
} }
batcher.flush(); forceFlush();
break; break;
case "data-thinking-step": { case "data-thinking-step": {
@ -2217,34 +2265,21 @@ export default function NewChatPage() {
} }
case "data-action-log": { case "data-action-log": {
const al = parsed.data; if (threadId !== null) {
const matchedToolCallId = al.lc_tool_call_id applyActionLogSse(queryClient, threadId, searchSpaceId, parsed.data);
? findToolCallIdByLcId(contentPartsState, al.lc_tool_call_id) }
: null;
upsertAgentAction({
action: {
id: al.id,
threadId,
lcToolCallId: al.lc_tool_call_id,
chatTurnId: al.chat_turn_id,
toolName: al.tool_name,
reversible: al.reversible,
reverseDescriptorPresent: al.reverse_descriptor_present,
error: al.error,
revertedByActionId: null,
isRevertAction: false,
createdAt: al.created_at,
},
toolCallId: matchedToolCallId,
});
break; break;
} }
case "data-action-log-updated": { case "data-action-log-updated": {
updateAgentActionReversible({ if (threadId !== null) {
id: parsed.data.id, applyActionLogUpdatedSse(
reversible: parsed.data.reversible, queryClient,
}); threadId,
parsed.data.id,
parsed.data.reversible
);
}
break; break;
} }
@ -2281,12 +2316,16 @@ export default function NewChatPage() {
: `Reverted ${summary.reverted} downstream actions before regenerating.` : `Reverted ${summary.reverted} downstream actions before regenerating.`
); );
} }
for (const r of summary.results) { if (threadId !== null) {
if (r.status === "reverted" || r.status === "already_reverted") { for (const r of summary.results) {
markAgentActionReverted({ if (r.status === "reverted" || r.status === "already_reverted") {
id: r.action_id, markActionRevertedInCache(
newActionId: r.new_action_id ?? null, queryClient,
}); threadId,
r.action_id,
r.new_action_id ?? null
);
}
} }
} }
break; break;
@ -2459,16 +2498,26 @@ export default function NewChatPage() {
const downstream = messages.slice(editedIndex + 1); const downstream = messages.slice(editedIndex + 1);
downstreamTotalCount = downstream.length; downstreamTotalCount = downstream.length;
const seenTurns = new Set<string>(); const seenTurns = new Set<string>();
const downstreamTurnIds = new Set<string>();
for (const m of downstream) { for (const m of downstream) {
const meta = (m.metadata ?? {}) as { custom?: { chatTurnId?: string } }; const meta = (m.metadata ?? {}) as { custom?: { chatTurnId?: string } };
const tid = meta.custom?.chatTurnId; const tid = meta.custom?.chatTurnId;
if (!tid || seenTurns.has(tid)) continue; if (!tid || seenTurns.has(tid)) continue;
seenTurns.add(tid); seenTurns.add(tid);
const turnActions = agentActionsByChatTurnId.get(tid) ?? []; downstreamTurnIds.add(tid);
for (const a of turnActions) { }
if (a.reversible && a.revertedByActionId === null && !a.isRevertAction && !a.error) { // Source of truth: the unified react-query cache. Every
downstreamReversibleCount += 1; // action whose ``chat_turn_id`` belongs to the slice we're
} // about to drop counts toward the prompt.
for (const a of agentActionItems) {
if (!a.chat_turn_id || !downstreamTurnIds.has(a.chat_turn_id)) continue;
if (
a.reversible &&
(a.reverted_by_action_id === null || a.reverted_by_action_id === undefined) &&
!a.is_revert_action &&
(a.error === null || a.error === undefined)
) {
downstreamReversibleCount += 1;
} }
} }
} }
@ -2492,7 +2541,7 @@ export default function NewChatPage() {
downstreamTotalCount, downstreamTotalCount,
}); });
}, },
[handleRegenerate, messages, agentActionsByChatTurnId] [handleRegenerate, messages, agentActionItems]
); );
const handleEditDialogChoice = useCallback( const handleEditDialogChoice = useCallback(

View file

@ -1,194 +0,0 @@
"use client";
import { atom } from "jotai";
/**
 * Minimal per-row projection of ``AgentActionLog`` that the tool card
 * needs to decide whether to render a Revert button.
 *
 * Fields are deliberately a subset of the full ``AgentAction`` so the
 * SSE side-channel (``data-action-log`` / ``data-action-log-updated``)
 * can populate them without depending on the REST endpoint
 * ``GET /threads/.../actions`` (which 503s when
 * ``SURFSENSE_ENABLE_ACTION_LOG`` is off).
 */
export interface AgentActionLite {
  // AgentActionLog row id; the SSE ``data-action-log-updated`` payload
  // identifies rows by this id only.
  id: number;
  threadId: number | null;
  // LangChain ``tool_call.id`` — key into ``agentActionByLcIdAtom``.
  lcToolCallId: string | null;
  // Per-turn correlation id — key into ``agentActionsByChatTurnIdAtom``.
  chatTurnId: string | null;
  toolName: string;
  reversible: boolean;
  reverseDescriptorPresent: boolean;
  // Error flag for the action row (NOTE(review): presumably true when
  // the underlying tool call failed — confirm against the producer).
  error: boolean;
  // Id of the compensating action once reverted; ``null`` while the
  // action is still revertible. Set locally by the mark-reverted atoms.
  revertedByActionId: number | null;
  // True when this row IS a revert action rather than an original one.
  isRevertAction: boolean;
  createdAt: string | null;
}
/**
 * Index of action rows keyed by the LangChain ``tool_call.id``
 * (mirrors ``ContentPart tool-call.langchainToolCallId``).
 */
export const agentActionByLcIdAtom = atom(new Map<string, AgentActionLite>());
/**
 * Companion index keyed by the synthetic chat-card ``toolCallId``
 * (``call_<run-id>``) so ``ToolFallback`` — which only receives the
 * synthetic id from assistant-ui — can join its card to the action
 * log. Kept in sync with the lc-id index by ``upsertAgentActionAtom``.
 */
export const agentActionByToolCallIdAtom = atom(new Map<string, AgentActionLite>());
/**
 * Index keyed by ``chat_turn_id`` so the per-turn revert UI can answer
 * "how many reversible actions does this assistant turn contain?" in
 * O(1). Each list is in insertion order, which for a single turn
 * matches ``created_at`` because action-log writes happen
 * synchronously.
 */
export const agentActionsByChatTurnIdAtom = atom(new Map<string, AgentActionLite[]>());
/**
 * Write-only atom that upserts one ``AgentActionLite`` row into every
 * index it belongs to.
 *
 * ``toolCallId`` is the synthetic card id (``call_<run-id>`` from
 * ``stream_new_chat.py``). When provided alongside ``lcToolCallId``,
 * the action is indexed under BOTH ids so the tool card can perform
 * the lookup without going via the streaming state.
 */
export const upsertAgentActionAtom = atom(
  null,
  (_get, set, payload: { action: AgentActionLite; toolCallId?: string | null }) => {
    const { action, toolCallId } = payload;

    // Merge the incoming row with whatever is already stored, keeping
    // the local "reverted" bookkeeping sticky: a stale
    // ``reversible=true`` event arriving AFTER the user reverted via
    // the REST route must never resurrect a Reverted card.
    const mergeSticky = (prior: AgentActionLite | undefined): AgentActionLite => ({
      ...action,
      revertedByActionId: prior?.revertedByActionId ?? action.revertedByActionId,
      isRevertAction: prior?.isRevertAction ?? action.isRevertAction,
    });

    // Copy-on-write upsert into one of the string-keyed maps.
    const withEntry = (
      prev: Map<string, AgentActionLite>,
      key: string
    ): Map<string, AgentActionLite> => {
      const next = new Map(prev);
      next.set(key, mergeSticky(next.get(key)));
      return next;
    };

    if (action.lcToolCallId) {
      set(agentActionByLcIdAtom, (prev) => withEntry(prev, action.lcToolCallId as string));
    }
    if (toolCallId) {
      set(agentActionByToolCallIdAtom, (prev) => withEntry(prev, toolCallId));
    }
    if (action.chatTurnId) {
      set(agentActionsByChatTurnIdAtom, (prev) => {
        const turnId = action.chatTurnId as string;
        const next = new Map(prev);
        const rows = next.get(turnId) ?? [];
        const merged = mergeSticky(rows.find((row) => row.id === action.id));
        // Drop any prior row with the same id and append the merged
        // one, preserving insertion order for the rest of the turn.
        next.set(turnId, [...rows.filter((row) => row.id !== action.id), merged]);
        return next;
      });
    }
  }
);
/**
 * Return a copy of ``prev`` with every entry whose ``id`` matches
 * rewritten through ``mutator``. When nothing matches, the ORIGINAL
 * map reference is returned unchanged so reference-equality
 * subscribers see "no update".
 */
function mutateById(
  prev: Map<string, AgentActionLite>,
  id: number,
  mutator: (entry: AgentActionLite) => AgentActionLite
): Map<string, AgentActionLite> {
  const copy = new Map(prev);
  let touched = false;
  for (const [key, entry] of prev) {
    if (entry.id !== id) continue;
    copy.set(key, mutator(entry));
    touched = true;
  }
  return touched ? copy : prev;
}
/**
 * Same copy-on-write mutation as ``mutateById``, but over the
 * turn-indexed map of row LISTS: every row with a matching ``id`` in
 * any turn's list is rewritten through ``mutator``. Lists without a
 * match keep their original array reference; a fully missed id
 * returns the original map reference.
 */
function mutateByIdInTurnIndex(
  prev: Map<string, AgentActionLite[]>,
  id: number,
  mutator: (entry: AgentActionLite) => AgentActionLite
): Map<string, AgentActionLite[]> {
  const copy = new Map(prev);
  let anyHit = false;
  for (const [turnKey, rows] of prev) {
    if (!rows.some((row) => row.id === id)) continue;
    anyHit = true;
    copy.set(
      turnKey,
      rows.map((row) => (row.id === id ? mutator(row) : row))
    );
  }
  return anyHit ? copy : prev;
}
/**
 * Write-only atom that flips an entry's ``reversible`` flag across
 * all three indexes, keyed by the AgentActionLog row id (the SSE
 * ``data-action-log-updated`` payload does NOT carry
 * ``lcToolCallId``).
 */
export const updateAgentActionReversibleAtom = atom(
  null,
  (_get, set, payload: { id: number; reversible: boolean }) => {
    const flip = (entry: AgentActionLite): AgentActionLite => ({
      ...entry,
      reversible: payload.reversible,
    });
    set(agentActionByLcIdAtom, (prev) => mutateById(prev, payload.id, flip));
    set(agentActionByToolCallIdAtom, (prev) => mutateById(prev, payload.id, flip));
    set(agentActionsByChatTurnIdAtom, (prev) =>
      mutateByIdInTurnIndex(prev, payload.id, flip)
    );
  }
);
/**
 * Write-only atom marking an existing entry as reverted (called after
 * a revert API call succeeds), keyed by the action-log row id.
 */
export const markAgentActionRevertedAtom = atom(
  null,
  (_get, set, payload: { id: number; newActionId: number | null }) => {
    // ``-1`` is a local sentinel used when the compensating action's id
    // is unknown — NOTE(review): any non-null value appears to mean
    // "already reverted"; confirm against consumers.
    const markReverted = (entry: AgentActionLite): AgentActionLite => ({
      ...entry,
      revertedByActionId: payload.newActionId ?? -1,
    });
    set(agentActionByLcIdAtom, (prev) => mutateById(prev, payload.id, markReverted));
    set(agentActionByToolCallIdAtom, (prev) => mutateById(prev, payload.id, markReverted));
    set(agentActionsByChatTurnIdAtom, (prev) =>
      mutateByIdInTurnIndex(prev, payload.id, markReverted)
    );
  }
);
/**
 * Write-only atom that marks a batch of actions as reverted, given
 * ``(id, newActionId)`` pairs — delegates each entry to
 * ``markAgentActionRevertedAtom``.
 */
export const markAgentActionsRevertedBatchAtom = atom(
  null,
  (_get, set, payload: { entries: Array<{ id: number; newActionId: number | null }> }) => {
    payload.entries.forEach((entry) => set(markAgentActionRevertedAtom, entry));
  }
);
/**
 * Write-only atom that clears all three action indexes at once (e.g.
 * when the active thread changes, so stale rows from the previous
 * thread can't leak into the new thread's Revert UI).
 */
export const resetAgentActionMapAtom = atom(null, (_get, set) => {
  set(agentActionByLcIdAtom, new Map());
  set(agentActionByToolCallIdAtom, new Map());
  set(agentActionsByChatTurnIdAtom, new Map());
});

View file

@ -1,9 +1,9 @@
"use client"; "use client";
import { useQuery, useQueryClient } from "@tanstack/react-query"; import { useQueryClient } from "@tanstack/react-query";
import { useAtom, useAtomValue } from "jotai"; import { useAtom, useAtomValue } from "jotai";
import { Activity, RefreshCcw } from "lucide-react"; import { Activity, RefreshCcw } from "lucide-react";
import { useCallback, useMemo } from "react"; import { useCallback } from "react";
import { actionLogSheetAtom } from "@/atoms/agent/action-log-sheet.atom"; import { actionLogSheetAtom } from "@/atoms/agent/action-log-sheet.atom";
import { agentFlagsAtom } from "@/atoms/agent/agent-flags-query.atom"; import { agentFlagsAtom } from "@/atoms/agent/agent-flags-query.atom";
import { Badge } from "@/components/ui/badge"; import { Badge } from "@/components/ui/badge";
@ -17,15 +17,12 @@ import {
SheetTitle, SheetTitle,
} from "@/components/ui/sheet"; } from "@/components/ui/sheet";
import { Skeleton } from "@/components/ui/skeleton"; import { Skeleton } from "@/components/ui/skeleton";
import { agentActionsApiService } from "@/lib/apis/agent-actions-api.service"; import {
agentActionsQueryKey,
useAgentActionsQuery,
} from "@/hooks/use-agent-actions-query";
import { ActionLogItem } from "./action-log-item"; import { ActionLogItem } from "./action-log-item";
const ACTION_LOG_PAGE_SIZE = 50;
function actionLogQueryKey(threadId: number) {
return ["agent-actions", threadId] as const;
}
function EmptyState() { function EmptyState() {
return ( return (
<div className="flex flex-1 flex-col items-center justify-center gap-3 px-6 text-center"> <div className="flex flex-1 flex-col items-center justify-center gap-3 px-6 text-center">
@ -85,25 +82,17 @@ export function ActionLogSheet() {
const threadId = state.threadId; const threadId = state.threadId;
const { data, isLoading, isFetching, isError, error, refetch } = useQuery({ const { data, items, isLoading, isFetching, isError, error, refetch } = useAgentActionsQuery(
queryKey: threadId !== null ? actionLogQueryKey(threadId) : ["agent-actions", "none"], threadId,
queryFn: () => { enabled: state.open && actionLogEnabled }
agentActionsApiService.listForThread(threadId as number, { );
page: 0,
pageSize: ACTION_LOG_PAGE_SIZE,
}),
enabled: state.open && threadId !== null && actionLogEnabled,
staleTime: 15 * 1000,
});
const handleRevertSuccess = useCallback(() => { const handleRevertSuccess = useCallback(() => {
if (threadId !== null) { if (threadId !== null) {
queryClient.invalidateQueries({ queryKey: actionLogQueryKey(threadId) }); queryClient.invalidateQueries({ queryKey: agentActionsQueryKey(threadId) });
} }
}, [queryClient, threadId]); }, [queryClient, threadId]);
const items = useMemo(() => data?.items ?? [], [data]);
return ( return (
<Sheet open={state.open} onOpenChange={(open) => setState((s) => ({ ...s, open }))}> <Sheet open={state.open} onOpenChange={(open) => setState((s) => ({ ...s, open }))}>
<SheetContent <SheetContent

View file

@ -4,26 +4,22 @@
* "Revert turn" button rendered at the bottom of every completed * "Revert turn" button rendered at the bottom of every completed
* assistant turn that has at least one reversible action. * assistant turn that has at least one reversible action.
* *
* The button reads the action map keyed by ``chat_turn_id`` from the * The button reads from the unified ``useAgentActionsQuery`` cache
* SSE side-channel (``data-action-log`` events). It shows a confirmation * (the SAME react-query cache the agent-actions sheet and the inline
* dialog summarising "N reversible / M total" and, on confirm, calls * Revert button consume) filtered by ``chat_turn_id``. It shows a
* ``POST /threads/{id}/revert-turn/{chat_turn_id}``. * confirmation dialog summarising "N reversible / M total" and, on
* confirm, calls ``POST /threads/{id}/revert-turn/{chat_turn_id}``.
* *
* The route returns a per-action result list and never collapses the * The route returns a per-action result list and never collapses the
* batch into a 4xx so we render any failed/not_reversible rows inline * batch into a 4xx so we render any failed/not_reversible rows inline
* with their messages. * with their messages.
*/ */
import { useAtomValue, useSetAtom } from "jotai"; import { useQueryClient } from "@tanstack/react-query";
import { selectAtom } from "jotai/utils"; import { useAtomValue } from "jotai";
import { CheckIcon, RotateCcw, XCircleIcon } from "lucide-react"; import { CheckIcon, RotateCcw, XCircleIcon } from "lucide-react";
import { useMemo, useState } from "react"; import { useMemo, useState } from "react";
import { toast } from "sonner"; import { toast } from "sonner";
import {
type AgentActionLite,
agentActionsByChatTurnIdAtom,
markAgentActionsRevertedBatchAtom,
} from "@/atoms/chat/agent-actions.atom";
import { chatSessionStateAtom } from "@/atoms/chat/chat-session-state.atom"; import { chatSessionStateAtom } from "@/atoms/chat/chat-session-state.atom";
import { import {
AlertDialog, AlertDialog,
@ -38,6 +34,10 @@ import {
} from "@/components/ui/alert-dialog"; } from "@/components/ui/alert-dialog";
import { Button } from "@/components/ui/button"; import { Button } from "@/components/ui/button";
import { getToolDisplayName } from "@/contracts/enums/toolIcons"; import { getToolDisplayName } from "@/contracts/enums/toolIcons";
import {
applyRevertTurnResultsToCache,
useAgentActionsQuery,
} from "@/hooks/use-agent-actions-query";
import { import {
agentActionsApiService, agentActionsApiService,
type RevertTurnActionResult, type RevertTurnActionResult,
@ -49,49 +49,33 @@ interface RevertTurnButtonProps {
chatTurnId: string | null | undefined; chatTurnId: string | null | undefined;
} }
// Empty-array sentinel so the per-turn ``selectAtom`` slice returns a
// stable reference when the turn has no recorded actions yet. Without
// this every render allocates a fresh ``[]`` and Jotai's
// equality check would re-render the button on unrelated turn updates.
const EMPTY_ACTIONS: readonly AgentActionLite[] = Object.freeze([]);
export function RevertTurnButton({ chatTurnId }: RevertTurnButtonProps) { export function RevertTurnButton({ chatTurnId }: RevertTurnButtonProps) {
const session = useAtomValue(chatSessionStateAtom); const session = useAtomValue(chatSessionStateAtom);
const markRevertedBatch = useSetAtom(markAgentActionsRevertedBatchAtom); const threadId = session?.threadId ?? null;
const queryClient = useQueryClient();
const { findByChatTurnId } = useAgentActionsQuery(threadId);
const [isReverting, setIsReverting] = useState(false); const [isReverting, setIsReverting] = useState(false);
const [confirmOpen, setConfirmOpen] = useState(false); const [confirmOpen, setConfirmOpen] = useState(false);
const [resultsOpen, setResultsOpen] = useState(false); const [resultsOpen, setResultsOpen] = useState(false);
const [results, setResults] = useState<RevertTurnActionResult[]>([]); const [results, setResults] = useState<RevertTurnActionResult[]>([]);
// Subscribe ONLY to the slice of the global action map that belongs const actions = useMemo(() => findByChatTurnId(chatTurnId), [findByChatTurnId, chatTurnId]);
// to ``chatTurnId``. Previously the button read the whole
// ``agentActionsByChatTurnIdAtom``, which meant every action
// upsert (one per tool call) re-rendered every Revert button on
// the page. With ``selectAtom`` we re-render only when our turn's
// list reference changes — and the upsert/mark atoms produce a
// fresh list reference for the affected turn only.
const sliceAtom = useMemo(
() =>
selectAtom(
agentActionsByChatTurnIdAtom,
(turnIndex) => (chatTurnId ? turnIndex.get(chatTurnId) : undefined) ?? EMPTY_ACTIONS
),
[chatTurnId]
);
const actions = useAtomValue(sliceAtom);
const reversibleCount = useMemo( const reversibleCount = useMemo(
() => () =>
actions.filter( actions.filter(
(a) => a.reversible && a.revertedByActionId === null && !a.isRevertAction && !a.error (a) =>
a.reversible &&
(a.reverted_by_action_id === null || a.reverted_by_action_id === undefined) &&
!a.is_revert_action &&
(a.error === null || a.error === undefined)
).length, ).length,
[actions] [actions]
); );
const totalCount = useMemo(() => actions.filter((a) => !a.isRevertAction).length, [actions]); const totalCount = useMemo(() => actions.filter((a) => !a.is_revert_action).length, [actions]);
if (!chatTurnId) return null; if (!chatTurnId) return null;
if (reversibleCount === 0) return null; if (reversibleCount === 0) return null;
const threadId = session?.threadId;
if (!threadId) return null; if (!threadId) return null;
const handleRevertTurn = async () => { const handleRevertTurn = async () => {
@ -103,7 +87,7 @@ export function RevertTurnButton({ chatTurnId }: RevertTurnButtonProps) {
.filter((r) => r.status === "reverted" || r.status === "already_reverted") .filter((r) => r.status === "reverted" || r.status === "already_reverted")
.map((r) => ({ id: r.action_id, newActionId: r.new_action_id ?? null })); .map((r) => ({ id: r.action_id, newActionId: r.new_action_id ?? null }));
if (revertedEntries.length > 0) { if (revertedEntries.length > 0) {
markRevertedBatch({ entries: revertedEntries }); applyRevertTurnResultsToCache(queryClient, threadId, revertedEntries);
} }
if (response.status === "ok") { if (response.status === "ok") {
toast.success( toast.success(

View file

@ -1,12 +1,12 @@
import type { ToolCallMessagePartComponent } from "@assistant-ui/react";
import { useAtomValue, useSetAtom } from "jotai";
import { CheckIcon, ChevronDownIcon, ChevronUpIcon, RotateCcw, XCircleIcon } from "lucide-react";
import { useMemo, useState } from "react";
import { toast } from "sonner";
import { import {
agentActionByToolCallIdAtom, type ToolCallMessagePartComponent,
markAgentActionRevertedAtom, useAuiState,
} from "@/atoms/chat/agent-actions.atom"; } from "@assistant-ui/react";
import { useQueryClient } from "@tanstack/react-query";
import { useAtomValue } from "jotai";
import { CheckIcon, ChevronDownIcon, RotateCcw, XCircleIcon } from "lucide-react";
import { useEffect, useMemo, useState } from "react";
import { toast } from "sonner";
import { chatSessionStateAtom } from "@/atoms/chat/chat-session-state.atom"; import { chatSessionStateAtom } from "@/atoms/chat/chat-session-state.atom";
import { import {
DoomLoopApprovalToolUI, DoomLoopApprovalToolUI,
@ -24,8 +24,17 @@ import {
AlertDialogTitle, AlertDialogTitle,
AlertDialogTrigger, AlertDialogTrigger,
} from "@/components/ui/alert-dialog"; } from "@/components/ui/alert-dialog";
import { Badge } from "@/components/ui/badge";
import { Button } from "@/components/ui/button"; import { Button } from "@/components/ui/button";
import { getToolDisplayName, getToolIcon } from "@/contracts/enums/toolIcons"; import { Card } from "@/components/ui/card";
import { Collapsible, CollapsibleContent, CollapsibleTrigger } from "@/components/ui/collapsible";
import { Separator } from "@/components/ui/separator";
import { Spinner } from "@/components/ui/spinner";
import { getToolDisplayName } from "@/contracts/enums/toolIcons";
import {
markActionRevertedInCache,
useAgentActionsQuery,
} from "@/hooks/use-agent-actions-query";
import { agentActionsApiService } from "@/lib/apis/agent-actions-api.service"; import { agentActionsApiService } from "@/lib/apis/agent-actions-api.service";
import { AppError } from "@/lib/error"; import { AppError } from "@/lib/error";
import { isInterruptResult } from "@/lib/hitl"; import { isInterruptResult } from "@/lib/hitl";
@ -34,31 +43,128 @@ import { cn } from "@/lib/utils";
/** /**
* Inline Revert button rendered on a tool card when the matching * Inline Revert button rendered on a tool card when the matching
* ``AgentActionLog`` row is reversible and hasn't been reverted yet. * ``AgentActionLog`` row is reversible and hasn't been reverted yet.
* Reads from the SSE side-channel atom keyed by the synthetic *
* ``toolCallId`` so it lights up even when ``GET /threads/.../actions`` * Reads from the unified ``useAgentActionsQuery`` cache the SAME
* is gated behind ``SURFSENSE_ENABLE_ACTION_LOG=False`` (503). * react-query cache the agent-actions sheet consumes. SSE events
* (``data-action-log`` / ``data-action-log-updated``) and
* ``POST /threads/{id}/revert/{id}`` responses both flow through the
* cache via ``setQueryData`` helpers, so the card and the sheet stay
* in lockstep on every code path: page reload, navigation, live
* stream, post-stream reversibility flip, and explicit revert clicks.
*
* Match key (in priority order):
* 1. ``a.tool_call_id === toolCallId`` direct hit in parity_v2 when
* the model streamed ``tool_call_chunks`` so the card's synthetic
* id IS the LangChain id.
* 2. ``a.tool_call_id === langchainToolCallId`` legacy mode (or
* parity_v2 with provider-side chunk emission) where the card's
* synthetic id is ``call_<run_id>`` and the LangChain id is
* backfilled onto the part by ``tool-output-available``.
* 3. ``(chat_turn_id, tool_name, position-within-turn)`` fallback
* for cards whose synthetic id is ``call_<run_id>`` AND whose
* ``langchainToolCallId`` never got backfilled (provider emitted
* the tool_call as a single payload with no chunks AND streaming
* pre-dated the ``tool-output-available langchainToolCallId``
* backfill, e.g. older threads). Reads the parent message's
* ``chatTurnId`` and ``content`` via ``useAuiState`` so we can
* match position-by-tool-name within the turn against the
* action_log rows the server returned in ``created_at`` order.
*/ */
function ToolCardRevertButton({ toolCallId }: { toolCallId: string }) { function ToolCardRevertButton({
toolCallId,
toolName,
langchainToolCallId,
}: {
toolCallId: string;
toolName: string;
langchainToolCallId?: string;
}) {
const session = useAtomValue(chatSessionStateAtom); const session = useAtomValue(chatSessionStateAtom);
const actionMap = useAtomValue(agentActionByToolCallIdAtom); const threadId = session?.threadId ?? null;
const markReverted = useSetAtom(markAgentActionRevertedAtom); const queryClient = useQueryClient();
const action = actionMap.get(toolCallId); const { findByToolCallId, findByChatTurnAndTool } = useAgentActionsQuery(threadId);
// Parent message metadata, read via the narrowest possible
// selectors so this card doesn't re-render on every text-delta of
// every other part in the same message during streaming.
//
// IMPORTANT — ``useAuiState`` re-renders the component whenever the
// returned slice's identity changes. Returning ``message?.content``
// (an array) would re-render on every token because the runtime
// rebuilds the parts array. Returning a PRIMITIVE (the position
// number) lets ``useAuiState``'s ``Object.is`` check short-circuit
// when the position hasn't actually moved — which is the common
// case during text streaming, when only ``text``/``reasoning``
// parts are mutating and the same-toolName tool-call ordering is
// stable. (See Vercel React rule ``rerender-defer-reads``.)
const chatTurnId = useAuiState(({ message }) => {
const meta = message?.metadata as { custom?: { chatTurnId?: string } } | undefined;
return meta?.custom?.chatTurnId ?? null;
});
const positionInTurn = useAuiState(({ message }) => {
const content = message?.content;
if (!Array.isArray(content)) return -1;
let n = -1;
for (const part of content) {
if (
part &&
typeof part === "object" &&
(part as { type?: string }).type === "tool-call" &&
(part as { toolName?: string }).toolName === toolName
) {
n += 1;
if ((part as { toolCallId?: string }).toolCallId === toolCallId) return n;
}
}
return -1;
});
const action = useMemo(() => {
// Tier 1 + 2: O(1) Map-backed direct id match. Covers
// ~all parity_v2 streams and any legacy stream that backfilled
// ``langchainToolCallId`` via ``tool-output-available``.
const direct =
findByToolCallId(toolCallId) ?? findByToolCallId(langchainToolCallId);
if (direct) return direct;
// Tier 3: position-within-turn fallback. Only kicks in when the
// card has a synthetic ``call_<run_id>`` id AND no
// ``langchainToolCallId`` was ever backfilled — i.e. the tool
// was emitted as a single non-chunked payload AND streaming
// pre-dated the on_tool_end backfill.
if (!chatTurnId || positionInTurn < 0) return null;
const turnSameTool = findByChatTurnAndTool(chatTurnId, toolName);
return turnSameTool[positionInTurn] ?? null;
}, [
findByToolCallId,
findByChatTurnAndTool,
toolCallId,
langchainToolCallId,
chatTurnId,
toolName,
positionInTurn,
]);
const [isReverting, setIsReverting] = useState(false); const [isReverting, setIsReverting] = useState(false);
const [confirmOpen, setConfirmOpen] = useState(false); const [confirmOpen, setConfirmOpen] = useState(false);
if (!action) return null; if (!action) return null;
if (!action.reversible) return null; if (!action.reversible) return null;
if (action.revertedByActionId !== null) return null; if (action.reverted_by_action_id !== null && action.reverted_by_action_id !== undefined)
if (action.isRevertAction) return null; return null;
if (action.error) return null; if (action.is_revert_action) return null;
const threadId = session?.threadId; if (action.error !== null && action.error !== undefined) return null;
if (!threadId) return null; if (!threadId) return null;
const handleRevert = async () => { const handleRevert = async () => {
setIsReverting(true); setIsReverting(true);
try { try {
const response = await agentActionsApiService.revert(threadId, action.id); const response = await agentActionsApiService.revert(threadId, action.id);
markReverted({ id: action.id, newActionId: response.new_action_id ?? null }); markActionRevertedInCache(
queryClient,
threadId,
action.id,
response.new_action_id ?? null
);
toast.success(response.message || "Action reverted."); toast.success(response.message || "Action reverted.");
} catch (err) { } catch (err) {
// 503 means revert is gated off on this deployment — hide the // 503 means revert is gated off on this deployment — hide the
@ -91,8 +197,17 @@ function ToolCardRevertButton({ toolCallId }: { toolCallId: string }) {
e.stopPropagation(); e.stopPropagation();
setConfirmOpen(true); setConfirmOpen(true);
}} }}
disabled={isReverting}
> >
<RotateCcw className="size-3.5" /> {isReverting ? (
// Spinner's typed props don't accept ``data-icon`` and
// it renders an <output>, not an <svg>, so Button's
// auto-sizing rule doesn't apply. Bare spinner +
// Button's gap handle layout.
<Spinner size="xs" />
) : (
<RotateCcw data-icon="inline-start" />
)}
Revert Revert
</Button> </Button>
</AlertDialogTrigger> </AlertDialogTrigger>
@ -101,7 +216,7 @@ function ToolCardRevertButton({ toolCallId }: { toolCallId: string }) {
<AlertDialogTitle>Revert this action?</AlertDialogTitle> <AlertDialogTitle>Revert this action?</AlertDialogTitle>
<AlertDialogDescription> <AlertDialogDescription>
This will undo{" "} This will undo{" "}
<span className="font-medium">{getToolDisplayName(action.toolName)}</span> and add a <span className="font-medium">{getToolDisplayName(action.tool_name)}</span> and add a
new entry to the history. Your chat is preserved only the changes the agent made to new entry to the history. Your chat is preserved only the changes the agent made to
your knowledge base or connected apps will be rolled back where possible. your knowledge base or connected apps will be rolled back where possible.
</AlertDialogDescription> </AlertDialogDescription>
@ -114,8 +229,10 @@ function ToolCardRevertButton({ toolCallId }: { toolCallId: string }) {
handleRevert(); handleRevert();
}} }}
disabled={isReverting} disabled={isReverting}
className="gap-1.5"
> >
{isReverting ? "Reverting…" : "Revert"} {isReverting && <Spinner size="xs" />}
Revert
</AlertDialogAction> </AlertDialogAction>
</AlertDialogFooter> </AlertDialogFooter>
</AlertDialogContent> </AlertDialogContent>
@ -123,18 +240,49 @@ function ToolCardRevertButton({ toolCallId }: { toolCallId: string }) {
); );
} }
const DefaultToolFallbackInner: ToolCallMessagePartComponent = ({ /**
toolCallId, * Compact tool-call card.
toolName, *
argsText, * shadcn composition note: we intentionally use ``Card`` as a visual
result, * frame WITHOUT ``CardHeader / CardContent``. The full composition's
status, * ``p-6`` padding doesn't fit a compact collapsible header that IS the
}) => { * trigger; using ``Card`` alone preserves the rounded border, shadow,
const [isExpanded, setIsExpanded] = useState(false); * and ``bg-card`` token (semantic colors) without forcing a layout
* that doesn't fit. All status colors use semantic tokens no manual
* dark-mode overrides, no raw hex.
*/
const DefaultToolFallbackInner: ToolCallMessagePartComponent = (props) => {
const { toolCallId, toolName, argsText, result, status } = props;
// ``langchainToolCallId`` is a SurfSense-specific extension the
// streaming pipeline attaches to the tool-call content part so
// the Revert button can resolve its ``AgentActionLog`` row even
// when only the LC id is known. assistant-ui's
// ``ToolCallMessagePartProps`` doesn't list it, but the runtime
// spreads ``{...part}`` so the prop reaches us at runtime.
const langchainToolCallId = (props as { langchainToolCallId?: string }).langchainToolCallId;
const isCancelled = status?.type === "incomplete" && status.reason === "cancelled"; const isCancelled = status?.type === "incomplete" && status.reason === "cancelled";
const isError = status?.type === "incomplete" && status.reason === "error"; const isError = status?.type === "incomplete" && status.reason === "error";
const isRunning = status?.type === "running" || status?.type === "requires-action"; const isRunning = status?.type === "running" || status?.type === "requires-action";
/*
Per-card expansion state. Initial value is ``isRunning`` so a
card streaming in mounts already-expanded (no flash of
collapsed expanded on first paint), while a card loaded from
history (status="complete") mounts collapsed. The useEffect
below keeps this in lockstep with this card's own ``isRunning``
when it transitions: false true auto-expands (e.g. a tool
that re-runs after edit), true false auto-collapses once the
tool finishes. Because the dep is per-card ``isRunning`` and
not the chat-level streaming flag, sibling cards on the same
assistant turn each manage their own expansion independently.
Once ``isRunning`` is false the user controls expansion via
``onOpenChange``.
*/
const [isExpanded, setIsExpanded] = useState(isRunning);
useEffect(() => {
setIsExpanded(isRunning);
}, [isRunning]);
const errorData = status?.type === "incomplete" ? status.error : undefined; const errorData = status?.type === "incomplete" ? status.error : undefined;
const serializedError = useMemo( const serializedError = useMemo(
() => (errorData && typeof errorData !== "string" ? JSON.stringify(errorData) : null), () => (errorData && typeof errorData !== "string" ? JSON.stringify(errorData) : null),
@ -160,108 +308,207 @@ const DefaultToolFallbackInner: ToolCallMessagePartComponent = ({
: serializedError : serializedError
: null; : null;
const Icon = getToolIcon(toolName);
const displayName = getToolDisplayName(toolName); const displayName = getToolDisplayName(toolName);
const subtitle = errorReason ?? cancelledReason;
return ( return (
<div <Card
className={cn( className={cn(
"my-4 max-w-lg overflow-hidden rounded-2xl border bg-muted/30 select-none", "my-4 max-w-lg overflow-hidden",
isCancelled && "opacity-60", isCancelled && "opacity-60",
isError && "border-destructive/20 bg-destructive/5" isError && "border-destructive/30"
)} )}
> >
<button {/*
type="button" ``group`` lets the chevron (rendered as a sibling of the
onClick={() => setIsExpanded((prev) => !prev)} main trigger button) read the Collapsible Root's
className="flex w-full items-center gap-3 px-5 py-4 text-left transition-colors hover:bg-muted/50 focus:outline-none focus-visible:outline-none" ``data-[state=open]`` for rotation. The Collapsible is
fully controlled via ``isExpanded`` the useEffect
above syncs it to ``isRunning`` so the card auto-opens
while a tool streams in and auto-collapses once it
finishes. We deliberately DON'T pass ``disabled`` so
both triggers stay clickable; ``onOpenChange`` is wired
to a setter that no-ops while ``isRunning`` (see
``handleOpenChange`` below) which keeps the card pinned
open mid-stream without losing keyboard / pointer
affordance the moment streaming ends.
*/}
<Collapsible
className="group"
open={isExpanded}
onOpenChange={(next) => {
// Block manual collapse while the tool is still
// streaming — otherwise a stray click on either
// trigger would close the card and hide the live
// ``argsText`` panel mid-run. After streaming the
// user has full control again.
if (isRunning) return;
setIsExpanded(next);
}}
> >
<div {/*
className={cn( Header row: main trigger on the left (icon + title
"flex size-8 shrink-0 items-center justify-center rounded-lg", col), Revert + chevron-trigger on the right as
isError ? "bg-destructive/10" : isCancelled ? "bg-muted" : "bg-primary/10" siblings of the main trigger. The chevron is wrapped
)} in its OWN ``CollapsibleTrigger`` (Radix supports
> multiple triggers per Root) so clicking the chevron
{isError ? ( toggles the same state as clicking the title row.
<XCircleIcon className="size-4 text-destructive" /> The Revert button stays a separate AlertDialog
) : isCancelled ? ( trigger and stops propagation in its onClick so it
<XCircleIcon className="size-4 text-muted-foreground" /> doesn't toggle the collapsible while opening the
) : isRunning ? ( confirm dialog. Keeping these as flat siblings
<Icon className="size-4 text-primary animate-pulse" /> rather than nesting Revert / chevron inside the
) : ( title trigger avoids invalid HTML
<CheckIcon className="size-4 text-primary" /> (button-in-button) and lets the Revert button
)} render in BOTH the collapsed and expanded states.
</div> */}
<div className="flex items-stretch transition-colors hover:bg-muted/50">
<CollapsibleTrigger asChild>
<button
type="button"
className={cn(
"flex flex-1 min-w-0 items-center gap-3 py-4 pl-5 pr-2 text-left",
// Inset ring — Card's ``overflow-hidden`` would
// clip an ``offset-2`` ring; ``ring-inset``
// paints inside the button box.
"focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-inset",
"disabled:cursor-default"
)}
>
<div
className={cn(
"flex size-8 shrink-0 items-center justify-center rounded-lg",
isError ? "bg-destructive/10" : isCancelled ? "bg-muted" : "bg-primary/10"
)}
>
{isError ? (
<XCircleIcon className="size-4 text-destructive" />
) : isCancelled ? (
<XCircleIcon className="size-4 text-muted-foreground" />
) : isRunning ? (
<Spinner size="sm" className="text-primary" />
) : (
<CheckIcon className="size-4 text-primary" />
)}
</div>
<div className="flex-1 min-w-0"> <div className="flex flex-1 min-w-0 flex-col gap-0.5">
<p <div className="flex items-center gap-2">
className={cn( <p
"text-sm font-semibold", className={cn(
isError "text-sm font-semibold truncate",
? "text-destructive" isCancelled && "text-muted-foreground line-through",
: isCancelled isError && "text-destructive"
? "text-muted-foreground line-through" )}
: "text-foreground" >
)} {displayName}
> </p>
{isRunning {isRunning && <Badge variant="secondary">Running</Badge>}
? displayName {isError && <Badge variant="destructive">Failed</Badge>}
: isCancelled {isCancelled && <Badge variant="outline">Cancelled</Badge>}
? `Cancelled: ${displayName}` </div>
: isError {subtitle && (
? `Failed: ${displayName}` <p
: displayName} className={cn(
</p> "text-xs truncate",
{isRunning && <p className="text-xs text-muted-foreground mt-0.5">Working</p>} isError ? "text-destructive/80" : "text-muted-foreground"
{cancelledReason && ( )}
<p className="text-xs text-muted-foreground mt-0.5 truncate">{cancelledReason}</p> >
)} {subtitle}
{errorReason && ( </p>
<p className="text-xs text-destructive/80 mt-0.5 truncate">{errorReason}</p> )}
)} </div>
</div> </button>
</CollapsibleTrigger>
{!isRunning && ( {/*
<div className="shrink-0 text-muted-foreground"> Right-side controls. The Revert button is
{isExpanded ? ( visible whenever the matching action is
<ChevronDownIcon className="size-4" /> reversible including the collapsed state
) : ( but ``ToolCardRevertButton`` itself returns
<ChevronUpIcon className="size-4" /> ``null`` while a tool is still running because
)} no action-log row exists yet, so it doesn't
need an explicit ``isRunning`` gate here.
*/}
<div className="flex shrink-0 items-center gap-2 pl-2 pr-5">
<ToolCardRevertButton
toolCallId={toolCallId}
toolName={toolName}
langchainToolCallId={langchainToolCallId}
/>
<CollapsibleTrigger asChild>
<button
type="button"
aria-label={isExpanded ? "Collapse details" : "Expand details"}
className={cn(
"flex size-7 shrink-0 items-center justify-center rounded-md",
"text-muted-foreground hover:bg-muted hover:text-foreground",
"focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-inset",
"disabled:cursor-default"
)}
>
<ChevronDownIcon
className={cn(
"size-4 transition-transform duration-200",
"group-data-[state=open]:rotate-180"
)}
/>
</button>
</CollapsibleTrigger>
</div> </div>
)} </div>
</button>
{isExpanded && !isRunning && ( {/*
<> CollapsibleContent body auto-open while streaming
<div className="mx-5 h-px bg-border/50" /> (see ``open`` prop above) so the live ``argsText``
<div className="px-5 py-3 space-y-3"> streams into the Inputs panel directly, no need for
{argsText && ( a separate "Live input" panel. Native
<div> ``overflow-auto`` instead of ``ScrollArea`` because
<p className="text-xs font-medium text-muted-foreground mb-1">Inputs</p> Radix's Viewport can let content bleed past
<pre className="text-xs text-foreground/80 whitespace-pre-wrap break-all"> ``max-h-*`` in dynamic flex layouts. ``min-w-0`` on
{argsText} the column wrappers guarantees ``break-all`` wraps
</pre> correctly within the bounded ``max-w-lg`` Card.
*/}
<CollapsibleContent>
<Separator />
<div className="flex flex-col gap-3 px-5 py-3">
{(argsText || isRunning) && (
<div className="flex flex-col gap-1 min-w-0">
<p className="text-xs font-medium text-muted-foreground">Inputs</p>
<div className="max-h-48 overflow-auto rounded-md bg-muted/40">
{argsText ? (
<pre className="px-3 py-2 text-xs text-foreground/80 whitespace-pre-wrap break-all font-mono">
{argsText}
</pre>
) : (
// Bridges the brief gap between
// ``tool-input-start`` (creates the
// card, ``argsText`` undefined) and
// the first ``tool-input-delta``.
<p className="px-3 py-2 text-xs italic text-muted-foreground">
Waiting for input
</p>
)}
</div>
</div> </div>
)} )}
{!isCancelled && result !== undefined && ( {!isCancelled && result !== undefined && (
<> <>
<div className="h-px bg-border/30" /> <Separator />
<div> <div className="flex flex-col gap-1 min-w-0">
<p className="text-xs font-medium text-muted-foreground mb-1">Result</p> <p className="text-xs font-medium text-muted-foreground">Result</p>
<pre className="text-xs text-foreground/80 whitespace-pre-wrap break-all"> <div className="max-h-64 overflow-auto rounded-md bg-muted/40">
{typeof result === "string" ? result : serializedResult} <pre className="px-3 py-2 text-xs text-foreground/80 whitespace-pre-wrap break-all font-mono">
</pre> {typeof result === "string" ? result : serializedResult}
</pre>
</div>
</div> </div>
</> </>
)} )}
<div className="flex justify-end">
<ToolCardRevertButton toolCallId={toolCallId} />
</div>
</div> </div>
</> </CollapsibleContent>
)} </Collapsible>
</div> </Card>
); );
}; };

View file

@ -22,6 +22,7 @@ import {
addToolCall, addToolCall,
appendReasoning, appendReasoning,
appendText, appendText,
appendToolInputDelta,
buildContentForUI, buildContentForUI,
type ContentPartsState, type ContentPartsState,
endReasoning, endReasoning,
@ -188,6 +189,10 @@ export function FreeChatPage() {
); );
}; };
const scheduleFlush = () => batcher.schedule(flushMessages); const scheduleFlush = () => batcher.schedule(flushMessages);
const forceFlush = () => {
scheduleFlush();
batcher.flush();
};
try { try {
for await (const parsed of readSSEStream(response)) { for await (const parsed of readSSEStream(response)) {
@ -225,13 +230,20 @@ export function FreeChatPage() {
false, false,
parsed.langchainToolCallId parsed.langchainToolCallId
); );
batcher.flush(); forceFlush();
break; break;
case "tool-input-available": case "tool-input-delta":
appendToolInputDelta(contentPartsState, parsed.toolCallId, parsed.inputTextDelta);
scheduleFlush();
break;
case "tool-input-available": {
const finalArgsText = JSON.stringify(parsed.input ?? {}, null, 2);
if (toolCallIndices.has(parsed.toolCallId)) { if (toolCallIndices.has(parsed.toolCallId)) {
updateToolCall(contentPartsState, parsed.toolCallId, { updateToolCall(contentPartsState, parsed.toolCallId, {
args: parsed.input || {}, args: parsed.input || {},
argsText: finalArgsText,
langchainToolCallId: parsed.langchainToolCallId, langchainToolCallId: parsed.langchainToolCallId,
}); });
} else { } else {
@ -244,16 +256,20 @@ export function FreeChatPage() {
false, false,
parsed.langchainToolCallId parsed.langchainToolCallId
); );
updateToolCall(contentPartsState, parsed.toolCallId, {
argsText: finalArgsText,
});
} }
batcher.flush(); forceFlush();
break; break;
}
case "tool-output-available": case "tool-output-available":
updateToolCall(contentPartsState, parsed.toolCallId, { updateToolCall(contentPartsState, parsed.toolCallId, {
result: parsed.output, result: parsed.output,
langchainToolCallId: parsed.langchainToolCallId, langchainToolCallId: parsed.langchainToolCallId,
}); });
batcher.flush(); forceFlush();
break; break;
case "data-thinking-step": { case "data-thinking-step": {

View file

@ -1,7 +1,13 @@
import { z } from "zod"; import { z } from "zod";
/** /**
* Raw message from database (real-time sync) * Raw message from database (real-time sync).
*
* ``turn_id`` is included so consumers (e.g. ``convertToThreadMessage``)
* can populate ``metadata.custom.chatTurnId`` on the
* ``ThreadMessageLike`` even after the live-collab Zero re-sync. The
* inline Revert button's ``(chat_turn_id, tool_name, position)``
* fallback in tool-fallback.tsx depends on it.
*/ */
export const rawMessage = z.object({ export const rawMessage = z.object({
id: z.number(), id: z.number(),
@ -10,6 +16,7 @@ export const rawMessage = z.object({
content: z.unknown(), content: z.unknown(),
author_id: z.string().nullable(), author_id: z.string().nullable(),
created_at: z.string(), created_at: z.string(),
turn_id: z.string().nullable().optional(),
}); });
export type RawMessage = z.infer<typeof rawMessage>; export type RawMessage = z.infer<typeof rawMessage>;

View file

@ -0,0 +1,416 @@
"use client";
import { type QueryClient, useQuery } from "@tanstack/react-query";
import { useCallback, useEffect, useMemo, useRef } from "react";
import {
type AgentAction,
type AgentActionListResponse,
agentActionsApiService,
} from "@/lib/apis/agent-actions-api.service";
// =============================================================================
// DIAGNOSTIC LOGGING — gated behind a single switch. Flip ``RevertDebug``
// to ``true`` to trace the full SSE → cache → card → button pipeline in
// the browser console. Off by default so we don't spam production. The
// infrastructure stays in place because the underlying id-mismatch
// failure mode is rare-but-real and surfaces only at runtime.
// =============================================================================
// Master switch for the revert-pipeline trace logging below.
const RevertDebug = false;

// Trace helper for the SSE → cache → card pipeline. Emits nothing
// unless ``RevertDebug`` is flipped on AND we are in a browser
// (``window`` exists), so server-side rendering stays silent.
const dbg = (...args: unknown[]): void => {
  if (!RevertDebug) return;
  if (typeof window === "undefined") return;
  // eslint-disable-next-line no-console
  console.log("[RevertDebug]", ...args);
};
/**
* Unified store for ``AgentActionLog`` rows scoped to one thread.
*
* Replaces the previous SSE side-channel atom mess
* (``agentActionByLcIdAtom`` / ``agentActionByToolCallIdAtom`` /
* ``agentActionsByChatTurnIdAtom``) and the standalone hydration hook.
* One react-query cache entry is now the single source of truth for:
*
* * the inline Revert button on every tool-call card
* * the per-turn "Revert turn" button under each assistant message
* * the edit-from-position pre-flight that decides whether to show
* the confirmation dialog
* * the agent-actions sheet
*
* The cache is hydrated by ``GET /threads/{id}/actions`` (sized to
* 200, the server max) and updated incrementally by helpers that turn
* SSE events / revert RPC responses into ``setQueryData`` mutations.
 * That keeps the card and the sheet in lockstep on every code path:
 * page reload, navigation, live stream, post-stream reversibility flip,
 * and explicit revert clicks.
*/
export const ACTION_LOG_PAGE_SIZE = 200;
/** Stable react-query key for the per-thread action list. */
export function agentActionsQueryKey(threadId: number | null) {
  // ``"none"`` sentinel keeps the key shape stable (and the query
  // disabled-but-cacheable) when no thread is selected yet.
  if (threadId === null) {
    return ["agent-actions", "none"] as const;
  }
  return ["agent-actions", threadId] as const;
}
/**
 * Subset of the SSE ``data-action-log`` payload we care about.
 *
 * Consumed by ``applyActionLogSse``, which upserts these fields into
 * the thread-scoped react-query cache; fields the SSE event does not
 * carry are filled with null placeholders until the next REST refetch.
 */
export interface ActionLogSseEvent {
  // Server-side ``AgentActionLog`` row id — the upsert key in the cache.
  id: number;
  // Authoritative LangChain ``tool_call.id`` when the backend knew it.
  lc_tool_call_id: string | null;
  chat_turn_id: string | null;
  tool_name: string;
  reversible: boolean;
  // True when the server stored a reverse descriptor; the cache
  // materializes it as an empty-object placeholder until refetch.
  reverse_descriptor_present: boolean;
  // Boolean flag only — the full error payload arrives on refetch.
  error: boolean;
  created_at: string | null;
}
/**
* Append or upsert a freshly-emitted ``AgentActionLog`` row into the
* thread-scoped query cache.
*
* The SSE payload is a strict subset of ``AgentAction``; missing
* fields (``args``, ``reverse_descriptor``, ``user_id``) are filled
* with ``null`` placeholders. The next refetch (sheet open, user
* focus, route stale) backfills them but the inline Revert button
* only reads the fields the SSE payload carries, so it lights up
* immediately.
*/
export function applyActionLogSse(
  queryClient: QueryClient,
  threadId: number,
  searchSpaceId: number,
  event: ActionLogSseEvent
): void {
  dbg("applyActionLogSse: incoming SSE event", {
    threadId,
    searchSpaceId,
    event,
  });
  queryClient.setQueryData<AgentActionListResponse>(
    agentActionsQueryKey(threadId),
    (cached) => {
      // The SSE payload is a strict subset of ``AgentAction`` — fill
      // the fields it does not carry with null placeholders; the next
      // REST refetch backfills them.
      const row: AgentAction = {
        id: event.id,
        thread_id: threadId,
        user_id: null,
        search_space_id: searchSpaceId,
        tool_name: event.tool_name,
        args: null,
        result_id: null,
        reversible: event.reversible,
        reverse_descriptor: event.reverse_descriptor_present ? {} : null,
        error: event.error ? {} : null,
        reverse_of: null,
        reverted_by_action_id: null,
        is_revert_action: false,
        tool_call_id: event.lc_tool_call_id,
        chat_turn_id: event.chat_turn_id,
        created_at: event.created_at ?? new Date().toISOString(),
      };

      // Cold cache: seed a single-row first page.
      if (!cached) {
        return {
          items: [row],
          total: 1,
          page: 0,
          page_size: ACTION_LOG_PAGE_SIZE,
          has_more: false,
        };
      }

      const idx = cached.items.findIndex((a) => a.id === event.id);
      if (idx >= 0) {
        // Update path: refresh only the fields the SSE event is
        // authoritative for, preserving anything REST already
        // backfilled on this row.
        const nextItems = [...cached.items];
        const current = nextItems[idx];
        if (current) {
          nextItems[idx] = {
            ...current,
            reversible: event.reversible,
            tool_call_id: event.lc_tool_call_id ?? current.tool_call_id,
            chat_turn_id: event.chat_turn_id ?? current.chat_turn_id,
          };
        }
        dbg("applyActionLogSse: merged into existing entry", {
          id: event.id,
          tool_call_id: nextItems[idx]?.tool_call_id,
          reversible: nextItems[idx]?.reversible,
        });
        return { ...cached, items: nextItems };
      }

      dbg("applyActionLogSse: appended new placeholder", {
        id: event.id,
        tool_call_id: row.tool_call_id,
        tool_name: row.tool_name,
        reversible: row.reversible,
        cacheSizeAfter: cached.items.length + 1,
      });
      // REST returns newest-first — keep that ordering when
      // the server eventually refetches by prepending.
      return {
        ...cached,
        items: [row, ...cached.items],
        total: cached.total + 1,
      };
    }
  );
}
/**
 * Apply a post-SAVEPOINT reversibility flip
 * (``data-action-log-updated`` SSE event) to the cache. Leaves the
 * cache object reference-equal (no re-render) when the thread has no
 * cache yet or the id isn't present — both cases are dbg-logged.
 */
export function applyActionLogUpdatedSse(
	queryClient: QueryClient,
	threadId: number,
	id: number,
	reversible: boolean
): void {
	dbg("applyActionLogUpdatedSse: reversibility flip", {
		threadId,
		id,
		reversible,
	});
	queryClient.setQueryData<AgentActionListResponse>(
		agentActionsQueryKey(threadId),
		(prev) => {
			if (!prev) {
				dbg("applyActionLogUpdatedSse: NO prev cache for thread; flip dropped", {
					threadId,
					id,
				});
				return prev;
			}
			const present = prev.items.some((row) => row.id === id);
			if (!present) {
				dbg("applyActionLogUpdatedSse: id not in cache; flip dropped", {
					threadId,
					id,
					cacheSize: prev.items.length,
					cacheIds: prev.items.map((row) => row.id),
				});
				return prev;
			}
			return {
				...prev,
				items: prev.items.map((row) =>
					row.id === id ? { ...row, reversible } : row
				),
			};
		}
	);
}
/**
 * Optimistically mark ``id`` as reverted.
 *
 * Used by the inline / per-turn Revert button immediately after the
 * server returns success so the UI flips to "Reverted" without
 * waiting for a refetch. ``newActionId`` is the id of the new
 * ``is_revert_action`` row the server inserted; pass ``null`` if the
 * server didn't return it.
 */
export function markActionRevertedInCache(
	queryClient: QueryClient,
	threadId: number,
	id: number,
	newActionId: number | null
): void {
	// ``-1`` is a sentinel meaning "we know it was reverted but the
	// server didn't tell us the new row's id".
	const revertedBy = newActionId ?? -1;
	queryClient.setQueryData<AgentActionListResponse>(
		agentActionsQueryKey(threadId),
		(prev) => {
			if (!prev) return prev;
			// Keep the cache object reference-equal when the target
			// row isn't present so react-query sees "no change".
			if (!prev.items.some((row) => row.id === id)) return prev;
			return {
				...prev,
				items: prev.items.map((row) =>
					row.id === id ? { ...row, reverted_by_action_id: revertedBy } : row
				),
			};
		}
	);
}
/**
 * Apply a batch of revert results (per-turn revert response) to the
 * cache. Every row named in ``entries`` gets its
 * ``reverted_by_action_id`` set (``-1`` when the server didn't report
 * the new revert-action row's id); all other rows are left alone.
 */
export function applyRevertTurnResultsToCache(
	queryClient: QueryClient,
	threadId: number,
	entries: Array<{ id: number; newActionId: number | null }>
): void {
	if (entries.length === 0) return;
	// Index once so the per-row membership check below is O(1).
	const lookup = new Map(entries.map((entry) => [entry.id, entry.newActionId]));
	queryClient.setQueryData<AgentActionListResponse>(
		agentActionsQueryKey(threadId),
		(prev) => {
			if (!prev) return prev;
			let touched = false;
			const items = prev.items.map((row) => {
				if (!lookup.has(row.id)) return row;
				touched = true;
				return { ...row, reverted_by_action_id: lookup.get(row.id) ?? -1 };
			});
			// Preserve the previous object reference when nothing matched.
			return touched ? { ...prev, items } : prev;
		}
	);
}
/**
 * Read-side hook used by the card, the turn button, the sheet, and
 * the edit-from-position pre-flight.
 *
 * Returns the raw query state plus convenience selectors
 * (``findByToolCallId`` / ``findByChatTurnId`` /
 * ``findByChatTurnAndTool``) so consumers don't reach into
 * ``data.items`` directly. ``enabled`` is the only knob — pass
 * ``false`` to keep the query dormant when the consumer doesn't yet
 * have a thread id.
 */
export function useAgentActionsQuery(
	threadId: number | null,
	options: { enabled?: boolean } = {}
) {
	// Never fetch without a thread id — this gate is what makes the
	// ``threadId as number`` cast inside the queryFn safe.
	const enabled = (options.enabled ?? true) && threadId !== null;
	const query = useQuery({
		queryKey: agentActionsQueryKey(threadId),
		queryFn: async () => {
			dbg("useAgentActionsQuery: REST fetch START", {
				threadId,
				pageSize: ACTION_LOG_PAGE_SIZE,
			});
			// Single page sized to ACTION_LOG_PAGE_SIZE (the server max).
			const res = await agentActionsApiService.listForThread(threadId as number, {
				page: 0,
				pageSize: ACTION_LOG_PAGE_SIZE,
			});
			dbg("useAgentActionsQuery: REST fetch DONE", {
				threadId,
				total: res.total,
				returned: res.items.length,
				items: res.items.map((a) => ({
					id: a.id,
					tool_name: a.tool_name,
					tool_call_id: a.tool_call_id,
					reversible: a.reversible,
					reverted_by_action_id: a.reverted_by_action_id,
					is_revert_action: a.is_revert_action,
				})),
			});
			return res;
		},
		enabled,
		staleTime: 15 * 1000,
	});
	// Stable ``[]`` fallback keyed on ``query.data`` so the memos and
	// callbacks below don't re-key on every render while loading.
	const items = useMemo(() => query.data?.items ?? [], [query.data]);
	// Index ``items`` once per change so the lookups below are O(1)
	// instead of O(N) per card per render. With the cache sized to 200
	// rows and many tool cards visible at once, the unindexed scan was
	// the hottest path on every assistant text-delta. (Vercel React
	// rule ``js-index-maps`` / ``js-set-map-lookups``.)
	const byToolCallId = useMemo(() => {
		const m = new Map<string, AgentAction>();
		for (const a of items) {
			// Rows without a tool_call_id can't be addressed by a card.
			if (a.tool_call_id) m.set(a.tool_call_id, a);
		}
		return m;
	}, [items]);
	// Pre-grouped + pre-sorted (oldest-first, the order the agent
	// actually executed them in) so the (chat_turn_id, tool_name,
	// position) fallback in ``tool-fallback.tsx`` is also O(1) per
	// card. Excludes ``is_revert_action`` rows so the position index
	// matches the agent's original execution order.
	const byTurnAndTool = useMemo(() => {
		const m = new Map<string, AgentAction[]>();
		for (const a of items) {
			if (!a.chat_turn_id || a.is_revert_action) continue;
			const key = `${a.chat_turn_id}::${a.tool_name}`;
			const bucket = m.get(key);
			if (bucket) bucket.push(a);
			else m.set(key, [a]);
		}
		// Sort each bucket ascending by created_at (oldest first).
		for (const bucket of m.values()) {
			bucket.sort(
				(a, b) =>
					new Date(a.created_at).getTime() - new Date(b.created_at).getTime()
			);
		}
		return m;
	}, [items]);
	// Snapshot the cache shape when its size changes — easiest way to
	// spot when the cache is empty or stale at the moment a card
	// mounts. Tracked on a ref so we don't re-run the diff on
	// reference-equal cache reads.
	const lastSnapshotRef = useRef<{ threadId: number | null; size: number } | null>(null);
	useEffect(() => {
		const last = lastSnapshotRef.current;
		if (!last || last.threadId !== threadId || last.size !== items.length) {
			dbg("useAgentActionsQuery: cache snapshot", {
				threadId,
				enabled,
				itemCount: items.length,
				// First 8 rows are enough to eyeball ordering/staleness.
				itemKeys: items.slice(0, 8).map((a) => ({
					id: a.id,
					tool_name: a.tool_name,
					tool_call_id: a.tool_call_id,
					chat_turn_id: a.chat_turn_id,
					reversible: a.reversible,
				})),
			});
			lastSnapshotRef.current = { threadId, size: items.length };
		}
	}, [threadId, enabled, items]);
	// O(1) card → action lookup. Misses are only logged once the
	// cache has rows, so cold-start renders stay quiet.
	const findByToolCallId = useCallback(
		(toolCallId: string | null | undefined): AgentAction | null => {
			if (!toolCallId) return null;
			const found = byToolCallId.get(toolCallId) ?? null;
			if (!found && items.length > 0) {
				dbg("findByToolCallId: MISS", {
					queriedToolCallId: toolCallId,
					itemCount: items.length,
					availableToolCallIds: Array.from(byToolCallId.keys()),
				});
			}
			return found;
		},
		[byToolCallId, items.length]
	);
	const findByChatTurnId = useCallback(
		(chatTurnId: string | null | undefined): AgentAction[] => {
			if (!chatTurnId) return [];
			// Per-turn aggregation is uncommon enough (only the
			// "Revert turn" button uses it) that re-scanning is fine;
			// indexing it would just bloat memory.
			return items.filter((a) => a.chat_turn_id === chatTurnId);
		},
		[items]
	);
	const findByChatTurnAndTool = useCallback(
		(
			chatTurnId: string | null | undefined,
			toolName: string | null | undefined
		): AgentAction[] => {
			if (!chatTurnId || !toolName) return [];
			return byTurnAndTool.get(`${chatTurnId}::${toolName}`) ?? [];
		},
		[byTurnAndTool]
	);
	// Spread the raw query so consumers keep isLoading / refetch /
	// error alongside the derived selectors.
	return {
		...query,
		items,
		findByToolCallId,
		findByChatTurnId,
		findByChatTurnAndTool,
	};
}

View file

@ -31,6 +31,14 @@ export function useMessagesSync(
content: msg.content, content: msg.content,
author_id: msg.authorId ?? null, author_id: msg.authorId ?? null,
created_at: new Date(msg.createdAt).toISOString(), created_at: new Date(msg.createdAt).toISOString(),
// Forward the per-turn correlation id so post-stream Zero
// re-syncs preserve ``metadata.custom.chatTurnId`` on the
// converted ``ThreadMessageLike``. Without this the inline
// Revert button's ``(chat_turn_id, tool_name, position)``
// fallback breaks the moment Zero overwrites the messages
// state after a live stream completes (see
// ``handleSyncedMessagesUpdate`` in the chat page).
turn_id: msg.turnId ?? null,
})); }));
onMessagesUpdateRef.current(mapped); onMessagesUpdateRef.current(mapped);

View file

@ -16,6 +16,23 @@ export type ContentPart =
toolName: string; toolName: string;
args: Record<string, unknown>; args: Record<string, unknown>;
result?: unknown; result?: unknown;
/**
* Live / finalized JSON text for the tool's input arguments.
*
* - During streaming: accumulated partial JSON text from
* ``tool-input-delta`` events (may be invalid JSON
* mid-stream). assistant-ui's argsText parser tolerates
* invalid JSON gracefully (changelog 0.7.32 / 0.7.78).
* - On completion (``tool-input-available``): replaced with
* ``JSON.stringify(input, null, 2)`` so the post-stream
* card renders pretty-printed JSON instead of the
* model's possibly-fragmented formatting.
*
* Per assistant-ui ``ThreadMessageLike`` precedence
* (changelog 0.11.6 ``d318c83``), when ``argsText`` is
* supplied it wins over ``JSON.stringify(args)``.
*/
argsText?: string;
/** /**
* Authoritative LangChain ``tool_call.id`` propagated by the backend * Authoritative LangChain ``tool_call.id`` propagated by the backend
* via ``langchainToolCallId`` on tool-input-start/available and * via ``langchainToolCallId`` on tool-input-start/available and
@ -282,12 +299,22 @@ export function findToolCallIdByLcId(
export function updateToolCall( export function updateToolCall(
state: ContentPartsState, state: ContentPartsState,
toolCallId: string, toolCallId: string,
update: { args?: Record<string, unknown>; result?: unknown; langchainToolCallId?: string } update: {
args?: Record<string, unknown>;
argsText?: string;
result?: unknown;
langchainToolCallId?: string;
}
): void { ): void {
const index = state.toolCallIndices.get(toolCallId); const index = state.toolCallIndices.get(toolCallId);
if (index !== undefined && state.contentParts[index]?.type === "tool-call") { if (index !== undefined && state.contentParts[index]?.type === "tool-call") {
const tc = state.contentParts[index] as ContentPart & { type: "tool-call" }; const tc = state.contentParts[index] as ContentPart & { type: "tool-call" };
if (update.args) tc.args = update.args; if (update.args) tc.args = update.args;
// ``!== undefined`` (NOT a truthy check): an explicit empty
// string CAN clear, and a finalization with
// ``JSON.stringify({}, null, 2) === "{}"`` (truthy but
// represents an empty-input call) still applies.
if (update.argsText !== undefined) tc.argsText = update.argsText;
if (update.result !== undefined) tc.result = update.result; if (update.result !== undefined) tc.result = update.result;
// Only backfill langchainToolCallId if not already set — the // Only backfill langchainToolCallId if not already set — the
// authoritative ``on_tool_end`` value should override an earlier // authoritative ``on_tool_end`` value should override an earlier
@ -299,6 +326,25 @@ export function updateToolCall(
} }
} }
/**
* Append a streamed args-delta chunk to the active tool call's
* ``argsText``. No-ops when no card has been registered yet for the
* given ``toolCallId`` (the matching ``tool-input-start`` either lost
* the wire race or this id never had a card either way the deltas
* have nowhere safe to land).
*/
export function appendToolInputDelta(
state: ContentPartsState,
toolCallId: string,
delta: string
): void {
const idx = state.toolCallIndices.get(toolCallId);
if (idx === undefined) return;
const tc = state.contentParts[idx];
if (tc?.type !== "tool-call") return;
tc.argsText = (tc.argsText ?? "") + delta;
}
function _hasInterruptResult(part: ContentPart): boolean { function _hasInterruptResult(part: ContentPart): boolean {
if (part.type !== "tool-call") return false; if (part.type !== "tool-call") return false;
const r = (part as { result?: unknown }).result; const r = (part as { result?: unknown }).result;
@ -371,6 +417,18 @@ export type SSEEvent =
/** Authoritative LangChain ``tool_call.id``. Optional. */ /** Authoritative LangChain ``tool_call.id``. Optional. */
langchainToolCallId?: string; langchainToolCallId?: string;
} }
| {
/**
* Live tool-call argument delta. Concatenated into
* ``argsText`` on the matching ``tool-call`` content part
* by ``appendToolInputDelta``. parity_v2 only the legacy
* code path emits ``tool-input-available`` without prior
* deltas.
*/
type: "tool-input-delta";
toolCallId: string;
inputTextDelta: string;
}
| { | {
type: "tool-input-available"; type: "tool-input-available";
toolCallId: string; toolCallId: string;

View file

@ -8,6 +8,13 @@ export const newChatMessageTable = table("new_chat_messages")
threadId: number().from("thread_id"), threadId: number().from("thread_id"),
authorId: string().optional().from("author_id"), authorId: string().optional().from("author_id"),
createdAt: number().from("created_at"), createdAt: number().from("created_at"),
// Per-turn correlation id sourced from ``configurable.turn_id``
// at streaming time. Required by the inline Revert button's
// (chat_turn_id, tool_name, position) fallback in tool-fallback.tsx
// — without it the live-collab Zero sync would clobber the
// metadata we set during streaming and the button would vanish
// the moment Zero re-syncs after the stream finishes.
turnId: string().optional().from("turn_id"),
}) })
.primaryKey("id"); .primaryKey("id");