mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-13 01:32:40 +02:00
Merge pull request #1357 from CREDO23/feature/multi-agent
[Feature] Multi-agent chat: hierarchical timeline, live subagent streaming, and inline HITL approvals
This commit is contained in:
commit
28a02a9143
232 changed files with 9014 additions and 4055 deletions
|
|
@ -31,7 +31,6 @@ def _clear_all(monkeypatch: pytest.MonkeyPatch) -> None:
|
|||
"SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE",
|
||||
"SURFSENSE_ENABLE_ACTION_LOG",
|
||||
"SURFSENSE_ENABLE_REVERT_ROUTE",
|
||||
"SURFSENSE_ENABLE_STREAM_PARITY_V2",
|
||||
"SURFSENSE_ENABLE_PLUGIN_LOADER",
|
||||
"SURFSENSE_ENABLE_OTEL",
|
||||
"SURFSENSE_ENABLE_AGENT_CACHE",
|
||||
|
|
@ -61,7 +60,6 @@ def test_defaults_match_shipped_agent_stack(monkeypatch: pytest.MonkeyPatch) ->
|
|||
assert flags.enable_kb_planner_runnable is True
|
||||
assert flags.enable_action_log is True
|
||||
assert flags.enable_revert_route is True
|
||||
assert flags.enable_stream_parity_v2 is True
|
||||
assert flags.enable_plugin_loader is False
|
||||
assert flags.enable_otel is False
|
||||
# Phase 2: agent cache is now default-on (the prerequisite tool
|
||||
|
|
@ -127,7 +125,6 @@ def test_each_flag_can_be_set_independently(monkeypatch: pytest.MonkeyPatch) ->
|
|||
"enable_kb_planner_runnable": "SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE",
|
||||
"enable_action_log": "SURFSENSE_ENABLE_ACTION_LOG",
|
||||
"enable_revert_route": "SURFSENSE_ENABLE_REVERT_ROUTE",
|
||||
"enable_stream_parity_v2": "SURFSENSE_ENABLE_STREAM_PARITY_V2",
|
||||
"enable_plugin_loader": "SURFSENSE_ENABLE_PLUGIN_LOADER",
|
||||
"enable_otel": "SURFSENSE_ENABLE_OTEL",
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,79 @@
|
|||
"""Pin the wire compactness rule and the top-level ``emitted_by`` field name."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.streaming.emitter import (
|
||||
Emitter,
|
||||
attach_emitted_by,
|
||||
main_emitter,
|
||||
subagent_emitter,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def test_main_emitter_payload_contains_only_level() -> None:
    """The main emitter serialises to a payload holding only ``level``."""
    payload = main_emitter().to_payload()
    assert payload == {"level": "main"}
|
||||
|
||||
|
||||
def test_subagent_emitter_payload_includes_all_set_fields() -> None:
    """Every field passed to ``subagent_emitter`` appears in the payload."""
    payload = subagent_emitter(
        subagent_type="deliverables",
        subagent_run_id="subagent_abc",
        parent_tool_call_id="call_xyz",
    ).to_payload()
    assert payload == {
        "level": "subagent",
        "subagent_type": "deliverables",
        "subagent_run_id": "subagent_abc",
        "parent_tool_call_id": "call_xyz",
    }
|
||||
|
||||
|
||||
def test_subagent_emitter_payload_omits_unset_optional_fields() -> None:
    """parent_tool_call_id is None when the run is started outside a tool boundary."""
    payload = Emitter(
        level="subagent",
        subagent_type="email",
        subagent_run_id="subagent_1",
    ).to_payload()
    # The unset optional key must be absent, not present with a null value.
    assert "parent_tool_call_id" not in payload
    assert payload["subagent_type"] == "email"
|
||||
|
||||
|
||||
def test_extra_fields_merge_into_payload() -> None:
    """Future extension fields (e.g. lane colour, label) flow through ``extra``."""
    emitter = subagent_emitter(
        subagent_type="search",
        subagent_run_id="r1",
        extra={"label": "Web Search"},
    )
    assert emitter.to_payload()["label"] == "Web Search"
|
||||
|
||||
|
||||
def test_attach_emitted_by_with_none_is_noop() -> None:
    """Attaching a ``None`` emitter returns the same dict without ``emitted_by``."""
    payload = {"type": "text-delta", "delta": "hi"}
    result = attach_emitted_by(payload, None)
    assert "emitted_by" not in result
    # Identity, not just equality: no copy is made for the no-op case.
    assert result is payload
|
||||
|
||||
|
||||
def test_attach_emitted_by_adds_payload_under_snake_case_top_level_key() -> None:
    """The emitter payload is stored in place under the ``emitted_by`` key."""
    payload = {"type": "text-delta", "delta": "hi"}
    # The return value is deliberately ignored: the assertion below checks
    # that the input dict itself was mutated.
    attach_emitted_by(
        payload,
        subagent_emitter(
            subagent_type="x",
            subagent_run_id="y",
            parent_tool_call_id="z",
        ),
    )
    assert payload["emitted_by"] == {
        "level": "subagent",
        "subagent_type": "x",
        "subagent_run_id": "y",
        "parent_tool_call_id": "z",
    }
|
||||
|
|
@ -0,0 +1,111 @@
|
|||
"""Pin the parent_ids walk + parallel sub-agent isolation that drives lane attribution."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.streaming.emitter import (
|
||||
Emitter,
|
||||
EmitterRegistry,
|
||||
main_emitter,
|
||||
subagent_emitter,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def _sub(run_id: str, kind: str = "deliverables") -> Emitter:
    """Build a sub-agent emitter whose ids are derived from ``run_id``.

    The run id becomes ``sub_<run_id>`` and the parent tool call id becomes
    ``call_<run_id>``; ``kind`` is passed through as the sub-agent type.
    """
    return subagent_emitter(
        subagent_type=kind,
        subagent_run_id=f"sub_{run_id}",
        parent_tool_call_id=f"call_{run_id}",
    )
|
||||
|
||||
|
||||
def test_unregistered_event_resolves_to_main_emitter() -> None:
    """With nothing registered, resolution yields the main emitter object."""
    registry = EmitterRegistry()
    resolved = registry.resolve(run_id="run_1", parent_ids=["root"])
    assert resolved is main_emitter()
|
||||
|
||||
|
||||
def test_event_owned_by_registered_run_id_returns_that_emitter() -> None:
    """An event whose own ``run_id`` is registered resolves to that emitter."""
    registry = EmitterRegistry()
    emitter = _sub("a")
    registry.register("run_task_a", emitter)
    assert registry.resolve(run_id="run_task_a", parent_ids=[]) is emitter
|
||||
|
||||
|
||||
def test_descendant_resolves_via_parent_ids_chain() -> None:
    """A model-call event nested under the task tool inherits its sub-agent emitter."""
    registry = EmitterRegistry()
    emitter = _sub("a")
    registry.register("run_task_a", emitter)
    descendant = registry.resolve(
        run_id="run_chat_model",
        parent_ids=["root", "run_agent", "run_task_a"],
    )
    assert descendant is emitter
|
||||
|
||||
|
||||
def test_nearest_registered_ancestor_wins_over_distant_ones() -> None:
    """Inner sub-agents owe their emitter to the nearest task tool, not the outer one."""
    registry = EmitterRegistry()
    outer = _sub("outer", kind="planner")
    inner = _sub("inner", kind="email")
    registry.register("run_outer", outer)
    registry.register("run_inner", inner)
    resolved = registry.resolve(
        run_id="run_inner_tool",
        # Both ancestors are registered; the last entry is the expected winner.
        parent_ids=["root", "run_outer", "run_inner"],
    )
    assert resolved is inner
|
||||
|
||||
|
||||
def test_parallel_subagents_do_not_bleed_into_each_other() -> None:
    """Two concurrent task tools each own their own descendant events."""
    registry = EmitterRegistry()
    a = _sub("a", kind="search")
    b = _sub("b", kind="email")
    registry.register("run_task_a", a)
    registry.register("run_task_b", b)

    from_a = registry.resolve(run_id="x", parent_ids=["root", "run_task_a"])
    from_b = registry.resolve(run_id="y", parent_ids=["root", "run_task_b"])
    from_main = registry.resolve(run_id="z", parent_ids=["root"])

    assert from_a is a
    assert from_b is b
    assert from_main is main_emitter()
|
||||
|
||||
|
||||
def test_unregister_releases_run_id_so_descendants_fall_back_to_main() -> None:
    """After ``unregister`` a former descendant resolves to the main emitter."""
    registry = EmitterRegistry()
    emitter = _sub("a")
    registry.register("run_task_a", emitter)
    registry.unregister("run_task_a")
    assert registry.resolve(run_id="x", parent_ids=["run_task_a"]) is main_emitter()
|
||||
|
||||
|
||||
def test_unregister_returns_the_previously_registered_emitter() -> None:
    """Lets callers emit ``data-subagent-finish`` carrying the same emitter they opened with."""
    registry = EmitterRegistry()
    emitter = _sub("a")
    registry.register("run_task_a", emitter)
    assert registry.unregister("run_task_a") is emitter
|
||||
|
||||
|
||||
def test_has_active_subagents_tracks_open_lanes() -> None:
    """``has_active_subagents`` follows the register/unregister lifecycle."""
    registry = EmitterRegistry()
    assert not registry.has_active_subagents()
    registry.register("run_task_a", _sub("a"))
    assert registry.has_active_subagents()
    registry.unregister("run_task_a")
    assert not registry.has_active_subagents()
|
||||
|
||||
|
||||
def test_empty_run_id_and_parent_ids_resolves_to_main() -> None:
    """Defensive: events without identifiers always belong to the main lane."""
    registry = EmitterRegistry()
    # A sub-agent is registered so a fallback to main is a real decision.
    registry.register("run_task_a", _sub("a"))
    assert registry.resolve(run_id=None, parent_ids=None) is main_emitter()
    assert registry.resolve(run_id="", parent_ids=[]) is main_emitter()
|
||||
|
|
@ -0,0 +1,164 @@
|
|||
"""Pin id-aware pending-interrupt lookup that replaces the buggy first-wins."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.streaming.interrupt_correlation import (
|
||||
PendingInterrupt,
|
||||
first_pending_interrupt,
|
||||
get_pending_interrupt_by_id,
|
||||
get_pending_interrupt_for_tool_call,
|
||||
list_pending_interrupts,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
@dataclass
class _Interrupt:
    """Test stand-in for an interrupt object exposing ``value`` and ``id``."""

    # Interrupt payload handed to the lookup helpers under test.
    value: dict[str, Any]
    # Optional interrupt identifier; defaults to None.
    id: str | None = None
|
||||
|
||||
|
||||
@dataclass
class _Task:
    """Test stand-in for a state task carrying ``interrupts`` and an ``id``."""

    # Interrupts attached to this task; empty by default.
    interrupts: tuple[_Interrupt, ...] = ()
    # Optional task identifier; defaults to None.
    id: str | None = None
|
||||
|
||||
|
||||
@dataclass
class _State:
    """Test stand-in for a state snapshot with ``tasks`` and root ``interrupts``."""

    # Tasks in the snapshot; empty by default.
    tasks: tuple[_Task, ...] = ()
    # Interrupts attached directly to the state; empty by default.
    interrupts: tuple[_Interrupt, ...] = ()
|
||||
|
||||
|
||||
def _hitl(name: str, tool_call_id: str | None = None) -> dict[str, Any]:
    """Minimal LangChain HITLRequest payload for one action.

    Args:
        name: Action name, used for the action request and its review config.
        tool_call_id: When not None, stored on the action under ``tool_call_id``.

    Returns:
        A dict with one entry each in ``action_requests`` and ``review_configs``.
    """
    action: dict[str, Any] = {"name": name, "args": {}}
    if tool_call_id is not None:
        action["tool_call_id"] = tool_call_id
    return {
        "action_requests": [action],
        "review_configs": [{"action_name": name, "allowed_decisions": ["approve"]}],
    }
|
||||
|
||||
|
||||
def test_empty_state_has_no_pending_interrupts() -> None:
    """An empty state lists no interrupts and has no first interrupt."""
    state = _State()
    assert list_pending_interrupts(state) == []
    assert first_pending_interrupt(state) is None
|
||||
|
||||
|
||||
def test_single_pending_interrupt_in_task_is_returned() -> None:
    """One task interrupt yields one ``PendingInterrupt`` with id, value and task id."""
    state = _State(
        tasks=(
            _Task(
                id="task_1",
                interrupts=(_Interrupt(value=_hitl("send_email"), id="int_1"),),
            ),
        )
    )
    pending = list_pending_interrupts(state)
    assert len(pending) == 1
    assert pending[0] == PendingInterrupt(
        interrupt_id="int_1",
        value=_hitl("send_email"),
        source_task_id="task_1",
    )
|
||||
|
||||
|
||||
def test_pending_interrupts_returned_in_task_then_root_order() -> None:
    """Determinism matters: callers iterate in this order to render the UI."""
    state = _State(
        tasks=(
            _Task(
                id="task_a",
                interrupts=(_Interrupt(value=_hitl("a"), id="int_a"),),
            ),
            _Task(
                id="task_b",
                interrupts=(_Interrupt(value=_hitl("b"), id="int_b"),),
            ),
        ),
        interrupts=(_Interrupt(value=_hitl("c"), id="int_c"),),
    )
    pending = list_pending_interrupts(state)
    ids = [p.interrupt_id for p in pending]
    # Task interrupts come first in task order, then root-level interrupts.
    assert ids == ["int_a", "int_b", "int_c"]
|
||||
|
||||
|
||||
def test_get_by_id_finds_the_right_interrupt_under_parallel_load() -> None:
    """Replacing first-wins: id-aware lookup MUST pick the requested one."""
    state = _State(
        tasks=(
            _Task(interrupts=(_Interrupt(value=_hitl("a"), id="int_a"),)),
            _Task(interrupts=(_Interrupt(value=_hitl("b"), id="int_b"),)),
            _Task(interrupts=(_Interrupt(value=_hitl("c"), id="int_c"),)),
        )
    )
    found = get_pending_interrupt_by_id(state, "int_b")
    assert found is not None
    assert found.value["action_requests"][0]["name"] == "b"
|
||||
|
||||
|
||||
def test_get_by_id_returns_none_when_id_is_not_pending() -> None:
    """Looking up an id that no pending interrupt carries returns None."""
    state = _State(
        tasks=(_Task(interrupts=(_Interrupt(value=_hitl("a"), id="int_a"),)),)
    )
    assert get_pending_interrupt_by_id(state, "missing") is None
|
||||
|
||||
|
||||
def test_get_by_tool_call_id_matches_action_request_payload() -> None:
    """HITLRequest carries ``tool_call_id`` per action; lookup uses that."""
    state = _State(
        tasks=(
            _Task(
                interrupts=(
                    _Interrupt(
                        value=_hitl("a", tool_call_id="call_xxx"), id="int_a"
                    ),
                    _Interrupt(
                        value=_hitl("b", tool_call_id="call_yyy"), id="int_b"
                    ),
                )
            ),
        )
    )
    found = get_pending_interrupt_for_tool_call(state, "call_yyy")
    assert found is not None
    assert found.interrupt_id == "int_b"
|
||||
|
||||
|
||||
def test_first_pending_interrupt_matches_legacy_first_wins_behaviour() -> None:
    """Sequential-turn safety: the explicit shortcut still returns the first."""
    state = _State(
        tasks=(_Task(interrupts=(_Interrupt(value=_hitl("first"), id="int_1"),)),),
        interrupts=(_Interrupt(value=_hitl("second"), id="int_2"),),
    )
    first = first_pending_interrupt(state)
    assert first is not None
    # The task-level interrupt is expected ahead of the root-level one.
    assert first.interrupt_id == "int_1"
|
||||
|
||||
|
||||
def test_interrupt_without_id_falls_back_to_none() -> None:
    """Snapshots from older LangGraph versions may omit ``id`` — preserve that."""
    state = _State(
        tasks=(_Task(interrupts=(_Interrupt(value=_hitl("a"), id=None),)),)
    )
    pending = list_pending_interrupts(state)
    # The id-less interrupt is still listed, with a None id.
    assert len(pending) == 1
    assert pending[0].interrupt_id is None
|
||||
|
||||
|
||||
def test_non_dict_interrupt_values_are_ignored() -> None:
    """Defensive: a non-dict value should not crash the iteration."""

    class _Raw:
        # A string where the other tests supply a dict payload.
        value = "not a dict"

    state = _State(tasks=(_Task(interrupts=(_Raw(),)),))  # type: ignore[arg-type]
    assert list_pending_interrupts(state) == []
|
||||
|
|
@ -0,0 +1,91 @@
|
|||
"""Pin interrupt-payload normalisation and the optional correlation fields on the wire."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.streaming.events.interrupt import (
|
||||
format_interrupt_request,
|
||||
normalize_interrupt_payload,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def _decode(frame: str) -> dict:
    """Parse the JSON body of one SSE frame.

    Strips a leading ``"data: "`` and a trailing blank-line terminator when
    present, then decodes the remainder with ``json.loads``.
    """
    body = frame.removeprefix("data: ").removesuffix("\n\n")
    return json.loads(body)
|
||||
|
||||
|
||||
def test_hitlrequest_shape_is_passed_through_unchanged() -> None:
    """A payload already in HITLRequest shape normalises to an equal dict."""
    raw = {
        "action_requests": [{"name": "send_email", "args": {"to": "a@b"}}],
        "review_configs": [
            {"action_name": "send_email", "allowed_decisions": ["approve"]}
        ],
    }
    assert normalize_interrupt_payload(raw) == raw
|
||||
|
||||
|
||||
def test_custom_interrupt_primitive_is_converted_to_canonical_shape() -> None:
    """A custom type/message/action/context payload maps to the canonical shape."""
    raw = {
        "type": "permission",
        "message": "Allow send?",
        "action": {"tool": "send_email", "params": {"to": "a@b"}},
        "context": {"reason": "destructive"},
    }
    out = normalize_interrupt_payload(raw)
    # ``action.tool`` / ``action.params`` become the request name / args.
    assert out["action_requests"] == [
        {"name": "send_email", "args": {"to": "a@b"}}
    ]
    assert out["review_configs"] == [
        {
            "action_name": "send_email",
            "allowed_decisions": ["approve", "edit", "reject"],
        }
    ]
    # ``type`` is surfaced as ``interrupt_type``; message and context pass through.
    assert out["interrupt_type"] == "permission"
    assert out["message"] == "Allow send?"
    assert out["context"] == {"reason": "destructive"}
|
||||
|
||||
|
||||
def test_custom_interrupt_without_message_omits_message_key() -> None:
    """Optional fields stay optional on the wire; FE does not see ``"message": None``."""
    raw = {"action": {"tool": "send_email"}}
    out = normalize_interrupt_payload(raw)
    assert "message" not in out
|
||||
|
||||
|
||||
def test_custom_interrupt_without_tool_falls_back_to_unknown_tool() -> None:
    """Defensive: a malformed ``action`` block must not crash the relay."""
    out = normalize_interrupt_payload({"type": "x", "action": {}})
    assert out["action_requests"][0]["name"] == "unknown_tool"
    assert out["review_configs"][0]["action_name"] == "unknown_tool"
|
||||
|
||||
|
||||
def test_format_interrupt_request_carries_correlation_fields_on_the_wire() -> None:
    """Supplied correlation fields appear under ``data`` in the decoded frame."""
    frame = format_interrupt_request(
        {"action_requests": [], "review_configs": []},
        interrupt_id="int_42",
        pending_interrupt_count=3,
        chat_turn_id="turn_99",
    )
    payload = _decode(frame)
    assert payload["type"] == "data-interrupt-request"
    inner = payload["data"]
    assert inner["interrupt_id"] == "int_42"
    assert inner["pending_interrupt_count"] == 3
    assert inner["chat_turn_id"] == "turn_99"
|
||||
|
||||
|
||||
def test_format_interrupt_request_omits_correlation_fields_when_unset() -> None:
    """Backward compat: legacy single-interrupt callers don't have to supply ids."""
    frame = format_interrupt_request(
        {"action_requests": [], "review_configs": []},
    )
    inner = _decode(frame)["data"]
    assert "interrupt_id" not in inner
    assert "pending_interrupt_count" not in inner
    assert "chat_turn_id" not in inner
|
||||
|
|
@ -0,0 +1,142 @@
|
|||
"""Pin that sub-agent emitter reaches every wire event the relay emits."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.streaming.emitter import subagent_emitter
|
||||
from app.services.streaming.service import StreamingService
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def _decode(frame: str) -> dict:
    """Parse the JSON body of one SSE frame.

    Strips a leading ``"data: "`` and a trailing blank-line terminator when
    present, then decodes the remainder with ``json.loads``.
    """
    body = frame.removeprefix("data: ").removesuffix("\n\n")
    return json.loads(body)
|
||||
|
||||
|
||||
@pytest.fixture
def service() -> StreamingService:
    """Provide a fresh ``StreamingService`` for each test."""
    return StreamingService()
|
||||
|
||||
|
||||
@pytest.fixture
def sub_emitter():
    """Provide a sub-agent emitter with fixed type, run id and parent call id."""
    return subagent_emitter(
        subagent_type="deliverables",
        subagent_run_id="sub_xyz",
        parent_tool_call_id="call_parent",
    )
|
||||
|
||||
|
||||
def test_text_delta_carries_subagent_emitter_on_the_wire(service, sub_emitter) -> None:
    """A text-delta frame carries ``emitted_by`` alongside its ``delta``."""
    payload = _decode(service.format_text_delta("text_1", "hi", emitter=sub_emitter))
    assert payload["emitted_by"]["subagent_run_id"] == "sub_xyz"
    assert payload["delta"] == "hi"
|
||||
|
||||
|
||||
def test_reasoning_delta_carries_subagent_emitter_on_the_wire(
    service, sub_emitter
) -> None:
    """A reasoning-delta frame carries the sub-agent run id in ``emitted_by``."""
    payload = _decode(
        service.format_reasoning_delta("r_1", "thinking", emitter=sub_emitter)
    )
    assert payload["emitted_by"]["subagent_run_id"] == "sub_xyz"
|
||||
|
||||
|
||||
def test_tool_input_start_carries_subagent_emitter_and_lc_id(
    service, sub_emitter
) -> None:
    """A tool-input-start frame carries emitter, LangChain call id and tool name."""
    payload = _decode(
        service.format_tool_input_start(
            "call_1",
            "send_email",
            langchain_tool_call_id="lc_1",
            emitter=sub_emitter,
        )
    )
    assert payload["emitted_by"]["subagent_type"] == "deliverables"
    # These two wire keys are camelCase, unlike the snake_case ``emitted_by``.
    assert payload["langchainToolCallId"] == "lc_1"
    assert payload["toolName"] == "send_email"
|
||||
|
||||
|
||||
def test_tool_output_available_carries_subagent_emitter(service, sub_emitter) -> None:
    """A tool-output frame carries ``emitted_by`` and the unchanged output."""
    payload = _decode(
        service.format_tool_output_available(
            "call_1", {"ok": True}, emitter=sub_emitter
        )
    )
    assert payload["emitted_by"]["subagent_run_id"] == "sub_xyz"
    assert payload["output"] == {"ok": True}
|
||||
|
||||
|
||||
def test_thinking_step_carries_subagent_emitter(service, sub_emitter) -> None:
    """A thinking-step frame has the expected type and carries ``emitted_by``."""
    payload = _decode(
        service.format_thinking_step(
            step_id="s1",
            title="Sending email",
            status="in_progress",
            emitter=sub_emitter,
        )
    )
    assert payload["type"] == "data-thinking-step"
    assert payload["emitted_by"]["subagent_run_id"] == "sub_xyz"
|
||||
|
||||
|
||||
def test_action_log_carries_subagent_emitter(service, sub_emitter) -> None:
    """An action-log frame carries ``emitted_by`` and the log entry under ``data``."""
    payload = _decode(
        service.format_action_log(
            {"id": 1, "tool_name": "send_email", "reversible": False},
            emitter=sub_emitter,
        )
    )
    assert payload["emitted_by"]["subagent_run_id"] == "sub_xyz"
    assert payload["data"]["tool_name"] == "send_email"
|
||||
|
||||
|
||||
def test_subagent_lifecycle_events_share_run_id_for_pairing(
    service, sub_emitter
) -> None:
    """Start and finish frames carry the same run id and distinct event types."""
    start = _decode(
        service.format_subagent_start(
            subagent_run_id="sub_xyz",
            subagent_type="deliverables",
            parent_tool_call_id="call_parent",
            emitter=sub_emitter,
        )
    )
    finish = _decode(
        service.format_subagent_finish(
            subagent_run_id="sub_xyz",
            subagent_type="deliverables",
            parent_tool_call_id="call_parent",
            emitter=sub_emitter,
        )
    )
    assert start["data"]["subagent_run_id"] == finish["data"]["subagent_run_id"]
    assert start["type"] == "data-subagent-start"
    assert finish["type"] == "data-subagent-finish"
|
||||
|
||||
|
||||
def test_main_emitter_events_omit_emitted_by_field(service) -> None:
    """A frame formatted without an emitter has no ``emitted_by`` key."""
    payload = _decode(service.format_text_delta("text_1", "hi"))
    assert "emitted_by" not in payload
|
||||
|
||||
|
||||
def test_resolve_emitter_through_service_uses_registry(service, sub_emitter) -> None:
    """``resolve_emitter`` finds an emitter registered on ``emitter_registry``."""
    service.emitter_registry.register("run_task_1", sub_emitter)
    resolved = service.resolve_emitter(
        run_id="run_chat_model",
        parent_ids=["root", "run_task_1"],
    )
    assert resolved is sub_emitter
|
||||
|
||||
|
||||
def test_message_id_is_assigned_on_message_start_and_reused(service) -> None:
    """The start frame's ``msg_`` id is also stored on ``service.message_id``."""
    frame = service.format_message_start()
    payload = _decode(frame)
    assigned = payload["messageId"]
    assert assigned.startswith("msg_")
    assert service.message_id == assigned
|
||||
|
|
@ -0,0 +1,51 @@
|
|||
"""Pin the exact SSE wire bytes the FE parser depends on."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.streaming.envelope import (
|
||||
format_done,
|
||||
format_sse,
|
||||
get_response_headers,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
class TestFormatSse:
    """Pin the framing ``format_sse`` applies to dict and string payloads."""

    def test_dict_payload_is_json_serialised(self) -> None:
        """A dict is emitted as ``data: <json>`` followed by a blank line."""
        frame = format_sse({"type": "start", "messageId": "msg_1"})
        assert frame.startswith("data: ")
        assert frame.endswith("\n\n")
        # Drop the prefix and the two trailing newlines to isolate the JSON.
        body = frame[len("data: ") : -2]
        assert json.loads(body) == {"type": "start", "messageId": "msg_1"}

    def test_string_payload_is_emitted_verbatim(self) -> None:
        """A string payload is framed as-is, without re-encoding."""
        frame = format_sse('{"already":"json"}')
        assert frame == 'data: {"already":"json"}\n\n'

    def test_nested_payload_round_trips(self) -> None:
        """A nested dict decodes back to an equal value."""
        payload = {
            "type": "data-action-log",
            "data": {"id": 7, "tool_name": "ls", "reversible": False},
        }
        frame = format_sse(payload)
        body = frame.removeprefix("data: ").removesuffix("\n\n")
        assert json.loads(body) == payload
|
||||
|
||||
|
||||
class TestFormatDone:
    """Pin the stream terminator frame."""

    def test_done_marker_is_literal(self) -> None:
        """``format_done`` returns the literal ``[DONE]`` frame."""
        assert format_done() == "data: [DONE]\n\n"
|
||||
|
||||
|
||||
class TestResponseHeaders:
    """Pin the HTTP headers returned by ``get_response_headers``."""

    def test_headers_pin_ai_sdk_v1_protocol(self) -> None:
        """The SSE headers and the ``v1`` UI message stream marker are present."""
        headers = get_response_headers()
        assert headers["Content-Type"] == "text/event-stream"
        assert headers["Cache-Control"] == "no-cache"
        assert headers["Connection"] == "keep-alive"
        assert headers["x-vercel-ai-ui-message-stream"] == "v1"
|
||||
|
|
@ -0,0 +1,292 @@
|
|||
"""Pin Stage 1 extractions as faithful copies of the old helpers.
|
||||
|
||||
Extractions under ``app.tasks.chat.streaming`` are compared to
|
||||
``app.tasks.chat.stream_new_chat`` helpers.
|
||||
For each Stage 1 extraction we assert the new function returns the same
|
||||
output as the old one for a representative input set. The moment the
|
||||
two diverge - intentionally or otherwise - this file fails loudly so
|
||||
the divergence is reviewed rather than shipped silently.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from app.agents.new_chat.errors import BusyError
|
||||
from app.agents.new_chat.middleware.busy_mutex import request_cancel, reset_cancel
|
||||
from app.tasks.chat.stream_new_chat import (
|
||||
_classify_stream_exception as old_classify,
|
||||
_emit_stream_terminal_error as old_emit_terminal_error,
|
||||
_extract_chunk_parts as old_extract_chunk_parts,
|
||||
_extract_resolved_file_path as old_extract_resolved_file_path,
|
||||
_first_interrupt_value as old_first_interrupt_value,
|
||||
_tool_output_has_error as old_tool_output_has_error,
|
||||
_tool_output_to_text as old_tool_output_to_text,
|
||||
)
|
||||
from app.tasks.chat.streaming.errors.classifier import (
|
||||
classify_stream_exception as new_classify,
|
||||
)
|
||||
from app.tasks.chat.streaming.errors.emitter import (
|
||||
emit_stream_terminal_error as new_emit_terminal_error,
|
||||
)
|
||||
from app.tasks.chat.streaming.helpers.chunk_parts import (
|
||||
extract_chunk_parts as new_extract_chunk_parts,
|
||||
)
|
||||
from app.tasks.chat.streaming.helpers.interrupt_inspector import (
|
||||
first_interrupt_value as new_first_interrupt_value,
|
||||
)
|
||||
from app.tasks.chat.streaming.helpers.tool_output import (
|
||||
extract_resolved_file_path as new_extract_resolved_file_path,
|
||||
tool_output_has_error as new_tool_output_has_error,
|
||||
tool_output_to_text as new_tool_output_to_text,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
# ---------------------------------------------------------------- chunk parts
|
||||
|
||||
|
||||
@dataclass
class _Chunk:
    """Test stand-in for a streamed message chunk.

    Exposes the three attributes the chunk cases below populate.
    """

    # Message content; the cases use a string, a list of blocks, or an int.
    content: Any = ""
    # Extra keyword data; one case sets ``reasoning_content`` here.
    additional_kwargs: dict[str, Any] = field(default_factory=dict)
    # Tool-call chunk dicts supplied as an attribute rather than in content.
    tool_call_chunks: list[dict[str, Any]] = field(default_factory=list)
|
||||
|
||||
|
||||
# Representative inputs fed to both the old and new ``extract_chunk_parts``.
_CHUNK_CASES: list[Any] = [
    None,
    _Chunk(content=""),
    _Chunk(content="hello"),
    _Chunk(content=42),  # invalid type, defensively coerced to empty
    # Content as a list of text blocks.
    _Chunk(
        content=[
            {"type": "text", "text": "Hello "},
            {"type": "text", "text": "world"},
        ]
    ),
    # Reasoning blocks (under both ``reasoning`` and ``text`` keys) plus text.
    _Chunk(
        content=[
            {"type": "reasoning", "reasoning": "hmm "},
            {"type": "reasoning", "text": "still"},
            {"type": "text", "text": "answer"},
        ]
    ),
    # Tool-call style blocks mixed with an unrelated block type.
    _Chunk(
        content=[
            {"type": "tool_call_chunk", "id": "c1", "name": "x", "args": "{"},
            {"type": "tool_use", "id": "c2", "name": "y"},
            {"type": "image_url", "url": "ignored"},
        ]
    ),
    # Plain content with reasoning carried in ``additional_kwargs``.
    _Chunk(
        content="visible",
        additional_kwargs={"reasoning_content": "private"},
    ),
    # Tool-call chunks supplied only via the attribute.
    _Chunk(
        tool_call_chunks=[
            {"id": None, "name": None, "args": '{"a":1}', "index": 0},
            {"id": "c", "name": "n", "args": "}", "index": 0},
        ]
    ),
    # Tool-call chunks present both as a content block and as the attribute.
    _Chunk(
        content=[{"type": "tool_call_chunk", "id": "from-block", "name": "x"}],
        tool_call_chunks=[{"id": "from-attr", "name": "y"}],
    ),
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("chunk", _CHUNK_CASES)
def test_extract_chunk_parts_matches_old_implementation(chunk: Any) -> None:
    """The extracted helper returns the same value as the old one per case."""
    assert new_extract_chunk_parts(chunk) == old_extract_chunk_parts(chunk)
|
||||
|
||||
|
||||
# ---------------------------------------------------------- interrupt inspector
|
||||
|
||||
|
||||
@dataclass
class _Interrupt:
    """Test stand-in for an interrupt wrapper exposing only ``value``."""

    # Interrupt payload the inspector is expected to surface.
    value: dict[str, Any]
|
||||
|
||||
|
||||
@dataclass
class _Task:
    """Test stand-in for a state task exposing only ``interrupts``."""

    # Interrupts attached to the task; empty by default.
    interrupts: tuple[Any, ...] = ()
|
||||
|
||||
|
||||
@dataclass
class _State:
    """Test stand-in for a state snapshot with ``tasks`` and root ``interrupts``."""

    # Tasks in the snapshot; typed loosely so cases can pass malformed tasks.
    tasks: tuple[Any, ...] = ()
    # Interrupts attached directly to the state; empty by default.
    interrupts: tuple[Any, ...] = ()
|
||||
|
||||
|
||||
# Representative states fed to both the old and new ``first_interrupt_value``.
_INTERRUPT_CASES: list[Any] = [
    _State(),
    _State(tasks=(_Task(interrupts=(_Interrupt(value={"name": "send"}),)),)),
    # Multiple tasks: must return the FIRST one in iteration order.
    _State(
        tasks=(
            _Task(interrupts=(_Interrupt(value={"name": "first"}),)),
            _Task(interrupts=(_Interrupt(value={"name": "second"}),)),
        )
    ),
    # Empty task interrupts -> falls back to root state.interrupts.
    _State(
        tasks=(_Task(interrupts=()),),
        interrupts=(_Interrupt(value={"name": "root"}),),
    ),
    # Interrupts as plain dicts (not wrapper objects).
    _State(interrupts=({"value": {"name": "dict_root"}},)),
    # A defective task (a bare ``object()`` has no ``.interrupts`` attribute)
    # - must be tolerated.
    _State(tasks=(object(),)),
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("state", _INTERRUPT_CASES)
def test_first_interrupt_value_matches_old_implementation(state: Any) -> None:
    """The extracted inspector returns the same value as the old one per state."""
    assert new_first_interrupt_value(state) == old_first_interrupt_value(state)
|
||||
|
||||
|
||||
# ----------------------------------------------------------- error classifier
|
||||
|
||||
|
||||
def _classify_cases() -> list[Exception]:
    """Inputs that the FE depends on being mapped to specific error codes.

    Returns a fresh list on every call: a generic error, two provider-style
    rate-limit messages, a ``BusyError`` and a plain busy-thread message.
    """
    return [
        Exception("totally generic error"),
        Exception(
            '{"error":{"type":"rate_limit_error","message":"slow down"}}'
        ),
        Exception(
            'OpenrouterException - {"error":{"message":"Provider returned error",'
            '"code":429}}'
        ),
        BusyError(request_id="thread-busy-parity"),
        Exception("Thread is busy with another request"),
    ]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("exc", _classify_cases())
def test_classify_stream_exception_matches_old_implementation(
    exc: Exception,
) -> None:
    """Old and new classifiers agree on every field except the retry timestamp."""
    new = new_classify(exc, flow_label="parity-test")
    old = old_classify(exc, flow_label="parity-test")
    # Strip the wall-clock retry timestamp before comparing — both
    # implementations call ``time.time()`` independently and the call
    # order is enough to differ by 1 ms in practice. Every other field
    # in the tuple must match exactly.
    # Copy the dicts so popping does not mutate the classifiers' results.
    new_extra = dict(new[5]) if isinstance(new[5], dict) else new[5]
    old_extra = dict(old[5]) if isinstance(old[5], dict) else old[5]
    if isinstance(new_extra, dict) and isinstance(old_extra, dict):
        new_extra.pop("retry_after_at", None)
        old_extra.pop("retry_after_at", None)
    assert new[:5] == old[:5]
    assert new_extra == old_extra
|
||||
|
||||
|
||||
def test_classify_turn_cancelling_branch_parity() -> None:
    """The TURN_CANCELLING branch reads cancel state for the busy thread id;
    both implementations must agree on retry-window semantics, not just the
    plain THREAD_BUSY code."""
    thread_id = "parity-cancelling-thread"
    # Start from a clean slate, then flag the thread as cancelling.
    reset_cancel(thread_id)
    request_cancel(thread_id)
    try:
        exc = BusyError(request_id=thread_id)
        new = new_classify(exc, flow_label="parity-test")
        old = old_classify(exc, flow_label="parity-test")
        assert new[0] == old[0] == "thread_busy"
        assert new[1] == old[1] == "TURN_CANCELLING"
        assert isinstance(new[5], dict) and isinstance(old[5], dict)
        assert new[5]["retry_after_ms"] == old[5]["retry_after_ms"]
    finally:
        # Clear the flag even on failure so it cannot leak into other tests.
        # NOTE(review): assumes reset_cancel undoes request_cancel, as its
        # use above to establish the clean starting state suggests — confirm.
        reset_cancel(thread_id)
|
||||
|
||||
|
||||
# ------------------------------------------------------------ terminal emitter
|
||||
|
||||
|
||||
class _FakeStreamingService:
    """Duck-types ``format_error`` for both old and new emitters."""

    def __init__(self) -> None:
        # One dict per ``format_error`` call, in call order.
        self.calls: list[dict[str, Any]] = []

    def format_error(
        self, message: str, *, error_code: str, extra: dict[str, Any] | None = None
    ) -> str:
        """Record the call and return a fixed-shape SSE error frame.

        Only ``message`` is embedded in the frame; ``error_code`` and
        ``extra`` are captured in ``self.calls`` for later comparison.
        """
        self.calls.append(
            {"message": message, "error_code": error_code, "extra": extra}
        )
        return f"data: {{\"type\":\"error\",\"errorText\":\"{message}\"}}\n\n"
|
||||
|
||||
|
||||
def test_emit_stream_terminal_error_matches_old_output_and_logs(caplog) -> None:
    """The new emitter must produce the same SSE frame and log the same
    structured payload as the old one for the same arguments."""
    args: dict[str, Any] = dict(
        flow="new",
        request_id="req-parity",
        thread_id=7,
        search_space_id=9,
        user_id="user-parity",
        message="boom",
        error_kind="server_error",
        error_code="SERVER_ERROR",
        severity="error",
        is_expected=False,
        extra={"foo": "bar"},
    )

    # Separate fakes so each emitter's formatting calls are recorded apart.
    new_svc, old_svc = _FakeStreamingService(), _FakeStreamingService()

    with caplog.at_level(logging.ERROR):
        new_frame = new_emit_terminal_error(streaming_service=new_svc, **args)
        old_frame = old_emit_terminal_error(streaming_service=old_svc, **args)

    assert new_frame == old_frame
    assert new_svc.calls == old_svc.calls

    # One log line per emit call (two emits -> two records).
    tagged = [
        record for record in caplog.records if "[chat_stream_error]" in record.message
    ]
    assert len(tagged) == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------- tool output
|
||||
|
||||
|
||||
def test_tool_output_helpers_match_old_implementation() -> None:
    """Old and new tool-output helpers must agree on every sample input."""
    samples: list[Any] = [
        {"result": "ok"},
        {"error": "bad"},
        {"result": "Error: x"},
        "Error: plain",
        "fine",
        {"nested": {"a": 1}},
    ]
    for sample in samples:
        assert new_tool_output_to_text(sample) == old_tool_output_to_text(sample)
        assert new_tool_output_has_error(sample) == old_tool_output_has_error(sample)

    # Each factory builds a fresh kwargs dict per call so the two
    # implementations never share argument objects.
    scenarios = (
        # Path taken from the tool output.
        lambda: {
            "tool_name": "write_file",
            "tool_output": {"path": " /tmp/x "},
            "tool_input": None,
        },
        # Empty output; path supplied through the tool input instead.
        lambda: {
            "tool_name": "write_file",
            "tool_output": {},
            "tool_input": {"file_path": " /fallback "},
        },
    )
    for make_kwargs in scenarios:
        assert new_extract_resolved_file_path(
            **make_kwargs()
        ) == old_extract_resolved_file_path(**make_kwargs())
|
||||
|
|
@ -0,0 +1,241 @@
|
|||
"""Parity tests for Stage 2 extractions (tool matching, thinking step, custom events)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from app.tasks.chat.stream_new_chat import _legacy_match_lc_id as old_legacy_match
|
||||
from app.tasks.chat.streaming.handlers.custom_events import (
|
||||
handle_action_log,
|
||||
handle_action_log_updated,
|
||||
handle_document_created,
|
||||
handle_report_progress,
|
||||
)
|
||||
from app.tasks.chat.streaming.helpers.tool_call_matching import (
|
||||
match_buffered_langchain_tool_call_id as new_legacy_match,
|
||||
)
|
||||
from app.tasks.chat.streaming.relay.state import AgentEventRelayState
|
||||
from app.tasks.chat.streaming.relay.thinking_step_completion import (
|
||||
complete_active_thinking_step,
|
||||
)
|
||||
from app.tasks.chat.streaming.relay.thinking_step_sse import emit_thinking_step_frame
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def _copy_chunk_buffer(raw: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Return a new list holding a shallow copy of each chunk dict in *raw*."""
    return list(map(dict, raw))
|
||||
|
||||
|
||||
def test_legacy_tool_call_match_matches_old_implementation() -> None:
    """The extracted matcher must return, and leave behind, exactly what the
    original does: same result, same chunk buffer, same id map."""
    cases: list[tuple[list[dict[str, Any]], str, str, dict[str, str]]] = [
        # A chunk with the requested tool name is present.
        (
            [
                {"name": "write_file", "id": "lc-a"},
                {"name": "other", "id": "lc-b"},
            ],
            "write_file",
            "run-1",
            {},
        ),
        # No chunk carries the requested name; one chunk has no id.
        (
            [{"name": "x", "id": None}, {"name": "y", "id": "lc-fallback"}],
            "write_file",
            "run-2",
            {},
        ),
        # The only chunk has no ``id`` key at all.
        ([{"name": "no_id"}], "write_file", "run-3", {}),
    ]
    for template, tool_name, run_id, seed in cases:
        # Each implementation gets its own copies of the inputs.
        chunks_for_old = _copy_chunk_buffer(template)
        chunks_for_new = _copy_chunk_buffer(template)
        map_for_old = dict(seed)
        map_for_new = dict(seed)

        expected = old_legacy_match(chunks_for_old, tool_name, run_id, map_for_old)
        actual = new_legacy_match(chunks_for_new, tool_name, run_id, map_for_new)

        assert actual == expected
        assert chunks_for_new == chunks_for_old
        assert map_for_new == map_for_old
|
||||
|
||||
|
||||
def test_emit_thinking_step_frame_invokes_builder_before_service() -> None:
    """The content builder is notified before the service formats the frame."""
    call_log: list[str] = []

    def record_builder(*args: Any, **kwargs: Any) -> None:
        call_log.append("builder")

    def record_service(**kwargs: Any) -> str:
        call_log.append("service")
        return "frame"

    content_builder = MagicMock()
    content_builder.on_thinking_step.side_effect = record_builder

    service = MagicMock()
    service.format_thinking_step.side_effect = record_service

    frame = emit_thinking_step_frame(
        streaming_service=service,
        content_builder=content_builder,
        step_id="thinking-1",
        title="Working",
        status="in_progress",
        items=["a"],
    )

    assert frame == "frame"
    assert call_log == ["builder", "service"]
    content_builder.on_thinking_step.assert_called_once()
    service.format_thinking_step.assert_called_once()
|
||||
|
||||
|
||||
def test_emit_thinking_step_frame_skips_builder_when_none() -> None:
    """With no content builder, the frame still comes from the service."""
    service = MagicMock(return_value="x")
    service.format_thinking_step.return_value = "frame"

    frame = emit_thinking_step_frame(
        streaming_service=service,
        content_builder=None,
        step_id="s",
        title="t",
    )

    assert frame == "frame"
    service.format_thinking_step.assert_called_once()
|
||||
|
||||
|
||||
def test_complete_active_thinking_step_mirrors_closure_semantics() -> None:
    """The first completion emits a frame and clears the active id; a repeat
    for an already-completed step emits nothing and returns the id as-is."""
    service = MagicMock()
    service.format_thinking_step.return_value = "done-frame"
    done_ids: set[str] = set()
    state = AgentEventRelayState.for_invocation()

    # First call: the step has not been completed yet.
    first_frame, first_id = complete_active_thinking_step(
        state=state,
        streaming_service=service,
        content_builder=None,
        last_active_step_id="thinking-1",
        last_active_step_title="T",
        last_active_step_items=["x"],
        completed_step_ids=done_ids,
    )
    assert first_frame == "done-frame"
    assert first_id is None
    assert "thinking-1" in done_ids

    # Second call: same step id, now already recorded as completed.
    second_frame, second_id = complete_active_thinking_step(
        state=state,
        streaming_service=service,
        content_builder=None,
        last_active_step_id="thinking-1",
        last_active_step_title="T",
        last_active_step_items=[],
        completed_step_ids=done_ids,
    )
    assert second_frame is None
    assert second_id == "thinking-1"
|
||||
|
||||
|
||||
def test_agent_event_relay_state_factory_matches_counter_rule() -> None:
    """A fresh state starts its counter at 0; a state seeded with an initial
    step starts at 1 and hands out ``thinking-2`` next."""
    fresh = AgentEventRelayState.for_invocation()
    assert fresh.thinking_step_counter == 0
    assert fresh.last_active_step_id is None

    seeded = AgentEventRelayState.for_invocation(
        initial_step_id="thinking-resume-1",
        initial_step_title="Inherited",
        initial_step_items=["Topic: X"],
    )
    assert seeded.thinking_step_counter == 1
    assert seeded.last_active_step_id == "thinking-resume-1"
    assert seeded.next_thinking_step_id("thinking") == "thinking-2"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ("phase", "message", "start_items", "expected_tail"),
    [
        (
            "revising_section",
            "progress line",
            ["Topic: Foo", "Modifying bar", "stale..."],
            ["Topic: Foo", "Modifying bar", "progress line"],
        ),
        (
            "other",
            "phase msg",
            ["Topic: Foo", "old line"],
            ["Topic: Foo", "phase msg"],
        ),
    ],
)
def test_report_progress_items_match_reference(
    phase: str,
    message: str,
    start_items: list[str],
    expected_tail: list[str],
) -> None:
    """``handle_report_progress`` returns the expected item list and passes
    the same list to the streaming service."""
    service = MagicMock()
    service.format_thinking_step.return_value = "sse"

    # Pass a copy so the parametrized fixture list is never mutated.
    event = {"message": message, "phase": phase}
    frame, updated_items = handle_report_progress(
        event,
        last_active_step_id="step-1",
        last_active_step_title="Report",
        last_active_step_items=list(start_items),
        streaming_service=service,
        content_builder=None,
    )

    assert frame == "sse"
    assert updated_items == expected_tail
    formatted_with = service.format_thinking_step.call_args.kwargs
    assert formatted_with["items"] == expected_tail
|
||||
|
||||
|
||||
def test_report_progress_noop_when_missing_message_or_step() -> None:
    """An empty message or a missing active step yields no frame and returns
    the very same items list object."""
    service = MagicMock()
    items = ["Topic: A"]

    # Empty message.
    frame_a, items_a = handle_report_progress(
        {"message": "", "phase": "x"},
        last_active_step_id="s",
        last_active_step_title="t",
        last_active_step_items=items,
        streaming_service=service,
        content_builder=None,
    )
    assert frame_a is None and items_a is items

    # No active step id.
    frame_b, items_b = handle_report_progress(
        {"message": "m", "phase": "x"},
        last_active_step_id=None,
        last_active_step_title="t",
        last_active_step_items=items,
        streaming_service=service,
        content_builder=None,
    )
    assert frame_b is None and items_b is items
|
||||
|
||||
|
||||
def test_document_action_handlers_match_format_data_guards() -> None:
    """Each handler returns ``None`` for a missing or falsy id and otherwise
    forwards the payload to ``format_data`` under its event name."""
    service = MagicMock()
    service.format_data.return_value = "data-frame"

    # document-created: no id, then a zero id, then a valid payload.
    assert handle_document_created({}, streaming_service=service) is None
    assert handle_document_created({"id": 0}, streaming_service=service) is None
    created_payload = {"id": 42, "title": "x"}
    handle_document_created(created_payload, streaming_service=service)
    service.format_data.assert_called_with(
        "documents-updated", {"action": "created", "document": {"id": 42, "title": "x"}}
    )

    # action-log: reset first so ``assert_called_once_with`` counts only
    # the calls made in this section.
    service.reset_mock()
    assert handle_action_log({"id": None}, streaming_service=service) is None
    handle_action_log({"id": 1}, streaming_service=service)
    service.format_data.assert_called_once_with("action-log", {"id": 1})

    # action-log-updated: same pattern.
    service.reset_mock()
    assert handle_action_log_updated({"id": None}, streaming_service=service) is None
    handle_action_log_updated({"id": 2}, streaming_service=service)
    service.format_data.assert_called_once_with("action-log-updated", {"id": 2})
|
||||
|
|
@ -0,0 +1,116 @@
|
|||
"""Tests for ``stream_output`` (LangGraph events → SSE)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from app.tasks.chat.streaming.graph_stream import stream_output
|
||||
from app.tasks.chat.streaming.graph_stream.result import StreamingResult
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
@dataclass
class _Chunk:
    """Test double for a streamed model chunk carried in event ``data``."""

    # Text (or other content) carried by this chunk.
    content: Any = ""
    # Each instance gets its own dict / list via ``default_factory``.
    additional_kwargs: dict[str, Any] = field(default_factory=dict)
    tool_call_chunks: list[dict[str, Any]] = field(default_factory=list)
|
||||
|
||||
|
||||
class _StreamingService:
    """Fake streaming service that renders text events as plain strings."""

    def __init__(self) -> None:
        # Number of text ids handed out so far.
        self._text_idx = 0

    def generate_text_id(self) -> str:
        """Return the next sequential id: ``text-1``, ``text-2``, ..."""
        upcoming = self._text_idx + 1
        self._text_idx = upcoming
        return "text-{}".format(upcoming)

    def format_text_start(self, text_id: str) -> str:
        """Render the start marker for *text_id*."""
        return "text_start:{}".format(text_id)

    def format_text_delta(self, text_id: str, text: str) -> str:
        """Render one text delta for *text_id*."""
        return "text_delta:{}:{}".format(text_id, text)

    def format_text_end(self, text_id: str) -> str:
        """Render the end marker for *text_id*."""
        return "text_end:{}".format(text_id)
|
||||
|
||||
|
||||
class _Agent:
    """Fake agent whose ``astream_events`` replays a fixed list of events."""

    def __init__(self, events: list[dict[str, Any]]) -> None:
        # Copy so later changes to the caller's list do not affect replay.
        self.events = [*events]
        # One ``(input_data, kwargs)`` tuple per ``astream_events`` call.
        self.calls: list[tuple[Any, dict[str, Any]]] = []

    async def astream_events(self, input_data: Any, **kwargs: Any):
        """Record the call, then yield the stored events in order."""
        invocation = (input_data, kwargs)
        self.calls.append(invocation)
        for item in self.events:
            yield item
|
||||
|
||||
|
||||
async def _collect(stream: Any) -> list[str]:
    """Drain the async iterable *stream* and return its items as a list."""
    return [frame async for frame in stream]
|
||||
|
||||
|
||||
async def test_stream_output_emits_text_lifecycle_and_updates_result() -> None:
    """Two streamed chunks produce start, two deltas and end under one text
    id, and the accumulated text is stored on the result."""
    model_events = [
        {"event": "on_chat_model_stream", "data": {"chunk": _Chunk(content=piece)}}
        for piece in ("Hello", " world")
    ]
    streaming_service = _StreamingService()
    fake_agent = _Agent(model_events)
    result = StreamingResult()

    stream = stream_output(
        agent=fake_agent,
        config={"configurable": {"thread_id": "t-1"}},
        input_data={"messages": []},
        streaming_service=streaming_service,
        result=result,
    )
    frames = await _collect(stream)

    expected_frames = [
        "text_start:text-1",
        "text_delta:text-1:Hello",
        "text_delta:text-1: world",
        "text_end:text-1",
    ]
    assert frames == expected_frames
    assert result.accumulated_text == "Hello world"
    assert result.agent_called_update_memory is False
|
||||
|
||||
|
||||
async def test_stream_output_passes_runtime_context_to_agent() -> None:
    """``runtime_context`` must reach the agent as a truthy ``context`` kwarg."""

    class _ContextAwareAgent:
        # Emits "ctx-ok" only when ``astream_events`` receives a truthy
        # ``context`` keyword, so the emitted text shows whether it arrived.
        async def astream_events(self, input_data: Any, **kwargs: Any):
            del input_data
            if kwargs.get("context"):
                text = "ctx-ok"
            else:
                text = "ctx-missing"
            yield {"event": "on_chat_model_stream", "data": {"chunk": _Chunk(text)}}

    streaming_service = _StreamingService()
    result = StreamingResult()

    stream = stream_output(
        agent=_ContextAwareAgent(),
        config={"configurable": {"thread_id": "t-2"}},
        input_data={"messages": []},
        streaming_service=streaming_service,
        result=result,
        runtime_context={"mentioned_document_ids": [1, 2]},
    )
    frames = await _collect(stream)

    assert frames == [
        "text_start:text-1",
        "text_delta:text-1:ctx-ok",
        "text_end:text-1",
    ]
|
||||
|
|
@ -0,0 +1,69 @@
|
|||
"""Unit tests for ``task_span`` open/close helpers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.tasks.chat.streaming.relay.state import AgentEventRelayState
|
||||
from app.tasks.chat.streaming.relay.task_span import (
|
||||
clear_task_span_if_delegating_task_ended,
|
||||
ensure_pending_task_span_for_lc,
|
||||
open_task_span,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def test_open_task_span_sets_span_and_run_id() -> None:
    """Opening a span returns an ``spn_`` id and records it, together with
    the run id, on the relay state."""
    relay_state = AgentEventRelayState.for_invocation()
    span_id = open_task_span(relay_state, run_id="run-abc")

    assert span_id.startswith("spn_")
    assert relay_state.active_span_id == span_id
    assert relay_state.active_task_run_id == "run-abc"
    assert relay_state.span_metadata_if_active() == {"spanId": span_id}
|
||||
|
||||
|
||||
def test_clear_ignored_for_non_task_tool() -> None:
    """A tool other than ``task`` ending must leave the open span untouched,
    even when its run id matches."""
    relay_state = AgentEventRelayState.for_invocation()
    open_task_span(relay_state, run_id="run-1")
    span_before = relay_state.active_span_id

    clear_task_span_if_delegating_task_ended(
        relay_state, tool_name="web_search", run_id="run-1"
    )

    assert relay_state.active_span_id == span_before
    assert relay_state.active_task_run_id == "run-1"
|
||||
|
||||
|
||||
def test_clear_ignored_when_task_run_id_mismatches() -> None:
    """A ``task`` end with a different run id must not close the open span."""
    relay_state = AgentEventRelayState.for_invocation()
    open_task_span(relay_state, run_id="run-open")

    clear_task_span_if_delegating_task_ended(
        relay_state, tool_name="task", run_id="run-other"
    )

    assert relay_state.active_span_id is not None
    assert relay_state.active_task_run_id == "run-open"
|
||||
|
||||
|
||||
def test_clear_on_matching_task_end() -> None:
    """A ``task`` end with the matching run id clears span and run id."""
    relay_state = AgentEventRelayState.for_invocation()
    open_task_span(relay_state, run_id="run-x")

    clear_task_span_if_delegating_task_ended(
        relay_state, tool_name="task", run_id="run-x"
    )

    assert relay_state.active_span_id is None
    assert relay_state.active_task_run_id is None
    assert relay_state.span_metadata_if_active() is None
|
||||
|
||||
|
||||
def test_clear_noop_when_no_open_span() -> None:
    """Clearing when no span was ever opened leaves the span id as ``None``."""
    relay_state = AgentEventRelayState.for_invocation()

    clear_task_span_if_delegating_task_ended(
        relay_state, tool_name="task", run_id="run-x"
    )

    assert relay_state.active_span_id is None
|
||||
|
||||
|
||||
def test_pending_then_open_reuses_same_span_id() -> None:
    """A span reserved for a LangChain tool-call id is reused when the task
    span opens for that id, and the pending entry is removed."""
    lc_id = "lc-task-1"
    relay_state = AgentEventRelayState.for_invocation()

    pending_span = ensure_pending_task_span_for_lc(relay_state, lc_id)
    assert relay_state.pending_task_span_by_lc["lc-task-1"] == pending_span

    active_span = open_task_span(
        relay_state, run_id="run-1", langchain_tool_call_id=lc_id
    )
    assert active_span == pending_span
    assert relay_state.active_span_id == pending_span
    assert "lc-task-1" not in relay_state.pending_task_span_by_lc
|
||||
|
|
@ -0,0 +1,42 @@
|
|||
"""Unit tests for ``AgentEventRelayState.tool_activity_metadata``."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.tasks.chat.streaming.relay.state import AgentEventRelayState
|
||||
from app.tasks.chat.streaming.relay.task_span import open_task_span
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def test_returns_none_when_no_span_and_no_thinking_step() -> None:
    """With no active span, a ``None``, empty or whitespace-only thinking
    step id produces no metadata."""
    relay_state = AgentEventRelayState.for_invocation()
    for blank_id in (None, "", "  "):
        assert relay_state.tool_activity_metadata(thinking_step_id=blank_id) is None
|
||||
|
||||
|
||||
def test_thinking_step_id_only() -> None:
    """Without an active span, the metadata holds only the thinking step id."""
    relay_state = AgentEventRelayState.for_invocation()
    metadata = relay_state.tool_activity_metadata(thinking_step_id="thinking-3")
    assert metadata == {"thinkingStepId": "thinking-3"}
|
||||
|
||||
|
||||
def test_span_only_when_active() -> None:
    """With an open span and no thinking step, only ``spanId`` is reported."""
    relay_state = AgentEventRelayState.for_invocation()
    open_task_span(relay_state, run_id="run-x")

    metadata = relay_state.tool_activity_metadata(thinking_step_id=None)
    assert metadata == {"spanId": relay_state.active_span_id}
|
||||
|
||||
|
||||
def test_merges_span_and_thinking_step_when_both_set() -> None:
    """An open span plus a thinking step id yields both keys in one dict."""
    relay_state = AgentEventRelayState.for_invocation()
    open_task_span(relay_state, run_id="run-x")

    metadata = relay_state.tool_activity_metadata(thinking_step_id="thinking-7")
    expected = {
        "spanId": relay_state.active_span_id,
        "thinkingStepId": "thinking-7",
    }
    assert metadata == expected
|
||||
|
|
@ -15,6 +15,7 @@ import json
|
|||
|
||||
import pytest
|
||||
|
||||
from app.services.new_streaming_service import VercelStreamingService
|
||||
from app.tasks.chat.content_builder import AssistantContentBuilder
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
|
@ -161,7 +162,7 @@ class TestToolHeavyTurn:
|
|||
_assert_jsonb_safe(snap)
|
||||
|
||||
def test_tool_input_available_without_prior_start_creates_card(self):
|
||||
# Legacy / parity_v2-OFF path: tool-input-available may be
|
||||
# Late-registration: tool-input-available may be
|
||||
# emitted without a prior tool-input-start (no streamed
|
||||
# tool_call_chunks). The card should still be created.
|
||||
b = AssistantContentBuilder()
|
||||
|
|
@ -187,7 +188,7 @@ class TestToolHeavyTurn:
|
|||
assert part["result"] == {"matches": 3}
|
||||
|
||||
def test_tool_input_start_idempotent_for_same_ui_id(self):
|
||||
# parity_v2: tool-input-start can fire from BOTH the chunk
|
||||
# tool-input-start can fire from BOTH the chunk
|
||||
# registration path AND the canonical ``on_tool_start`` path.
|
||||
# The second call must not create a duplicate part.
|
||||
b = AssistantContentBuilder()
|
||||
|
|
@ -231,6 +232,155 @@ class TestToolHeavyTurn:
|
|||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Task-span metadata on tool-call parts (JSONB persistence)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestToolCallSpanMetadata:
    """Metadata merge rules on tool-call parts across the
    start / input-available / output-available lifecycle."""

    def test_input_available_merges_new_metadata_keys_after_start(self):
        """Keys first seen on ``input_available`` join those set at start."""
        builder = AssistantContentBuilder()
        builder.on_tool_input_start(
            "call_t", "task", "lc_t", metadata={"spanId": "spn_1"}
        )
        builder.on_tool_input_available(
            "call_t",
            "task",
            {"goal": "x"},
            "lc_t",
            metadata={"traceId": "tr_1"},
        )
        first_part = builder.snapshot()[0]
        assert first_part["metadata"]["spanId"] == "spn_1"
        assert first_part["metadata"]["traceId"] == "tr_1"
        _assert_jsonb_safe(builder.snapshot())

    def test_input_available_does_not_overwrite_existing_metadata_keys(self):
        """A key already set at start keeps its original value."""
        builder = AssistantContentBuilder()
        builder.on_tool_input_start(
            "call_t", "task", "lc_t", metadata={"spanId": "spn_keep"}
        )
        builder.on_tool_input_available(
            "call_t", "task", {}, "lc_t", metadata={"spanId": "spn_other"}
        )
        metadata = builder.snapshot()[0]["metadata"]
        assert metadata["spanId"] == "spn_keep"

    def test_late_tool_input_available_carries_metadata(self):
        """``input_available`` with no prior start still stores metadata."""
        builder = AssistantContentBuilder()
        builder.on_tool_input_available(
            "call_l",
            "grep",
            {"pattern": "TODO"},
            None,
            metadata={"spanId": "spn_l"},
        )
        first_part = builder.snapshot()[0]
        assert first_part["metadata"] == {"spanId": "spn_l"}
        _assert_jsonb_safe(builder.snapshot())

    def test_output_available_merges_without_clobbering_span_id(self):
        """Output metadata adds new keys but keeps the original ``spanId``."""
        builder = AssistantContentBuilder()
        builder.on_tool_input_start(
            "call_t", "ls", "lc", metadata={"spanId": "spn_x"}
        )
        builder.on_tool_input_available("call_t", "ls", {"path": "/"}, "lc")
        builder.on_tool_output_available(
            "call_t",
            {"ok": True},
            "lc",
            metadata={"spanId": "spn_y", "extra": 1},
        )
        metadata = builder.snapshot()[0]["metadata"]
        assert metadata["spanId"] == "spn_x"
        assert metadata["extra"] == 1

    def test_output_available_adds_thinking_step_id_without_clobbering_span(self):
        """Re-sending identical metadata on output leaves both keys intact."""
        shared = {"spanId": "spn_x", "thinkingStepId": "thinking-3"}
        builder = AssistantContentBuilder()
        builder.on_tool_input_start("call_t", "ls", "lc", metadata=dict(shared))
        builder.on_tool_input_available("call_t", "ls", {"path": "/"}, "lc")
        builder.on_tool_output_available(
            "call_t", {"ok": True}, "lc", metadata=dict(shared)
        )
        metadata = builder.snapshot()[0]["metadata"]
        assert metadata["spanId"] == "spn_x"
        assert metadata["thinkingStepId"] == "thinking-3"

    def test_output_available_with_none_metadata_preserves_prior(self):
        """``metadata=None`` on output leaves earlier metadata unchanged."""
        builder = AssistantContentBuilder()
        builder.on_tool_input_start("c", "ls", "lc", metadata={"spanId": "spn_1"})
        builder.on_tool_input_available("c", "ls", {}, "lc")
        builder.on_tool_output_available("c", {"r": 1}, "lc", metadata=None)
        first_part = builder.snapshot()[0]
        assert first_part["metadata"] == {"spanId": "spn_1"}

    def test_available_adds_thinking_step_id_after_chunk_only_start(self):
        """Mirrors chunk ``tool-input-start`` then ``on_tool_start`` ``available``."""
        builder = AssistantContentBuilder()
        builder.on_tool_input_start(
            "lc_1", "ls", "lc_1", metadata={"spanId": "spn_a"}
        )
        builder.on_tool_input_available(
            "lc_1",
            "ls",
            {"path": "/"},
            "lc_1",
            metadata={"spanId": "spn_a", "thinkingStepId": "thinking-2"},
        )
        metadata = builder.snapshot()[0]["metadata"]
        assert metadata["spanId"] == "spn_a"
        assert metadata["thinkingStepId"] == "thinking-2"
|
||||
|
||||
|
||||
class TestVercelStreamingServiceToolMetadataWire:
    """SSE payloads include optional ``metadata`` for FE grouping."""

    @staticmethod
    def _parse_sse_data_line(raw: str) -> dict:
        """Decode the JSON body of a single ``data: ...`` SSE frame."""
        assert raw.startswith("data: ")
        # Everything after the first "data: " and before the blank line
        # that terminates the frame.
        _, _, after_prefix = raw.partition("data: ")
        body_text = after_prefix.partition("\n\n")[0].strip()
        return json.loads(body_text)

    def test_tool_input_available_includes_metadata_when_set(self):
        """Metadata passed to the formatter appears verbatim on the wire."""
        expected_metadata = {"spanId": "spn_w", "thinkingStepId": "thinking-4"}
        service = VercelStreamingService()
        frame = service.format_tool_input_available(
            "id1",
            "task",
            {"a": 1},
            langchain_tool_call_id="lc1",
            metadata={"spanId": "spn_w", "thinkingStepId": "thinking-4"},
        )
        body = self._parse_sse_data_line(frame)
        assert body["type"] == "tool-input-available"
        assert body["metadata"] == expected_metadata

    def test_tool_output_available_includes_metadata_when_set(self):
        """Same contract for ``tool-output-available`` frames."""
        expected_metadata = {"spanId": "spn_o", "thinkingStepId": "thinking-9"}
        service = VercelStreamingService()
        frame = service.format_tool_output_available(
            "id1",
            {"status": "completed"},
            langchain_tool_call_id="lc1",
            metadata={"spanId": "spn_o", "thinkingStepId": "thinking-9"},
        )
        body = self._parse_sse_data_line(frame)
        assert body["type"] == "tool-output-available"
        assert body["metadata"] == expected_metadata

    def test_tool_input_available_omits_metadata_key_when_none(self):
        """No metadata argument means no ``metadata`` key in the payload."""
        service = VercelStreamingService()
        frame = service.format_tool_input_available("id1", "ls", {})
        body = self._parse_sse_data_line(frame)
        assert "metadata" not in body
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Thinking steps & separators
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
|
|||
|
|
@ -1,16 +1,13 @@
|
|||
"""Unit tests for live tool-call argument streaming.
|
||||
|
||||
Pins the wire format that ``_stream_agent_events`` emits when
|
||||
``SURFSENSE_ENABLE_STREAM_PARITY_V2=true``: ``tool-input-start`` →
|
||||
``tool-input-delta``... → ``tool-input-available`` → ``tool-output-available``
|
||||
all keyed by the same LangChain ``tool_call.id``.
|
||||
Pins the wire format that ``_stream_agent_events`` emits:
|
||||
``tool-input-start`` → ``tool-input-delta``... → ``tool-input-available`` →
|
||||
``tool-output-available``, keyed consistently with LangChain ``tool_call.id``
|
||||
when the model streams indexed chunks.
|
||||
|
||||
Identity is tracked in ``index_to_meta`` (per-chunk ``index``) and
|
||||
``ui_tool_call_id_by_run`` (LangGraph ``run_id``); both are private to
|
||||
``_stream_agent_events`` so we exercise them via the public wire output.
|
||||
|
||||
These tests also lock in the legacy / parity_v2-OFF behaviour so the
|
||||
synthetic ``call_<run_id>`` shape stays stable for older clients.
|
||||
``ui_tool_call_id_by_run`` (LangGraph ``run_id``); both are internal to the
|
||||
streaming layer so we assert on the public SSE payloads.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
|
@ -22,8 +19,6 @@ from typing import Any
|
|||
|
||||
import pytest
|
||||
|
||||
import app.tasks.chat.stream_new_chat as stream_module
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||
from app.services.new_streaming_service import VercelStreamingService
|
||||
from app.tasks.chat.stream_new_chat import (
|
||||
StreamResult,
|
||||
|
|
@ -164,24 +159,6 @@ def _tool_end(
|
|||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def parity_v2_on(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
stream_module,
|
||||
"get_flags",
|
||||
lambda: AgentFeatureFlags(enable_stream_parity_v2=True),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def parity_v2_off(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
stream_module,
|
||||
"get_flags",
|
||||
lambda: AgentFeatureFlags(enable_stream_parity_v2=False),
|
||||
)
|
||||
|
||||
|
||||
async def _drain(
|
||||
events: list[dict[str, Any]], state: _FakeAgentState | None = None
|
||||
) -> list[dict[str, Any]]:
|
||||
|
|
@ -253,12 +230,12 @@ class TestLegacyMatch:
|
|||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# parity_v2 wire format tests.
|
||||
# Tool input streaming wire format
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_idless_chunk_merging_by_index(parity_v2_on: None) -> None:
|
||||
async def test_idless_chunk_merging_by_index() -> None:
|
||||
"""First chunk carries id+name; later idless chunks at the same
|
||||
``index`` merge into the SAME ``tool-input-start`` ui id and emit
|
||||
one ``tool-input-delta`` per chunk."""
|
||||
|
|
@ -302,9 +279,7 @@ async def test_idless_chunk_merging_by_index(parity_v2_on: None) -> None:
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_two_interleaved_tool_calls_route_by_index(
|
||||
parity_v2_on: None,
|
||||
) -> None:
|
||||
async def test_two_interleaved_tool_calls_route_by_index() -> None:
|
||||
"""Two same-name calls with distinct indices keep their deltas
|
||||
routed to the right card."""
|
||||
events = [
|
||||
|
|
@ -344,7 +319,7 @@ async def test_two_interleaved_tool_calls_route_by_index(
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_identity_stable_across_lifecycle(parity_v2_on: None) -> None:
|
||||
async def test_identity_stable_across_lifecycle() -> None:
|
||||
"""Whatever id ``tool-input-start`` chose must be the SAME id used
|
||||
on ``tool-input-available`` AND ``tool-output-available``."""
|
||||
events = [
|
||||
|
|
@ -367,7 +342,7 @@ async def test_identity_stable_across_lifecycle(parity_v2_on: None) -> None:
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_duplicate_tool_input_start(parity_v2_on: None) -> None:
|
||||
async def test_no_duplicate_tool_input_start() -> None:
|
||||
"""When the chunk-emission loop already fired ``tool-input-start``
|
||||
for this run, ``on_tool_start`` MUST NOT emit a second one."""
|
||||
events = [
|
||||
|
|
@ -386,9 +361,7 @@ async def test_no_duplicate_tool_input_start(parity_v2_on: None) -> None:
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_active_text_closes_before_early_tool_input_start(
|
||||
parity_v2_on: None,
|
||||
) -> None:
|
||||
async def test_active_text_closes_before_early_tool_input_start() -> None:
|
||||
"""Streaming a text-delta then a tool-call chunk in subsequent
|
||||
chunks: the wire MUST contain ``text-end`` before the FIRST
|
||||
``tool-input-start`` (clean part boundary on the frontend)."""
|
||||
|
|
@ -409,9 +382,7 @@ async def test_active_text_closes_before_early_tool_input_start(
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mixed_text_and_tool_chunk_preserve_order(
|
||||
parity_v2_on: None,
|
||||
) -> None:
|
||||
async def test_mixed_text_and_tool_chunk_preserve_order() -> None:
|
||||
"""One AIMessageChunk that carries BOTH ``text`` content AND
|
||||
``tool_call_chunks`` should emit the text delta FIRST, then close
|
||||
text, then ``tool-input-start``+``tool-input-delta``."""
|
||||
|
|
@ -441,45 +412,7 @@ async def test_mixed_text_and_tool_chunk_preserve_order(
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parity_v2_off_preserves_legacy_shape(
|
||||
parity_v2_off: None,
|
||||
) -> None:
|
||||
"""When the flag is OFF, no deltas are emitted and the ``toolCallId``
|
||||
is ``call_<run_id>`` (NOT the lc id)."""
|
||||
events = [
|
||||
_model_stream(
|
||||
tool_call_chunks=[
|
||||
{"id": "lc-1", "name": "ls", "args": '{"path":"/"}', "index": 0}
|
||||
]
|
||||
),
|
||||
_tool_start(name="ls", run_id="run-A", input_payload={"path": "/"}),
|
||||
_tool_end(name="ls", run_id="run-A", tool_call_id="lc-1"),
|
||||
]
|
||||
payloads = await _drain(events)
|
||||
|
||||
assert _of_type(payloads, "tool-input-delta") == []
|
||||
starts = _of_type(payloads, "tool-input-start")
|
||||
assert len(starts) == 1
|
||||
assert starts[0]["toolCallId"].startswith("call_run-A")
|
||||
# No ``langchainToolCallId`` propagation on ``tool-input-start`` in
|
||||
# legacy mode (the start event fires before the ToolMessage is
|
||||
# available, so we can't extract the authoritative LangChain id yet).
|
||||
assert "langchainToolCallId" not in starts[0]
|
||||
output = _of_type(payloads, "tool-output-available")
|
||||
assert output[0]["toolCallId"].startswith("call_run-A")
|
||||
# ``tool-output-available`` MUST carry ``langchainToolCallId`` even
|
||||
# in legacy mode: the chat tool card uses it to backfill the
|
||||
# LangChain id and join against the ``data-action-log`` SSE event
|
||||
# (keyed by ``lc_tool_call_id``) so the inline Revert button can
|
||||
# light up. Sourced from the returned ``ToolMessage.tool_call_id``,
|
||||
# which is populated regardless of feature-flag state.
|
||||
assert output[0]["langchainToolCallId"] == "lc-1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skip_append_prevents_stale_id_reuse(
|
||||
parity_v2_on: None,
|
||||
) -> None:
|
||||
async def test_skip_append_prevents_stale_id_reuse() -> None:
|
||||
"""Two same-name tools: the SECOND tool's ``langchainToolCallId``
|
||||
must NOT come from the first tool's chunk (``pending_tool_call_chunks``
|
||||
must stay empty for indexed-registered chunks)."""
|
||||
|
|
@ -506,9 +439,7 @@ async def test_skip_append_prevents_stale_id_reuse(
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_registration_waits_for_both_id_and_name(
|
||||
parity_v2_on: None,
|
||||
) -> None:
|
||||
async def test_registration_waits_for_both_id_and_name() -> None:
|
||||
"""An id-only chunk (no name yet) must NOT emit ``tool-input-start``."""
|
||||
events = [
|
||||
_model_stream(
|
||||
|
|
@ -520,12 +451,9 @@ async def test_registration_waits_for_both_id_and_name(
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unmatched_fallback_still_attaches_lc_id(
|
||||
parity_v2_on: None,
|
||||
) -> None:
|
||||
"""parity_v2 ON, but the provider didn't include an ``index``: the
|
||||
legacy fallback path must still emit ``tool-input-start`` with the
|
||||
matching ``langchainToolCallId``."""
|
||||
async def test_unmatched_fallback_still_attaches_lc_id() -> None:
|
||||
"""When the provider omits chunk ``index``, buffered chunks still get a
|
||||
``tool-input-start`` with the matching ``langchainToolCallId``."""
|
||||
events = [
|
||||
# No index on the chunk → not registered into index_to_meta;
|
||||
# falls through to ``pending_tool_call_chunks`` so the legacy
|
||||
|
|
@ -542,9 +470,7 @@ async def test_unmatched_fallback_still_attaches_lc_id(
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_interrupt_request_uses_task_that_contains_interrupt(
|
||||
parity_v2_on: None,
|
||||
) -> None:
|
||||
async def test_interrupt_request_uses_task_that_contains_interrupt() -> None:
|
||||
interrupt_payload = {
|
||||
"type": "calendar_event_create",
|
||||
"action": {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue