mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-12 17:22:38 +02:00
refactor(chat): stream agent events via stream_output and remove parity v2 flag
This commit is contained in:
parent
7e07092f67
commit
78f4747382
17 changed files with 76 additions and 1676 deletions
|
|
@ -324,7 +324,6 @@ SURFSENSE_ENABLE_ACTION_LOG=true
|
||||||
SURFSENSE_ENABLE_REVERT_ROUTE=true
|
SURFSENSE_ENABLE_REVERT_ROUTE=true
|
||||||
SURFSENSE_ENABLE_PERMISSION=true
|
SURFSENSE_ENABLE_PERMISSION=true
|
||||||
SURFSENSE_ENABLE_DOOM_LOOP=true
|
SURFSENSE_ENABLE_DOOM_LOOP=true
|
||||||
SURFSENSE_ENABLE_STREAM_PARITY_V2=true
|
|
||||||
|
|
||||||
# Periodic connector sync interval (default: 5m)
|
# Periodic connector sync interval (default: 5m)
|
||||||
# SCHEDULE_CHECKER_INTERVAL=5m
|
# SCHEDULE_CHECKER_INTERVAL=5m
|
||||||
|
|
|
||||||
|
|
@ -315,14 +315,6 @@ LANGSMITH_PROJECT=surfsense
|
||||||
# SURFSENSE_ENABLE_ACTION_LOG=false
|
# SURFSENSE_ENABLE_ACTION_LOG=false
|
||||||
# SURFSENSE_ENABLE_REVERT_ROUTE=false # Backend-only; flip when UI ships
|
# SURFSENSE_ENABLE_REVERT_ROUTE=false # Backend-only; flip when UI ships
|
||||||
|
|
||||||
# Streaming parity v2 — opt in to LangChain's structured AIMessageChunk
|
|
||||||
# content (typed reasoning blocks, tool-input deltas) and propagate the
|
|
||||||
# real tool_call_id to the SSE layer. When OFF, the stream falls back to
|
|
||||||
# the str-only text path and synthetic "call_<run_id>" tool-call ids.
|
|
||||||
# Schema migrations 135/136 ship unconditionally because they are
|
|
||||||
# forward-compatible.
|
|
||||||
# SURFSENSE_ENABLE_STREAM_PARITY_V2=false
|
|
||||||
|
|
||||||
# Plugins
|
# Plugins
|
||||||
# SURFSENSE_ENABLE_PLUGIN_LOADER=false
|
# SURFSENSE_ENABLE_PLUGIN_LOADER=false
|
||||||
# Comma-separated allowlist of plugin entry-point names
|
# Comma-separated allowlist of plugin entry-point names
|
||||||
|
|
|
||||||
|
|
@ -28,7 +28,6 @@ Defaults:
|
||||||
SURFSENSE_ENABLE_PERMISSION=true
|
SURFSENSE_ENABLE_PERMISSION=true
|
||||||
SURFSENSE_ENABLE_DOOM_LOOP=true
|
SURFSENSE_ENABLE_DOOM_LOOP=true
|
||||||
SURFSENSE_ENABLE_LLM_TOOL_SELECTOR=false # adds a per-turn LLM call
|
SURFSENSE_ENABLE_LLM_TOOL_SELECTOR=false # adds a per-turn LLM call
|
||||||
SURFSENSE_ENABLE_STREAM_PARITY_V2=true
|
|
||||||
|
|
||||||
Master kill-switch (overrides everything else):
|
Master kill-switch (overrides everything else):
|
||||||
|
|
||||||
|
|
@ -88,15 +87,6 @@ class AgentFeatureFlags:
|
||||||
enable_action_log: bool = True
|
enable_action_log: bool = True
|
||||||
enable_revert_route: bool = True
|
enable_revert_route: bool = True
|
||||||
|
|
||||||
# Streaming parity v2 — opt in to LangChain's structured
|
|
||||||
# ``AIMessageChunk`` content (typed reasoning blocks, tool-input
|
|
||||||
# deltas) and propagate the real ``tool_call_id`` to the SSE layer.
|
|
||||||
# When OFF the ``stream_new_chat`` task falls back to the str-only
|
|
||||||
# text path and the synthetic ``call_<run_id>`` tool-call id (no
|
|
||||||
# ``langchainToolCallId`` propagation). Schema migrations 135/136
|
|
||||||
# ship unconditionally because they're forward-compatible.
|
|
||||||
enable_stream_parity_v2: bool = True
|
|
||||||
|
|
||||||
# Plugins
|
# Plugins
|
||||||
enable_plugin_loader: bool = False
|
enable_plugin_loader: bool = False
|
||||||
|
|
||||||
|
|
@ -169,7 +159,6 @@ class AgentFeatureFlags:
|
||||||
enable_kb_planner_runnable=False,
|
enable_kb_planner_runnable=False,
|
||||||
enable_action_log=False,
|
enable_action_log=False,
|
||||||
enable_revert_route=False,
|
enable_revert_route=False,
|
||||||
enable_stream_parity_v2=False,
|
|
||||||
enable_plugin_loader=False,
|
enable_plugin_loader=False,
|
||||||
enable_otel=False,
|
enable_otel=False,
|
||||||
enable_agent_cache=False,
|
enable_agent_cache=False,
|
||||||
|
|
@ -208,10 +197,6 @@ class AgentFeatureFlags:
|
||||||
# Snapshot / revert
|
# Snapshot / revert
|
||||||
enable_action_log=_env_bool("SURFSENSE_ENABLE_ACTION_LOG", True),
|
enable_action_log=_env_bool("SURFSENSE_ENABLE_ACTION_LOG", True),
|
||||||
enable_revert_route=_env_bool("SURFSENSE_ENABLE_REVERT_ROUTE", True),
|
enable_revert_route=_env_bool("SURFSENSE_ENABLE_REVERT_ROUTE", True),
|
||||||
# Streaming parity v2
|
|
||||||
enable_stream_parity_v2=_env_bool(
|
|
||||||
"SURFSENSE_ENABLE_STREAM_PARITY_V2", True
|
|
||||||
),
|
|
||||||
# Plugins
|
# Plugins
|
||||||
enable_plugin_loader=_env_bool("SURFSENSE_ENABLE_PLUGIN_LOADER", False),
|
enable_plugin_loader=_env_bool("SURFSENSE_ENABLE_PLUGIN_LOADER", False),
|
||||||
# Observability
|
# Observability
|
||||||
|
|
|
||||||
|
|
@ -608,15 +608,14 @@ class VercelStreamingService:
|
||||||
Args:
|
Args:
|
||||||
tool_call_id: The unique tool call identifier. May be EITHER the
|
tool_call_id: The unique tool call identifier. May be EITHER the
|
||||||
synthetic ``call_<run_id>`` id derived from LangGraph
|
synthetic ``call_<run_id>`` id derived from LangGraph
|
||||||
``run_id`` (legacy / ``SURFSENSE_ENABLE_STREAM_PARITY_V2``
|
``run_id`` (unmatched chunk fallback when no ``index`` was
|
||||||
OFF, or the unmatched-fallback path under parity_v2) OR
|
registered) OR the authoritative LangChain ``tool_call.id``
|
||||||
the authoritative LangChain ``tool_call.id`` (parity_v2
|
(when the provider streams ``tool_call_chunks`` we register
|
||||||
path: when the provider streams ``tool_call_chunks`` we
|
the ``index`` and reuse the lc-id as the card id so live
|
||||||
register the ``index`` and reuse the lc-id as the card
|
``tool-input-delta`` events route without a downstream join).
|
||||||
id so live ``tool-input-delta`` events can be routed
|
Either way, the same id is preserved across
|
||||||
without a downstream join). Either way, the same id is
|
``tool-input-start`` / ``-delta`` / ``-available`` /
|
||||||
preserved across ``tool-input-start`` / ``-delta`` /
|
``tool-output-available`` for one call.
|
||||||
``-available`` / ``tool-output-available`` for one call.
|
|
||||||
tool_name: The name of the tool being called.
|
tool_name: The name of the tool being called.
|
||||||
langchain_tool_call_id: Optional authoritative LangChain
|
langchain_tool_call_id: Optional authoritative LangChain
|
||||||
``tool_call.id``. When set, surfaces as
|
``tool_call.id``. When set, surfaces as
|
||||||
|
|
|
||||||
|
|
@ -85,8 +85,8 @@ class AssistantContentBuilder:
|
||||||
self._current_text_idx: int = -1
|
self._current_text_idx: int = -1
|
||||||
self._current_reasoning_idx: int = -1
|
self._current_reasoning_idx: int = -1
|
||||||
# ``ui_id``-keyed indexes for tool-call parts. ``ui_id`` is the
|
# ``ui_id``-keyed indexes for tool-call parts. ``ui_id`` is the
|
||||||
# synthetic ``call_<run_id>`` (legacy) or the LangChain
|
# synthetic ``call_<run_id>`` (chunk fallback) or the LangChain
|
||||||
# ``tool_call.id`` (parity_v2) — same key the streaming layer
|
# ``tool_call.id`` (indexed chunk path) — same key the streaming layer
|
||||||
# threads through every ``tool-input-*`` / ``tool-output-*`` event.
|
# threads through every ``tool-input-*`` / ``tool-output-*`` event.
|
||||||
self._tool_call_idx_by_ui_id: dict[str, int] = {}
|
self._tool_call_idx_by_ui_id: dict[str, int] = {}
|
||||||
# Live argsText accumulator (concatenated ``tool-input-delta`` chunks)
|
# Live argsText accumulator (concatenated ``tool-input-delta`` chunks)
|
||||||
|
|
@ -181,7 +181,7 @@ class AssistantContentBuilder:
|
||||||
"""Register a tool-call card. Args are filled in by later events."""
|
"""Register a tool-call card. Args are filled in by later events."""
|
||||||
if not ui_id:
|
if not ui_id:
|
||||||
return
|
return
|
||||||
# Skip duplicate registration: parity_v2 may emit
|
# Skip duplicate registration: the stream may emit
|
||||||
# ``tool-input-start`` from both ``on_chat_model_stream``
|
# ``tool-input-start`` from both ``on_chat_model_stream``
|
||||||
# (when tool_call_chunks register a name) and ``on_tool_start``
|
# (when tool_call_chunks register a name) and ``on_tool_start``
|
||||||
# (the canonical path). The FE de-dupes via ``toolCallIndices``;
|
# (the canonical path). The FE de-dupes via ``toolCallIndices``;
|
||||||
|
|
@ -243,7 +243,7 @@ class AssistantContentBuilder:
|
||||||
pretty-printed JSON, sets the full ``args`` dict, and backfills
|
pretty-printed JSON, sets the full ``args`` dict, and backfills
|
||||||
``langchainToolCallId`` if it wasn't known at ``tool-input-start`` time.
|
``langchainToolCallId`` if it wasn't known at ``tool-input-start`` time.
|
||||||
Also creates the card if no prior ``tool-input-start`` registered it
|
Also creates the card if no prior ``tool-input-start`` registered it
|
||||||
(legacy parity_v2-OFF / late-registration paths).
|
(late-registration path).
|
||||||
"""
|
"""
|
||||||
if not ui_id:
|
if not ui_id:
|
||||||
return
|
return
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load diff
|
|
@ -5,7 +5,6 @@ from __future__ import annotations
|
||||||
from collections.abc import AsyncIterator
|
from collections.abc import AsyncIterator
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from app.agents.new_chat.feature_flags import get_flags
|
|
||||||
from app.tasks.chat.streaming.graph_stream.result import StreamingResult
|
from app.tasks.chat.streaming.graph_stream.result import StreamingResult
|
||||||
from app.tasks.chat.streaming.relay.event_relay import EventRelay
|
from app.tasks.chat.streaming.relay.event_relay import EventRelay
|
||||||
from app.tasks.chat.streaming.relay.state import AgentEventRelayState
|
from app.tasks.chat.streaming.relay.state import AgentEventRelayState
|
||||||
|
|
@ -30,7 +29,6 @@ async def stream_output(
|
||||||
initial_step_id=initial_step_id,
|
initial_step_id=initial_step_id,
|
||||||
initial_step_title=initial_step_title,
|
initial_step_title=initial_step_title,
|
||||||
initial_step_items=initial_step_items,
|
initial_step_items=initial_step_items,
|
||||||
parity_v2=bool(get_flags().enable_stream_parity_v2),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
astream_kwargs: dict[str, Any] = {"config": config, "version": "v2"}
|
astream_kwargs: dict[str, Any] = {"config": config, "version": "v2"}
|
||||||
|
|
|
||||||
|
|
@ -33,7 +33,7 @@ def iter_chat_model_stream_frames(
|
||||||
reasoning_delta = parts["reasoning"]
|
reasoning_delta = parts["reasoning"]
|
||||||
text_delta = parts["text"]
|
text_delta = parts["text"]
|
||||||
|
|
||||||
if state.parity_v2 and reasoning_delta:
|
if reasoning_delta:
|
||||||
if state.current_text_id is not None:
|
if state.current_text_id is not None:
|
||||||
yield streaming_service.format_text_end(state.current_text_id)
|
yield streaming_service.format_text_end(state.current_text_id)
|
||||||
if content_builder is not None:
|
if content_builder is not None:
|
||||||
|
|
@ -100,7 +100,7 @@ def iter_chat_model_stream_frames(
|
||||||
if content_builder is not None:
|
if content_builder is not None:
|
||||||
content_builder.on_text_delta(state.current_text_id, text_delta)
|
content_builder.on_text_delta(state.current_text_id, text_delta)
|
||||||
|
|
||||||
if state.parity_v2 and parts["tool_call_chunks"]:
|
if parts["tool_call_chunks"]:
|
||||||
for tcc in parts["tool_call_chunks"]:
|
for tcc in parts["tool_call_chunks"]:
|
||||||
idx = tcc.get("index")
|
idx = tcc.get("index")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -77,12 +77,11 @@ def iter_tool_start_frames(
|
||||||
yield emit_thinking_step_frame(**frame_kw)
|
yield emit_thinking_step_frame(**frame_kw)
|
||||||
|
|
||||||
matched_meta: dict[str, str] | None = None
|
matched_meta: dict[str, str] | None = None
|
||||||
if state.parity_v2:
|
taken_ui_ids = set(state.ui_tool_call_id_by_run.values())
|
||||||
taken_ui_ids = set(state.ui_tool_call_id_by_run.values())
|
for meta in state.index_to_meta.values():
|
||||||
for meta in state.index_to_meta.values():
|
if meta["name"] == tool_name and meta["ui_id"] not in taken_ui_ids:
|
||||||
if meta["name"] == tool_name and meta["ui_id"] not in taken_ui_ids:
|
matched_meta = meta
|
||||||
matched_meta = meta
|
break
|
||||||
break
|
|
||||||
|
|
||||||
tool_call_id: str
|
tool_call_id: str
|
||||||
langchain_tool_call_id: str | None = None
|
langchain_tool_call_id: str | None = None
|
||||||
|
|
@ -97,13 +96,12 @@ def iter_tool_start_frames(
|
||||||
if run_id
|
if run_id
|
||||||
else streaming_service.generate_tool_call_id()
|
else streaming_service.generate_tool_call_id()
|
||||||
)
|
)
|
||||||
if state.parity_v2:
|
langchain_tool_call_id = match_buffered_langchain_tool_call_id(
|
||||||
langchain_tool_call_id = match_buffered_langchain_tool_call_id(
|
state.pending_tool_call_chunks,
|
||||||
state.pending_tool_call_chunks,
|
tool_name,
|
||||||
tool_name,
|
run_id,
|
||||||
run_id,
|
state.lc_tool_call_id_by_run,
|
||||||
state.lc_tool_call_id_by_run,
|
)
|
||||||
)
|
|
||||||
yield streaming_service.format_tool_input_start(
|
yield streaming_service.format_tool_input_start(
|
||||||
tool_call_id,
|
tool_call_id,
|
||||||
tool_name,
|
tool_name,
|
||||||
|
|
|
||||||
|
|
@ -22,7 +22,6 @@ class AgentEventRelayState:
|
||||||
active_tool_depth: int = 0
|
active_tool_depth: int = 0
|
||||||
called_update_memory: bool = False
|
called_update_memory: bool = False
|
||||||
current_reasoning_id: str | None = None
|
current_reasoning_id: str | None = None
|
||||||
parity_v2: bool = False
|
|
||||||
pending_tool_call_chunks: list[dict[str, Any]] = field(default_factory=list)
|
pending_tool_call_chunks: list[dict[str, Any]] = field(default_factory=list)
|
||||||
lc_tool_call_id_by_run: dict[str, str] = field(default_factory=dict)
|
lc_tool_call_id_by_run: dict[str, str] = field(default_factory=dict)
|
||||||
file_path_by_run: dict[str, str] = field(default_factory=dict)
|
file_path_by_run: dict[str, str] = field(default_factory=dict)
|
||||||
|
|
@ -39,7 +38,6 @@ class AgentEventRelayState:
|
||||||
initial_step_id: str | None = None,
|
initial_step_id: str | None = None,
|
||||||
initial_step_title: str = "",
|
initial_step_title: str = "",
|
||||||
initial_step_items: list[str] | None = None,
|
initial_step_items: list[str] | None = None,
|
||||||
parity_v2: bool,
|
|
||||||
) -> AgentEventRelayState:
|
) -> AgentEventRelayState:
|
||||||
counter = 1 if initial_step_id else 0
|
counter = 1 if initial_step_id else 0
|
||||||
return cls(
|
return cls(
|
||||||
|
|
@ -47,7 +45,6 @@ class AgentEventRelayState:
|
||||||
last_active_step_id=initial_step_id,
|
last_active_step_id=initial_step_id,
|
||||||
last_active_step_title=initial_step_title,
|
last_active_step_title=initial_step_title,
|
||||||
last_active_step_items=list(initial_step_items or []),
|
last_active_step_items=list(initial_step_items or []),
|
||||||
parity_v2=parity_v2,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def next_thinking_step_id(self, step_prefix: str) -> str:
|
def next_thinking_step_id(self, step_prefix: str) -> str:
|
||||||
|
|
|
||||||
|
|
@ -31,7 +31,6 @@ def _clear_all(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
"SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE",
|
"SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE",
|
||||||
"SURFSENSE_ENABLE_ACTION_LOG",
|
"SURFSENSE_ENABLE_ACTION_LOG",
|
||||||
"SURFSENSE_ENABLE_REVERT_ROUTE",
|
"SURFSENSE_ENABLE_REVERT_ROUTE",
|
||||||
"SURFSENSE_ENABLE_STREAM_PARITY_V2",
|
|
||||||
"SURFSENSE_ENABLE_PLUGIN_LOADER",
|
"SURFSENSE_ENABLE_PLUGIN_LOADER",
|
||||||
"SURFSENSE_ENABLE_OTEL",
|
"SURFSENSE_ENABLE_OTEL",
|
||||||
"SURFSENSE_ENABLE_AGENT_CACHE",
|
"SURFSENSE_ENABLE_AGENT_CACHE",
|
||||||
|
|
@ -61,7 +60,6 @@ def test_defaults_match_shipped_agent_stack(monkeypatch: pytest.MonkeyPatch) ->
|
||||||
assert flags.enable_kb_planner_runnable is True
|
assert flags.enable_kb_planner_runnable is True
|
||||||
assert flags.enable_action_log is True
|
assert flags.enable_action_log is True
|
||||||
assert flags.enable_revert_route is True
|
assert flags.enable_revert_route is True
|
||||||
assert flags.enable_stream_parity_v2 is True
|
|
||||||
assert flags.enable_plugin_loader is False
|
assert flags.enable_plugin_loader is False
|
||||||
assert flags.enable_otel is False
|
assert flags.enable_otel is False
|
||||||
# Phase 2: agent cache is now default-on (the prerequisite tool
|
# Phase 2: agent cache is now default-on (the prerequisite tool
|
||||||
|
|
@ -127,7 +125,6 @@ def test_each_flag_can_be_set_independently(monkeypatch: pytest.MonkeyPatch) ->
|
||||||
"enable_kb_planner_runnable": "SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE",
|
"enable_kb_planner_runnable": "SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE",
|
||||||
"enable_action_log": "SURFSENSE_ENABLE_ACTION_LOG",
|
"enable_action_log": "SURFSENSE_ENABLE_ACTION_LOG",
|
||||||
"enable_revert_route": "SURFSENSE_ENABLE_REVERT_ROUTE",
|
"enable_revert_route": "SURFSENSE_ENABLE_REVERT_ROUTE",
|
||||||
"enable_stream_parity_v2": "SURFSENSE_ENABLE_STREAM_PARITY_V2",
|
|
||||||
"enable_plugin_loader": "SURFSENSE_ENABLE_PLUGIN_LOADER",
|
"enable_plugin_loader": "SURFSENSE_ENABLE_PLUGIN_LOADER",
|
||||||
"enable_otel": "SURFSENSE_ENABLE_OTEL",
|
"enable_otel": "SURFSENSE_ENABLE_OTEL",
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -137,7 +137,7 @@ def test_complete_active_thinking_step_mirrors_closure_semantics() -> None:
|
||||||
|
|
||||||
|
|
||||||
def test_agent_event_relay_state_factory_matches_counter_rule() -> None:
|
def test_agent_event_relay_state_factory_matches_counter_rule() -> None:
|
||||||
s0 = AgentEventRelayState.for_invocation(parity_v2=False)
|
s0 = AgentEventRelayState.for_invocation()
|
||||||
assert s0.thinking_step_counter == 0
|
assert s0.thinking_step_counter == 0
|
||||||
assert s0.last_active_step_id is None
|
assert s0.last_active_step_id is None
|
||||||
|
|
||||||
|
|
@ -145,11 +145,9 @@ def test_agent_event_relay_state_factory_matches_counter_rule() -> None:
|
||||||
initial_step_id="thinking-resume-1",
|
initial_step_id="thinking-resume-1",
|
||||||
initial_step_title="Inherited",
|
initial_step_title="Inherited",
|
||||||
initial_step_items=["Topic: X"],
|
initial_step_items=["Topic: X"],
|
||||||
parity_v2=True,
|
|
||||||
)
|
)
|
||||||
assert s1.thinking_step_counter == 1
|
assert s1.thinking_step_counter == 1
|
||||||
assert s1.last_active_step_id == "thinking-resume-1"
|
assert s1.last_active_step_id == "thinking-resume-1"
|
||||||
assert s1.parity_v2 is True
|
|
||||||
assert s1.next_thinking_step_id("thinking") == "thinking-2"
|
assert s1.next_thinking_step_id("thinking") == "thinking-2"
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -161,7 +161,7 @@ class TestToolHeavyTurn:
|
||||||
_assert_jsonb_safe(snap)
|
_assert_jsonb_safe(snap)
|
||||||
|
|
||||||
def test_tool_input_available_without_prior_start_creates_card(self):
|
def test_tool_input_available_without_prior_start_creates_card(self):
|
||||||
# Legacy / parity_v2-OFF path: tool-input-available may be
|
# Late-registration: tool-input-available may be
|
||||||
# emitted without a prior tool-input-start (no streamed
|
# emitted without a prior tool-input-start (no streamed
|
||||||
# tool_call_chunks). The card should still be created.
|
# tool_call_chunks). The card should still be created.
|
||||||
b = AssistantContentBuilder()
|
b = AssistantContentBuilder()
|
||||||
|
|
@ -187,7 +187,7 @@ class TestToolHeavyTurn:
|
||||||
assert part["result"] == {"matches": 3}
|
assert part["result"] == {"matches": 3}
|
||||||
|
|
||||||
def test_tool_input_start_idempotent_for_same_ui_id(self):
|
def test_tool_input_start_idempotent_for_same_ui_id(self):
|
||||||
# parity_v2: tool-input-start can fire from BOTH the chunk
|
# tool-input-start can fire from BOTH the chunk
|
||||||
# registration path AND the canonical ``on_tool_start`` path.
|
# registration path AND the canonical ``on_tool_start`` path.
|
||||||
# The second call must not create a duplicate part.
|
# The second call must not create a duplicate part.
|
||||||
b = AssistantContentBuilder()
|
b = AssistantContentBuilder()
|
||||||
|
|
|
||||||
|
|
@ -1,16 +1,13 @@
|
||||||
"""Unit tests for live tool-call argument streaming.
|
"""Unit tests for live tool-call argument streaming.
|
||||||
|
|
||||||
Pins the wire format that ``_stream_agent_events`` emits when
|
Pins the wire format that ``_stream_agent_events`` emits:
|
||||||
``SURFSENSE_ENABLE_STREAM_PARITY_V2=true``: ``tool-input-start`` →
|
``tool-input-start`` → ``tool-input-delta``... → ``tool-input-available`` →
|
||||||
``tool-input-delta``... → ``tool-input-available`` → ``tool-output-available``
|
``tool-output-available``, keyed consistently with LangChain ``tool_call.id``
|
||||||
all keyed by the same LangChain ``tool_call.id``.
|
when the model streams indexed chunks.
|
||||||
|
|
||||||
Identity is tracked in ``index_to_meta`` (per-chunk ``index``) and
|
Identity is tracked in ``index_to_meta`` (per-chunk ``index``) and
|
||||||
``ui_tool_call_id_by_run`` (LangGraph ``run_id``); both are private to
|
``ui_tool_call_id_by_run`` (LangGraph ``run_id``); both are internal to the
|
||||||
``_stream_agent_events`` so we exercise them via the public wire output.
|
streaming layer so we assert on the public SSE payloads.
|
||||||
|
|
||||||
These tests also lock in the legacy / parity_v2-OFF behaviour so the
|
|
||||||
synthetic ``call_<run_id>`` shape stays stable for older clients.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
@ -22,8 +19,6 @@ from typing import Any
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
import app.tasks.chat.stream_new_chat as stream_module
|
|
||||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
|
||||||
from app.services.new_streaming_service import VercelStreamingService
|
from app.services.new_streaming_service import VercelStreamingService
|
||||||
from app.tasks.chat.stream_new_chat import (
|
from app.tasks.chat.stream_new_chat import (
|
||||||
StreamResult,
|
StreamResult,
|
||||||
|
|
@ -164,24 +159,6 @@ def _tool_end(
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def parity_v2_on(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
||||||
monkeypatch.setattr(
|
|
||||||
stream_module,
|
|
||||||
"get_flags",
|
|
||||||
lambda: AgentFeatureFlags(enable_stream_parity_v2=True),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def parity_v2_off(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
||||||
monkeypatch.setattr(
|
|
||||||
stream_module,
|
|
||||||
"get_flags",
|
|
||||||
lambda: AgentFeatureFlags(enable_stream_parity_v2=False),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def _drain(
|
async def _drain(
|
||||||
events: list[dict[str, Any]], state: _FakeAgentState | None = None
|
events: list[dict[str, Any]], state: _FakeAgentState | None = None
|
||||||
) -> list[dict[str, Any]]:
|
) -> list[dict[str, Any]]:
|
||||||
|
|
@ -253,12 +230,12 @@ class TestLegacyMatch:
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# parity_v2 wire format tests.
|
# Tool input streaming wire format
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_idless_chunk_merging_by_index(parity_v2_on: None) -> None:
|
async def test_idless_chunk_merging_by_index() -> None:
|
||||||
"""First chunk carries id+name; later idless chunks at the same
|
"""First chunk carries id+name; later idless chunks at the same
|
||||||
``index`` merge into the SAME ``tool-input-start`` ui id and emit
|
``index`` merge into the SAME ``tool-input-start`` ui id and emit
|
||||||
one ``tool-input-delta`` per chunk."""
|
one ``tool-input-delta`` per chunk."""
|
||||||
|
|
@ -302,9 +279,7 @@ async def test_idless_chunk_merging_by_index(parity_v2_on: None) -> None:
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_two_interleaved_tool_calls_route_by_index(
|
async def test_two_interleaved_tool_calls_route_by_index() -> None:
|
||||||
parity_v2_on: None,
|
|
||||||
) -> None:
|
|
||||||
"""Two same-name calls with distinct indices keep their deltas
|
"""Two same-name calls with distinct indices keep their deltas
|
||||||
routed to the right card."""
|
routed to the right card."""
|
||||||
events = [
|
events = [
|
||||||
|
|
@ -344,7 +319,7 @@ async def test_two_interleaved_tool_calls_route_by_index(
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_identity_stable_across_lifecycle(parity_v2_on: None) -> None:
|
async def test_identity_stable_across_lifecycle() -> None:
|
||||||
"""Whatever id ``tool-input-start`` chose must be the SAME id used
|
"""Whatever id ``tool-input-start`` chose must be the SAME id used
|
||||||
on ``tool-input-available`` AND ``tool-output-available``."""
|
on ``tool-input-available`` AND ``tool-output-available``."""
|
||||||
events = [
|
events = [
|
||||||
|
|
@ -367,7 +342,7 @@ async def test_identity_stable_across_lifecycle(parity_v2_on: None) -> None:
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_no_duplicate_tool_input_start(parity_v2_on: None) -> None:
|
async def test_no_duplicate_tool_input_start() -> None:
|
||||||
"""When the chunk-emission loop already fired ``tool-input-start``
|
"""When the chunk-emission loop already fired ``tool-input-start``
|
||||||
for this run, ``on_tool_start`` MUST NOT emit a second one."""
|
for this run, ``on_tool_start`` MUST NOT emit a second one."""
|
||||||
events = [
|
events = [
|
||||||
|
|
@ -386,9 +361,7 @@ async def test_no_duplicate_tool_input_start(parity_v2_on: None) -> None:
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_active_text_closes_before_early_tool_input_start(
|
async def test_active_text_closes_before_early_tool_input_start() -> None:
|
||||||
parity_v2_on: None,
|
|
||||||
) -> None:
|
|
||||||
"""Streaming a text-delta then a tool-call chunk in subsequent
|
"""Streaming a text-delta then a tool-call chunk in subsequent
|
||||||
chunks: the wire MUST contain ``text-end`` before the FIRST
|
chunks: the wire MUST contain ``text-end`` before the FIRST
|
||||||
``tool-input-start`` (clean part boundary on the frontend)."""
|
``tool-input-start`` (clean part boundary on the frontend)."""
|
||||||
|
|
@ -409,9 +382,7 @@ async def test_active_text_closes_before_early_tool_input_start(
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_mixed_text_and_tool_chunk_preserve_order(
|
async def test_mixed_text_and_tool_chunk_preserve_order() -> None:
|
||||||
parity_v2_on: None,
|
|
||||||
) -> None:
|
|
||||||
"""One AIMessageChunk that carries BOTH ``text`` content AND
|
"""One AIMessageChunk that carries BOTH ``text`` content AND
|
||||||
``tool_call_chunks`` should emit the text delta FIRST, then close
|
``tool_call_chunks`` should emit the text delta FIRST, then close
|
||||||
text, then ``tool-input-start``+``tool-input-delta``."""
|
text, then ``tool-input-start``+``tool-input-delta``."""
|
||||||
|
|
@ -441,45 +412,7 @@ async def test_mixed_text_and_tool_chunk_preserve_order(
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_parity_v2_off_preserves_legacy_shape(
|
async def test_skip_append_prevents_stale_id_reuse() -> None:
|
||||||
parity_v2_off: None,
|
|
||||||
) -> None:
|
|
||||||
"""When the flag is OFF, no deltas are emitted and the ``toolCallId``
|
|
||||||
is ``call_<run_id>`` (NOT the lc id)."""
|
|
||||||
events = [
|
|
||||||
_model_stream(
|
|
||||||
tool_call_chunks=[
|
|
||||||
{"id": "lc-1", "name": "ls", "args": '{"path":"/"}', "index": 0}
|
|
||||||
]
|
|
||||||
),
|
|
||||||
_tool_start(name="ls", run_id="run-A", input_payload={"path": "/"}),
|
|
||||||
_tool_end(name="ls", run_id="run-A", tool_call_id="lc-1"),
|
|
||||||
]
|
|
||||||
payloads = await _drain(events)
|
|
||||||
|
|
||||||
assert _of_type(payloads, "tool-input-delta") == []
|
|
||||||
starts = _of_type(payloads, "tool-input-start")
|
|
||||||
assert len(starts) == 1
|
|
||||||
assert starts[0]["toolCallId"].startswith("call_run-A")
|
|
||||||
# No ``langchainToolCallId`` propagation on ``tool-input-start`` in
|
|
||||||
# legacy mode (the start event fires before the ToolMessage is
|
|
||||||
# available, so we can't extract the authoritative LangChain id yet).
|
|
||||||
assert "langchainToolCallId" not in starts[0]
|
|
||||||
output = _of_type(payloads, "tool-output-available")
|
|
||||||
assert output[0]["toolCallId"].startswith("call_run-A")
|
|
||||||
# ``tool-output-available`` MUST carry ``langchainToolCallId`` even
|
|
||||||
# in legacy mode: the chat tool card uses it to backfill the
|
|
||||||
# LangChain id and join against the ``data-action-log`` SSE event
|
|
||||||
# (keyed by ``lc_tool_call_id``) so the inline Revert button can
|
|
||||||
# light up. Sourced from the returned ``ToolMessage.tool_call_id``,
|
|
||||||
# which is populated regardless of feature-flag state.
|
|
||||||
assert output[0]["langchainToolCallId"] == "lc-1"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_skip_append_prevents_stale_id_reuse(
|
|
||||||
parity_v2_on: None,
|
|
||||||
) -> None:
|
|
||||||
"""Two same-name tools: the SECOND tool's ``langchainToolCallId``
|
"""Two same-name tools: the SECOND tool's ``langchainToolCallId``
|
||||||
must NOT come from the first tool's chunk (``pending_tool_call_chunks``
|
must NOT come from the first tool's chunk (``pending_tool_call_chunks``
|
||||||
must stay empty for indexed-registered chunks)."""
|
must stay empty for indexed-registered chunks)."""
|
||||||
|
|
@ -506,9 +439,7 @@ async def test_skip_append_prevents_stale_id_reuse(
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_registration_waits_for_both_id_and_name(
|
async def test_registration_waits_for_both_id_and_name() -> None:
|
||||||
parity_v2_on: None,
|
|
||||||
) -> None:
|
|
||||||
"""An id-only chunk (no name yet) must NOT emit ``tool-input-start``."""
|
"""An id-only chunk (no name yet) must NOT emit ``tool-input-start``."""
|
||||||
events = [
|
events = [
|
||||||
_model_stream(
|
_model_stream(
|
||||||
|
|
@ -520,12 +451,9 @@ async def test_registration_waits_for_both_id_and_name(
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_unmatched_fallback_still_attaches_lc_id(
|
async def test_unmatched_fallback_still_attaches_lc_id() -> None:
|
||||||
parity_v2_on: None,
|
"""When the provider omits chunk ``index``, buffered chunks still get a
|
||||||
) -> None:
|
``tool-input-start`` with the matching ``langchainToolCallId``."""
|
||||||
"""parity_v2 ON, but the provider didn't include an ``index``: the
|
|
||||||
legacy fallback path must still emit ``tool-input-start`` with the
|
|
||||||
matching ``langchainToolCallId``."""
|
|
||||||
events = [
|
events = [
|
||||||
# No index on the chunk → not registered into index_to_meta;
|
# No index on the chunk → not registered into index_to_meta;
|
||||||
# falls through to ``pending_tool_call_chunks`` so the legacy
|
# falls through to ``pending_tool_call_chunks`` so the legacy
|
||||||
|
|
@ -542,9 +470,7 @@ async def test_unmatched_fallback_still_attaches_lc_id(
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_interrupt_request_uses_task_that_contains_interrupt(
|
async def test_interrupt_request_uses_task_that_contains_interrupt() -> None:
|
||||||
parity_v2_on: None,
|
|
||||||
) -> None:
|
|
||||||
interrupt_payload = {
|
interrupt_payload = {
|
||||||
"type": "calendar_event_create",
|
"type": "calendar_event_create",
|
||||||
"action": {
|
"action": {
|
||||||
|
|
|
||||||
|
|
@ -7,8 +7,8 @@ import { TextShimmerLoader } from "@/components/prompt-kit/loader";
|
||||||
import { cn } from "@/lib/utils";
|
import { cn } from "@/lib/utils";
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Renders the structured `reasoning` part emitted by the backend's
|
* Renders the structured `reasoning` part emitted by the backend stream
|
||||||
* stream-parity v2 path (A1).
|
* (typed reasoning deltas from the chat model).
|
||||||
*
|
*
|
||||||
* Behaviour mirrors the existing `ThinkingStepsDisplay`:
|
* Behaviour mirrors the existing `ThinkingStepsDisplay`:
|
||||||
* - collapsed by default;
|
* - collapsed by default;
|
||||||
|
|
|
||||||
|
|
@ -48,13 +48,11 @@ import { cn } from "@/lib/utils";
|
||||||
* stream, post-stream reversibility flip, and explicit revert clicks.
|
* stream, post-stream reversibility flip, and explicit revert clicks.
|
||||||
*
|
*
|
||||||
* Match key (in priority order):
|
* Match key (in priority order):
|
||||||
* 1. ``a.tool_call_id === toolCallId`` — direct hit in parity_v2 when
|
* 1. ``a.tool_call_id === toolCallId`` — direct hit when the model
|
||||||
* the model streamed ``tool_call_chunks`` so the card's synthetic
|
* streamed ``tool_call_chunks`` so the card id matches the LangChain id.
|
||||||
* id IS the LangChain id.
|
* 2. ``a.tool_call_id === langchainToolCallId`` — synthetic card id is
|
||||||
* 2. ``a.tool_call_id === langchainToolCallId`` — legacy mode (or
|
* ``call_<run_id>`` and the LangChain id is backfilled by
|
||||||
* parity_v2 with provider-side chunk emission) where the card's
|
* ``tool-output-available``.
|
||||||
* synthetic id is ``call_<run_id>`` and the LangChain id is
|
|
||||||
* backfilled onto the part by ``tool-output-available``.
|
|
||||||
* 3. ``(chat_turn_id, tool_name, position-within-turn)`` — fallback
|
* 3. ``(chat_turn_id, tool_name, position-within-turn)`` — fallback
|
||||||
* for cards whose synthetic id is ``call_<run_id>`` AND whose
|
* for cards whose synthetic id is ``call_<run_id>`` AND whose
|
||||||
* ``langchainToolCallId`` never got backfilled (provider emitted
|
* ``langchainToolCallId`` never got backfilled (provider emitted
|
||||||
|
|
@ -116,7 +114,7 @@ function ToolCardRevertButton({
|
||||||
|
|
||||||
const action = useMemo(() => {
|
const action = useMemo(() => {
|
||||||
// Tier 1 + 2: O(1) Map-backed direct id match. Covers
|
// Tier 1 + 2: O(1) Map-backed direct id match. Covers
|
||||||
// ~all parity_v2 streams and any legacy stream that backfilled
|
// indexed chunk streams and any stream that backfilled
|
||||||
// ``langchainToolCallId`` via ``tool-output-available``.
|
// ``langchainToolCallId`` via ``tool-output-available``.
|
||||||
const direct = findByToolCallId(toolCallId) ?? findByToolCallId(langchainToolCallId);
|
const direct = findByToolCallId(toolCallId) ?? findByToolCallId(langchainToolCallId);
|
||||||
if (direct) return direct;
|
if (direct) return direct;
|
||||||
|
|
|
||||||
|
|
@ -421,9 +421,8 @@ export type SSEEvent =
|
||||||
/**
|
/**
|
||||||
* Live tool-call argument delta. Concatenated into
|
* Live tool-call argument delta. Concatenated into
|
||||||
* ``argsText`` on the matching ``tool-call`` content part
|
* ``argsText`` on the matching ``tool-call`` content part
|
||||||
* by ``appendToolInputDelta``. parity_v2 only — the legacy
|
* by ``appendToolInputDelta``. Some providers emit
|
||||||
* code path emits ``tool-input-available`` without prior
|
* ``tool-input-available`` without prior deltas.
|
||||||
* deltas.
|
|
||||||
*/
|
*/
|
||||||
type: "tool-input-delta";
|
type: "tool-input-delta";
|
||||||
toolCallId: string;
|
toolCallId: string;
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue