Merge pull request #1357 from CREDO23/feature/multi-agent

[Feature] Multi-agent chat: hierarchical timeline, live subagent streaming, and inline HITL approvals
This commit is contained in:
Rohan Verma 2026-05-09 16:13:04 -07:00 committed by GitHub
commit 28a02a9143
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
232 changed files with 9014 additions and 4055 deletions

View file

@ -31,7 +31,6 @@ def _clear_all(monkeypatch: pytest.MonkeyPatch) -> None:
"SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE",
"SURFSENSE_ENABLE_ACTION_LOG",
"SURFSENSE_ENABLE_REVERT_ROUTE",
"SURFSENSE_ENABLE_STREAM_PARITY_V2",
"SURFSENSE_ENABLE_PLUGIN_LOADER",
"SURFSENSE_ENABLE_OTEL",
"SURFSENSE_ENABLE_AGENT_CACHE",
@ -61,7 +60,6 @@ def test_defaults_match_shipped_agent_stack(monkeypatch: pytest.MonkeyPatch) ->
assert flags.enable_kb_planner_runnable is True
assert flags.enable_action_log is True
assert flags.enable_revert_route is True
assert flags.enable_stream_parity_v2 is True
assert flags.enable_plugin_loader is False
assert flags.enable_otel is False
# Phase 2: agent cache is now default-on (the prerequisite tool
@ -127,7 +125,6 @@ def test_each_flag_can_be_set_independently(monkeypatch: pytest.MonkeyPatch) ->
"enable_kb_planner_runnable": "SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE",
"enable_action_log": "SURFSENSE_ENABLE_ACTION_LOG",
"enable_revert_route": "SURFSENSE_ENABLE_REVERT_ROUTE",
"enable_stream_parity_v2": "SURFSENSE_ENABLE_STREAM_PARITY_V2",
"enable_plugin_loader": "SURFSENSE_ENABLE_PLUGIN_LOADER",
"enable_otel": "SURFSENSE_ENABLE_OTEL",
}

View file

@ -0,0 +1,79 @@
"""Pin the wire compactness rule and the top-level ``emitted_by`` field name."""
from __future__ import annotations
import pytest
from app.services.streaming.emitter import (
Emitter,
attach_emitted_by,
main_emitter,
subagent_emitter,
)
pytestmark = pytest.mark.unit
def test_main_emitter_payload_contains_only_level() -> None:
    """The main lane serialises to the minimal ``{"level": "main"}`` payload."""
    wire = main_emitter().to_payload()
    assert wire == {"level": "main"}
def test_subagent_emitter_payload_includes_all_set_fields() -> None:
    """Every populated correlation field appears in the serialised payload."""
    emitter = subagent_emitter(
        subagent_type="deliverables",
        subagent_run_id="subagent_abc",
        parent_tool_call_id="call_xyz",
    )
    expected = {
        "level": "subagent",
        "subagent_type": "deliverables",
        "subagent_run_id": "subagent_abc",
        "parent_tool_call_id": "call_xyz",
    }
    assert emitter.to_payload() == expected
def test_subagent_emitter_payload_omits_unset_optional_fields() -> None:
    """parent_tool_call_id is None when the run is started outside a tool boundary."""
    emitter = Emitter(
        level="subagent",
        subagent_type="email",
        subagent_run_id="subagent_1",
    )
    wire = emitter.to_payload()
    assert "parent_tool_call_id" not in wire
    assert wire["subagent_type"] == "email"
def test_extra_fields_merge_into_payload() -> None:
    """Future extension fields (e.g. lane colour, label) flow through ``extra``."""
    payload = subagent_emitter(
        subagent_type="search",
        subagent_run_id="r1",
        extra={"label": "Web Search"},
    ).to_payload()
    assert payload["label"] == "Web Search"
def test_attach_emitted_by_with_none_is_noop() -> None:
    """A ``None`` emitter leaves the payload object untouched and unwrapped."""
    original = {"type": "text-delta", "delta": "hi"}
    returned = attach_emitted_by(original, None)
    assert returned is original
    assert "emitted_by" not in returned
def test_attach_emitted_by_adds_payload_under_snake_case_top_level_key() -> None:
    """The emitter payload is attached in place under the ``emitted_by`` key."""
    emitter = subagent_emitter(
        subagent_type="x",
        subagent_run_id="y",
        parent_tool_call_id="z",
    )
    payload = {"type": "text-delta", "delta": "hi"}
    attach_emitted_by(payload, emitter)
    expected = {
        "level": "subagent",
        "subagent_type": "x",
        "subagent_run_id": "y",
        "parent_tool_call_id": "z",
    }
    assert payload["emitted_by"] == expected

View file

@ -0,0 +1,111 @@
"""Pin the parent_ids walk + parallel sub-agent isolation that drives lane attribution."""
from __future__ import annotations
import pytest
from app.services.streaming.emitter import (
Emitter,
EmitterRegistry,
main_emitter,
subagent_emitter,
)
pytestmark = pytest.mark.unit
def _sub(run_id: str, kind: str = "deliverables") -> Emitter:
    """Build a sub-agent emitter whose run/tool-call ids derive from *run_id*."""
    return subagent_emitter(
        subagent_type=kind,
        subagent_run_id="sub_" + run_id,
        parent_tool_call_id="call_" + run_id,
    )
def test_unregistered_event_resolves_to_main_emitter() -> None:
    """With nothing registered, every event belongs to the main lane."""
    registry = EmitterRegistry()
    assert registry.resolve(run_id="run_1", parent_ids=["root"]) is main_emitter()
def test_event_owned_by_registered_run_id_returns_that_emitter() -> None:
    """A direct hit on a registered run id resolves to that lane's emitter."""
    emitter = _sub("a")
    registry = EmitterRegistry()
    registry.register("run_task_a", emitter)
    resolved = registry.resolve(run_id="run_task_a", parent_ids=[])
    assert resolved is emitter
def test_descendant_resolves_via_parent_ids_chain() -> None:
    """A model-call event nested under the task tool inherits its sub-agent emitter."""
    registry = EmitterRegistry()
    lane = _sub("a")
    registry.register("run_task_a", lane)
    resolved = registry.resolve(
        run_id="run_chat_model",
        parent_ids=["root", "run_agent", "run_task_a"],
    )
    assert resolved is lane
def test_nearest_registered_ancestor_wins_over_distant_ones() -> None:
    """Inner sub-agents owe their emitter to the nearest task tool, not the outer one."""
    registry = EmitterRegistry()
    outer_lane = _sub("outer", kind="planner")
    inner_lane = _sub("inner", kind="email")
    registry.register("run_outer", outer_lane)
    registry.register("run_inner", inner_lane)
    # parent_ids lists ancestors outermost-first; the last registered one wins.
    resolved = registry.resolve(
        run_id="run_inner_tool",
        parent_ids=["root", "run_outer", "run_inner"],
    )
    assert resolved is inner_lane
def test_parallel_subagents_do_not_bleed_into_each_other() -> None:
    """Two concurrent task tools each own their own descendant events."""
    registry = EmitterRegistry()
    search_lane = _sub("a", kind="search")
    email_lane = _sub("b", kind="email")
    registry.register("run_task_a", search_lane)
    registry.register("run_task_b", email_lane)
    resolved_a = registry.resolve(run_id="x", parent_ids=["root", "run_task_a"])
    resolved_b = registry.resolve(run_id="y", parent_ids=["root", "run_task_b"])
    resolved_main = registry.resolve(run_id="z", parent_ids=["root"])
    assert resolved_a is search_lane
    assert resolved_b is email_lane
    assert resolved_main is main_emitter()
def test_unregister_releases_run_id_so_descendants_fall_back_to_main() -> None:
    """Once a lane is closed, its descendants rejoin the main lane."""
    registry = EmitterRegistry()
    registry.register("run_task_a", _sub("a"))
    registry.unregister("run_task_a")
    resolved = registry.resolve(run_id="x", parent_ids=["run_task_a"])
    assert resolved is main_emitter()
def test_unregister_returns_the_previously_registered_emitter() -> None:
    """Lets callers emit ``data-subagent-finish`` carrying the same emitter they opened with."""
    lane = _sub("a")
    registry = EmitterRegistry()
    registry.register("run_task_a", lane)
    released = registry.unregister("run_task_a")
    assert released is lane
def test_has_active_subagents_tracks_open_lanes() -> None:
    """The active-lane flag flips on register and back off on unregister."""
    reg = EmitterRegistry()
    assert not reg.has_active_subagents()
    lane = _sub("a")
    reg.register("run_task_a", lane)
    assert reg.has_active_subagents()
    reg.unregister("run_task_a")
    assert not reg.has_active_subagents()
def test_empty_run_id_and_parent_ids_resolves_to_main() -> None:
    """Defensive: events without identifiers always belong to the main lane."""
    reg = EmitterRegistry()
    reg.register("run_task_a", _sub("a"))
    for run_id, parents in ((None, None), ("", [])):
        assert reg.resolve(run_id=run_id, parent_ids=parents) is main_emitter()

View file

@ -0,0 +1,164 @@
"""Pin id-aware pending-interrupt lookup that replaces the buggy first-wins."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any
import pytest
from app.services.streaming.interrupt_correlation import (
PendingInterrupt,
first_pending_interrupt,
get_pending_interrupt_by_id,
get_pending_interrupt_for_tool_call,
list_pending_interrupts,
)
pytestmark = pytest.mark.unit
@dataclass
class _Interrupt:
    # Stand-in for a LangGraph interrupt: a payload ``value`` plus optional id.
    value: dict[str, Any]
    id: str | None = None
@dataclass
class _Task:
    # Stand-in for a LangGraph task snapshot holding zero or more interrupts.
    interrupts: tuple[_Interrupt, ...] = ()
    id: str | None = None
@dataclass
class _State:
    # Stand-in for a graph state snapshot: per-task plus root-level interrupts.
    tasks: tuple[_Task, ...] = ()
    interrupts: tuple[_Interrupt, ...] = ()
def _hitl(name: str, tool_call_id: str | None = None) -> dict[str, Any]:
"""Minimal LangChain HITLRequest payload for one action."""
action: dict[str, Any] = {"name": name, "args": {}}
if tool_call_id is not None:
action["tool_call_id"] = tool_call_id
return {
"action_requests": [action],
"review_configs": [{"action_name": name, "allowed_decisions": ["approve"]}],
}
def test_empty_state_has_no_pending_interrupts() -> None:
    """An empty snapshot yields neither list entries nor a first interrupt."""
    empty = _State()
    assert list_pending_interrupts(empty) == []
    assert first_pending_interrupt(empty) is None
def test_single_pending_interrupt_in_task_is_returned() -> None:
    """A lone interrupt inside one task yields exactly one PendingInterrupt."""
    state = _State(
        tasks=(
            _Task(
                id="task_1",
                interrupts=(_Interrupt(value=_hitl("send_email"), id="int_1"),),
            ),
        )
    )
    pending = list_pending_interrupts(state)
    assert len(pending) == 1
    # interrupt_id and source_task_id must be threaded through from the snapshot.
    assert pending[0] == PendingInterrupt(
        interrupt_id="int_1",
        value=_hitl("send_email"),
        source_task_id="task_1",
    )
def test_pending_interrupts_returned_in_task_then_root_order() -> None:
    """Determinism matters: callers iterate in this order to render the UI."""
    task_a = _Task(
        id="task_a",
        interrupts=(_Interrupt(value=_hitl("a"), id="int_a"),),
    )
    task_b = _Task(
        id="task_b",
        interrupts=(_Interrupt(value=_hitl("b"), id="int_b"),),
    )
    state = _State(
        tasks=(task_a, task_b),
        interrupts=(_Interrupt(value=_hitl("c"), id="int_c"),),
    )
    ordered_ids = [p.interrupt_id for p in list_pending_interrupts(state)]
    assert ordered_ids == ["int_a", "int_b", "int_c"]
def test_get_by_id_finds_the_right_interrupt_under_parallel_load() -> None:
    """Replacing first-wins: id-aware lookup MUST pick the requested one."""
    tasks = tuple(
        _Task(interrupts=(_Interrupt(value=_hitl(name), id=f"int_{name}"),))
        for name in ("a", "b", "c")
    )
    state = _State(tasks=tasks)
    match = get_pending_interrupt_by_id(state, "int_b")
    assert match is not None
    assert match.value["action_requests"][0]["name"] == "b"
def test_get_by_id_returns_none_when_id_is_not_pending() -> None:
    """An unknown interrupt id yields None rather than some other interrupt."""
    only_task = _Task(interrupts=(_Interrupt(value=_hitl("a"), id="int_a"),))
    state = _State(tasks=(only_task,))
    assert get_pending_interrupt_by_id(state, "missing") is None
def test_get_by_tool_call_id_matches_action_request_payload() -> None:
    """HITLRequest carries ``tool_call_id`` per action; lookup uses that."""
    # Two interrupts in the same task: matching must inspect each action's
    # tool_call_id rather than returning the first interrupt found.
    state = _State(
        tasks=(
            _Task(
                interrupts=(
                    _Interrupt(
                        value=_hitl("a", tool_call_id="call_xxx"), id="int_a"
                    ),
                    _Interrupt(
                        value=_hitl("b", tool_call_id="call_yyy"), id="int_b"
                    ),
                )
            ),
        )
    )
    found = get_pending_interrupt_for_tool_call(state, "call_yyy")
    assert found is not None
    assert found.interrupt_id == "int_b"
def test_first_pending_interrupt_matches_legacy_first_wins_behaviour() -> None:
    """Sequential-turn safety: the explicit shortcut still returns the first."""
    state = _State(
        tasks=(_Task(interrupts=(_Interrupt(value=_hitl("first"), id="int_1"),)),),
        interrupts=(_Interrupt(value=_hitl("second"), id="int_2"),),
    )
    winner = first_pending_interrupt(state)
    assert winner is not None
    assert winner.interrupt_id == "int_1"
def test_interrupt_without_id_falls_back_to_none() -> None:
    """Snapshots from older LangGraph versions may omit ``id`` — preserve that."""
    legacy = _Interrupt(value=_hitl("a"), id=None)
    state = _State(tasks=(_Task(interrupts=(legacy,)),))
    pending = list_pending_interrupts(state)
    assert len(pending) == 1
    assert pending[0].interrupt_id is None
def test_non_dict_interrupt_values_are_ignored() -> None:
    """Defensive: a non-dict value should not crash the iteration."""
    class _Raw:
        # Mimics an interrupt wrapper whose ``value`` is not a mapping.
        value = "not a dict"
    state = _State(tasks=(_Task(interrupts=(_Raw(),)),))  # type: ignore[arg-type]
    assert list_pending_interrupts(state) == []

View file

@ -0,0 +1,91 @@
"""Pin interrupt-payload normalisation and the optional correlation fields on the wire."""
from __future__ import annotations
import json
import pytest
from app.services.streaming.events.interrupt import (
format_interrupt_request,
normalize_interrupt_payload,
)
pytestmark = pytest.mark.unit
def _decode(frame: str) -> dict:
body = frame.removeprefix("data: ").removesuffix("\n\n")
return json.loads(body)
def test_hitlrequest_shape_is_passed_through_unchanged() -> None:
    """A payload already in HITLRequest shape is returned verbatim."""
    action = {"name": "send_email", "args": {"to": "a@b"}}
    config = {"action_name": "send_email", "allowed_decisions": ["approve"]}
    raw = {"action_requests": [action], "review_configs": [config]}
    assert normalize_interrupt_payload(raw) == raw
def test_custom_interrupt_primitive_is_converted_to_canonical_shape() -> None:
    """A legacy permission-style interrupt is lifted into HITLRequest shape."""
    raw = {
        "type": "permission",
        "message": "Allow send?",
        "action": {"tool": "send_email", "params": {"to": "a@b"}},
        "context": {"reason": "destructive"},
    }
    out = normalize_interrupt_payload(raw)
    expected_actions = [{"name": "send_email", "args": {"to": "a@b"}}]
    expected_configs = [
        {
            "action_name": "send_email",
            "allowed_decisions": ["approve", "edit", "reject"],
        }
    ]
    assert out["action_requests"] == expected_actions
    assert out["review_configs"] == expected_configs
    assert out["interrupt_type"] == "permission"
    assert out["message"] == "Allow send?"
    assert out["context"] == {"reason": "destructive"}
def test_custom_interrupt_without_message_omits_message_key() -> None:
    """Optional fields stay optional on the wire; FE does not see ``"message": None``."""
    out = normalize_interrupt_payload({"action": {"tool": "send_email"}})
    assert "message" not in out
def test_custom_interrupt_without_tool_falls_back_to_unknown_tool() -> None:
    """Defensive: a malformed ``action`` block must not crash the relay."""
    out = normalize_interrupt_payload({"type": "x", "action": {}})
    for list_key, name_key in (
        ("action_requests", "name"),
        ("review_configs", "action_name"),
    ):
        assert out[list_key][0][name_key] == "unknown_tool"
def test_format_interrupt_request_carries_correlation_fields_on_the_wire() -> None:
    """interrupt id, pending count, and turn id all survive the SSE round trip."""
    frame = format_interrupt_request(
        {"action_requests": [], "review_configs": []},
        interrupt_id="int_42",
        pending_interrupt_count=3,
        chat_turn_id="turn_99",
    )
    decoded = _decode(frame)
    assert decoded["type"] == "data-interrupt-request"
    data = decoded["data"]
    assert data["interrupt_id"] == "int_42"
    assert data["pending_interrupt_count"] == 3
    assert data["chat_turn_id"] == "turn_99"
def test_format_interrupt_request_omits_correlation_fields_when_unset() -> None:
    """Backward compat: legacy single-interrupt callers don't have to supply ids."""
    frame = format_interrupt_request({"action_requests": [], "review_configs": []})
    data = _decode(frame)["data"]
    for key in ("interrupt_id", "pending_interrupt_count", "chat_turn_id"):
        assert key not in data

View file

@ -0,0 +1,142 @@
"""Pin that sub-agent emitter reaches every wire event the relay emits."""
from __future__ import annotations
import json
import pytest
from app.services.streaming.emitter import subagent_emitter
from app.services.streaming.service import StreamingService
pytestmark = pytest.mark.unit
def _decode(frame: str) -> dict:
body = frame.removeprefix("data: ").removesuffix("\n\n")
return json.loads(body)
@pytest.fixture
def service() -> StreamingService:
    """A fresh StreamingService per test so message state cannot leak."""
    return StreamingService()
@pytest.fixture
def sub_emitter():
    """A representative sub-agent emitter with all correlation fields set."""
    return subagent_emitter(
        subagent_type="deliverables",
        subagent_run_id="sub_xyz",
        parent_tool_call_id="call_parent",
    )
def test_text_delta_carries_subagent_emitter_on_the_wire(service, sub_emitter) -> None:
    """Text deltas from a sub-agent lane carry its run id in ``emitted_by``."""
    frame = service.format_text_delta("text_1", "hi", emitter=sub_emitter)
    decoded = _decode(frame)
    assert decoded["delta"] == "hi"
    assert decoded["emitted_by"]["subagent_run_id"] == "sub_xyz"
def test_reasoning_delta_carries_subagent_emitter_on_the_wire(
    service, sub_emitter
) -> None:
    """Reasoning deltas are attributed to the sub-agent lane as well."""
    frame = service.format_reasoning_delta("r_1", "thinking", emitter=sub_emitter)
    assert _decode(frame)["emitted_by"]["subagent_run_id"] == "sub_xyz"
def test_tool_input_start_carries_subagent_emitter_and_lc_id(
    service, sub_emitter
) -> None:
    """tool-input-start keeps the LangChain id and the sub-agent attribution."""
    frame = service.format_tool_input_start(
        "call_1",
        "send_email",
        langchain_tool_call_id="lc_1",
        emitter=sub_emitter,
    )
    decoded = _decode(frame)
    assert decoded["toolName"] == "send_email"
    assert decoded["langchainToolCallId"] == "lc_1"
    assert decoded["emitted_by"]["subagent_type"] == "deliverables"
def test_tool_output_available_carries_subagent_emitter(service, sub_emitter) -> None:
    """Tool outputs stay attributed to the sub-agent lane that produced them."""
    frame = service.format_tool_output_available(
        "call_1", {"ok": True}, emitter=sub_emitter
    )
    decoded = _decode(frame)
    assert decoded["output"] == {"ok": True}
    assert decoded["emitted_by"]["subagent_run_id"] == "sub_xyz"
def test_thinking_step_carries_subagent_emitter(service, sub_emitter) -> None:
    """Thinking-step frames keep their type tag plus the lane attribution."""
    frame = service.format_thinking_step(
        step_id="s1",
        title="Sending email",
        status="in_progress",
        emitter=sub_emitter,
    )
    decoded = _decode(frame)
    assert decoded["type"] == "data-thinking-step"
    assert decoded["emitted_by"]["subagent_run_id"] == "sub_xyz"
def test_action_log_carries_subagent_emitter(service, sub_emitter) -> None:
    """Action-log frames carry the sub-agent attribution next to the log data."""
    frame = service.format_action_log(
        {"id": 1, "tool_name": "send_email", "reversible": False},
        emitter=sub_emitter,
    )
    decoded = _decode(frame)
    assert decoded["data"]["tool_name"] == "send_email"
    assert decoded["emitted_by"]["subagent_run_id"] == "sub_xyz"
def test_subagent_lifecycle_events_share_run_id_for_pairing(
    service, sub_emitter
) -> None:
    """start/finish frames share one subagent_run_id so the FE can pair them."""
    start = _decode(
        service.format_subagent_start(
            subagent_run_id="sub_xyz",
            subagent_type="deliverables",
            parent_tool_call_id="call_parent",
            emitter=sub_emitter,
        )
    )
    finish = _decode(
        service.format_subagent_finish(
            subagent_run_id="sub_xyz",
            subagent_type="deliverables",
            parent_tool_call_id="call_parent",
            emitter=sub_emitter,
        )
    )
    assert start["data"]["subagent_run_id"] == finish["data"]["subagent_run_id"]
    assert start["type"] == "data-subagent-start"
    assert finish["type"] == "data-subagent-finish"
def test_main_emitter_events_omit_emitted_by_field(service) -> None:
    """Main-lane frames stay byte-compatible with the pre-multi-agent wire."""
    decoded = _decode(service.format_text_delta("text_1", "hi"))
    assert "emitted_by" not in decoded
def test_resolve_emitter_through_service_uses_registry(service, sub_emitter) -> None:
    """The service delegates lane resolution to its emitter registry."""
    service.emitter_registry.register("run_task_1", sub_emitter)
    resolved = service.resolve_emitter(
        run_id="run_chat_model",
        parent_ids=["root", "run_task_1"],
    )
    assert resolved is sub_emitter
def test_message_id_is_assigned_on_message_start_and_reused(service) -> None:
    """message-start mints a ``msg_`` id and pins it on the service for reuse."""
    decoded = _decode(service.format_message_start())
    minted = decoded["messageId"]
    assert minted.startswith("msg_")
    assert service.message_id == minted

View file

@ -0,0 +1,51 @@
"""Pin the exact SSE wire bytes the FE parser depends on."""
from __future__ import annotations
import json
import pytest
from app.services.streaming.envelope import (
format_done,
format_sse,
get_response_headers,
)
pytestmark = pytest.mark.unit
class TestFormatSse:
    """format_sse must emit exactly ``data: <json>`` plus a blank-line terminator."""

    def test_dict_payload_is_json_serialised(self) -> None:
        frame = format_sse({"type": "start", "messageId": "msg_1"})
        assert frame.startswith("data: ")
        assert frame.endswith("\n\n")
        body = frame.removeprefix("data: ").removesuffix("\n\n")
        assert json.loads(body) == {"type": "start", "messageId": "msg_1"}

    def test_string_payload_is_emitted_verbatim(self) -> None:
        # Pre-serialised JSON strings must not be double-encoded.
        assert format_sse('{"already":"json"}') == 'data: {"already":"json"}\n\n'

    def test_nested_payload_round_trips(self) -> None:
        payload = {
            "type": "data-action-log",
            "data": {"id": 7, "tool_name": "ls", "reversible": False},
        }
        body = format_sse(payload).removeprefix("data: ").removesuffix("\n\n")
        assert json.loads(body) == payload
class TestFormatDone:
    """The stream terminator must be the exact byte sequence the FE looks for."""

    def test_done_marker_is_literal(self) -> None:
        assert format_done() == "data: [DONE]\n\n"
class TestResponseHeaders:
    """Response headers pin the AI SDK v1 SSE contract."""

    def test_headers_pin_ai_sdk_v1_protocol(self) -> None:
        headers = get_response_headers()
        expected = {
            "Content-Type": "text/event-stream",
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "x-vercel-ai-ui-message-stream": "v1",
        }
        for name, value in expected.items():
            assert headers[name] == value

View file

@ -0,0 +1,292 @@
"""Pin Stage 1 extractions as faithful copies of the old helpers.
Extractions under ``app.tasks.chat.streaming`` are compared to
``app.tasks.chat.stream_new_chat`` helpers.
For each Stage 1 extraction we assert the new function returns the same
output as the old one for a representative input set. The moment the
two diverge - intentionally or otherwise - this file fails loudly so
the divergence is reviewed rather than shipped silently.
"""
from __future__ import annotations
import logging
from dataclasses import dataclass, field
from typing import Any
import pytest
from app.agents.new_chat.errors import BusyError
from app.agents.new_chat.middleware.busy_mutex import request_cancel, reset_cancel
from app.tasks.chat.stream_new_chat import (
_classify_stream_exception as old_classify,
_emit_stream_terminal_error as old_emit_terminal_error,
_extract_chunk_parts as old_extract_chunk_parts,
_extract_resolved_file_path as old_extract_resolved_file_path,
_first_interrupt_value as old_first_interrupt_value,
_tool_output_has_error as old_tool_output_has_error,
_tool_output_to_text as old_tool_output_to_text,
)
from app.tasks.chat.streaming.errors.classifier import (
classify_stream_exception as new_classify,
)
from app.tasks.chat.streaming.errors.emitter import (
emit_stream_terminal_error as new_emit_terminal_error,
)
from app.tasks.chat.streaming.helpers.chunk_parts import (
extract_chunk_parts as new_extract_chunk_parts,
)
from app.tasks.chat.streaming.helpers.interrupt_inspector import (
first_interrupt_value as new_first_interrupt_value,
)
from app.tasks.chat.streaming.helpers.tool_output import (
extract_resolved_file_path as new_extract_resolved_file_path,
tool_output_has_error as new_tool_output_has_error,
tool_output_to_text as new_tool_output_to_text,
)
pytestmark = pytest.mark.unit
# ---------------------------------------------------------------- chunk parts
@dataclass
class _Chunk:
    # Minimal stand-in for a LangChain message chunk as seen by the extractors:
    # ``content`` may be a string or a list of typed blocks.
    content: Any = ""
    additional_kwargs: dict[str, Any] = field(default_factory=dict)
    tool_call_chunks: list[dict[str, Any]] = field(default_factory=list)
# Chunk shapes covering: plain/empty/invalid content, typed content blocks
# (text / reasoning / tool-call), provider reasoning in additional_kwargs,
# attribute-level tool_call_chunks, and precedence between the two sources.
_CHUNK_CASES: list[Any] = [
    None,
    _Chunk(content=""),
    _Chunk(content="hello"),
    _Chunk(content=42),  # invalid type, defensively coerced to empty
    _Chunk(
        content=[
            {"type": "text", "text": "Hello "},
            {"type": "text", "text": "world"},
        ]
    ),
    _Chunk(
        content=[
            {"type": "reasoning", "reasoning": "hmm "},
            {"type": "reasoning", "text": "still"},
            {"type": "text", "text": "answer"},
        ]
    ),
    _Chunk(
        content=[
            {"type": "tool_call_chunk", "id": "c1", "name": "x", "args": "{"},
            {"type": "tool_use", "id": "c2", "name": "y"},
            {"type": "image_url", "url": "ignored"},
        ]
    ),
    _Chunk(
        content="visible",
        additional_kwargs={"reasoning_content": "private"},
    ),
    _Chunk(
        tool_call_chunks=[
            {"id": None, "name": None, "args": '{"a":1}', "index": 0},
            {"id": "c", "name": "n", "args": "}", "index": 0},
        ]
    ),
    _Chunk(
        content=[{"type": "tool_call_chunk", "id": "from-block", "name": "x"}],
        tool_call_chunks=[{"id": "from-attr", "name": "y"}],
    ),
]
@pytest.mark.parametrize("chunk", _CHUNK_CASES)
def test_extract_chunk_parts_matches_old_implementation(chunk: Any) -> None:
    """The extraction must agree with the legacy helper on every chunk shape."""
    legacy = old_extract_chunk_parts(chunk)
    assert new_extract_chunk_parts(chunk) == legacy
# ---------------------------------------------------------- interrupt inspector
@dataclass
class _Interrupt:
    # Wrapper exposing the ``.value`` attribute both inspectors read.
    value: dict[str, Any]
@dataclass
class _Task:
    # Task stand-in; ``interrupts`` may hold wrapper objects or plain dicts.
    interrupts: tuple[Any, ...] = ()
@dataclass
class _State:
    # State snapshot stand-in with task-level and root-level interrupts.
    tasks: tuple[Any, ...] = ()
    interrupts: tuple[Any, ...] = ()
# Snapshot shapes both inspectors must agree on, including defensive cases.
_INTERRUPT_CASES: list[Any] = [
    _State(),
    _State(tasks=(_Task(interrupts=(_Interrupt(value={"name": "send"}),)),)),
    # Multiple tasks: must return the FIRST one in iteration order.
    _State(
        tasks=(
            _Task(interrupts=(_Interrupt(value={"name": "first"}),)),
            _Task(interrupts=(_Interrupt(value={"name": "second"}),)),
        )
    ),
    # Empty task interrupts -> falls back to root state.interrupts.
    _State(
        tasks=(_Task(interrupts=()),),
        interrupts=(_Interrupt(value={"name": "root"}),),
    ),
    # Interrupts as plain dicts (not wrapper objects).
    _State(interrupts=({"value": {"name": "dict_root"}},)),
    # A defective task whose `.interrupts` raises - must be tolerated.
    _State(tasks=(object(),)),
]
@pytest.mark.parametrize("state", _INTERRUPT_CASES)
def test_first_interrupt_value_matches_old_implementation(state: Any) -> None:
    """Old and new first-interrupt inspectors must agree on every snapshot."""
    legacy = old_first_interrupt_value(state)
    assert new_first_interrupt_value(state) == legacy
# ----------------------------------------------------------- error classifier
def _classify_cases() -> list[Exception]:
    """Inputs that the FE depends on being mapped to specific error codes."""
    generic = Exception("totally generic error")
    rate_limited = Exception(
        '{"error":{"type":"rate_limit_error","message":"slow down"}}'
    )
    provider_429 = Exception(
        'OpenrouterException - {"error":{"message":"Provider returned error",'
        '"code":429}}'
    )
    busy = BusyError(request_id="thread-busy-parity")
    busy_text = Exception("Thread is busy with another request")
    return [generic, rate_limited, provider_429, busy, busy_text]
@pytest.mark.parametrize("exc", _classify_cases())
def test_classify_stream_exception_matches_old_implementation(
    exc: Exception,
) -> None:
    """New classifier must match the old tuple field-for-field (modulo clock)."""
    new = new_classify(exc, flow_label="parity-test")
    old = old_classify(exc, flow_label="parity-test")
    # Strip the wall-clock retry timestamp before comparing — both
    # implementations call ``time.time()`` independently and the call
    # order is enough to differ by 1 ms in practice. Every other field
    # in the tuple must match exactly.
    new_extra = dict(new[5]) if isinstance(new[5], dict) else new[5]
    old_extra = dict(old[5]) if isinstance(old[5], dict) else old[5]
    if isinstance(new_extra, dict) and isinstance(old_extra, dict):
        new_extra.pop("retry_after_at", None)
        old_extra.pop("retry_after_at", None)
    assert new[:5] == old[:5]
    assert new_extra == old_extra
def test_classify_turn_cancelling_branch_parity() -> None:
    """The TURN_CANCELLING branch reads cancel state for the busy thread id;
    both implementations must agree on retry-window semantics, not just the
    plain THREAD_BUSY code."""
    thread_id = "parity-cancelling-thread"
    # Arm the cancel flag so classification takes the TURN_CANCELLING branch.
    reset_cancel(thread_id)
    request_cancel(thread_id)
    exc = BusyError(request_id=thread_id)
    new = new_classify(exc, flow_label="parity-test")
    old = old_classify(exc, flow_label="parity-test")
    assert new[0] == old[0] == "thread_busy"
    assert new[1] == old[1] == "TURN_CANCELLING"
    assert isinstance(new[5], dict) and isinstance(old[5], dict)
    assert new[5]["retry_after_ms"] == old[5]["retry_after_ms"]
# ------------------------------------------------------------ terminal emitter
class _FakeStreamingService:
"""Duck-types ``format_error`` for both old and new emitters."""
def __init__(self) -> None:
self.calls: list[dict[str, Any]] = []
def format_error(
self, message: str, *, error_code: str, extra: dict[str, Any] | None = None
) -> str:
self.calls.append(
{"message": message, "error_code": error_code, "extra": extra}
)
return f"data: {{\"type\":\"error\",\"errorText\":\"{message}\"}}\n\n"
def test_emit_stream_terminal_error_matches_old_output_and_logs(caplog) -> None:
    """The new emitter must produce the same SSE frame and log the same
    structured payload as the old one for the same arguments."""
    args: dict[str, Any] = {
        "flow": "new",
        "request_id": "req-parity",
        "thread_id": 7,
        "search_space_id": 9,
        "user_id": "user-parity",
        "message": "boom",
        "error_kind": "server_error",
        "error_code": "SERVER_ERROR",
        "severity": "error",
        "is_expected": False,
        "extra": {"foo": "bar"},
    }
    new_svc = _FakeStreamingService()
    old_svc = _FakeStreamingService()
    # Run both emits inside one caplog window so frames, recorded service
    # calls, and log volume can be compared in a single pass.
    with caplog.at_level(logging.ERROR):
        new_frame = new_emit_terminal_error(streaming_service=new_svc, **args)
        old_frame = old_emit_terminal_error(streaming_service=old_svc, **args)
    assert new_frame == old_frame
    assert new_svc.calls == old_svc.calls
    chat_error_records = [
        r for r in caplog.records if "[chat_stream_error]" in r.message
    ]
    # One log line per emit call (two emits -> two records).
    assert len(chat_error_records) == 2
# ---------------------------------------------------------------- tool output
def test_tool_output_helpers_match_old_implementation() -> None:
    """Old and new tool-output helpers agree on text, error flag, and paths."""
    samples: list[Any] = [
        {"result": "ok"},
        {"error": "bad"},
        {"result": "Error: x"},
        "Error: plain",
        "fine",
        {"nested": {"a": 1}},
    ]
    for sample in samples:
        assert new_tool_output_to_text(sample) == old_tool_output_to_text(sample)
        assert new_tool_output_has_error(sample) == old_tool_output_has_error(sample)
    # Path resolution: once from the tool output, once via the input fallback.
    path_cases = [
        ({"path": " /tmp/x "}, None),
        ({}, {"file_path": " /fallback "}),
    ]
    for tool_output, tool_input in path_cases:
        assert new_extract_resolved_file_path(
            tool_name="write_file",
            tool_output=tool_output,
            tool_input=tool_input,
        ) == old_extract_resolved_file_path(
            tool_name="write_file",
            tool_output=tool_output,
            tool_input=tool_input,
        )

View file

@ -0,0 +1,241 @@
"""Parity tests for Stage 2 extractions (tool matching, thinking step, custom events)."""
from __future__ import annotations
from typing import Any
from unittest.mock import MagicMock
import pytest
from app.tasks.chat.stream_new_chat import _legacy_match_lc_id as old_legacy_match
from app.tasks.chat.streaming.handlers.custom_events import (
handle_action_log,
handle_action_log_updated,
handle_document_created,
handle_report_progress,
)
from app.tasks.chat.streaming.helpers.tool_call_matching import (
match_buffered_langchain_tool_call_id as new_legacy_match,
)
from app.tasks.chat.streaming.relay.state import AgentEventRelayState
from app.tasks.chat.streaming.relay.thinking_step_completion import (
complete_active_thinking_step,
)
from app.tasks.chat.streaming.relay.thinking_step_sse import emit_thinking_step_frame
pytestmark = pytest.mark.unit
def _copy_chunk_buffer(raw: list[dict[str, Any]]) -> list[dict[str, Any]]:
return [dict(x) for x in raw]
def test_legacy_tool_call_match_matches_old_implementation() -> None:
    """Old and new matchers must agree on the returned id AND on how they
    mutate the chunk buffer and the lc-id map (both are passed by reference)."""
    cases: list[tuple[list[dict[str, Any]], str, str, dict[str, str]]] = [
        (
            [
                {"name": "write_file", "id": "lc-a"},
                {"name": "other", "id": "lc-b"},
            ],
            "write_file",
            "run-1",
            {},
        ),
        (
            [{"name": "x", "id": None}, {"name": "y", "id": "lc-fallback"}],
            "write_file",
            "run-2",
            {},
        ),
        ([{"name": "no_id"}], "write_file", "run-3", {}),
    ]
    for chunks_template, tool_name, run_id, lc_map_seed in cases:
        # Each side gets its own copies so mutations can be compared afterwards.
        old_chunks = _copy_chunk_buffer(chunks_template)
        new_chunks = _copy_chunk_buffer(chunks_template)
        old_map = dict(lc_map_seed)
        new_map = dict(lc_map_seed)
        old_out = old_legacy_match(old_chunks, tool_name, run_id, old_map)
        new_out = new_legacy_match(new_chunks, tool_name, run_id, new_map)
        assert new_out == old_out
        assert new_chunks == old_chunks
        assert new_map == old_map
def test_emit_thinking_step_frame_invokes_builder_before_service() -> None:
    """The content builder must observe the step before the SSE frame is built."""
    call_order: list[str] = []
    builder = MagicMock()

    def _record_builder(*args: Any, **kwargs: Any) -> None:
        call_order.append("builder")

    builder.on_thinking_step.side_effect = _record_builder
    svc = MagicMock()

    def _record_service(**kwargs: Any) -> str:
        call_order.append("service")
        return "frame"

    svc.format_thinking_step.side_effect = _record_service
    frame = emit_thinking_step_frame(
        streaming_service=svc,
        content_builder=builder,
        step_id="thinking-1",
        title="Working",
        status="in_progress",
        items=["a"],
    )
    assert frame == "frame"
    assert call_order == ["builder", "service"]
    builder.on_thinking_step.assert_called_once()
    svc.format_thinking_step.assert_called_once()
def test_emit_thinking_step_frame_skips_builder_when_none() -> None:
    """With no content builder, the SSE frame is still produced.

    The previous version created ``MagicMock(return_value="x")``: that
    configures the return of calling the mock *itself*, which never happens —
    only ``format_thinking_step`` is invoked. The dead configuration was
    misleading, so the mock is now created bare.
    """
    svc = MagicMock()
    svc.format_thinking_step.return_value = "frame"
    frame = emit_thinking_step_frame(
        streaming_service=svc,
        content_builder=None,
        step_id="s",
        title="t",
    )
    assert frame == "frame"
    svc.format_thinking_step.assert_called_once()
def test_complete_active_thinking_step_mirrors_closure_semantics() -> None:
    """First completion emits a done-frame and records the step id; a repeat
    for the same id is a no-op that keeps the id active."""
    svc = MagicMock()
    svc.format_thinking_step.return_value = "done-frame"
    completed: set[str] = set()
    relay_state = AgentEventRelayState.for_invocation()
    frame, new_id = complete_active_thinking_step(
        state=relay_state,
        streaming_service=svc,
        content_builder=None,
        last_active_step_id="thinking-1",
        last_active_step_title="T",
        last_active_step_items=["x"],
        completed_step_ids=completed,
    )
    assert frame == "done-frame"
    assert new_id is None
    assert "thinking-1" in completed
    # Completing the same step id again must not emit a second done-frame.
    frame2, id2 = complete_active_thinking_step(
        state=relay_state,
        streaming_service=svc,
        content_builder=None,
        last_active_step_id="thinking-1",
        last_active_step_title="T",
        last_active_step_items=[],
        completed_step_ids=completed,
    )
    assert frame2 is None
    assert id2 == "thinking-1"
def test_agent_event_relay_state_factory_matches_counter_rule() -> None:
    """No seed step → counter 0; a seeded step → counter 1 and id inherited."""
    fresh = AgentEventRelayState.for_invocation()
    assert fresh.thinking_step_counter == 0
    assert fresh.last_active_step_id is None

    seeded = AgentEventRelayState.for_invocation(
        initial_step_id="thinking-resume-1",
        initial_step_title="Inherited",
        initial_step_items=["Topic: X"],
    )
    assert seeded.thinking_step_counter == 1
    assert seeded.last_active_step_id == "thinking-resume-1"
    # Next generated id continues the sequence after the inherited step.
    assert seeded.next_thinking_step_id("thinking") == "thinking-2"
@pytest.mark.parametrize(
    ("phase", "message", "start_items", "expected_tail"),
    [
        (
            "revising_section",
            "progress line",
            ["Topic: Foo", "Modifying bar", "stale..."],
            ["Topic: Foo", "Modifying bar", "progress line"],
        ),
        (
            "other",
            "phase msg",
            ["Topic: Foo", "old line"],
            ["Topic: Foo", "phase msg"],
        ),
    ],
)
def test_report_progress_items_match_reference(
    phase: str,
    message: str,
    start_items: list[str],
    expected_tail: list[str],
) -> None:
    """``handle_report_progress`` rewrites the step-item tail per phase rules."""
    svc = MagicMock()
    svc.format_thinking_step.return_value = "sse"
    working_items = list(start_items)

    frame, updated_items = handle_report_progress(
        {"message": message, "phase": phase},
        last_active_step_id="step-1",
        last_active_step_title="Report",
        last_active_step_items=working_items,
        streaming_service=svc,
        content_builder=None,
    )

    assert frame == "sse"
    assert updated_items == expected_tail
    # The service must have been handed the same rewritten item list.
    assert svc.format_thinking_step.call_args.kwargs["items"] == expected_tail
def test_report_progress_noop_when_missing_message_or_step() -> None:
    """An empty message or a missing active step id must be a no-op.

    In both cases the handler returns no frame and hands back the SAME list
    object (identity, not equality).
    """
    svc = MagicMock()
    items = ["Topic: A"]

    frame, out_items = handle_report_progress(
        {"message": "", "phase": "x"},
        last_active_step_id="s",
        last_active_step_title="t",
        last_active_step_items=items,
        streaming_service=svc,
        content_builder=None,
    )
    # Split the original compound ``assert a and b`` so a failure pinpoints
    # which guarantee broke (flake8-pytest-style PT018).
    assert frame is None
    assert out_items is items

    frame2, out_items2 = handle_report_progress(
        {"message": "m", "phase": "x"},
        last_active_step_id=None,
        last_active_step_title="t",
        last_active_step_items=items,
        streaming_service=svc,
        content_builder=None,
    )
    assert frame2 is None
    assert out_items2 is items
def test_document_action_handlers_match_format_data_guards() -> None:
    """Falsy ids short-circuit every handler; truthy ids hit ``format_data``."""
    svc = MagicMock()
    svc.format_data.return_value = "data-frame"

    # document-created: missing or zero id is rejected before formatting.
    assert handle_document_created({}, streaming_service=svc) is None
    assert handle_document_created({"id": 0}, streaming_service=svc) is None
    handle_document_created({"id": 42, "title": "x"}, streaming_service=svc)
    svc.format_data.assert_called_with(
        "documents-updated", {"action": "created", "document": {"id": 42, "title": "x"}}
    )

    # Both action-log handlers share the same id guard and event naming.
    for handler, event_name, log_id in (
        (handle_action_log, "action-log", 1),
        (handle_action_log_updated, "action-log-updated", 2),
    ):
        svc.reset_mock()
        assert handler({"id": None}, streaming_service=svc) is None
        handler({"id": log_id}, streaming_service=svc)
        svc.format_data.assert_called_once_with(event_name, {"id": log_id})

View file

@ -0,0 +1,116 @@
"""Tests for ``stream_output`` (LangGraph events → SSE)."""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any
import pytest
from app.tasks.chat.streaming.graph_stream import stream_output
from app.tasks.chat.streaming.graph_stream.result import StreamingResult
pytestmark = pytest.mark.unit
@dataclass
class _Chunk:
    # Minimal stand-in for a LangChain AIMessageChunk: only the attributes
    # the streaming layer reads are modelled here.
    content: Any = ""  # streamed text payload; may be empty for tool-only chunks
    additional_kwargs: dict[str, Any] = field(default_factory=dict)
    tool_call_chunks: list[dict[str, Any]] = field(default_factory=list)
class _StreamingService:
def __init__(self) -> None:
self._text_idx = 0
def generate_text_id(self) -> str:
self._text_idx += 1
return f"text-{self._text_idx}"
def format_text_start(self, text_id: str) -> str:
return f"text_start:{text_id}"
def format_text_delta(self, text_id: str, text: str) -> str:
return f"text_delta:{text_id}:{text}"
def format_text_end(self, text_id: str) -> str:
return f"text_end:{text_id}"
class _Agent:
def __init__(self, events: list[dict[str, Any]]) -> None:
self.events = list(events)
self.calls: list[tuple[Any, dict[str, Any]]] = []
async def astream_events(self, input_data: Any, **kwargs: Any):
self.calls.append((input_data, kwargs))
for event in self.events:
yield event
async def _collect(stream: Any) -> list[str]:
out: list[str] = []
async for x in stream:
out.append(x)
return out
async def test_stream_output_emits_text_lifecycle_and_updates_result() -> None:
    """Two text chunks yield start/delta/delta/end and accumulate into result."""
    service = _StreamingService()
    chunk_events = [
        {"event": "on_chat_model_stream", "data": {"chunk": _Chunk(content="Hello")}},
        {"event": "on_chat_model_stream", "data": {"chunk": _Chunk(content=" world")}},
    ]
    agent = _Agent(chunk_events)
    result = StreamingResult()

    frames = await _collect(
        stream_output(
            agent=agent,
            config={"configurable": {"thread_id": "t-1"}},
            input_data={"messages": []},
            streaming_service=service,
            result=result,
        )
    )

    expected_frames = [
        "text_start:text-1",
        "text_delta:text-1:Hello",
        "text_delta:text-1: world",
        "text_end:text-1",
    ]
    assert frames == expected_frames
    assert result.accumulated_text == "Hello world"
    assert result.agent_called_update_memory is False
async def test_stream_output_passes_runtime_context_to_agent() -> None:
    """``runtime_context`` must reach the agent as the ``context`` kwarg."""
    service = _StreamingService()

    class _ContextAwareAgent:
        async def astream_events(self, input_data: Any, **kwargs: Any):
            del input_data
            # Echo whether the context kwarg arrived into the streamed text.
            marker = "ctx-ok" if kwargs.get("context") else "ctx-missing"
            yield {
                "event": "on_chat_model_stream",
                "data": {"chunk": _Chunk(marker)},
            }

    result = StreamingResult()
    frames = await _collect(
        stream_output(
            agent=_ContextAwareAgent(),
            config={"configurable": {"thread_id": "t-2"}},
            input_data={"messages": []},
            streaming_service=service,
            result=result,
            runtime_context={"mentioned_document_ids": [1, 2]},
        )
    )

    assert frames == [
        "text_start:text-1",
        "text_delta:text-1:ctx-ok",
        "text_end:text-1",
    ]

View file

@ -0,0 +1,69 @@
"""Unit tests for ``task_span`` open/close helpers."""
from __future__ import annotations
import pytest
from app.tasks.chat.streaming.relay.state import AgentEventRelayState
from app.tasks.chat.streaming.relay.task_span import (
clear_task_span_if_delegating_task_ended,
ensure_pending_task_span_for_lc,
open_task_span,
)
pytestmark = pytest.mark.unit
def test_open_task_span_sets_span_and_run_id() -> None:
    """Opening a span records its id + run id and exposes span metadata."""
    state = AgentEventRelayState.for_invocation()
    span_id = open_task_span(state, run_id="run-abc")
    assert span_id.startswith("spn_")
    assert state.active_span_id == span_id
    assert state.active_task_run_id == "run-abc"
    assert state.span_metadata_if_active() == {"spanId": span_id}
def test_clear_ignored_for_non_task_tool() -> None:
    """Ending a non-``task`` tool never closes the delegating span."""
    state = AgentEventRelayState.for_invocation()
    open_task_span(state, run_id="run-1")
    opened_span = state.active_span_id

    clear_task_span_if_delegating_task_ended(
        state, tool_name="web_search", run_id="run-1"
    )

    assert state.active_span_id == opened_span
    assert state.active_task_run_id == "run-1"
def test_clear_ignored_when_task_run_id_mismatches() -> None:
    """A ``task`` end for a different run id leaves the span open."""
    state = AgentEventRelayState.for_invocation()
    open_task_span(state, run_id="run-open")
    clear_task_span_if_delegating_task_ended(
        state, tool_name="task", run_id="run-other"
    )
    assert state.active_span_id is not None
    assert state.active_task_run_id == "run-open"
def test_clear_on_matching_task_end() -> None:
    """Matching tool name + run id fully resets the span bookkeeping."""
    state = AgentEventRelayState.for_invocation()
    open_task_span(state, run_id="run-x")
    clear_task_span_if_delegating_task_ended(state, tool_name="task", run_id="run-x")
    # All three views of the span must be cleared together.
    assert (
        state.active_span_id,
        state.active_task_run_id,
        state.span_metadata_if_active(),
    ) == (None, None, None)
def test_clear_noop_when_no_open_span() -> None:
    """Clearing without an open span is a harmless no-op."""
    state = AgentEventRelayState.for_invocation()
    clear_task_span_if_delegating_task_ended(state, tool_name="task", run_id="run-x")
    assert state.active_span_id is None
def test_pending_then_open_reuses_same_span_id() -> None:
    """A pre-announced (pending) span id is reused when the span opens."""
    state = AgentEventRelayState.for_invocation()

    pending_id = ensure_pending_task_span_for_lc(state, "lc-task-1")
    assert state.pending_task_span_by_lc["lc-task-1"] == pending_id

    active_id = open_task_span(
        state, run_id="run-1", langchain_tool_call_id="lc-task-1"
    )
    assert active_id == pending_id
    assert state.active_span_id == pending_id
    # The pending entry is consumed on open.
    assert "lc-task-1" not in state.pending_task_span_by_lc

View file

@ -0,0 +1,42 @@
"""Unit tests for ``AgentEventRelayState.tool_activity_metadata``."""
from __future__ import annotations
import pytest
from app.tasks.chat.streaming.relay.state import AgentEventRelayState
from app.tasks.chat.streaming.relay.task_span import open_task_span
pytestmark = pytest.mark.unit
def test_returns_none_when_no_span_and_no_thinking_step() -> None:
    """Blank or missing thinking-step ids with no span yield no metadata."""
    state = AgentEventRelayState.for_invocation()
    for step_id in (None, "", " "):
        assert state.tool_activity_metadata(thinking_step_id=step_id) is None
def test_thinking_step_id_only() -> None:
    """With no active span, only ``thinkingStepId`` is emitted."""
    state = AgentEventRelayState.for_invocation()
    metadata = state.tool_activity_metadata(thinking_step_id="thinking-3")
    assert metadata == {"thinkingStepId": "thinking-3"}
def test_span_only_when_active() -> None:
    """An open span with no thinking step yields only ``spanId``."""
    state = AgentEventRelayState.for_invocation()
    open_task_span(state, run_id="run-x")
    metadata = state.tool_activity_metadata(thinking_step_id=None)
    assert metadata == {"spanId": state.active_span_id}
def test_merges_span_and_thinking_step_when_both_set() -> None:
    """Span id and thinking-step id merge into one metadata mapping."""
    state = AgentEventRelayState.for_invocation()
    open_task_span(state, run_id="run-x")
    merged = state.tool_activity_metadata(thinking_step_id="thinking-7")
    assert merged == {
        "spanId": state.active_span_id,
        "thinkingStepId": "thinking-7",
    }

View file

@ -15,6 +15,7 @@ import json
import pytest
from app.services.new_streaming_service import VercelStreamingService
from app.tasks.chat.content_builder import AssistantContentBuilder
pytestmark = pytest.mark.unit
@ -161,7 +162,7 @@ class TestToolHeavyTurn:
_assert_jsonb_safe(snap)
def test_tool_input_available_without_prior_start_creates_card(self):
# Legacy / parity_v2-OFF path: tool-input-available may be
# Late-registration: tool-input-available may be
# emitted without a prior tool-input-start (no streamed
# tool_call_chunks). The card should still be created.
b = AssistantContentBuilder()
@ -187,7 +188,7 @@ class TestToolHeavyTurn:
assert part["result"] == {"matches": 3}
def test_tool_input_start_idempotent_for_same_ui_id(self):
# parity_v2: tool-input-start can fire from BOTH the chunk
# tool-input-start can fire from BOTH the chunk
# registration path AND the canonical ``on_tool_start`` path.
# The second call must not create a duplicate part.
b = AssistantContentBuilder()
@ -231,6 +232,155 @@ class TestToolHeavyTurn:
)
# ---------------------------------------------------------------------------
# Task-span metadata on tool-call parts (JSONB persistence)
# ---------------------------------------------------------------------------
class TestToolCallSpanMetadata:
    """Task-span metadata on tool-call parts must survive JSONB persistence.

    The builder merges ``metadata`` across the tool lifecycle
    (input-start → input-available → output-available). The tests below pin
    that merging is additive and first-writer-wins for conflicting keys.
    """

    def test_input_available_merges_new_metadata_keys_after_start(self):
        # New keys from a later event are added alongside the existing ones.
        b = AssistantContentBuilder()
        b.on_tool_input_start(
            "call_t", "task", "lc_t", metadata={"spanId": "spn_1"}
        )
        b.on_tool_input_available(
            "call_t",
            "task",
            {"goal": "x"},
            "lc_t",
            metadata={"traceId": "tr_1"},
        )
        part = b.snapshot()[0]
        assert part["metadata"]["spanId"] == "spn_1"
        assert part["metadata"]["traceId"] == "tr_1"
        _assert_jsonb_safe(b.snapshot())

    def test_input_available_does_not_overwrite_existing_metadata_keys(self):
        # First-writer-wins: a later, conflicting spanId is ignored.
        b = AssistantContentBuilder()
        b.on_tool_input_start(
            "call_t", "task", "lc_t", metadata={"spanId": "spn_keep"}
        )
        b.on_tool_input_available(
            "call_t", "task", {}, "lc_t", metadata={"spanId": "spn_other"}
        )
        assert b.snapshot()[0]["metadata"]["spanId"] == "spn_keep"

    def test_late_tool_input_available_carries_metadata(self):
        # A card created by a late input-available (no prior start) still
        # records the metadata passed on that first event.
        b = AssistantContentBuilder()
        b.on_tool_input_available(
            "call_l",
            "grep",
            {"pattern": "TODO"},
            None,
            metadata={"spanId": "spn_l"},
        )
        part = b.snapshot()[0]
        assert part["metadata"] == {"spanId": "spn_l"}
        _assert_jsonb_safe(b.snapshot())

    def test_output_available_merges_without_clobbering_span_id(self):
        # Output metadata merges additively; the original spanId survives
        # even when the output event carries a different one.
        b = AssistantContentBuilder()
        b.on_tool_input_start("call_t", "ls", "lc", metadata={"spanId": "spn_x"})
        b.on_tool_input_available("call_t", "ls", {"path": "/"}, "lc")
        b.on_tool_output_available(
            "call_t",
            {"ok": True},
            "lc",
            metadata={"spanId": "spn_y", "extra": 1},
        )
        md = b.snapshot()[0]["metadata"]
        assert md["spanId"] == "spn_x"
        assert md["extra"] == 1

    def test_output_available_adds_thinking_step_id_without_clobbering_span(self):
        # Re-sending identical keys on the output event keeps both values.
        b = AssistantContentBuilder()
        b.on_tool_input_start(
            "call_t",
            "ls",
            "lc",
            metadata={"spanId": "spn_x", "thinkingStepId": "thinking-3"},
        )
        b.on_tool_input_available("call_t", "ls", {"path": "/"}, "lc")
        b.on_tool_output_available(
            "call_t",
            {"ok": True},
            "lc",
            metadata={"spanId": "spn_x", "thinkingStepId": "thinking-3"},
        )
        md = b.snapshot()[0]["metadata"]
        assert md["spanId"] == "spn_x"
        assert md["thinkingStepId"] == "thinking-3"

    def test_output_available_with_none_metadata_preserves_prior(self):
        # ``metadata=None`` on output must not wipe previously stored keys.
        b = AssistantContentBuilder()
        b.on_tool_input_start("c", "ls", "lc", metadata={"spanId": "spn_1"})
        b.on_tool_input_available("c", "ls", {}, "lc")
        b.on_tool_output_available("c", {"r": 1}, "lc", metadata=None)
        assert b.snapshot()[0]["metadata"] == {"spanId": "spn_1"}

    def test_available_adds_thinking_step_id_after_chunk_only_start(self):
        """Mirrors chunk ``tool-input-start`` then ``on_tool_start`` ``available``."""
        # The start here uses the LangChain id as the UI id (chunk path);
        # the later available event augments metadata without clobbering.
        b = AssistantContentBuilder()
        b.on_tool_input_start("lc_1", "ls", "lc_1", metadata={"spanId": "spn_a"})
        b.on_tool_input_available(
            "lc_1",
            "ls",
            {"path": "/"},
            "lc_1",
            metadata={"spanId": "spn_a", "thinkingStepId": "thinking-2"},
        )
        md = b.snapshot()[0]["metadata"]
        assert md["spanId"] == "spn_a"
        assert md["thinkingStepId"] == "thinking-2"
class TestVercelStreamingServiceToolMetadataWire:
    """SSE payloads include optional ``metadata`` for FE grouping."""

    @staticmethod
    def _parse_sse_data_line(raw: str) -> dict:
        """Extract and decode the JSON body of one ``data: ...`` SSE frame."""
        assert raw.startswith("data: ")
        _, _, remainder = raw.partition("data: ")
        body, _, _ = remainder.partition("\n\n")
        return json.loads(body.strip())

    def test_tool_input_available_includes_metadata_when_set(self):
        svc = VercelStreamingService()
        frame = svc.format_tool_input_available(
            "id1",
            "task",
            {"a": 1},
            langchain_tool_call_id="lc1",
            metadata={"spanId": "spn_w", "thinkingStepId": "thinking-4"},
        )
        payload = self._parse_sse_data_line(frame)
        assert payload["type"] == "tool-input-available"
        assert payload["metadata"] == {
            "spanId": "spn_w",
            "thinkingStepId": "thinking-4",
        }

    def test_tool_output_available_includes_metadata_when_set(self):
        svc = VercelStreamingService()
        frame = svc.format_tool_output_available(
            "id1",
            {"status": "completed"},
            langchain_tool_call_id="lc1",
            metadata={"spanId": "spn_o", "thinkingStepId": "thinking-9"},
        )
        payload = self._parse_sse_data_line(frame)
        assert payload["type"] == "tool-output-available"
        assert payload["metadata"] == {
            "spanId": "spn_o",
            "thinkingStepId": "thinking-9",
        }

    def test_tool_input_available_omits_metadata_key_when_none(self):
        # Wire compactness: no metadata key at all when nothing was set.
        svc = VercelStreamingService()
        frame = svc.format_tool_input_available("id1", "ls", {})
        payload = self._parse_sse_data_line(frame)
        assert "metadata" not in payload
# ---------------------------------------------------------------------------
# Thinking steps & separators
# ---------------------------------------------------------------------------

View file

@ -1,16 +1,13 @@
"""Unit tests for live tool-call argument streaming.
Pins the wire format that ``_stream_agent_events`` emits when
``SURFSENSE_ENABLE_STREAM_PARITY_V2=true``: ``tool-input-start``
``tool-input-delta``... ``tool-input-available`` ``tool-output-available``
all keyed by the same LangChain ``tool_call.id``.
Pins the wire format that ``_stream_agent_events`` emits:
``tool-input-start`` ``tool-input-delta``... ``tool-input-available``
``tool-output-available``, keyed consistently with LangChain ``tool_call.id``
when the model streams indexed chunks.
Identity is tracked in ``index_to_meta`` (per-chunk ``index``) and
``ui_tool_call_id_by_run`` (LangGraph ``run_id``); both are private to
``_stream_agent_events`` so we exercise them via the public wire output.
These tests also lock in the legacy / parity_v2-OFF behaviour so the
synthetic ``call_<run_id>`` shape stays stable for older clients.
``ui_tool_call_id_by_run`` (LangGraph ``run_id``); both are internal to the
streaming layer so we assert on the public SSE payloads.
"""
from __future__ import annotations
@ -22,8 +19,6 @@ from typing import Any
import pytest
import app.tasks.chat.stream_new_chat as stream_module
from app.agents.new_chat.feature_flags import AgentFeatureFlags
from app.services.new_streaming_service import VercelStreamingService
from app.tasks.chat.stream_new_chat import (
StreamResult,
@ -164,24 +159,6 @@ def _tool_end(
}
@pytest.fixture
def parity_v2_on(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
stream_module,
"get_flags",
lambda: AgentFeatureFlags(enable_stream_parity_v2=True),
)
@pytest.fixture
def parity_v2_off(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
stream_module,
"get_flags",
lambda: AgentFeatureFlags(enable_stream_parity_v2=False),
)
async def _drain(
events: list[dict[str, Any]], state: _FakeAgentState | None = None
) -> list[dict[str, Any]]:
@ -253,12 +230,12 @@ class TestLegacyMatch:
# ---------------------------------------------------------------------------
# parity_v2 wire format tests.
# Tool input streaming wire format
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_idless_chunk_merging_by_index(parity_v2_on: None) -> None:
async def test_idless_chunk_merging_by_index() -> None:
"""First chunk carries id+name; later idless chunks at the same
``index`` merge into the SAME ``tool-input-start`` ui id and emit
one ``tool-input-delta`` per chunk."""
@ -302,9 +279,7 @@ async def test_idless_chunk_merging_by_index(parity_v2_on: None) -> None:
@pytest.mark.asyncio
async def test_two_interleaved_tool_calls_route_by_index(
parity_v2_on: None,
) -> None:
async def test_two_interleaved_tool_calls_route_by_index() -> None:
"""Two same-name calls with distinct indices keep their deltas
routed to the right card."""
events = [
@ -344,7 +319,7 @@ async def test_two_interleaved_tool_calls_route_by_index(
@pytest.mark.asyncio
async def test_identity_stable_across_lifecycle(parity_v2_on: None) -> None:
async def test_identity_stable_across_lifecycle() -> None:
"""Whatever id ``tool-input-start`` chose must be the SAME id used
on ``tool-input-available`` AND ``tool-output-available``."""
events = [
@ -367,7 +342,7 @@ async def test_identity_stable_across_lifecycle(parity_v2_on: None) -> None:
@pytest.mark.asyncio
async def test_no_duplicate_tool_input_start(parity_v2_on: None) -> None:
async def test_no_duplicate_tool_input_start() -> None:
"""When the chunk-emission loop already fired ``tool-input-start``
for this run, ``on_tool_start`` MUST NOT emit a second one."""
events = [
@ -386,9 +361,7 @@ async def test_no_duplicate_tool_input_start(parity_v2_on: None) -> None:
@pytest.mark.asyncio
async def test_active_text_closes_before_early_tool_input_start(
parity_v2_on: None,
) -> None:
async def test_active_text_closes_before_early_tool_input_start() -> None:
"""Streaming a text-delta then a tool-call chunk in subsequent
chunks: the wire MUST contain ``text-end`` before the FIRST
``tool-input-start`` (clean part boundary on the frontend)."""
@ -409,9 +382,7 @@ async def test_active_text_closes_before_early_tool_input_start(
@pytest.mark.asyncio
async def test_mixed_text_and_tool_chunk_preserve_order(
parity_v2_on: None,
) -> None:
async def test_mixed_text_and_tool_chunk_preserve_order() -> None:
"""One AIMessageChunk that carries BOTH ``text`` content AND
``tool_call_chunks`` should emit the text delta FIRST, then close
text, then ``tool-input-start``+``tool-input-delta``."""
@ -441,45 +412,7 @@ async def test_mixed_text_and_tool_chunk_preserve_order(
@pytest.mark.asyncio
async def test_parity_v2_off_preserves_legacy_shape(
parity_v2_off: None,
) -> None:
"""When the flag is OFF, no deltas are emitted and the ``toolCallId``
is ``call_<run_id>`` (NOT the lc id)."""
events = [
_model_stream(
tool_call_chunks=[
{"id": "lc-1", "name": "ls", "args": '{"path":"/"}', "index": 0}
]
),
_tool_start(name="ls", run_id="run-A", input_payload={"path": "/"}),
_tool_end(name="ls", run_id="run-A", tool_call_id="lc-1"),
]
payloads = await _drain(events)
assert _of_type(payloads, "tool-input-delta") == []
starts = _of_type(payloads, "tool-input-start")
assert len(starts) == 1
assert starts[0]["toolCallId"].startswith("call_run-A")
# No ``langchainToolCallId`` propagation on ``tool-input-start`` in
# legacy mode (the start event fires before the ToolMessage is
# available, so we can't extract the authoritative LangChain id yet).
assert "langchainToolCallId" not in starts[0]
output = _of_type(payloads, "tool-output-available")
assert output[0]["toolCallId"].startswith("call_run-A")
# ``tool-output-available`` MUST carry ``langchainToolCallId`` even
# in legacy mode: the chat tool card uses it to backfill the
# LangChain id and join against the ``data-action-log`` SSE event
# (keyed by ``lc_tool_call_id``) so the inline Revert button can
# light up. Sourced from the returned ``ToolMessage.tool_call_id``,
# which is populated regardless of feature-flag state.
assert output[0]["langchainToolCallId"] == "lc-1"
@pytest.mark.asyncio
async def test_skip_append_prevents_stale_id_reuse(
parity_v2_on: None,
) -> None:
async def test_skip_append_prevents_stale_id_reuse() -> None:
"""Two same-name tools: the SECOND tool's ``langchainToolCallId``
must NOT come from the first tool's chunk (``pending_tool_call_chunks``
must stay empty for indexed-registered chunks)."""
@ -506,9 +439,7 @@ async def test_skip_append_prevents_stale_id_reuse(
@pytest.mark.asyncio
async def test_registration_waits_for_both_id_and_name(
parity_v2_on: None,
) -> None:
async def test_registration_waits_for_both_id_and_name() -> None:
"""An id-only chunk (no name yet) must NOT emit ``tool-input-start``."""
events = [
_model_stream(
@ -520,12 +451,9 @@ async def test_registration_waits_for_both_id_and_name(
@pytest.mark.asyncio
async def test_unmatched_fallback_still_attaches_lc_id(
parity_v2_on: None,
) -> None:
"""parity_v2 ON, but the provider didn't include an ``index``: the
legacy fallback path must still emit ``tool-input-start`` with the
matching ``langchainToolCallId``."""
async def test_unmatched_fallback_still_attaches_lc_id() -> None:
"""When the provider omits chunk ``index``, buffered chunks still get a
``tool-input-start`` with the matching ``langchainToolCallId``."""
events = [
# No index on the chunk → not registered into index_to_meta;
# falls through to ``pending_tool_call_chunks`` so the legacy
@ -542,9 +470,7 @@ async def test_unmatched_fallback_still_attaches_lc_id(
@pytest.mark.asyncio
async def test_interrupt_request_uses_task_that_contains_interrupt(
parity_v2_on: None,
) -> None:
async def test_interrupt_request_uses_task_that_contains_interrupt() -> None:
interrupt_payload = {
"type": "calendar_event_create",
"action": {