mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-12 09:12:40 +02:00
refactor(chat): stream agent events via stream_output and remove parity v2 flag
This commit is contained in:
parent
7e07092f67
commit
78f4747382
17 changed files with 76 additions and 1676 deletions
|
|
@ -31,7 +31,6 @@ def _clear_all(monkeypatch: pytest.MonkeyPatch) -> None:
|
|||
"SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE",
|
||||
"SURFSENSE_ENABLE_ACTION_LOG",
|
||||
"SURFSENSE_ENABLE_REVERT_ROUTE",
|
||||
"SURFSENSE_ENABLE_STREAM_PARITY_V2",
|
||||
"SURFSENSE_ENABLE_PLUGIN_LOADER",
|
||||
"SURFSENSE_ENABLE_OTEL",
|
||||
"SURFSENSE_ENABLE_AGENT_CACHE",
|
||||
|
|
@ -61,7 +60,6 @@ def test_defaults_match_shipped_agent_stack(monkeypatch: pytest.MonkeyPatch) ->
|
|||
assert flags.enable_kb_planner_runnable is True
|
||||
assert flags.enable_action_log is True
|
||||
assert flags.enable_revert_route is True
|
||||
assert flags.enable_stream_parity_v2 is True
|
||||
assert flags.enable_plugin_loader is False
|
||||
assert flags.enable_otel is False
|
||||
# Phase 2: agent cache is now default-on (the prerequisite tool
|
||||
|
|
@ -127,7 +125,6 @@ def test_each_flag_can_be_set_independently(monkeypatch: pytest.MonkeyPatch) ->
|
|||
"enable_kb_planner_runnable": "SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE",
|
||||
"enable_action_log": "SURFSENSE_ENABLE_ACTION_LOG",
|
||||
"enable_revert_route": "SURFSENSE_ENABLE_REVERT_ROUTE",
|
||||
"enable_stream_parity_v2": "SURFSENSE_ENABLE_STREAM_PARITY_V2",
|
||||
"enable_plugin_loader": "SURFSENSE_ENABLE_PLUGIN_LOADER",
|
||||
"enable_otel": "SURFSENSE_ENABLE_OTEL",
|
||||
}
|
||||
|
|
|
|||
|
|
@ -137,7 +137,7 @@ def test_complete_active_thinking_step_mirrors_closure_semantics() -> None:
|
|||
|
||||
|
||||
def test_agent_event_relay_state_factory_matches_counter_rule() -> None:
|
||||
s0 = AgentEventRelayState.for_invocation(parity_v2=False)
|
||||
s0 = AgentEventRelayState.for_invocation()
|
||||
assert s0.thinking_step_counter == 0
|
||||
assert s0.last_active_step_id is None
|
||||
|
||||
|
|
@ -145,11 +145,9 @@ def test_agent_event_relay_state_factory_matches_counter_rule() -> None:
|
|||
initial_step_id="thinking-resume-1",
|
||||
initial_step_title="Inherited",
|
||||
initial_step_items=["Topic: X"],
|
||||
parity_v2=True,
|
||||
)
|
||||
assert s1.thinking_step_counter == 1
|
||||
assert s1.last_active_step_id == "thinking-resume-1"
|
||||
assert s1.parity_v2 is True
|
||||
assert s1.next_thinking_step_id("thinking") == "thinking-2"
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -161,7 +161,7 @@ class TestToolHeavyTurn:
|
|||
_assert_jsonb_safe(snap)
|
||||
|
||||
def test_tool_input_available_without_prior_start_creates_card(self):
|
||||
# Legacy / parity_v2-OFF path: tool-input-available may be
|
||||
# Late-registration: tool-input-available may be
|
||||
# emitted without a prior tool-input-start (no streamed
|
||||
# tool_call_chunks). The card should still be created.
|
||||
b = AssistantContentBuilder()
|
||||
|
|
@ -187,7 +187,7 @@ class TestToolHeavyTurn:
|
|||
assert part["result"] == {"matches": 3}
|
||||
|
||||
def test_tool_input_start_idempotent_for_same_ui_id(self):
|
||||
# parity_v2: tool-input-start can fire from BOTH the chunk
|
||||
# tool-input-start can fire from BOTH the chunk
|
||||
# registration path AND the canonical ``on_tool_start`` path.
|
||||
# The second call must not create a duplicate part.
|
||||
b = AssistantContentBuilder()
|
||||
|
|
|
|||
|
|
@ -1,16 +1,13 @@
|
|||
"""Unit tests for live tool-call argument streaming.
|
||||
|
||||
Pins the wire format that ``_stream_agent_events`` emits when
|
||||
``SURFSENSE_ENABLE_STREAM_PARITY_V2=true``: ``tool-input-start`` →
|
||||
``tool-input-delta``... → ``tool-input-available`` → ``tool-output-available``
|
||||
all keyed by the same LangChain ``tool_call.id``.
|
||||
Pins the wire format that ``_stream_agent_events`` emits:
|
||||
``tool-input-start`` → ``tool-input-delta``... → ``tool-input-available`` →
|
||||
``tool-output-available``, keyed consistently with LangChain ``tool_call.id``
|
||||
when the model streams indexed chunks.
|
||||
|
||||
Identity is tracked in ``index_to_meta`` (per-chunk ``index``) and
|
||||
``ui_tool_call_id_by_run`` (LangGraph ``run_id``); both are private to
|
||||
``_stream_agent_events`` so we exercise them via the public wire output.
|
||||
|
||||
These tests also lock in the legacy / parity_v2-OFF behaviour so the
|
||||
synthetic ``call_<run_id>`` shape stays stable for older clients.
|
||||
``ui_tool_call_id_by_run`` (LangGraph ``run_id``); both are internal to the
|
||||
streaming layer so we assert on the public SSE payloads.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
|
@ -22,8 +19,6 @@ from typing import Any
|
|||
|
||||
import pytest
|
||||
|
||||
import app.tasks.chat.stream_new_chat as stream_module
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
from app.services.new_streaming_service import VercelStreamingService
|
||||
from app.tasks.chat.stream_new_chat import (
|
||||
StreamResult,
|
||||
|
|
@ -164,24 +159,6 @@ def _tool_end(
|
|||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def parity_v2_on(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
stream_module,
|
||||
"get_flags",
|
||||
lambda: AgentFeatureFlags(enable_stream_parity_v2=True),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def parity_v2_off(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
stream_module,
|
||||
"get_flags",
|
||||
lambda: AgentFeatureFlags(enable_stream_parity_v2=False),
|
||||
)
|
||||
|
||||
|
||||
async def _drain(
|
||||
events: list[dict[str, Any]], state: _FakeAgentState | None = None
|
||||
) -> list[dict[str, Any]]:
|
||||
|
|
@ -253,12 +230,12 @@ class TestLegacyMatch:
|
|||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# parity_v2 wire format tests.
|
||||
# Tool input streaming wire format
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_idless_chunk_merging_by_index(parity_v2_on: None) -> None:
|
||||
async def test_idless_chunk_merging_by_index() -> None:
|
||||
"""First chunk carries id+name; later idless chunks at the same
|
||||
``index`` merge into the SAME ``tool-input-start`` ui id and emit
|
||||
one ``tool-input-delta`` per chunk."""
|
||||
|
|
@ -302,9 +279,7 @@ async def test_idless_chunk_merging_by_index(parity_v2_on: None) -> None:
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_two_interleaved_tool_calls_route_by_index(
|
||||
parity_v2_on: None,
|
||||
) -> None:
|
||||
async def test_two_interleaved_tool_calls_route_by_index() -> None:
|
||||
"""Two same-name calls with distinct indices keep their deltas
|
||||
routed to the right card."""
|
||||
events = [
|
||||
|
|
@ -344,7 +319,7 @@ async def test_two_interleaved_tool_calls_route_by_index(
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_identity_stable_across_lifecycle(parity_v2_on: None) -> None:
|
||||
async def test_identity_stable_across_lifecycle() -> None:
|
||||
"""Whatever id ``tool-input-start`` chose must be the SAME id used
|
||||
on ``tool-input-available`` AND ``tool-output-available``."""
|
||||
events = [
|
||||
|
|
@ -367,7 +342,7 @@ async def test_identity_stable_across_lifecycle(parity_v2_on: None) -> None:
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_duplicate_tool_input_start(parity_v2_on: None) -> None:
|
||||
async def test_no_duplicate_tool_input_start() -> None:
|
||||
"""When the chunk-emission loop already fired ``tool-input-start``
|
||||
for this run, ``on_tool_start`` MUST NOT emit a second one."""
|
||||
events = [
|
||||
|
|
@ -386,9 +361,7 @@ async def test_no_duplicate_tool_input_start(parity_v2_on: None) -> None:
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_active_text_closes_before_early_tool_input_start(
|
||||
parity_v2_on: None,
|
||||
) -> None:
|
||||
async def test_active_text_closes_before_early_tool_input_start() -> None:
|
||||
"""Streaming a text-delta then a tool-call chunk in subsequent
|
||||
chunks: the wire MUST contain ``text-end`` before the FIRST
|
||||
``tool-input-start`` (clean part boundary on the frontend)."""
|
||||
|
|
@ -409,9 +382,7 @@ async def test_active_text_closes_before_early_tool_input_start(
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mixed_text_and_tool_chunk_preserve_order(
|
||||
parity_v2_on: None,
|
||||
) -> None:
|
||||
async def test_mixed_text_and_tool_chunk_preserve_order() -> None:
|
||||
"""One AIMessageChunk that carries BOTH ``text`` content AND
|
||||
``tool_call_chunks`` should emit the text delta FIRST, then close
|
||||
text, then ``tool-input-start``+``tool-input-delta``."""
|
||||
|
|
@ -441,45 +412,7 @@ async def test_mixed_text_and_tool_chunk_preserve_order(
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parity_v2_off_preserves_legacy_shape(
|
||||
parity_v2_off: None,
|
||||
) -> None:
|
||||
"""When the flag is OFF, no deltas are emitted and the ``toolCallId``
|
||||
is ``call_<run_id>`` (NOT the lc id)."""
|
||||
events = [
|
||||
_model_stream(
|
||||
tool_call_chunks=[
|
||||
{"id": "lc-1", "name": "ls", "args": '{"path":"/"}', "index": 0}
|
||||
]
|
||||
),
|
||||
_tool_start(name="ls", run_id="run-A", input_payload={"path": "/"}),
|
||||
_tool_end(name="ls", run_id="run-A", tool_call_id="lc-1"),
|
||||
]
|
||||
payloads = await _drain(events)
|
||||
|
||||
assert _of_type(payloads, "tool-input-delta") == []
|
||||
starts = _of_type(payloads, "tool-input-start")
|
||||
assert len(starts) == 1
|
||||
assert starts[0]["toolCallId"].startswith("call_run-A")
|
||||
# No ``langchainToolCallId`` propagation on ``tool-input-start`` in
|
||||
# legacy mode (the start event fires before the ToolMessage is
|
||||
# available, so we can't extract the authoritative LangChain id yet).
|
||||
assert "langchainToolCallId" not in starts[0]
|
||||
output = _of_type(payloads, "tool-output-available")
|
||||
assert output[0]["toolCallId"].startswith("call_run-A")
|
||||
# ``tool-output-available`` MUST carry ``langchainToolCallId`` even
|
||||
# in legacy mode: the chat tool card uses it to backfill the
|
||||
# LangChain id and join against the ``data-action-log`` SSE event
|
||||
# (keyed by ``lc_tool_call_id``) so the inline Revert button can
|
||||
# light up. Sourced from the returned ``ToolMessage.tool_call_id``,
|
||||
# which is populated regardless of feature-flag state.
|
||||
assert output[0]["langchainToolCallId"] == "lc-1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skip_append_prevents_stale_id_reuse(
|
||||
parity_v2_on: None,
|
||||
) -> None:
|
||||
async def test_skip_append_prevents_stale_id_reuse() -> None:
|
||||
"""Two same-name tools: the SECOND tool's ``langchainToolCallId``
|
||||
must NOT come from the first tool's chunk (``pending_tool_call_chunks``
|
||||
must stay empty for indexed-registered chunks)."""
|
||||
|
|
@ -506,9 +439,7 @@ async def test_skip_append_prevents_stale_id_reuse(
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_registration_waits_for_both_id_and_name(
|
||||
parity_v2_on: None,
|
||||
) -> None:
|
||||
async def test_registration_waits_for_both_id_and_name() -> None:
|
||||
"""An id-only chunk (no name yet) must NOT emit ``tool-input-start``."""
|
||||
events = [
|
||||
_model_stream(
|
||||
|
|
@ -520,12 +451,9 @@ async def test_registration_waits_for_both_id_and_name(
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unmatched_fallback_still_attaches_lc_id(
|
||||
parity_v2_on: None,
|
||||
) -> None:
|
||||
"""parity_v2 ON, but the provider didn't include an ``index``: the
|
||||
legacy fallback path must still emit ``tool-input-start`` with the
|
||||
matching ``langchainToolCallId``."""
|
||||
async def test_unmatched_fallback_still_attaches_lc_id() -> None:
|
||||
"""When the provider omits chunk ``index``, buffered chunks still get a
|
||||
``tool-input-start`` with the matching ``langchainToolCallId``."""
|
||||
events = [
|
||||
# No index on the chunk → not registered into index_to_meta;
|
||||
# falls through to ``pending_tool_call_chunks`` so the legacy
|
||||
|
|
@ -542,9 +470,7 @@ async def test_unmatched_fallback_still_attaches_lc_id(
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_interrupt_request_uses_task_that_contains_interrupt(
|
||||
parity_v2_on: None,
|
||||
) -> None:
|
||||
async def test_interrupt_request_uses_task_that_contains_interrupt() -> None:
|
||||
interrupt_payload = {
|
||||
"type": "calendar_event_create",
|
||||
"action": {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue