mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-12 17:22:38 +02:00
Add unit tests for streaming interrupts and service propagation.
This commit is contained in:
parent
619a8362b7
commit
366122da6e
4 changed files with 397 additions and 0 deletions
|
|
@ -0,0 +1,164 @@
|
||||||
|
"""Pin id-aware pending-interrupt lookup that replaces the buggy first-wins."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.services.streaming.interrupt_correlation import (
|
||||||
|
PendingInterrupt,
|
||||||
|
first_pending_interrupt,
|
||||||
|
get_pending_interrupt_by_id,
|
||||||
|
get_pending_interrupt_for_tool_call,
|
||||||
|
list_pending_interrupts,
|
||||||
|
)
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _Interrupt:
|
||||||
|
value: dict[str, Any]
|
||||||
|
id: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class _Task:
    """Test double for one task entry of a state snapshot."""

    # Interrupts attached to this task, in the order the tests supply them.
    interrupts: tuple[_Interrupt, ...] = ()
    # Task identifier; the tests compare it with ``PendingInterrupt.source_task_id``.
    id: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class _State:
    """Test double for the state snapshot handed to the lookup helpers."""

    # Task entries, each of which may carry its own interrupts.
    tasks: tuple[_Task, ...] = ()
    # Interrupts attached directly to the state rather than to a task.
    interrupts: tuple[_Interrupt, ...] = ()
|
||||||
|
|
||||||
|
|
||||||
|
def _hitl(name: str, tool_call_id: str | None = None) -> dict[str, Any]:
|
||||||
|
"""Minimal LangChain HITLRequest payload for one action."""
|
||||||
|
action: dict[str, Any] = {"name": name, "args": {}}
|
||||||
|
if tool_call_id is not None:
|
||||||
|
action["tool_call_id"] = tool_call_id
|
||||||
|
return {
|
||||||
|
"action_requests": [action],
|
||||||
|
"review_configs": [{"action_name": name, "allowed_decisions": ["approve"]}],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_empty_state_has_no_pending_interrupts() -> None:
    """A state with no tasks and no root interrupts yields nothing pending."""
    state = _State()
    assert list_pending_interrupts(state) == []
    assert first_pending_interrupt(state) is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_single_pending_interrupt_in_task_is_returned() -> None:
    """One task-level interrupt maps to one ``PendingInterrupt`` with its task id."""
    state = _State(
        tasks=(
            _Task(
                id="task_1",
                interrupts=(_Interrupt(value=_hitl("send_email"), id="int_1"),),
            ),
        )
    )
    pending = list_pending_interrupts(state)
    assert len(pending) == 1
    # The interrupt id, payload and owning task id must all be carried over.
    assert pending[0] == PendingInterrupt(
        interrupt_id="int_1",
        value=_hitl("send_email"),
        source_task_id="task_1",
    )
|
||||||
|
|
||||||
|
|
||||||
|
def test_pending_interrupts_returned_in_task_then_root_order() -> None:
    """Determinism matters: callers iterate in this order to render the UI.

    Task interrupts come first, in task order, followed by the interrupts
    attached directly to the state.
    """
    state = _State(
        tasks=(
            _Task(
                id="task_a",
                interrupts=(_Interrupt(value=_hitl("a"), id="int_a"),),
            ),
            _Task(
                id="task_b",
                interrupts=(_Interrupt(value=_hitl("b"), id="int_b"),),
            ),
        ),
        interrupts=(_Interrupt(value=_hitl("c"), id="int_c"),),
    )
    pending = list_pending_interrupts(state)
    ids = [p.interrupt_id for p in pending]
    assert ids == ["int_a", "int_b", "int_c"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_by_id_finds_the_right_interrupt_under_parallel_load() -> None:
    """Replacing first-wins: id-aware lookup MUST pick the requested one."""
    state = _State(
        tasks=(
            _Task(interrupts=(_Interrupt(value=_hitl("a"), id="int_a"),)),
            _Task(interrupts=(_Interrupt(value=_hitl("b"), id="int_b"),)),
            _Task(interrupts=(_Interrupt(value=_hitl("c"), id="int_c"),)),
        )
    )
    # Ask for the middle one so neither a first-wins nor a last-wins lookup passes.
    found = get_pending_interrupt_by_id(state, "int_b")
    assert found is not None
    assert found.value["action_requests"][0]["name"] == "b"
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_by_id_returns_none_when_id_is_not_pending() -> None:
    """Looking up an id that no pending interrupt carries returns ``None``."""
    state = _State(
        tasks=(_Task(interrupts=(_Interrupt(value=_hitl("a"), id="int_a"),)),)
    )
    assert get_pending_interrupt_by_id(state, "missing") is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_by_tool_call_id_matches_action_request_payload() -> None:
    """HITLRequest carries ``tool_call_id`` per action; lookup uses that."""
    # Two interrupts on the same task, distinguished only by tool call id.
    state = _State(
        tasks=(
            _Task(
                interrupts=(
                    _Interrupt(
                        value=_hitl("a", tool_call_id="call_xxx"), id="int_a"
                    ),
                    _Interrupt(
                        value=_hitl("b", tool_call_id="call_yyy"), id="int_b"
                    ),
                )
            ),
        )
    )
    found = get_pending_interrupt_for_tool_call(state, "call_yyy")
    assert found is not None
    assert found.interrupt_id == "int_b"
|
||||||
|
|
||||||
|
|
||||||
|
def test_first_pending_interrupt_matches_legacy_first_wins_behaviour() -> None:
    """Sequential-turn safety: the explicit shortcut still returns the first."""
    # One task-level and one root-level interrupt; the task-level one is expected.
    state = _State(
        tasks=(_Task(interrupts=(_Interrupt(value=_hitl("first"), id="int_1"),)),),
        interrupts=(_Interrupt(value=_hitl("second"), id="int_2"),),
    )
    first = first_pending_interrupt(state)
    assert first is not None
    assert first.interrupt_id == "int_1"
|
||||||
|
|
||||||
|
|
||||||
|
def test_interrupt_without_id_falls_back_to_none() -> None:
    """Snapshots from older LangGraph versions may omit ``id`` — preserve that."""
    state = _State(
        tasks=(_Task(interrupts=(_Interrupt(value=_hitl("a"), id=None),)),)
    )
    pending = list_pending_interrupts(state)
    # The interrupt is still listed; only its id is absent.
    assert len(pending) == 1
    assert pending[0].interrupt_id is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_non_dict_interrupt_values_are_ignored() -> None:
    """Defensive: a non-dict value should not crash the iteration."""

    # Stand-in interrupt whose ``value`` is a plain string instead of a dict.
    class _Raw:
        value = "not a dict"

    state = _State(tasks=(_Task(interrupts=(_Raw(),)),))  # type: ignore[arg-type]
    assert list_pending_interrupts(state) == []
|
||||||
|
|
@ -0,0 +1,91 @@
|
||||||
|
"""Pin interrupt-payload normalisation and the optional correlation fields on the wire."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.services.streaming.events.interrupt import (
|
||||||
|
format_interrupt_request,
|
||||||
|
normalize_interrupt_payload,
|
||||||
|
)
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
def _decode(frame: str) -> dict:
    """Parse the JSON body of one wire frame.

    Strips a leading ``"data: "`` and a trailing blank-line terminator when
    present, then decodes what is left as JSON.

    Args:
        frame: The frame text as returned by a formatter.

    Returns:
        The decoded JSON value.

    Raises:
        json.JSONDecodeError: If the remaining text is not valid JSON.
    """
    body = frame.removeprefix("data: ").removesuffix("\n\n")
    return json.loads(body)
|
||||||
|
|
||||||
|
|
||||||
|
def test_hitlrequest_shape_is_passed_through_unchanged() -> None:
    """A payload already in HITLRequest shape comes back equal to the input."""
    raw = {
        "action_requests": [{"name": "send_email", "args": {"to": "a@b"}}],
        "review_configs": [
            {"action_name": "send_email", "allowed_decisions": ["approve"]}
        ],
    }
    assert normalize_interrupt_payload(raw) == raw
|
||||||
|
|
||||||
|
|
||||||
|
def test_custom_interrupt_primitive_is_converted_to_canonical_shape() -> None:
    """A custom ``type``/``action`` payload is rewritten into HITLRequest shape.

    ``action.tool`` and ``action.params`` become the action request's name and
    args, all three decisions are allowed, and ``type``, ``message`` and
    ``context`` are kept (``type`` under the ``interrupt_type`` key).
    """
    raw = {
        "type": "permission",
        "message": "Allow send?",
        "action": {"tool": "send_email", "params": {"to": "a@b"}},
        "context": {"reason": "destructive"},
    }
    out = normalize_interrupt_payload(raw)
    assert out["action_requests"] == [
        {"name": "send_email", "args": {"to": "a@b"}}
    ]
    assert out["review_configs"] == [
        {
            "action_name": "send_email",
            "allowed_decisions": ["approve", "edit", "reject"],
        }
    ]
    assert out["interrupt_type"] == "permission"
    assert out["message"] == "Allow send?"
    assert out["context"] == {"reason": "destructive"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_custom_interrupt_without_message_omits_message_key() -> None:
    """Optional fields stay optional on the wire; FE does not see ``"message": None``."""
    raw = {"action": {"tool": "send_email"}}
    out = normalize_interrupt_payload(raw)
    assert "message" not in out
|
||||||
|
|
||||||
|
|
||||||
|
def test_custom_interrupt_without_tool_falls_back_to_unknown_tool() -> None:
    """Defensive: a malformed ``action`` block must not crash the relay."""
    out = normalize_interrupt_payload({"type": "x", "action": {}})
    # Both the request and its review config must use the same placeholder name.
    assert out["action_requests"][0]["name"] == "unknown_tool"
    assert out["review_configs"][0]["action_name"] == "unknown_tool"
|
||||||
|
|
||||||
|
|
||||||
|
def test_format_interrupt_request_carries_correlation_fields_on_the_wire() -> None:
    """Supplied correlation fields appear inside the frame's ``data`` object."""
    frame = format_interrupt_request(
        {"action_requests": [], "review_configs": []},
        interrupt_id="int_42",
        pending_interrupt_count=3,
        chat_turn_id="turn_99",
    )
    payload = _decode(frame)
    assert payload["type"] == "data-interrupt-request"
    inner = payload["data"]
    assert inner["interrupt_id"] == "int_42"
    assert inner["pending_interrupt_count"] == 3
    assert inner["chat_turn_id"] == "turn_99"
|
||||||
|
|
||||||
|
|
||||||
|
def test_format_interrupt_request_omits_correlation_fields_when_unset() -> None:
    """Backward compat: legacy single-interrupt callers don't have to supply ids."""
    frame = format_interrupt_request(
        {"action_requests": [], "review_configs": []},
    )
    inner = _decode(frame)["data"]
    # The keys must be absent altogether, not present with a null value.
    assert "interrupt_id" not in inner
    assert "pending_interrupt_count" not in inner
    assert "chat_turn_id" not in inner
|
||||||
|
|
@ -0,0 +1,142 @@
|
||||||
|
"""Pin that sub-agent emitter reaches every wire event the relay emits."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.services.streaming.emitter import subagent_emitter
|
||||||
|
from app.services.streaming.service import StreamingService
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
def _decode(frame: str) -> dict:
    """Parse the JSON body of one wire frame.

    Strips a leading ``"data: "`` and a trailing blank-line terminator when
    present, then decodes what is left as JSON.

    Args:
        frame: The frame text as returned by a formatter.

    Returns:
        The decoded JSON value.

    Raises:
        json.JSONDecodeError: If the remaining text is not valid JSON.
    """
    body = frame.removeprefix("data: ").removesuffix("\n\n")
    return json.loads(body)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def service() -> StreamingService:
    """Provide a newly constructed ``StreamingService``."""
    return StreamingService()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def sub_emitter():
    """Provide a sub-agent emitter built with fixed ids the tests assert on."""
    return subagent_emitter(
        subagent_type="deliverables",
        subagent_run_id="sub_xyz",
        parent_tool_call_id="call_parent",
    )
|
||||||
|
|
||||||
|
|
||||||
|
def test_text_delta_carries_subagent_emitter_on_the_wire(service, sub_emitter) -> None:
    """A text delta formatted with an emitter exposes it under ``emitted_by``."""
    payload = _decode(service.format_text_delta("text_1", "hi", emitter=sub_emitter))
    assert payload["emitted_by"]["subagent_run_id"] == "sub_xyz"
    assert payload["delta"] == "hi"
|
||||||
|
|
||||||
|
|
||||||
|
def test_reasoning_delta_carries_subagent_emitter_on_the_wire(
    service, sub_emitter
) -> None:
    """A reasoning delta formatted with an emitter exposes it under ``emitted_by``."""
    payload = _decode(
        service.format_reasoning_delta("r_1", "thinking", emitter=sub_emitter)
    )
    assert payload["emitted_by"]["subagent_run_id"] == "sub_xyz"
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_input_start_carries_subagent_emitter_and_lc_id(
    service, sub_emitter
) -> None:
    """Tool-input-start keeps the emitter, the LangChain call id and the tool name."""
    payload = _decode(
        service.format_tool_input_start(
            "call_1",
            "send_email",
            langchain_tool_call_id="lc_1",
            emitter=sub_emitter,
        )
    )
    assert payload["emitted_by"]["subagent_type"] == "deliverables"
    assert payload["langchainToolCallId"] == "lc_1"
    assert payload["toolName"] == "send_email"
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_output_available_carries_subagent_emitter(service, sub_emitter) -> None:
    """Tool output frames keep both the emitter and the output value."""
    payload = _decode(
        service.format_tool_output_available(
            "call_1", {"ok": True}, emitter=sub_emitter
        )
    )
    assert payload["emitted_by"]["subagent_run_id"] == "sub_xyz"
    assert payload["output"] == {"ok": True}
|
||||||
|
|
||||||
|
|
||||||
|
def test_thinking_step_carries_subagent_emitter(service, sub_emitter) -> None:
    """Thinking-step frames have the expected type and carry the emitter."""
    payload = _decode(
        service.format_thinking_step(
            step_id="s1",
            title="Sending email",
            status="in_progress",
            emitter=sub_emitter,
        )
    )
    assert payload["type"] == "data-thinking-step"
    assert payload["emitted_by"]["subagent_run_id"] == "sub_xyz"
|
||||||
|
|
||||||
|
|
||||||
|
def test_action_log_carries_subagent_emitter(service, sub_emitter) -> None:
    """Action-log frames carry the emitter and keep the log entry under ``data``."""
    payload = _decode(
        service.format_action_log(
            {"id": 1, "tool_name": "send_email", "reversible": False},
            emitter=sub_emitter,
        )
    )
    assert payload["emitted_by"]["subagent_run_id"] == "sub_xyz"
    assert payload["data"]["tool_name"] == "send_email"
|
||||||
|
|
||||||
|
|
||||||
|
def test_subagent_lifecycle_events_share_run_id_for_pairing(
    service, sub_emitter
) -> None:
    """Start and finish frames for one run carry the same ``subagent_run_id``."""
    start = _decode(
        service.format_subagent_start(
            subagent_run_id="sub_xyz",
            subagent_type="deliverables",
            parent_tool_call_id="call_parent",
            emitter=sub_emitter,
        )
    )
    finish = _decode(
        service.format_subagent_finish(
            subagent_run_id="sub_xyz",
            subagent_type="deliverables",
            parent_tool_call_id="call_parent",
            emitter=sub_emitter,
        )
    )
    assert start["data"]["subagent_run_id"] == finish["data"]["subagent_run_id"]
    assert start["type"] == "data-subagent-start"
    assert finish["type"] == "data-subagent-finish"
|
||||||
|
|
||||||
|
|
||||||
|
def test_main_emitter_events_omit_emitted_by_field(service) -> None:
    """A frame formatted without an emitter has no ``emitted_by`` key at all."""
    payload = _decode(service.format_text_delta("text_1", "hi"))
    assert "emitted_by" not in payload
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_emitter_through_service_uses_registry(service, sub_emitter) -> None:
    """An emitter registered for a parent run id is resolved for a child run."""
    service.emitter_registry.register("run_task_1", sub_emitter)
    # The child run itself is unregistered; only one of its parents is.
    resolved = service.resolve_emitter(
        run_id="run_chat_model",
        parent_ids=["root", "run_task_1"],
    )
    assert resolved is sub_emitter
|
||||||
|
|
||||||
|
|
||||||
|
def test_message_id_is_assigned_on_message_start_and_reused(service) -> None:
    """Message start emits a ``msg_``-prefixed id and stores it on the service."""
    # NOTE(review): "reused" is only checked via ``service.message_id`` here;
    # no later frame is inspected — confirm whether that is intended.
    frame = service.format_message_start()
    payload = _decode(frame)
    assigned = payload["messageId"]
    assert assigned.startswith("msg_")
    assert service.message_id == assigned
|
||||||
Loading…
Add table
Add a link
Reference in a new issue