diff --git a/surfsense_backend/app/services/new_streaming_service.py b/surfsense_backend/app/services/new_streaming_service.py index 5dbae91c5..3531d37af 100644 --- a/surfsense_backend/app/services/new_streaming_service.py +++ b/surfsense_backend/app/services/new_streaming_service.py @@ -595,8 +595,17 @@ class VercelStreamingService: Format the start of tool input streaming. Args: - tool_call_id: The unique tool call identifier (synthetic, derived - from LangGraph ``run_id`` so the frontend has a stable card id). + tool_call_id: The unique tool call identifier. May be EITHER the + synthetic ``call_`` id derived from LangGraph + ``run_id`` (legacy / ``SURFSENSE_ENABLE_STREAM_PARITY_V2`` + OFF, or the unmatched-fallback path under parity_v2) OR + the authoritative LangChain ``tool_call.id`` (parity_v2 + path: when the provider streams ``tool_call_chunks`` we + register the ``index`` and reuse the lc-id as the card + id so live ``tool-input-delta`` events can be routed + without a downstream join). Either way, the same id is + preserved across ``tool-input-start`` / ``-delta`` / + ``-available`` / ``tool-output-available`` for one call. tool_name: The name of the tool being called. langchain_tool_call_id: Optional authoritative LangChain ``tool_call.id``. When set, surfaces as diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index 1493c4326..c94945bb1 100644 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -338,6 +338,42 @@ def _log_file_contract(stage: str, result: StreamResult, **extra: Any) -> None: ) +def _legacy_match_lc_id( + pending_tool_call_chunks: list[dict[str, Any]], + tool_name: str, + run_id: str, + lc_tool_call_id_by_run: dict[str, str], +) -> str | None: + """Best-effort match a buffered ``tool_call_chunk`` to a tool name. + + Pure extract of the legacy in-line match used at ``on_tool_start`` for + parity_v2-OFF and unmatched (chunk path didn't register an index for + this call) tools. Pops the next id-bearing chunk whose ``name`` + matches ``tool_name`` (or any id-bearing chunk as a fallback) and + returns its id. Mutates ``pending_tool_call_chunks`` and + ``lc_tool_call_id_by_run`` in place. + """ + matched_idx: int | None = None + for idx, tcc in enumerate(pending_tool_call_chunks): + if tcc.get("name") == tool_name and tcc.get("id"): + matched_idx = idx + break + if matched_idx is None: + for idx, tcc in enumerate(pending_tool_call_chunks): + if tcc.get("id"): + matched_idx = idx + break + if matched_idx is None: + return None + matched = pending_tool_call_chunks.pop(matched_idx) + candidate = matched.get("id") + if isinstance(candidate, str) and candidate: + if run_id: + lc_tool_call_id_by_run[run_id] = candidate + return candidate + return None + + async def _stream_agent_events( agent: Any, config: dict[str, Any], @@ -403,10 +439,28 @@ async def _stream_agent_events( # ``tool_call_chunks`` from ``on_chat_model_stream``, key them by # name, and pop the next unconsumed entry at ``on_tool_start``. The # authoritative id is later filled in at ``on_tool_end`` from - # ``ToolMessage.tool_call_id``. + # ``ToolMessage.tool_call_id``. Under parity_v2 we ALSO short-circuit + # this list for chunks that already registered into ``index_to_meta`` + # below — so this list is reserved for the parity_v2-OFF / unmatched + # fallback path only and never re-pops a chunk we already streamed. pending_tool_call_chunks: list[dict[str, Any]] = [] lc_tool_call_id_by_run: dict[str, str] = {} + # parity_v2 only: live tool-call argument streaming. ``index_to_meta`` + # is keyed by the chunk's ``index`` field — LangChain + # ``ToolCallChunk``s for the same call share an index but only the + # first chunk carries id+name (subsequent ones are id=None, + # name=None, args=""). We register an index when both id and + # name are observed on a chunk (per ToolCallChunk semantics they + # arrive together on the first chunk), then route every later chunk + # at that index to the same ``ui_id`` as a ``tool-input-delta``. + # ``ui_tool_call_id_by_run`` maps LangGraph ``run_id`` to the + # ``ui_id`` used for that call's ``tool-input-start`` so the matching + # ``tool-output-available`` (emitted from ``on_tool_end``) lands on + # the same card. + index_to_meta: dict[int, dict[str, str]] = {} + ui_tool_call_id_by_run: dict[str, str] = {} + # Per-tool-end mutable cache for the LangChain tool_call_id resolved # at ``on_tool_end``. ``_emit_tool_output`` reads this so every # ``format_tool_output_available`` call automatically carries the @@ -452,13 +506,6 @@ async def _stream_agent_events( continue parts = _extract_chunk_parts(chunk) - # Accumulate any tool_call_chunks for best-effort - # correlation with ``on_tool_start`` below. We don't emit - # anything here; the matching is done at tool-start time. - if parity_v2 and parts["tool_call_chunks"]: - for tcc in parts["tool_call_chunks"]: - pending_tool_call_chunks.append(tcc) - reasoning_delta = parts["reasoning"] text_delta = parts["text"] @@ -504,6 +551,71 @@ async def _stream_agent_events( yield streaming_service.format_text_delta(current_text_id, text_delta) accumulated_text += text_delta + # Live tool-call argument streaming. Runs AFTER text/reasoning + # processing so chunks containing both stay in their natural + # wire order (text → text-end → tool-input-start). Active + # text/reasoning are closed inside the registration branch + # before ``tool-input-start`` so the frontend sees a clean + # part boundary even when providers interleave. + if parity_v2 and parts["tool_call_chunks"]: + for tcc in parts["tool_call_chunks"]: + idx = tcc.get("index") + + # Register this index when we first see id+name + # TOGETHER. Per LangChain ToolCallChunk semantics the + # first chunk for a tool call carries both fields + # together; later chunks have id=None, name=None and + # only ``args``. Requiring BOTH keeps wire + # ``tool-input-start`` always carrying a real + # toolName (assistant-ui's typed tool-part dispatch + # keys off it). + if idx is not None and idx not in index_to_meta: + lc_id = tcc.get("id") + name = tcc.get("name") + if lc_id and name: + ui_id = lc_id + + # Close active text/reasoning so wire + # ordering stays clean even on providers + # that interleave text and tool-call chunks + # within the same stream window. + if current_text_id is not None: + yield streaming_service.format_text_end(current_text_id) + current_text_id = None + if current_reasoning_id is not None: + yield streaming_service.format_reasoning_end( + current_reasoning_id + ) + current_reasoning_id = None + + index_to_meta[idx] = { + "ui_id": ui_id, + "lc_id": lc_id, + "name": name, + } + yield streaming_service.format_tool_input_start( + ui_id, + name, + langchain_tool_call_id=lc_id, + ) + + # Emit args delta for any chunk at a registered + # index (including idless continuations). Once an + # index is owned by ``index_to_meta`` we DO NOT + # append to ``pending_tool_call_chunks`` — that list + # is reserved for the parity_v2-OFF / unmatched + # fallback path so it never re-pops chunks already + # consumed here (skip-append). + meta = index_to_meta.get(idx) if idx is not None else None + if meta: + args_chunk = tcc.get("args") or "" + if args_chunk: + yield streaming_service.format_tool_input_delta( + meta["ui_id"], args_chunk + ) + else: + pending_tool_call_chunks.append(tcc) + elif event_type == "on_tool_start": active_tool_depth += 1 tool_name = event.get("name", "unknown_tool") @@ -834,44 +946,65 @@ async def _stream_agent_events( status="in_progress", ) - tool_call_id = ( - f"call_{run_id[:32]}" - if run_id - else streaming_service.generate_tool_call_id() - ) - - # Best-effort attach the LangChain ``tool_call_id``. We - # pop the first chunk in ``pending_tool_call_chunks`` whose - # name matches; if none match (the chunked args may not yet - # carry a ``name`` field, or the model skipped the chunked - # form) we leave ``langchainToolCallId`` unset for now and - # fill it in authoritatively at ``on_tool_end`` from - # ``ToolMessage.tool_call_id``. - langchain_tool_call_id: str | None = None - if parity_v2 and pending_tool_call_chunks: - matched_idx: int | None = None - for idx, tcc in enumerate(pending_tool_call_chunks): - if tcc.get("name") == tool_name and tcc.get("id"): - matched_idx = idx + # Resolve the card identity. If the chunk-emission loop + # already registered an ``index`` for this tool call (parity_v2 + # path), reuse the same ui_id so the card sees: + # tool-input-start → deltas… → tool-input-available → + # tool-output-available all keyed by lc_id. Otherwise fall + # back to the synthetic ``call_`` id and the legacy + # best-effort match against ``pending_tool_call_chunks``. + matched_meta: dict[str, str] | None = None + if parity_v2: + # FIFO over indices 0,1,2…; first unassigned same-name + # match wins. Handles parallel same-name calls (e.g. two + # write_file calls) deterministically as long as the + # model interleaves on_tool_start in the same order it + # streamed the args. + taken_ui_ids = set(ui_tool_call_id_by_run.values()) + for meta in index_to_meta.values(): + if meta["name"] == tool_name and meta["ui_id"] not in taken_ui_ids: + matched_meta = meta break - if matched_idx is None: - for idx, tcc in enumerate(pending_tool_call_chunks): - if tcc.get("id"): - matched_idx = idx - break - if matched_idx is not None: - matched = pending_tool_call_chunks.pop(matched_idx) - candidate = matched.get("id") - if isinstance(candidate, str) and candidate: - langchain_tool_call_id = candidate - if run_id: - lc_tool_call_id_by_run[run_id] = candidate - yield streaming_service.format_tool_input_start( - tool_call_id, - tool_name, - langchain_tool_call_id=langchain_tool_call_id, - ) + tool_call_id: str + langchain_tool_call_id: str | None = None + if matched_meta is not None: + tool_call_id = matched_meta["ui_id"] + langchain_tool_call_id = matched_meta["lc_id"] + # ``tool-input-start`` already fired during chunk + # emission — skip the duplicate. No pruning is needed + # because the chunk-emission loop intentionally never + # appends registered-index chunks to + # ``pending_tool_call_chunks`` (skip-append). + if run_id: + lc_tool_call_id_by_run[run_id] = matched_meta["lc_id"] + else: + tool_call_id = ( + f"call_{run_id[:32]}" + if run_id + else streaming_service.generate_tool_call_id() + ) + # Legacy fallback: parity_v2 OFF, or parity_v2 ON but the + # provider didn't stream tool_call_chunks for this call + # (no index registered). Run the existing best-effort + # match BEFORE emitting start so we still attach an + # authoritative ``langchainToolCallId`` when possible. + if parity_v2: + langchain_tool_call_id = _legacy_match_lc_id( + pending_tool_call_chunks, + tool_name, + run_id, + lc_tool_call_id_by_run, + ) + yield streaming_service.format_tool_input_start( + tool_call_id, + tool_name, + langchain_tool_call_id=langchain_tool_call_id, + ) + + if run_id: + ui_tool_call_id_by_run[run_id] = tool_call_id + # Sanitize tool_input: strip runtime-injected non-serializable # values (e.g. LangChain ToolRuntime) before sending over SSE. if isinstance(tool_input, dict): @@ -924,7 +1057,15 @@ async def _stream_agent_events( result.write_succeeded = True result.verification_succeeded = True - tool_call_id = f"call_{run_id[:32]}" if run_id else "call_unknown" + # Look up the SAME card id used at on_tool_start (either the + # parity_v2 lc-id-derived ui_id or the legacy synthetic + # ``call_``) so the output event always lands on the + # same card as start/delta/available. Fallback preserves the + # legacy synthetic shape for parity_v2-OFF / unknown-run paths. + tool_call_id = ui_tool_call_id_by_run.get( + run_id, + f"call_{run_id[:32]}" if run_id else "call_unknown", + ) original_step_id = tool_step_ids.get( run_id, f"{step_prefix}-unknown-{run_id[:8]}" ) @@ -935,17 +1076,22 @@ async def _stream_agent_events( # at ``on_tool_start`` time (kept in ``lc_tool_call_id_by_run``) # if the output isn't a ToolMessage. The value is stored in # ``current_lc_tool_call_id`` so ``_emit_tool_output`` - # picks it up for every output emit below. Stays None when - # parity_v2 is off so legacy emit paths are untouched. + # picks it up for every output emit below. + # + # Emitted in BOTH parity_v2 and legacy modes: the chat tool + # card needs the LangChain id to match against the + # ``data-action-log`` SSE event (keyed by ``lc_tool_call_id``) + # so the inline Revert button can light up. Reading + # ``raw_output.tool_call_id`` is a cheap, non-mutating attribute + # access that is safe regardless of feature-flag state. current_lc_tool_call_id["value"] = None - if parity_v2: - authoritative = getattr(raw_output, "tool_call_id", None) - if isinstance(authoritative, str) and authoritative: - current_lc_tool_call_id["value"] = authoritative - if run_id: - lc_tool_call_id_by_run[run_id] = authoritative - elif run_id and run_id in lc_tool_call_id_by_run: - current_lc_tool_call_id["value"] = lc_tool_call_id_by_run[run_id] + authoritative = getattr(raw_output, "tool_call_id", None) + if isinstance(authoritative, str) and authoritative: + current_lc_tool_call_id["value"] = authoritative + if run_id: + lc_tool_call_id_by_run[run_id] = authoritative + elif run_id and run_id in lc_tool_call_id_by_run: + current_lc_tool_call_id["value"] = lc_tool_call_id_by_run[run_id] if tool_name == "read_file": yield streaming_service.format_thinking_step( diff --git a/surfsense_backend/tests/unit/tasks/chat/test_extract_chunk_parts.py b/surfsense_backend/tests/unit/tasks/chat/test_extract_chunk_parts.py index 7f32bf456..1263a5fe1 100644 --- a/surfsense_backend/tests/unit/tasks/chat/test_extract_chunk_parts.py +++ b/surfsense_backend/tests/unit/tasks/chat/test_extract_chunk_parts.py @@ -183,3 +183,46 @@ class TestDefensive: assert out["text"] == "" assert out["reasoning"] == "" assert out["tool_call_chunks"] == [] + + +class TestIdlessContinuationChunks: + """Per LangChain ``ToolCallChunk`` semantics, the FIRST chunk for a + tool call carries id+name; later chunks for the same call have + ``id=None, name=None`` and only ``args`` + ``index``. Live tool-call + argument streaming relies on those idless continuation chunks + flowing through ``_extract_chunk_parts`` UNTOUCHED so the upstream + chunk-emission loop can still route them by ``index``. + """ + + def test_idless_continuation_chunk_preserved_verbatim(self) -> None: + chunk = _FakeChunk( + tool_call_chunks=[ + {"id": None, "name": None, "args": '_path":"/x"}', "index": 0} + ] + ) + out = _extract_chunk_parts(chunk) + assert len(out["tool_call_chunks"]) == 1 + tcc = out["tool_call_chunks"][0] + assert tcc.get("id") is None + assert tcc.get("name") is None + assert tcc.get("args") == '_path":"/x"}' + assert tcc.get("index") == 0 + + def test_first_then_idless_sequence_preserves_index(self) -> None: + """Both chunks for the same call share an ``index`` key — the + index-routing loop in ``stream_new_chat`` depends on it.""" + first = _FakeChunk( + tool_call_chunks=[ + {"id": "lc-1", "name": "write_file", "args": '{"file', "index": 0} + ] + ) + cont = _FakeChunk( + tool_call_chunks=[ + {"id": None, "name": None, "args": '_path":"/x"}', "index": 0} + ] + ) + out_first = _extract_chunk_parts(first) + out_cont = _extract_chunk_parts(cont) + assert out_first["tool_call_chunks"][0]["index"] == 0 + assert out_cont["tool_call_chunks"][0]["index"] == 0 + assert out_cont["tool_call_chunks"][0].get("id") is None diff --git a/surfsense_backend/tests/unit/tasks/chat/test_tool_input_streaming.py b/surfsense_backend/tests/unit/tasks/chat/test_tool_input_streaming.py new file mode 100644 index 000000000..9258d5cfe --- /dev/null +++ b/surfsense_backend/tests/unit/tasks/chat/test_tool_input_streaming.py @@ -0,0 +1,527 @@ +"""Unit tests for live tool-call argument streaming. + +Pins the wire format that ``_stream_agent_events`` emits when +``SURFSENSE_ENABLE_STREAM_PARITY_V2=true``: ``tool-input-start`` → +``tool-input-delta``... → ``tool-input-available`` → ``tool-output-available`` +all keyed by the same LangChain ``tool_call.id``. + +Identity is tracked in ``index_to_meta`` (per-chunk ``index``) and +``ui_tool_call_id_by_run`` (LangGraph ``run_id``); both are private to +``_stream_agent_events`` so we exercise them via the public wire output. + +These tests also lock in the legacy / parity_v2-OFF behaviour so the +synthetic ``call_`` shape stays stable for older clients. +""" + +from __future__ import annotations + +import json +from collections.abc import AsyncGenerator +from dataclasses import dataclass, field +from typing import Any + +import pytest + +import app.tasks.chat.stream_new_chat as stream_module +from app.agents.new_chat.feature_flags import AgentFeatureFlags +from app.services.new_streaming_service import VercelStreamingService +from app.tasks.chat.stream_new_chat import ( + StreamResult, + _legacy_match_lc_id, + _stream_agent_events, +) + +pytestmark = pytest.mark.unit + + +@dataclass +class _FakeChunk: + """Minimal stand-in for ``AIMessageChunk``.""" + + content: Any = "" + additional_kwargs: dict[str, Any] = field(default_factory=dict) + tool_call_chunks: list[dict[str, Any]] = field(default_factory=list) + + +@dataclass +class _FakeToolMessage: + """Stand-in for ``ToolMessage`` returned by ``on_tool_end``.""" + + content: Any + tool_call_id: str | None = None + + +class _FakeAgentState: + """Stand-in for ``StateSnapshot`` returned by ``aget_state``.""" + + def __init__(self) -> None: + # Empty values keeps the cloud-fallback safety-net branch a no-op, + # and an empty ``tasks`` list keeps the post-stream interrupt + # check a no-op too. + self.values: dict[str, Any] = {} + self.tasks: list[Any] = [] + + +class _FakeAgent: + """Replays a list of ``astream_events`` events.""" + + def __init__(self, events: list[dict[str, Any]]) -> None: + self._events = events + + async def astream_events( # type: ignore[no-untyped-def] + self, _input_data: Any, *, config: dict[str, Any], version: str + ) -> AsyncGenerator[dict[str, Any], None]: + del config, version # unused, contract-compatible + for ev in self._events: + yield ev + + async def aget_state(self, _config: dict[str, Any]) -> _FakeAgentState: + # Called once after astream_events drains so the cloud-fallback + # safety net can inspect staged filesystem work. The fake stays + # empty so the safety net is a no-op. + return _FakeAgentState() + + +def _model_stream( + *, + text: str = "", + reasoning: str = "", + tool_call_chunks: list[dict[str, Any]] | None = None, + tags: list[str] | None = None, +) -> dict[str, Any]: + return ( + { + "event": "on_chat_model_stream", + "tags": tags or [], + "data": { + "chunk": _FakeChunk( + content=text, + tool_call_chunks=list(tool_call_chunks or []), + ) + }, + # reasoning piggybacks via additional_kwargs path; if needed, + # override content to a typed-block list. Most tests just check + # tool_call_chunks routing so this is fine. + } + if not reasoning + else { + "event": "on_chat_model_stream", + "tags": tags or [], + "data": { + "chunk": _FakeChunk( + content=text, + additional_kwargs={"reasoning_content": reasoning}, + tool_call_chunks=list(tool_call_chunks or []), + ) + }, + } + ) + + +def _tool_start( + *, + name: str, + run_id: str, + input_payload: dict[str, Any] | None = None, +) -> dict[str, Any]: + return { + "event": "on_tool_start", + "name": name, + "run_id": run_id, + "data": {"input": input_payload or {}}, + } + + +def _tool_end( + *, + name: str, + run_id: str, + tool_call_id: str | None = None, + output: Any = "ok", +) -> dict[str, Any]: + return { + "event": "on_tool_end", + "name": name, + "run_id": run_id, + "data": { + "output": _FakeToolMessage( + content=json.dumps(output) if not isinstance(output, str) else output, + tool_call_id=tool_call_id, + ) + }, + } + + +@pytest.fixture +def parity_v2_on(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + stream_module, + "get_flags", + lambda: AgentFeatureFlags(enable_stream_parity_v2=True), + ) + + +@pytest.fixture +def parity_v2_off(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + stream_module, + "get_flags", + lambda: AgentFeatureFlags(enable_stream_parity_v2=False), + ) + + +async def _drain(events: list[dict[str, Any]]) -> list[dict[str, Any]]: + """Run ``_stream_agent_events`` against a fake agent and return the + SSE payloads (parsed JSON) it yielded. + """ + agent = _FakeAgent(events) + service = VercelStreamingService() + result = StreamResult() + config = {"configurable": {"thread_id": "test-thread"}} + sse_lines: list[str] = [] + async for sse in _stream_agent_events( + agent, config, {}, service, result, step_prefix="thinking" + ): + sse_lines.append(sse) + + parsed: list[dict[str, Any]] = [] + for line in sse_lines: + if not line.startswith("data: "): + continue + body = line[len("data: ") :].rstrip("\n") + if not body or body == "[DONE]": + continue + try: + parsed.append(json.loads(body)) + except json.JSONDecodeError: + continue + return parsed + + +def _types(payloads: list[dict[str, Any]]) -> list[str]: + return [p.get("type", "?") for p in payloads] + + +def _of_type(payloads: list[dict[str, Any]], type_name: str) -> list[dict[str, Any]]: + return [p for p in payloads if p.get("type") == type_name] + + +# --------------------------------------------------------------------------- +# Helper: ``_legacy_match_lc_id`` is a pure refactor; assert behaviour. +# --------------------------------------------------------------------------- + + +class TestLegacyMatch: + def test_pops_first_id_bearing_chunk_with_matching_name(self) -> None: + chunks: list[dict[str, Any]] = [ + {"id": "x1", "name": "ls"}, + {"id": "y1", "name": "write_file"}, + ] + runs: dict[str, str] = {} + result = _legacy_match_lc_id(chunks, "write_file", "run-1", runs) + assert result == "y1" + assert chunks == [{"id": "x1", "name": "ls"}] + assert runs == {"run-1": "y1"} + + def test_falls_back_to_any_id_bearing_when_name_mismatches(self) -> None: + chunks: list[dict[str, Any]] = [{"id": "anon", "name": None}] + runs: dict[str, str] = {} + out = _legacy_match_lc_id(chunks, "ls", "run-2", runs) + assert out == "anon" + assert chunks == [] + + def test_returns_none_when_no_id_bearing_chunk(self) -> None: + chunks: list[dict[str, Any]] = [{"id": None, "name": None}] + runs: dict[str, str] = {} + assert _legacy_match_lc_id(chunks, "ls", "run-3", runs) is None + assert chunks == [{"id": None, "name": None}] + assert runs == {} + + +# --------------------------------------------------------------------------- +# parity_v2 wire format tests. +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_idless_chunk_merging_by_index(parity_v2_on: None) -> None: + """First chunk carries id+name; later idless chunks at the same + ``index`` merge into the SAME ``tool-input-start`` ui id and emit + one ``tool-input-delta`` per chunk.""" + events = [ + _model_stream( + tool_call_chunks=[ + {"id": "lc-1", "name": "write_file", "args": '{"file', "index": 0} + ], + ), + _model_stream( + tool_call_chunks=[ + {"id": None, "name": None, "args": '_path":"/x"}', "index": 0} + ], + ), + _tool_start( + name="write_file", run_id="run-A", input_payload={"file_path": "/x"} + ), + _tool_end(name="write_file", run_id="run-A", tool_call_id="lc-1"), + ] + + payloads = await _drain(events) + + starts = _of_type(payloads, "tool-input-start") + deltas = _of_type(payloads, "tool-input-delta") + available = _of_type(payloads, "tool-input-available") + output = _of_type(payloads, "tool-output-available") + + assert len(starts) == 1 + assert starts[0]["toolCallId"] == "lc-1" + assert starts[0]["toolName"] == "write_file" + assert starts[0]["langchainToolCallId"] == "lc-1" + + assert [d["inputTextDelta"] for d in deltas] == ['{"file', '_path":"/x"}'] + assert all(d["toolCallId"] == "lc-1" for d in deltas) + + assert len(available) == 1 + assert available[0]["toolCallId"] == "lc-1" + + assert len(output) == 1 + assert output[0]["toolCallId"] == "lc-1" + + +@pytest.mark.asyncio +async def test_two_interleaved_tool_calls_route_by_index( + parity_v2_on: None, +) -> None: + """Two same-name calls with distinct indices keep their deltas + routed to the right card.""" + events = [ + _model_stream( + tool_call_chunks=[ + {"id": "lc-A", "name": "write_file", "args": '{"a":1', "index": 0}, + {"id": "lc-B", "name": "write_file", "args": '{"b":2', "index": 1}, + ] + ), + _model_stream( + tool_call_chunks=[ + {"id": None, "name": None, "args": "}", "index": 0}, + {"id": None, "name": None, "args": "}", "index": 1}, + ] + ), + _tool_start(name="write_file", run_id="run-A", input_payload={"a": 1}), + _tool_end(name="write_file", run_id="run-A", tool_call_id="lc-A"), + _tool_start(name="write_file", run_id="run-B", input_payload={"b": 2}), + _tool_end(name="write_file", run_id="run-B", tool_call_id="lc-B"), + ] + + payloads = await _drain(events) + + starts = _of_type(payloads, "tool-input-start") + deltas = _of_type(payloads, "tool-input-delta") + output = _of_type(payloads, "tool-output-available") + + assert {s["toolCallId"] for s in starts} == {"lc-A", "lc-B"} + + by_id: dict[str, list[str]] = {"lc-A": [], "lc-B": []} + for d in deltas: + by_id[d["toolCallId"]].append(d["inputTextDelta"]) + assert by_id["lc-A"] == ['{"a":1', "}"] + assert by_id["lc-B"] == ['{"b":2', "}"] + + assert {o["toolCallId"] for o in output} == {"lc-A", "lc-B"} + + +@pytest.mark.asyncio +async def test_identity_stable_across_lifecycle(parity_v2_on: None) -> None: + """Whatever id ``tool-input-start`` chose must be the SAME id used + on ``tool-input-available`` AND ``tool-output-available``.""" + events = [ + _model_stream( + tool_call_chunks=[ + {"id": "lc-9", "name": "ls", "args": '{"path":"/"}', "index": 0} + ] + ), + _tool_start(name="ls", run_id="run-X", input_payload={"path": "/"}), + _tool_end(name="ls", run_id="run-X", tool_call_id="lc-9"), + ] + payloads = await _drain(events) + relevant = [ + p + for p in payloads + if p.get("type") + in {"tool-input-start", "tool-input-available", "tool-output-available"} + ] + assert {p["toolCallId"] for p in relevant} == {"lc-9"} + + +@pytest.mark.asyncio +async def test_no_duplicate_tool_input_start(parity_v2_on: None) -> None: + """When the chunk-emission loop already fired ``tool-input-start`` + for this run, ``on_tool_start`` MUST NOT emit a second one.""" + events = [ + _model_stream( + tool_call_chunks=[ + {"id": "lc-1", "name": "write_file", "args": "{}", "index": 0} + ] + ), + _tool_start(name="write_file", run_id="run-A", input_payload={}), + _tool_end(name="write_file", run_id="run-A", tool_call_id="lc-1"), + ] + payloads = await _drain(events) + starts = _of_type(payloads, "tool-input-start") + assert len(starts) == 1 + assert starts[0]["toolCallId"] == "lc-1" + + +@pytest.mark.asyncio +async def test_active_text_closes_before_early_tool_input_start( + parity_v2_on: None, +) -> None: + """Streaming a text-delta then a tool-call chunk in subsequent + chunks: the wire MUST contain ``text-end`` before the FIRST + ``tool-input-start`` (clean part boundary on the frontend).""" + events = [ + _model_stream(text="Working on it"), + _model_stream( + tool_call_chunks=[ + {"id": "lc-1", "name": "write_file", "args": "{}", "index": 0} + ] + ), + _tool_start(name="write_file", run_id="run-A", input_payload={}), + _tool_end(name="write_file", run_id="run-A", tool_call_id="lc-1"), + ] + types = _types(await _drain(events)) + text_end_idx = types.index("text-end") + start_idx = types.index("tool-input-start") + assert text_end_idx < start_idx + + +@pytest.mark.asyncio +async def test_mixed_text_and_tool_chunk_preserve_order( + parity_v2_on: None, +) -> None: + """One AIMessageChunk that carries BOTH ``text`` content AND + ``tool_call_chunks`` should emit the text delta FIRST, then close + text, then ``tool-input-start``+``tool-input-delta``.""" + events = [ + _model_stream( + text="I'll update it", + tool_call_chunks=[ + { + "id": "lc-1", + "name": "write_file", + "args": '{"file_path":"/x"}', + "index": 0, + } + ], + ), + _tool_start( + name="write_file", run_id="run-A", input_payload={"file_path": "/x"} + ), + _tool_end(name="write_file", run_id="run-A", tool_call_id="lc-1"), + ] + types = _types(await _drain(events)) + # text-start … text-delta … text-end … tool-input-start … tool-input-delta + assert types.index("text-start") < types.index("text-delta") + assert types.index("text-delta") < types.index("text-end") + assert types.index("text-end") < types.index("tool-input-start") + assert types.index("tool-input-start") < types.index("tool-input-delta") + + +@pytest.mark.asyncio +async def test_parity_v2_off_preserves_legacy_shape( + parity_v2_off: None, +) -> None: + """When the flag is OFF, no deltas are emitted and the ``toolCallId`` + is ``call_`` (NOT the lc id).""" + events = [ + _model_stream( + tool_call_chunks=[ + {"id": "lc-1", "name": "ls", "args": '{"path":"/"}', "index": 0} + ] + ), + _tool_start(name="ls", run_id="run-A", input_payload={"path": "/"}), + _tool_end(name="ls", run_id="run-A", tool_call_id="lc-1"), + ] + payloads = await _drain(events) + + assert _of_type(payloads, "tool-input-delta") == [] + starts = _of_type(payloads, "tool-input-start") + assert len(starts) == 1 + assert starts[0]["toolCallId"].startswith("call_run-A") + # No ``langchainToolCallId`` propagation on ``tool-input-start`` in + # legacy mode (the start event fires before the ToolMessage is + # available, so we can't extract the authoritative LangChain id yet). + assert "langchainToolCallId" not in starts[0] + output = _of_type(payloads, "tool-output-available") + assert output[0]["toolCallId"].startswith("call_run-A") + # ``tool-output-available`` MUST carry ``langchainToolCallId`` even + # in legacy mode: the chat tool card uses it to backfill the + # LangChain id and join against the ``data-action-log`` SSE event + # (keyed by ``lc_tool_call_id``) so the inline Revert button can + # light up. Sourced from the returned ``ToolMessage.tool_call_id``, + # which is populated regardless of feature-flag state. + assert output[0]["langchainToolCallId"] == "lc-1" + + +@pytest.mark.asyncio +async def test_skip_append_prevents_stale_id_reuse( + parity_v2_on: None, +) -> None: + """Two same-name tools: the SECOND tool's ``langchainToolCallId`` + must NOT come from the first tool's chunk (``pending_tool_call_chunks`` + must stay empty for indexed-registered chunks).""" + events = [ + _model_stream( + tool_call_chunks=[ + {"id": "lc-A", "name": "write_file", "args": "{}", "index": 0}, + {"id": "lc-B", "name": "write_file", "args": "{}", "index": 1}, + ] + ), + _tool_start(name="write_file", run_id="run-1", input_payload={}), + _tool_end(name="write_file", run_id="run-1", tool_call_id="lc-A"), + _tool_start(name="write_file", run_id="run-2", input_payload={}), + _tool_end(name="write_file", run_id="run-2", tool_call_id="lc-B"), + ] + payloads = await _drain(events) + + starts = _of_type(payloads, "tool-input-start") + # Two distinct lc ids, each its own card. + assert {s["toolCallId"] for s in starts} == {"lc-A", "lc-B"} + # Each tool-output-available landed on its respective card. + output = _of_type(payloads, "tool-output-available") + assert {o["toolCallId"] for o in output} == {"lc-A", "lc-B"} + + +@pytest.mark.asyncio +async def test_registration_waits_for_both_id_and_name( + parity_v2_on: None, +) -> None: + """An id-only chunk (no name yet) must NOT emit ``tool-input-start``.""" + events = [ + _model_stream( + tool_call_chunks=[{"id": "lc-1", "name": None, "args": "", "index": 0}] + ), + ] + payloads = await _drain(events) + assert _of_type(payloads, "tool-input-start") == [] + + +@pytest.mark.asyncio +async def test_unmatched_fallback_still_attaches_lc_id( + parity_v2_on: None, +) -> None: + """parity_v2 ON, but the provider didn't include an ``index``: the + legacy fallback path must still emit ``tool-input-start`` with the + matching ``langchainToolCallId``.""" + events = [ + # No index on the chunk → not registered into index_to_meta; + # falls through to ``pending_tool_call_chunks`` so the legacy + # match path can pop it at on_tool_start. + _model_stream(tool_call_chunks=[{"id": "lc-orphan", "name": "ls", "args": ""}]), + _tool_start(name="ls", run_id="run-1", input_payload={"path": "/"}), + _tool_end(name="ls", run_id="run-1", tool_call_id="lc-orphan"), + ] + payloads = await _drain(events) + starts = _of_type(payloads, "tool-input-start") + assert len(starts) == 1 + assert starts[0]["toolCallId"].startswith("call_run-1") + assert starts[0]["langchainToolCallId"] == "lc-orphan" diff --git a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx index c2086e80a..e5ac61cd9 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx @@ -14,13 +14,6 @@ import { useCallback, useEffect, useMemo, useRef, useState } from "react"; import { toast } from "sonner"; import { z } from "zod"; import { disabledToolsAtom } from "@/atoms/agent-tools/agent-tools.atoms"; -import { - agentActionsByChatTurnIdAtom, - markAgentActionRevertedAtom, - resetAgentActionMapAtom, - updateAgentActionReversibleAtom, - upsertAgentActionAtom, -} from "@/atoms/chat/agent-actions.atom"; import { clearTargetCommentIdAtom, currentThreadAtom, @@ -55,6 +48,12 @@ import { type TokenUsageData, TokenUsageProvider, } from "@/components/assistant-ui/token-usage-context"; +import { + applyActionLogSse, + applyActionLogUpdatedSse, + markActionRevertedInCache, + useAgentActionsQuery, +} from "@/hooks/use-agent-actions-query"; import { useChatSessionStateSync } from "@/hooks/use-chat-session-state"; import { useMessagesSync } from "@/hooks/use-messages-sync"; import { getAgentFilesystemSelection } from "@/lib/agent-filesystem"; @@ -71,12 +70,12 @@ import { addToolCall, appendReasoning, appendText, + appendToolInputDelta, buildContentForPersistence, buildContentForUI, type ContentPartsState, endReasoning, FrameBatchedUpdater, - findToolCallIdByLcId, readSSEStream, type ThinkingStepData, type ToolUIGate, @@ -246,14 +245,6 @@ export default function NewChatPage() { const setAgentCreatedDocuments = useSetAtom(agentCreatedDocumentsAtom); const pendingUserImageUrls = useAtomValue(pendingUserImageDataUrlsAtom); const setPendingUserImageUrls = useSetAtom(pendingUserImageDataUrlsAtom); - // Agent action log SSE side-channel. - const upsertAgentAction = useSetAtom(upsertAgentActionAtom); - const updateAgentActionReversible = useSetAtom(updateAgentActionReversibleAtom); - const markAgentActionReverted = useSetAtom(markAgentActionRevertedAtom); - const resetAgentActionMap = useSetAtom(resetAgentActionMapAtom); - // Chat-turn-keyed action map for the edit-from-position pre-flight - // that decides whether to show the confirmation dialog. - const agentActionsByChatTurnId = useAtomValue(agentActionsByChatTurnIdAtom); // Edit dialog state. Holds the message id being edited and // the (already extracted) regenerate args so we can resume the edit // after the user picks "revert all" / "continue" / "cancel". @@ -282,6 +273,11 @@ export default function NewChatPage() { content: unknown; author_id: string | null; created_at: string; + // Forwarded so ``convertToThreadMessage`` can rebuild the + // ``metadata.custom.chatTurnId`` on the + // ``ThreadMessageLike``. Required by the inline Revert + // button's per-turn fallback. + turn_id?: string | null; }[] ) => { if (isRunning) { @@ -314,6 +310,11 @@ export default function NewChatPage() { created_at: msg.created_at, author_display_name: member?.user_display_name ?? existingAuthor?.displayName ?? null, author_avatar_url: member?.user_avatar_url ?? existingAuthor?.avatarUrl ?? null, + // Forward the per-turn correlation id so the + // inline Revert button's ``(chat_turn_id, + // tool_name, position)`` fallback survives the + // post-stream Zero re-sync. + turn_id: msg.turn_id ?? null, }); }); }); @@ -330,6 +331,13 @@ export default function NewChatPage() { return Number.isNaN(parsed) ? 0 : parsed; }, [params.search_space_id]); + // Unified store for agent-action rows (the same react-query cache + // the agent-actions sheet, the inline Revert button, and the + // per-turn Revert button all read). Hydrates from + // ``GET /threads/{id}/actions`` and is updated incrementally by the + // SSE handlers + revert-batch results below — no atom side-channel. + const { items: agentActionItems } = useAgentActionsQuery(threadId); + // Extract chat_id from URL params const urlChatId = useMemo(() => { const id = params.chat_id; @@ -357,7 +365,8 @@ export default function NewChatPage() { clearPlanOwnerRegistry(); closeReportPanel(); closeEditorPanel(); - resetAgentActionMap(); + // Note: agent-action data is keyed by threadId in react-query so + // switching threads naturally swaps caches; no explicit reset. try { if (urlChatId > 0) { @@ -426,7 +435,6 @@ export default function NewChatPage() { removeChatTab, searchSpaceId, tokenUsageStore, - resetAgentActionMap, ]); // Initialize on mount, and re-init when switching search spaces (even if urlChatId is the same) @@ -779,6 +787,15 @@ export default function NewChatPage() { ); }; const scheduleFlush = () => batcher.schedule(flushMessages); + // Force-flush helper: ``batcher.flush()`` is a no-op when + // ``dirty=false`` (e.g. a tool starts before any text + // streamed). ``scheduleFlush(); batcher.flush()`` sets + // the dirty bit FIRST so terminal events render + // promptly without the 50ms throttle delay. + const forceFlush = () => { + scheduleFlush(); + batcher.flush(); + }; for await (const parsed of readSSEStream(response)) { switch (parsed.type) { @@ -815,13 +832,23 @@ export default function NewChatPage() { false, parsed.langchainToolCallId ); - batcher.flush(); + forceFlush(); + break; + + case "tool-input-delta": + // High-frequency event: deltas can fire dozens + // of times per call, so use throttled + // scheduleFlush (NOT forceFlush) to coalesce. + appendToolInputDelta(contentPartsState, parsed.toolCallId, parsed.inputTextDelta); + scheduleFlush(); break; case "tool-input-available": { + const finalArgsText = JSON.stringify(parsed.input ?? {}, null, 2); if (toolCallIndices.has(parsed.toolCallId)) { updateToolCall(contentPartsState, parsed.toolCallId, { args: parsed.input || {}, + argsText: finalArgsText, langchainToolCallId: parsed.langchainToolCallId, }); } else { @@ -834,8 +861,14 @@ export default function NewChatPage() { false, parsed.langchainToolCallId ); + // addToolCall doesn't accept argsText today; + // backfill via updateToolCall so the new card + // renders pretty-printed JSON. + updateToolCall(contentPartsState, parsed.toolCallId, { + argsText: finalArgsText, + }); } - batcher.flush(); + forceFlush(); break; } @@ -854,7 +887,7 @@ export default function NewChatPage() { } } } - batcher.flush(); + forceFlush(); break; } @@ -950,34 +983,17 @@ export default function NewChatPage() { } case "data-action-log": { - const al = parsed.data; - const matchedToolCallId = al.lc_tool_call_id - ? findToolCallIdByLcId(contentPartsState, al.lc_tool_call_id) - : null; - upsertAgentAction({ - action: { - id: al.id, - threadId: currentThreadId, - lcToolCallId: al.lc_tool_call_id, - chatTurnId: al.chat_turn_id, - toolName: al.tool_name, - reversible: al.reversible, - reverseDescriptorPresent: al.reverse_descriptor_present, - error: al.error, - revertedByActionId: null, - isRevertAction: false, - createdAt: al.created_at, - }, - toolCallId: matchedToolCallId, - }); + applyActionLogSse(queryClient, currentThreadId, searchSpaceId, parsed.data); break; } case "data-action-log-updated": { - updateAgentActionReversible({ - id: parsed.data.id, - reversible: parsed.data.reversible, - }); + applyActionLogUpdatedSse( + queryClient, + currentThreadId, + parsed.data.id, + parsed.data.reversible + ); break; } @@ -1179,6 +1195,15 @@ export default function NewChatPage() { toolName: String(p.toolName), args: (p.args as Record) ?? {}, result: p.result as unknown, + // Restore argsText so persisted pretty-printed + // JSON survives reloads (assistant-ui prefers + // supplied argsText over JSON.stringify(args)). + // langchainToolCallId restoration also fixes a + // pre-existing dropped-id bug on resume. + ...(typeof p.argsText === "string" ? { argsText: p.argsText } : {}), + ...(typeof p.langchainToolCallId === "string" + ? { langchainToolCallId: p.langchainToolCallId } + : {}), }); contentPartsState.currentTextPartIndex = -1; } else if (p.type === "data-thinking-steps") { @@ -1200,7 +1225,12 @@ export default function NewChatPage() { const editedAction = decisions[0].edited_action; for (const part of contentParts) { if (part.type === "tool-call" && part.toolName === editedAction.name) { - part.args = { ...part.args, ...editedAction.args }; + const mergedArgs = { ...part.args, ...editedAction.args }; + part.args = mergedArgs; + // Sync argsText so the rendered card shows the + // edited inputs — assistant-ui prefers caller- + // supplied argsText over JSON.stringify(args). + part.argsText = JSON.stringify(mergedArgs, null, 2); break; } } @@ -1256,6 +1286,10 @@ export default function NewChatPage() { ); }; const scheduleFlush = () => batcher.schedule(flushMessages); + const forceFlush = () => { + scheduleFlush(); + batcher.flush(); + }; for await (const parsed of readSSEStream(response)) { switch (parsed.type) { @@ -1292,13 +1326,20 @@ export default function NewChatPage() { false, parsed.langchainToolCallId ); - batcher.flush(); + forceFlush(); break; - case "tool-input-available": + case "tool-input-delta": + appendToolInputDelta(contentPartsState, parsed.toolCallId, parsed.inputTextDelta); + scheduleFlush(); + break; + + case "tool-input-available": { + const finalArgsText = JSON.stringify(parsed.input ?? {}, null, 2); if (toolCallIndices.has(parsed.toolCallId)) { updateToolCall(contentPartsState, parsed.toolCallId, { args: parsed.input || {}, + argsText: finalArgsText, langchainToolCallId: parsed.langchainToolCallId, }); } else { @@ -1311,9 +1352,13 @@ export default function NewChatPage() { false, parsed.langchainToolCallId ); + updateToolCall(contentPartsState, parsed.toolCallId, { + argsText: finalArgsText, + }); } - batcher.flush(); + forceFlush(); break; + } case "tool-output-available": updateToolCall(contentPartsState, parsed.toolCallId, { @@ -1321,7 +1366,7 @@ export default function NewChatPage() { langchainToolCallId: parsed.langchainToolCallId, }); markInterruptsCompleted(contentParts); - batcher.flush(); + forceFlush(); break; case "data-thinking-step": { @@ -1381,34 +1426,17 @@ export default function NewChatPage() { } case "data-action-log": { - const al = parsed.data; - const matchedToolCallId = al.lc_tool_call_id - ? findToolCallIdByLcId(contentPartsState, al.lc_tool_call_id) - : null; - upsertAgentAction({ - action: { - id: al.id, - threadId: resumeThreadId, - lcToolCallId: al.lc_tool_call_id, - chatTurnId: al.chat_turn_id, - toolName: al.tool_name, - reversible: al.reversible, - reverseDescriptorPresent: al.reverse_descriptor_present, - error: al.error, - revertedByActionId: null, - isRevertAction: false, - createdAt: al.created_at, - }, - toolCallId: matchedToolCallId, - }); + applyActionLogSse(queryClient, resumeThreadId, searchSpaceId, parsed.data); break; } case "data-action-log-updated": { - updateAgentActionReversible({ - id: parsed.data.id, - reversible: parsed.data.reversible, - }); + applyActionLogUpdatedSse( + queryClient, + resumeThreadId, + parsed.data.id, + parsed.data.reversible + ); break; } @@ -1502,6 +1530,11 @@ export default function NewChatPage() { return { ...part, args: decision.edited_action.args, // Update displayed args + // Sync argsText so the rendered card shows + // the edited inputs — assistant-ui prefers + // caller-supplied argsText over + // JSON.stringify(args). + argsText: JSON.stringify(decision.edited_action.args, null, 2), result: { ...(part.result as Record), __decided__: decisionType, @@ -1712,6 +1745,10 @@ export default function NewChatPage() { ); }; const scheduleFlush = () => batcher.schedule(flushMessages); + const forceFlush = () => { + scheduleFlush(); + batcher.flush(); + }; for await (const parsed of readSSEStream(response)) { switch (parsed.type) { @@ -1748,13 +1785,20 @@ export default function NewChatPage() { false, parsed.langchainToolCallId ); - batcher.flush(); + forceFlush(); break; - case "tool-input-available": + case "tool-input-delta": + appendToolInputDelta(contentPartsState, parsed.toolCallId, parsed.inputTextDelta); + scheduleFlush(); + break; + + case "tool-input-available": { + const finalArgsText = JSON.stringify(parsed.input ?? {}, null, 2); if (toolCallIndices.has(parsed.toolCallId)) { updateToolCall(contentPartsState, parsed.toolCallId, { args: parsed.input || {}, + argsText: finalArgsText, langchainToolCallId: parsed.langchainToolCallId, }); } else { @@ -1767,9 +1811,13 @@ export default function NewChatPage() { false, parsed.langchainToolCallId ); + updateToolCall(contentPartsState, parsed.toolCallId, { + argsText: finalArgsText, + }); } - batcher.flush(); + forceFlush(); break; + } case "tool-output-available": updateToolCall(contentPartsState, parsed.toolCallId, { @@ -1786,7 +1834,7 @@ export default function NewChatPage() { } } } - batcher.flush(); + forceFlush(); break; case "data-thinking-step": { @@ -1802,34 +1850,21 @@ export default function NewChatPage() { } case "data-action-log": { - const al = parsed.data; - const matchedToolCallId = al.lc_tool_call_id - ? findToolCallIdByLcId(contentPartsState, al.lc_tool_call_id) - : null; - upsertAgentAction({ - action: { - id: al.id, - threadId, - lcToolCallId: al.lc_tool_call_id, - chatTurnId: al.chat_turn_id, - toolName: al.tool_name, - reversible: al.reversible, - reverseDescriptorPresent: al.reverse_descriptor_present, - error: al.error, - revertedByActionId: null, - isRevertAction: false, - createdAt: al.created_at, - }, - toolCallId: matchedToolCallId, - }); + if (threadId !== null) { + applyActionLogSse(queryClient, threadId, searchSpaceId, parsed.data); + } break; } case "data-action-log-updated": { - updateAgentActionReversible({ - id: parsed.data.id, - reversible: parsed.data.reversible, - }); + if (threadId !== null) { + applyActionLogUpdatedSse( + queryClient, + threadId, + parsed.data.id, + parsed.data.reversible + ); + } break; } @@ -1866,12 +1901,16 @@ export default function NewChatPage() { : `Reverted ${summary.reverted} downstream actions before regenerating.` ); } - for (const r of summary.results) { - if (r.status === "reverted" || r.status === "already_reverted") { - markAgentActionReverted({ - id: r.action_id, - newActionId: r.new_action_id ?? null, - }); + if (threadId !== null) { + for (const r of summary.results) { + if (r.status === "reverted" || r.status === "already_reverted") { + markActionRevertedInCache( + queryClient, + threadId, + r.action_id, + r.new_action_id ?? null + ); + } } } break; @@ -2019,16 +2058,26 @@ export default function NewChatPage() { const downstream = messages.slice(editedIndex + 1); downstreamTotalCount = downstream.length; const seenTurns = new Set(); + const downstreamTurnIds = new Set(); for (const m of downstream) { const meta = (m.metadata ?? {}) as { custom?: { chatTurnId?: string } }; const tid = meta.custom?.chatTurnId; if (!tid || seenTurns.has(tid)) continue; seenTurns.add(tid); - const turnActions = agentActionsByChatTurnId.get(tid) ?? []; - for (const a of turnActions) { - if (a.reversible && a.revertedByActionId === null && !a.isRevertAction && !a.error) { - downstreamReversibleCount += 1; - } + downstreamTurnIds.add(tid); + } + // Source of truth: the unified react-query cache. Every + // action whose ``chat_turn_id`` belongs to the slice we're + // about to drop counts toward the prompt. + for (const a of agentActionItems) { + if (!a.chat_turn_id || !downstreamTurnIds.has(a.chat_turn_id)) continue; + if ( + a.reversible && + (a.reverted_by_action_id === null || a.reverted_by_action_id === undefined) && + !a.is_revert_action && + (a.error === null || a.error === undefined) + ) { + downstreamReversibleCount += 1; } } } @@ -2052,7 +2101,7 @@ export default function NewChatPage() { downstreamTotalCount, }); }, - [handleRegenerate, messages, agentActionsByChatTurnId] + [handleRegenerate, messages, agentActionItems] ); const handleEditDialogChoice = useCallback( diff --git a/surfsense_web/atoms/chat/agent-actions.atom.ts b/surfsense_web/atoms/chat/agent-actions.atom.ts deleted file mode 100644 index 7830c8751..000000000 --- a/surfsense_web/atoms/chat/agent-actions.atom.ts +++ /dev/null @@ -1,194 +0,0 @@ -"use client"; - -import { atom } from "jotai"; - -/** - * Minimal per-row projection of ``AgentActionLog`` that the tool card - * needs to decide whether to render a Revert button. - * - * Fields are deliberately a subset of the full ``AgentAction`` so the - * SSE side-channel (``data-action-log`` / ``data-action-log-updated``) - * can populate them without depending on the REST endpoint - * ``GET /threads/.../actions`` (which 503s when - * ``SURFSENSE_ENABLE_ACTION_LOG`` is off). - */ -export interface AgentActionLite { - id: number; - threadId: number | null; - lcToolCallId: string | null; - chatTurnId: string | null; - toolName: string; - reversible: boolean; - reverseDescriptorPresent: boolean; - error: boolean; - revertedByActionId: number | null; - isRevertAction: boolean; - createdAt: string | null; -} - -/** - * Map keyed off the LangChain ``tool_call.id`` (mirrors ``ContentPart - * tool-call.langchainToolCallId``). - */ -export const agentActionByLcIdAtom = atom>(new Map()); - -/** - * Parallel map keyed off the synthetic chat-card ``toolCallId`` - * (``call_``) so ``ToolFallback`` (which only receives the - * synthetic id from assistant-ui) can join its card to the action log. - * - * Both maps are kept in sync by ``upsertAgentActionAtom``. - */ -export const agentActionByToolCallIdAtom = atom>(new Map()); - -/** - * Index keyed by ``chat_turn_id`` so the per-turn revert UI can answer - * "how many reversible actions does this assistant turn contain?" in - * O(1). Each entry's array is ordered by insertion (which - * for a single turn matches ``created_at`` because action-log writes - * happen synchronously). - */ -export const agentActionsByChatTurnIdAtom = atom>(new Map()); - -/** - * Action to upsert one ``AgentActionLite`` row. - * - * ``toolCallId`` is the synthetic card id (``call_`` from - * ``stream_new_chat.py``). When provided alongside ``lcToolCallId``, the - * action is indexed under BOTH ids so the tool card can perform the - * lookup without going via the streaming state. - */ -export const upsertAgentActionAtom = atom( - null, - (_get, set, payload: { action: AgentActionLite; toolCallId?: string | null }) => { - const { action, toolCallId } = payload; - const upsertInto = ( - prev: Map, - key: string - ): Map => { - const next = new Map(prev); - const existing = next.get(key); - next.set(key, { - ...action, - // Preserve the local "reverted" bookkeeping if a reversibility - // flip arrives AFTER the user already reverted via the REST - // route. We never want a stale ``reversible=true`` event to - // resurrect a Reverted card. - revertedByActionId: existing?.revertedByActionId ?? action.revertedByActionId, - isRevertAction: existing?.isRevertAction ?? action.isRevertAction, - }); - return next; - }; - if (action.lcToolCallId) { - set(agentActionByLcIdAtom, (prev) => upsertInto(prev, action.lcToolCallId as string)); - } - if (toolCallId) { - set(agentActionByToolCallIdAtom, (prev) => upsertInto(prev, toolCallId)); - } - if (action.chatTurnId) { - set(agentActionsByChatTurnIdAtom, (prev) => { - const next = new Map(prev); - const turnId = action.chatTurnId as string; - const existing = next.get(turnId) ?? []; - const priorEntry = existing.find((row) => row.id === action.id); - const merged: AgentActionLite = { - ...action, - revertedByActionId: priorEntry?.revertedByActionId ?? action.revertedByActionId, - isRevertAction: priorEntry?.isRevertAction ?? action.isRevertAction, - }; - const others = existing.filter((row) => row.id !== action.id); - next.set(turnId, [...others, merged]); - return next; - }); - } - } -); - -function mutateById( - prev: Map, - id: number, - mutator: (entry: AgentActionLite) => AgentActionLite -): Map { - let mutated = false; - const next = new Map(prev); - for (const [key, value] of next) { - if (value.id === id) { - next.set(key, mutator(value)); - mutated = true; - } - } - return mutated ? next : prev; -} - -function mutateByIdInTurnIndex( - prev: Map, - id: number, - mutator: (entry: AgentActionLite) => AgentActionLite -): Map { - let mutated = false; - const next = new Map(prev); - for (const [key, list] of next) { - let listMutated = false; - const updated = list.map((row) => { - if (row.id === id) { - listMutated = true; - return mutator(row); - } - return row; - }); - if (listMutated) { - next.set(key, updated); - mutated = true; - } - } - return mutated ? next : prev; -} - -/** - * Action to flip an existing entry's ``reversible`` flag, keyed by the - * AgentActionLog row id (the SSE ``data-action-log-updated`` payload - * does NOT carry ``lcToolCallId``). - */ -export const updateAgentActionReversibleAtom = atom( - null, - (_get, set, payload: { id: number; reversible: boolean }) => { - const apply = (entry: AgentActionLite): AgentActionLite => ({ - ...entry, - reversible: payload.reversible, - }); - set(agentActionByLcIdAtom, (prev) => mutateById(prev, payload.id, apply)); - set(agentActionByToolCallIdAtom, (prev) => mutateById(prev, payload.id, apply)); - set(agentActionsByChatTurnIdAtom, (prev) => mutateByIdInTurnIndex(prev, payload.id, apply)); - } -); - -/** Action to mark an existing entry as reverted (post-revert call). */ -export const markAgentActionRevertedAtom = atom( - null, - (_get, set, payload: { id: number; newActionId: number | null }) => { - const apply = (entry: AgentActionLite): AgentActionLite => ({ - ...entry, - revertedByActionId: payload.newActionId ?? -1, - }); - set(agentActionByLcIdAtom, (prev) => mutateById(prev, payload.id, apply)); - set(agentActionByToolCallIdAtom, (prev) => mutateById(prev, payload.id, apply)); - set(agentActionsByChatTurnIdAtom, (prev) => mutateByIdInTurnIndex(prev, payload.id, apply)); - } -); - -/** Mark every action in a turn as reverted, given a list of (id, newActionId) pairs. */ -export const markAgentActionsRevertedBatchAtom = atom( - null, - (_get, set, payload: { entries: Array<{ id: number; newActionId: number | null }> }) => { - for (const entry of payload.entries) { - set(markAgentActionRevertedAtom, entry); - } - } -); - -/** Reset all maps (e.g. when the active thread changes). */ -export const resetAgentActionMapAtom = atom(null, (_get, set) => { - set(agentActionByLcIdAtom, new Map()); - set(agentActionByToolCallIdAtom, new Map()); - set(agentActionsByChatTurnIdAtom, new Map()); -}); diff --git a/surfsense_web/components/agent-action-log/action-log-sheet.tsx b/surfsense_web/components/agent-action-log/action-log-sheet.tsx index 68d2ffef3..32c25771a 100644 --- a/surfsense_web/components/agent-action-log/action-log-sheet.tsx +++ b/surfsense_web/components/agent-action-log/action-log-sheet.tsx @@ -1,9 +1,9 @@ "use client"; -import { useQuery, useQueryClient } from "@tanstack/react-query"; +import { useQueryClient } from "@tanstack/react-query"; import { useAtom, useAtomValue } from "jotai"; import { Activity, RefreshCcw } from "lucide-react"; -import { useCallback, useMemo } from "react"; +import { useCallback } from "react"; import { actionLogSheetAtom } from "@/atoms/agent/action-log-sheet.atom"; import { agentFlagsAtom } from "@/atoms/agent/agent-flags-query.atom"; import { Badge } from "@/components/ui/badge"; @@ -17,15 +17,12 @@ import { SheetTitle, } from "@/components/ui/sheet"; import { Skeleton } from "@/components/ui/skeleton"; -import { agentActionsApiService } from "@/lib/apis/agent-actions-api.service"; +import { + agentActionsQueryKey, + useAgentActionsQuery, +} from "@/hooks/use-agent-actions-query"; import { ActionLogItem } from "./action-log-item"; -const ACTION_LOG_PAGE_SIZE = 50; - -function actionLogQueryKey(threadId: number) { - return ["agent-actions", threadId] as const; -} - function EmptyState() { return (
@@ -85,25 +82,17 @@ export function ActionLogSheet() { const threadId = state.threadId; - const { data, isLoading, isFetching, isError, error, refetch } = useQuery({ - queryKey: threadId !== null ? actionLogQueryKey(threadId) : ["agent-actions", "none"], - queryFn: () => - agentActionsApiService.listForThread(threadId as number, { - page: 0, - pageSize: ACTION_LOG_PAGE_SIZE, - }), - enabled: state.open && threadId !== null && actionLogEnabled, - staleTime: 15 * 1000, - }); + const { data, items, isLoading, isFetching, isError, error, refetch } = useAgentActionsQuery( + threadId, + { enabled: state.open && actionLogEnabled } + ); const handleRevertSuccess = useCallback(() => { if (threadId !== null) { - queryClient.invalidateQueries({ queryKey: actionLogQueryKey(threadId) }); + queryClient.invalidateQueries({ queryKey: agentActionsQueryKey(threadId) }); } }, [queryClient, threadId]); - const items = useMemo(() => data?.items ?? [], [data]); - return ( setState((s) => ({ ...s, open }))}> ([]); - // Subscribe ONLY to the slice of the global action map that belongs - // to ``chatTurnId``. Previously the button read the whole - // ``agentActionsByChatTurnIdAtom``, which meant every action - // upsert (one per tool call) re-rendered every Revert button on - // the page. With ``selectAtom`` we re-render only when our turn's - // list reference changes — and the upsert/mark atoms produce a - // fresh list reference for the affected turn only. - const sliceAtom = useMemo( - () => - selectAtom( - agentActionsByChatTurnIdAtom, - (turnIndex) => (chatTurnId ? turnIndex.get(chatTurnId) : undefined) ?? EMPTY_ACTIONS - ), - [chatTurnId] - ); - const actions = useAtomValue(sliceAtom); + const actions = useMemo(() => findByChatTurnId(chatTurnId), [findByChatTurnId, chatTurnId]); const reversibleCount = useMemo( () => actions.filter( - (a) => a.reversible && a.revertedByActionId === null && !a.isRevertAction && !a.error + (a) => + a.reversible && + (a.reverted_by_action_id === null || a.reverted_by_action_id === undefined) && + !a.is_revert_action && + (a.error === null || a.error === undefined) ).length, [actions] ); - const totalCount = useMemo(() => actions.filter((a) => !a.isRevertAction).length, [actions]); + const totalCount = useMemo(() => actions.filter((a) => !a.is_revert_action).length, [actions]); if (!chatTurnId) return null; if (reversibleCount === 0) return null; - const threadId = session?.threadId; if (!threadId) return null; const handleRevertTurn = async () => { @@ -103,7 +87,7 @@ export function RevertTurnButton({ chatTurnId }: RevertTurnButtonProps) { .filter((r) => r.status === "reverted" || r.status === "already_reverted") .map((r) => ({ id: r.action_id, newActionId: r.new_action_id ?? null })); if (revertedEntries.length > 0) { - markRevertedBatch({ entries: revertedEntries }); + applyRevertTurnResultsToCache(queryClient, threadId, revertedEntries); } if (response.status === "ok") { toast.success( diff --git a/surfsense_web/components/assistant-ui/tool-fallback.tsx b/surfsense_web/components/assistant-ui/tool-fallback.tsx index cc7582695..66e2ebd4a 100644 --- a/surfsense_web/components/assistant-ui/tool-fallback.tsx +++ b/surfsense_web/components/assistant-ui/tool-fallback.tsx @@ -1,12 +1,12 @@ -import type { ToolCallMessagePartComponent } from "@assistant-ui/react"; -import { useAtomValue, useSetAtom } from "jotai"; -import { CheckIcon, ChevronDownIcon, ChevronUpIcon, RotateCcw, XCircleIcon } from "lucide-react"; -import { useMemo, useState } from "react"; -import { toast } from "sonner"; import { - agentActionByToolCallIdAtom, - markAgentActionRevertedAtom, -} from "@/atoms/chat/agent-actions.atom"; + type ToolCallMessagePartComponent, + useAuiState, +} from "@assistant-ui/react"; +import { useQueryClient } from "@tanstack/react-query"; +import { useAtomValue } from "jotai"; +import { CheckIcon, ChevronDownIcon, RotateCcw, XCircleIcon } from "lucide-react"; +import { useEffect, useMemo, useState } from "react"; +import { toast } from "sonner"; import { chatSessionStateAtom } from "@/atoms/chat/chat-session-state.atom"; import { DoomLoopApprovalToolUI, @@ -24,8 +24,17 @@ import { AlertDialogTitle, AlertDialogTrigger, } from "@/components/ui/alert-dialog"; +import { Badge } from "@/components/ui/badge"; import { Button } from "@/components/ui/button"; -import { getToolDisplayName, getToolIcon } from "@/contracts/enums/toolIcons"; +import { Card } from "@/components/ui/card"; +import { Collapsible, CollapsibleContent, CollapsibleTrigger } from "@/components/ui/collapsible"; +import { Separator } from "@/components/ui/separator"; +import { Spinner } from "@/components/ui/spinner"; +import { getToolDisplayName } from "@/contracts/enums/toolIcons"; +import { + markActionRevertedInCache, + useAgentActionsQuery, +} from "@/hooks/use-agent-actions-query"; import { agentActionsApiService } from "@/lib/apis/agent-actions-api.service"; import { AppError } from "@/lib/error"; import { isInterruptResult } from "@/lib/hitl"; @@ -34,31 +43,128 @@ import { cn } from "@/lib/utils"; /** * Inline Revert button rendered on a tool card when the matching * ``AgentActionLog`` row is reversible and hasn't been reverted yet. - * Reads from the SSE side-channel atom keyed by the synthetic - * ``toolCallId`` so it lights up even when ``GET /threads/.../actions`` - * is gated behind ``SURFSENSE_ENABLE_ACTION_LOG=False`` (503). + * + * Reads from the unified ``useAgentActionsQuery`` cache — the SAME + * react-query cache the agent-actions sheet consumes. SSE events + * (``data-action-log`` / ``data-action-log-updated``) and + * ``POST /threads/{id}/revert/{id}`` responses both flow through the + * cache via ``setQueryData`` helpers, so the card and the sheet stay + * in lockstep on every code path: page reload, navigation, live + * stream, post-stream reversibility flip, and explicit revert clicks. + * + * Match key (in priority order): + * 1. ``a.tool_call_id === toolCallId`` — direct hit in parity_v2 when + * the model streamed ``tool_call_chunks`` so the card's synthetic + * id IS the LangChain id. + * 2. ``a.tool_call_id === langchainToolCallId`` — legacy mode (or + * parity_v2 with provider-side chunk emission) where the card's + * synthetic id is ``call_`` and the LangChain id is + * backfilled onto the part by ``tool-output-available``. + * 3. ``(chat_turn_id, tool_name, position-within-turn)`` — fallback + * for cards whose synthetic id is ``call_`` AND whose + * ``langchainToolCallId`` never got backfilled (provider emitted + * the tool_call as a single payload with no chunks AND streaming + * pre-dated the ``tool-output-available langchainToolCallId`` + * backfill, e.g. older threads). Reads the parent message's + * ``chatTurnId`` and ``content`` via ``useAuiState`` so we can + * match position-by-tool-name within the turn against the + * action_log rows the server returned in ``created_at`` order. */ -function ToolCardRevertButton({ toolCallId }: { toolCallId: string }) { +function ToolCardRevertButton({ + toolCallId, + toolName, + langchainToolCallId, +}: { + toolCallId: string; + toolName: string; + langchainToolCallId?: string; +}) { const session = useAtomValue(chatSessionStateAtom); - const actionMap = useAtomValue(agentActionByToolCallIdAtom); - const markReverted = useSetAtom(markAgentActionRevertedAtom); - const action = actionMap.get(toolCallId); + const threadId = session?.threadId ?? null; + const queryClient = useQueryClient(); + const { findByToolCallId, findByChatTurnAndTool } = useAgentActionsQuery(threadId); + + // Parent message metadata, read via the narrowest possible + // selectors so this card doesn't re-render on every text-delta of + // every other part in the same message during streaming. + // + // IMPORTANT — ``useAuiState`` re-renders the component whenever the + // returned slice's identity changes. Returning ``message?.content`` + // (an array) would re-render on every token because the runtime + // rebuilds the parts array. Returning a PRIMITIVE (the position + // number) lets ``useAuiState``'s ``Object.is`` check short-circuit + // when the position hasn't actually moved — which is the common + // case during text streaming, when only ``text``/``reasoning`` + // parts are mutating and the same-toolName tool-call ordering is + // stable. (See Vercel React rule ``rerender-defer-reads``.) + const chatTurnId = useAuiState(({ message }) => { + const meta = message?.metadata as { custom?: { chatTurnId?: string } } | undefined; + return meta?.custom?.chatTurnId ?? null; + }); + const positionInTurn = useAuiState(({ message }) => { + const content = message?.content; + if (!Array.isArray(content)) return -1; + let n = -1; + for (const part of content) { + if ( + part && + typeof part === "object" && + (part as { type?: string }).type === "tool-call" && + (part as { toolName?: string }).toolName === toolName + ) { + n += 1; + if ((part as { toolCallId?: string }).toolCallId === toolCallId) return n; + } + } + return -1; + }); + + const action = useMemo(() => { + // Tier 1 + 2: O(1) Map-backed direct id match. Covers + // ~all parity_v2 streams and any legacy stream that backfilled + // ``langchainToolCallId`` via ``tool-output-available``. + const direct = + findByToolCallId(toolCallId) ?? findByToolCallId(langchainToolCallId); + if (direct) return direct; + // Tier 3: position-within-turn fallback. Only kicks in when the + // card has a synthetic ``call_`` id AND no + // ``langchainToolCallId`` was ever backfilled — i.e. the tool + // was emitted as a single non-chunked payload AND streaming + // pre-dated the on_tool_end backfill. + if (!chatTurnId || positionInTurn < 0) return null; + const turnSameTool = findByChatTurnAndTool(chatTurnId, toolName); + return turnSameTool[positionInTurn] ?? null; + }, [ + findByToolCallId, + findByChatTurnAndTool, + toolCallId, + langchainToolCallId, + chatTurnId, + toolName, + positionInTurn, + ]); + const [isReverting, setIsReverting] = useState(false); const [confirmOpen, setConfirmOpen] = useState(false); if (!action) return null; if (!action.reversible) return null; - if (action.revertedByActionId !== null) return null; - if (action.isRevertAction) return null; - if (action.error) return null; - const threadId = session?.threadId; + if (action.reverted_by_action_id !== null && action.reverted_by_action_id !== undefined) + return null; + if (action.is_revert_action) return null; + if (action.error !== null && action.error !== undefined) return null; if (!threadId) return null; const handleRevert = async () => { setIsReverting(true); try { const response = await agentActionsApiService.revert(threadId, action.id); - markReverted({ id: action.id, newActionId: response.new_action_id ?? null }); + markActionRevertedInCache( + queryClient, + threadId, + action.id, + response.new_action_id ?? null + ); toast.success(response.message || "Action reverted."); } catch (err) { // 503 means revert is gated off on this deployment — hide the @@ -91,8 +197,17 @@ function ToolCardRevertButton({ toolCallId }: { toolCallId: string }) { e.stopPropagation(); setConfirmOpen(true); }} + disabled={isReverting} > - + {isReverting ? ( + // Spinner's typed props don't accept ``data-icon`` and + // it renders an , not an , so Button's + // auto-sizing rule doesn't apply. Bare spinner + + // Button's gap handle layout. + + ) : ( + + )} Revert @@ -101,7 +216,7 @@ function ToolCardRevertButton({ toolCallId }: { toolCallId: string }) { Revert this action? This will undo{" "} - {getToolDisplayName(action.toolName)} and add a + {getToolDisplayName(action.tool_name)} and add a new entry to the history. Your chat is preserved — only the changes the agent made to your knowledge base or connected apps will be rolled back where possible. @@ -114,8 +229,10 @@ function ToolCardRevertButton({ toolCallId }: { toolCallId: string }) { handleRevert(); }} disabled={isReverting} + className="gap-1.5" > - {isReverting ? "Reverting…" : "Revert"} + {isReverting && } + Revert @@ -123,18 +240,49 @@ function ToolCardRevertButton({ toolCallId }: { toolCallId: string }) { ); } -const DefaultToolFallbackInner: ToolCallMessagePartComponent = ({ - toolCallId, - toolName, - argsText, - result, - status, -}) => { - const [isExpanded, setIsExpanded] = useState(false); +/** + * Compact tool-call card. + * + * shadcn composition note: we intentionally use ``Card`` as a visual + * frame WITHOUT ``CardHeader / CardContent``. The full composition's + * ``p-6`` padding doesn't fit a compact collapsible header that IS the + * trigger; using ``Card`` alone preserves the rounded border, shadow, + * and ``bg-card`` token (semantic colors) without forcing a layout + * that doesn't fit. All status colors use semantic tokens — no manual + * dark-mode overrides, no raw hex. + */ +const DefaultToolFallbackInner: ToolCallMessagePartComponent = (props) => { + const { toolCallId, toolName, argsText, result, status } = props; + // ``langchainToolCallId`` is a SurfSense-specific extension the + // streaming pipeline attaches to the tool-call content part so + // the Revert button can resolve its ``AgentActionLog`` row even + // when only the LC id is known. assistant-ui's + // ``ToolCallMessagePartProps`` doesn't list it, but the runtime + // spreads ``{...part}`` so the prop reaches us at runtime. + const langchainToolCallId = (props as { langchainToolCallId?: string }).langchainToolCallId; const isCancelled = status?.type === "incomplete" && status.reason === "cancelled"; const isError = status?.type === "incomplete" && status.reason === "error"; const isRunning = status?.type === "running" || status?.type === "requires-action"; + + /* + Per-card expansion state. Initial value is ``isRunning`` so a + card streaming in mounts already-expanded (no flash of + collapsed → expanded on first paint), while a card loaded from + history (status="complete") mounts collapsed. The useEffect + below keeps this in lockstep with this card's own ``isRunning`` + when it transitions: false → true auto-expands (e.g. a tool + that re-runs after edit), true → false auto-collapses once the + tool finishes. Because the dep is per-card ``isRunning`` and + not the chat-level streaming flag, sibling cards on the same + assistant turn each manage their own expansion independently. + Once ``isRunning`` is false the user controls expansion via + ``onOpenChange``. + */ + const [isExpanded, setIsExpanded] = useState(isRunning); + useEffect(() => { + setIsExpanded(isRunning); + }, [isRunning]); const errorData = status?.type === "incomplete" ? status.error : undefined; const serializedError = useMemo( () => (errorData && typeof errorData !== "string" ? JSON.stringify(errorData) : null), @@ -160,108 +308,207 @@ const DefaultToolFallbackInner: ToolCallMessagePartComponent = ({ : serializedError : null; - const Icon = getToolIcon(toolName); const displayName = getToolDisplayName(toolName); + const subtitle = errorReason ?? cancelledReason; return ( -
- + - {!isRunning && ( -
- {isExpanded ? ( - - ) : ( - - )} + {/* + Right-side controls. The Revert button is + visible whenever the matching action is + reversible — including the collapsed state — + but ``ToolCardRevertButton`` itself returns + ``null`` while a tool is still running because + no action-log row exists yet, so it doesn't + need an explicit ``isRunning`` gate here. + */} +
+ + + +
- )} - +
- {isExpanded && !isRunning && ( - <> -
-
- {argsText && ( -
-

Inputs

-
-									{argsText}
-								
+ {/* + CollapsibleContent body — auto-open while streaming + (see ``open`` prop above) so the live ``argsText`` + streams into the Inputs panel directly, no need for + a separate "Live input" panel. Native + ``overflow-auto`` instead of ``ScrollArea`` because + Radix's Viewport can let content bleed past + ``max-h-*`` in dynamic flex layouts. ``min-w-0`` on + the column wrappers guarantees ``break-all`` wraps + correctly within the bounded ``max-w-lg`` Card. + */} + + +
+ {(argsText || isRunning) && ( +
+

Inputs

+
+ {argsText ? ( +
+											{argsText}
+										
+ ) : ( + // Bridges the brief gap between + // ``tool-input-start`` (creates the + // card, ``argsText`` undefined) and + // the first ``tool-input-delta``. +

+ Waiting for input… +

+ )} +
)} {!isCancelled && result !== undefined && ( <> -
-
-

Result

-
-										{typeof result === "string" ? result : serializedResult}
-									
+ +
+

Result

+
+
+											{typeof result === "string" ? result : serializedResult}
+										
+
)} -
- -
- - )} -
+ + + ); }; diff --git a/surfsense_web/components/free-chat/free-chat-page.tsx b/surfsense_web/components/free-chat/free-chat-page.tsx index bfdd613e2..05db99407 100644 --- a/surfsense_web/components/free-chat/free-chat-page.tsx +++ b/surfsense_web/components/free-chat/free-chat-page.tsx @@ -22,6 +22,7 @@ import { addToolCall, appendReasoning, appendText, + appendToolInputDelta, buildContentForUI, type ContentPartsState, endReasoning, @@ -146,6 +147,10 @@ export function FreeChatPage() { ); }; const scheduleFlush = () => batcher.schedule(flushMessages); + const forceFlush = () => { + scheduleFlush(); + batcher.flush(); + }; try { for await (const parsed of readSSEStream(response)) { @@ -183,13 +188,20 @@ export function FreeChatPage() { false, parsed.langchainToolCallId ); - batcher.flush(); + forceFlush(); break; - case "tool-input-available": + case "tool-input-delta": + appendToolInputDelta(contentPartsState, parsed.toolCallId, parsed.inputTextDelta); + scheduleFlush(); + break; + + case "tool-input-available": { + const finalArgsText = JSON.stringify(parsed.input ?? {}, null, 2); if (toolCallIndices.has(parsed.toolCallId)) { updateToolCall(contentPartsState, parsed.toolCallId, { args: parsed.input || {}, + argsText: finalArgsText, langchainToolCallId: parsed.langchainToolCallId, }); } else { @@ -202,16 +214,20 @@ export function FreeChatPage() { false, parsed.langchainToolCallId ); + updateToolCall(contentPartsState, parsed.toolCallId, { + argsText: finalArgsText, + }); } - batcher.flush(); + forceFlush(); break; + } case "tool-output-available": updateToolCall(contentPartsState, parsed.toolCallId, { result: parsed.output, langchainToolCallId: parsed.langchainToolCallId, }); - batcher.flush(); + forceFlush(); break; case "data-thinking-step": { diff --git a/surfsense_web/contracts/types/chat-messages.types.ts b/surfsense_web/contracts/types/chat-messages.types.ts index 0859f9f3b..ef16bb366 100644 --- a/surfsense_web/contracts/types/chat-messages.types.ts +++ b/surfsense_web/contracts/types/chat-messages.types.ts @@ -1,7 +1,13 @@ import { z } from "zod"; /** - * Raw message from database (real-time sync) + * Raw message from database (real-time sync). + * + * ``turn_id`` is included so consumers (e.g. ``convertToThreadMessage``) + * can populate ``metadata.custom.chatTurnId`` on the + * ``ThreadMessageLike`` even after the live-collab Zero re-sync. The + * inline Revert button's ``(chat_turn_id, tool_name, position)`` + * fallback in tool-fallback.tsx depends on it. */ export const rawMessage = z.object({ id: z.number(), @@ -10,6 +16,7 @@ export const rawMessage = z.object({ content: z.unknown(), author_id: z.string().nullable(), created_at: z.string(), + turn_id: z.string().nullable().optional(), }); export type RawMessage = z.infer; diff --git a/surfsense_web/hooks/use-agent-actions-query.ts b/surfsense_web/hooks/use-agent-actions-query.ts new file mode 100644 index 000000000..9a722fb2e --- /dev/null +++ b/surfsense_web/hooks/use-agent-actions-query.ts @@ -0,0 +1,416 @@ +"use client"; + +import { type QueryClient, useQuery } from "@tanstack/react-query"; +import { useCallback, useEffect, useMemo, useRef } from "react"; +import { + type AgentAction, + type AgentActionListResponse, + agentActionsApiService, +} from "@/lib/apis/agent-actions-api.service"; + +// ============================================================================= +// DIAGNOSTIC LOGGING — gated behind a single switch. Flip ``RevertDebug`` +// to ``true`` to trace the full SSE → cache → card → button pipeline in +// the browser console. Off by default so we don't spam production. The +// infrastructure stays in place because the underlying id-mismatch +// failure mode is rare-but-real and surfaces only at runtime. +// ============================================================================= +const RevertDebug = false; +const dbg = (...args: unknown[]) => { + if (RevertDebug && typeof window !== "undefined") { + // eslint-disable-next-line no-console + console.log("[RevertDebug]", ...args); + } +}; + +/** + * Unified store for ``AgentActionLog`` rows scoped to one thread. + * + * Replaces the previous SSE side-channel atom mess + * (``agentActionByLcIdAtom`` / ``agentActionByToolCallIdAtom`` / + * ``agentActionsByChatTurnIdAtom``) and the standalone hydration hook. + * One react-query cache entry is now the single source of truth for: + * + * * the inline Revert button on every tool-call card + * * the per-turn "Revert turn" button under each assistant message + * * the edit-from-position pre-flight that decides whether to show + * the confirmation dialog + * * the agent-actions sheet + * + * The cache is hydrated by ``GET /threads/{id}/actions`` (sized to + * 200, the server max) and updated incrementally by helpers that turn + * SSE events / revert RPC responses into ``setQueryData`` mutations. + * That keeps the card and the sheet in lockstep on every code path — + * page reload, navigation, live stream, post-stream reversibility flip, + * and explicit revert clicks. + */ + +export const ACTION_LOG_PAGE_SIZE = 200; + +/** Stable react-query key for the per-thread action list. */ +export function agentActionsQueryKey(threadId: number | null) { + return threadId !== null + ? (["agent-actions", threadId] as const) + : (["agent-actions", "none"] as const); +} + +/** Subset of the SSE ``data-action-log`` payload we care about. */ +export interface ActionLogSseEvent { + id: number; + lc_tool_call_id: string | null; + chat_turn_id: string | null; + tool_name: string; + reversible: boolean; + reverse_descriptor_present: boolean; + error: boolean; + created_at: string | null; +} + +/** + * Append or upsert a freshly-emitted ``AgentActionLog`` row into the + * thread-scoped query cache. + * + * The SSE payload is a strict subset of ``AgentAction``; missing + * fields (``args``, ``reverse_descriptor``, ``user_id``) are filled + * with ``null`` placeholders. The next refetch (sheet open, user + * focus, route stale) backfills them — but the inline Revert button + * only reads the fields the SSE payload carries, so it lights up + * immediately. + */ +export function applyActionLogSse( + queryClient: QueryClient, + threadId: number, + searchSpaceId: number, + event: ActionLogSseEvent +): void { + dbg("applyActionLogSse: incoming SSE event", { + threadId, + searchSpaceId, + event, + }); + queryClient.setQueryData( + agentActionsQueryKey(threadId), + (prev) => { + const placeholder: AgentAction = { + id: event.id, + thread_id: threadId, + user_id: null, + search_space_id: searchSpaceId, + tool_name: event.tool_name, + args: null, + result_id: null, + reversible: event.reversible, + reverse_descriptor: event.reverse_descriptor_present ? {} : null, + error: event.error ? {} : null, + reverse_of: null, + reverted_by_action_id: null, + is_revert_action: false, + tool_call_id: event.lc_tool_call_id, + chat_turn_id: event.chat_turn_id, + created_at: event.created_at ?? new Date().toISOString(), + }; + if (!prev) { + return { + items: [placeholder], + total: 1, + page: 0, + page_size: ACTION_LOG_PAGE_SIZE, + has_more: false, + }; + } + const existingIdx = prev.items.findIndex((a) => a.id === event.id); + if (existingIdx >= 0) { + const merged = [...prev.items]; + const existing = merged[existingIdx]; + if (existing) { + merged[existingIdx] = { + ...existing, + reversible: event.reversible, + tool_call_id: event.lc_tool_call_id ?? existing.tool_call_id, + chat_turn_id: event.chat_turn_id ?? existing.chat_turn_id, + }; + } + dbg("applyActionLogSse: merged into existing entry", { + id: event.id, + tool_call_id: merged[existingIdx]?.tool_call_id, + reversible: merged[existingIdx]?.reversible, + }); + return { ...prev, items: merged }; + } + dbg("applyActionLogSse: appended new placeholder", { + id: event.id, + tool_call_id: placeholder.tool_call_id, + tool_name: placeholder.tool_name, + reversible: placeholder.reversible, + cacheSizeAfter: prev.items.length + 1, + }); + // REST returns newest-first — keep that ordering when + // the server eventually refetches by prepending. + return { + ...prev, + items: [placeholder, ...prev.items], + total: prev.total + 1, + }; + } + ); +} + +/** + * Apply a post-SAVEPOINT reversibility flip + * (``data-action-log-updated`` SSE event) to the cache. + */ +export function applyActionLogUpdatedSse( + queryClient: QueryClient, + threadId: number, + id: number, + reversible: boolean +): void { + dbg("applyActionLogUpdatedSse: reversibility flip", { + threadId, + id, + reversible, + }); + queryClient.setQueryData( + agentActionsQueryKey(threadId), + (prev) => { + if (!prev) { + dbg("applyActionLogUpdatedSse: NO prev cache for thread; flip dropped", { + threadId, + id, + }); + return prev; + } + let mutated = false; + const items = prev.items.map((a) => { + if (a.id !== id) return a; + mutated = true; + return { ...a, reversible }; + }); + if (!mutated) { + dbg("applyActionLogUpdatedSse: id not in cache; flip dropped", { + threadId, + id, + cacheSize: prev.items.length, + cacheIds: prev.items.map((a) => a.id), + }); + } + return mutated ? { ...prev, items } : prev; + } + ); +} + +/** + * Optimistically mark ``id`` as reverted. + * + * Used by the inline / per-turn Revert button immediately after the + * server returns success so the UI flips to "Reverted" without + * waiting for a refetch. ``newActionId`` is the id of the new + * ``is_revert_action`` row the server inserted; pass ``null`` if the + * server didn't return it. + */ +export function markActionRevertedInCache( + queryClient: QueryClient, + threadId: number, + id: number, + newActionId: number | null +): void { + queryClient.setQueryData( + agentActionsQueryKey(threadId), + (prev) => { + if (!prev) return prev; + let mutated = false; + const items = prev.items.map((a) => { + if (a.id !== id) return a; + mutated = true; + // ``-1`` is a sentinel meaning "we know it was reverted + // but the server didn't tell us the new row's id". + return { + ...a, + reverted_by_action_id: newActionId ?? -1, + }; + }); + return mutated ? { ...prev, items } : prev; + } + ); +} + +/** + * Apply a batch of revert results (per-turn revert response) to the + * cache. Anything in the ``reverted`` / ``already_reverted`` buckets + * gets its ``reverted_by_action_id`` set; other rows are left alone. + */ +export function applyRevertTurnResultsToCache( + queryClient: QueryClient, + threadId: number, + entries: Array<{ id: number; newActionId: number | null }> +): void { + if (entries.length === 0) return; + queryClient.setQueryData( + agentActionsQueryKey(threadId), + (prev) => { + if (!prev) return prev; + const lookup = new Map(entries.map((e) => [e.id, e.newActionId])); + let mutated = false; + const items = prev.items.map((a) => { + if (!lookup.has(a.id)) return a; + mutated = true; + const newActionId = lookup.get(a.id) ?? null; + return { ...a, reverted_by_action_id: newActionId ?? -1 }; + }); + return mutated ? { ...prev, items } : prev; + } + ); +} + +/** + * Read-side hook used by the card, the turn button, the sheet, and + * the edit-from-position pre-flight. + * + * Returns the raw query state plus convenience selectors so consumers + * don't reach into ``data.items`` directly. ``enabled`` is the only + * knob — pass ``false`` to keep the query dormant when the consumer + * doesn't yet have a thread id. + */ +export function useAgentActionsQuery( + threadId: number | null, + options: { enabled?: boolean } = {} +) { + const enabled = (options.enabled ?? true) && threadId !== null; + const query = useQuery({ + queryKey: agentActionsQueryKey(threadId), + queryFn: async () => { + dbg("useAgentActionsQuery: REST fetch START", { + threadId, + pageSize: ACTION_LOG_PAGE_SIZE, + }); + const res = await agentActionsApiService.listForThread(threadId as number, { + page: 0, + pageSize: ACTION_LOG_PAGE_SIZE, + }); + dbg("useAgentActionsQuery: REST fetch DONE", { + threadId, + total: res.total, + returned: res.items.length, + items: res.items.map((a) => ({ + id: a.id, + tool_name: a.tool_name, + tool_call_id: a.tool_call_id, + reversible: a.reversible, + reverted_by_action_id: a.reverted_by_action_id, + is_revert_action: a.is_revert_action, + })), + }); + return res; + }, + enabled, + staleTime: 15 * 1000, + }); + + const items = useMemo(() => query.data?.items ?? [], [query.data]); + + // Index ``items`` once per change so the lookups below are O(1) + // instead of O(N) per card per render. With the cache sized to 200 + // rows and many tool cards visible at once, the unindexed scan was + // the hottest path on every assistant text-delta. (Vercel React + // rule ``js-index-maps`` / ``js-set-map-lookups``.) + const byToolCallId = useMemo(() => { + const m = new Map(); + for (const a of items) { + if (a.tool_call_id) m.set(a.tool_call_id, a); + } + return m; + }, [items]); + + // Pre-grouped + pre-sorted (oldest-first, the order the agent + // actually executed them in) so the (chat_turn_id, tool_name, + // position) fallback in ``tool-fallback.tsx`` is also O(1) per + // card. Excludes ``is_revert_action`` rows so the position index + // matches the agent's original execution order. + const byTurnAndTool = useMemo(() => { + const m = new Map(); + for (const a of items) { + if (!a.chat_turn_id || a.is_revert_action) continue; + const key = `${a.chat_turn_id}::${a.tool_name}`; + const bucket = m.get(key); + if (bucket) bucket.push(a); + else m.set(key, [a]); + } + for (const bucket of m.values()) { + bucket.sort( + (a, b) => + new Date(a.created_at).getTime() - new Date(b.created_at).getTime() + ); + } + return m; + }, [items]); + + // Snapshot the cache shape when its size changes — easiest way to + // spot when the cache is empty or stale at the moment a card + // mounts. Tracked on a ref so we don't re-run the diff on + // reference-equal cache reads. + const lastSnapshotRef = useRef<{ threadId: number | null; size: number } | null>(null); + useEffect(() => { + const last = lastSnapshotRef.current; + if (!last || last.threadId !== threadId || last.size !== items.length) { + dbg("useAgentActionsQuery: cache snapshot", { + threadId, + enabled, + itemCount: items.length, + itemKeys: items.slice(0, 8).map((a) => ({ + id: a.id, + tool_name: a.tool_name, + tool_call_id: a.tool_call_id, + chat_turn_id: a.chat_turn_id, + reversible: a.reversible, + })), + }); + lastSnapshotRef.current = { threadId, size: items.length }; + } + }, [threadId, enabled, items]); + + const findByToolCallId = useCallback( + (toolCallId: string | null | undefined): AgentAction | null => { + if (!toolCallId) return null; + const found = byToolCallId.get(toolCallId) ?? null; + if (!found && items.length > 0) { + dbg("findByToolCallId: MISS", { + queriedToolCallId: toolCallId, + itemCount: items.length, + availableToolCallIds: Array.from(byToolCallId.keys()), + }); + } + return found; + }, + [byToolCallId, items.length] + ); + + const findByChatTurnId = useCallback( + (chatTurnId: string | null | undefined): AgentAction[] => { + if (!chatTurnId) return []; + // Per-turn aggregation is uncommon enough (only the + // "Revert turn" button uses it) that re-scanning is fine; + // indexing it would just bloat memory. + return items.filter((a) => a.chat_turn_id === chatTurnId); + }, + [items] + ); + + const findByChatTurnAndTool = useCallback( + ( + chatTurnId: string | null | undefined, + toolName: string | null | undefined + ): AgentAction[] => { + if (!chatTurnId || !toolName) return []; + return byTurnAndTool.get(`${chatTurnId}::${toolName}`) ?? []; + }, + [byTurnAndTool] + ); + + return { + ...query, + items, + findByToolCallId, + findByChatTurnId, + findByChatTurnAndTool, + }; +} diff --git a/surfsense_web/hooks/use-messages-sync.ts b/surfsense_web/hooks/use-messages-sync.ts index ddbe8a757..5ccda23a5 100644 --- a/surfsense_web/hooks/use-messages-sync.ts +++ b/surfsense_web/hooks/use-messages-sync.ts @@ -31,6 +31,14 @@ export function useMessagesSync( content: msg.content, author_id: msg.authorId ?? null, created_at: new Date(msg.createdAt).toISOString(), + // Forward the per-turn correlation id so post-stream Zero + // re-syncs preserve ``metadata.custom.chatTurnId`` on the + // converted ``ThreadMessageLike``. Without this the inline + // Revert button's ``(chat_turn_id, tool_name, position)`` + // fallback breaks the moment Zero overwrites the messages + // state after a live stream completes (see + // ``handleSyncedMessagesUpdate`` in the chat page). + turn_id: msg.turnId ?? null, })); onMessagesUpdateRef.current(mapped); diff --git a/surfsense_web/lib/chat/streaming-state.ts b/surfsense_web/lib/chat/streaming-state.ts index 26fd7b98c..54faf7e7c 100644 --- a/surfsense_web/lib/chat/streaming-state.ts +++ b/surfsense_web/lib/chat/streaming-state.ts @@ -16,6 +16,23 @@ export type ContentPart = toolName: string; args: Record; result?: unknown; + /** + * Live / finalized JSON text for the tool's input arguments. + * + * - During streaming: accumulated partial JSON text from + * ``tool-input-delta`` events (may be invalid JSON + * mid-stream). assistant-ui's argsText parser tolerates + * invalid JSON gracefully (changelog 0.7.32 / 0.7.78). + * - On completion (``tool-input-available``): replaced with + * ``JSON.stringify(input, null, 2)`` so the post-stream + * card renders pretty-printed JSON instead of the + * model's possibly-fragmented formatting. + * + * Per assistant-ui ``ThreadMessageLike`` precedence + * (changelog 0.11.6 ``d318c83``), when ``argsText`` is + * supplied it wins over ``JSON.stringify(args)``. + */ + argsText?: string; /** * Authoritative LangChain ``tool_call.id`` propagated by the backend * via ``langchainToolCallId`` on tool-input-start/available and @@ -282,12 +299,22 @@ export function findToolCallIdByLcId( export function updateToolCall( state: ContentPartsState, toolCallId: string, - update: { args?: Record; result?: unknown; langchainToolCallId?: string } + update: { + args?: Record; + argsText?: string; + result?: unknown; + langchainToolCallId?: string; + } ): void { const index = state.toolCallIndices.get(toolCallId); if (index !== undefined && state.contentParts[index]?.type === "tool-call") { const tc = state.contentParts[index] as ContentPart & { type: "tool-call" }; if (update.args) tc.args = update.args; + // ``!== undefined`` (NOT a truthy check): an explicit empty + // string CAN clear, and a finalization with + // ``JSON.stringify({}, null, 2) === "{}"`` (truthy but + // represents an empty-input call) still applies. + if (update.argsText !== undefined) tc.argsText = update.argsText; if (update.result !== undefined) tc.result = update.result; // Only backfill langchainToolCallId if not already set — the // authoritative ``on_tool_end`` value should override an earlier @@ -299,6 +326,25 @@ export function updateToolCall( } } +/** + * Append a streamed args-delta chunk to the active tool call's + * ``argsText``. No-ops when no card has been registered yet for the + * given ``toolCallId`` (the matching ``tool-input-start`` either lost + * the wire race or this id never had a card — either way the deltas + * have nowhere safe to land). + */ +export function appendToolInputDelta( + state: ContentPartsState, + toolCallId: string, + delta: string +): void { + const idx = state.toolCallIndices.get(toolCallId); + if (idx === undefined) return; + const tc = state.contentParts[idx]; + if (tc?.type !== "tool-call") return; + tc.argsText = (tc.argsText ?? "") + delta; +} + function _hasInterruptResult(part: ContentPart): boolean { if (part.type !== "tool-call") return false; const r = (part as { result?: unknown }).result; @@ -371,6 +417,18 @@ export type SSEEvent = /** Authoritative LangChain ``tool_call.id``. Optional. */ langchainToolCallId?: string; } + | { + /** + * Live tool-call argument delta. Concatenated into + * ``argsText`` on the matching ``tool-call`` content part + * by ``appendToolInputDelta``. parity_v2 only — the legacy + * code path emits ``tool-input-available`` without prior + * deltas. + */ + type: "tool-input-delta"; + toolCallId: string; + inputTextDelta: string; + } | { type: "tool-input-available"; toolCallId: string; diff --git a/surfsense_web/zero/schema/chat.ts b/surfsense_web/zero/schema/chat.ts index 0293059fd..fb3d7651e 100644 --- a/surfsense_web/zero/schema/chat.ts +++ b/surfsense_web/zero/schema/chat.ts @@ -8,6 +8,13 @@ export const newChatMessageTable = table("new_chat_messages") threadId: number().from("thread_id"), authorId: string().optional().from("author_id"), createdAt: number().from("created_at"), + // Per-turn correlation id sourced from ``configurable.turn_id`` + // at streaming time. Required by the inline Revert button's + // (chat_turn_id, tool_name, position) fallback in tool-fallback.tsx + // — without it the live-collab Zero sync would clobber the + // metadata we set during streaming and the button would vanish + // the moment Zero re-syncs after the stream finishes. + turnId: string().optional().from("turn_id"), }) .primaryKey("id");