mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-12 17:22:38 +02:00
Add unit tests for streaming interrupts and service propagation.
This commit is contained in:
parent
619a8362b7
commit
366122da6e
4 changed files with 397 additions and 0 deletions
|
|
@@ -0,0 +1,164 @@
|
|||
"""Pin id-aware pending-interrupt lookup that replaces the buggy first-wins."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.streaming.interrupt_correlation import (
|
||||
PendingInterrupt,
|
||||
first_pending_interrupt,
|
||||
get_pending_interrupt_by_id,
|
||||
get_pending_interrupt_for_tool_call,
|
||||
list_pending_interrupts,
|
||||
)
|
||||
|
||||
# Apply the ``unit`` marker to every test in this module.
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
@dataclass
class _Interrupt:
    """Stand-in interrupt object used by these tests: a payload plus an optional id."""

    # Payload carried by the interrupt (an HITL request dict in most tests here).
    value: dict[str, Any]
    # Interrupt identifier; defaults to ``None`` when not supplied.
    id: str | None = None
|
||||
|
||||
|
||||
@dataclass
class _Task:
    """Stand-in task object used by these tests: its interrupts plus an optional id."""

    # Interrupts attached to this task; empty by default.
    interrupts: tuple[_Interrupt, ...] = ()
    # Task identifier; defaults to ``None`` when not supplied.
    id: str | None = None
|
||||
|
||||
|
||||
@dataclass
class _State:
    """Stand-in state snapshot used by these tests: tasks plus root-level interrupts."""

    # Tasks in the snapshot; empty by default.
    tasks: tuple[_Task, ...] = ()
    # Interrupts attached directly to the state rather than to a task.
    interrupts: tuple[_Interrupt, ...] = ()
|
||||
|
||||
|
||||
def _hitl(name: str, tool_call_id: str | None = None) -> dict[str, Any]:
    """Build a minimal LangChain HITLRequest payload containing a single action.

    ``tool_call_id`` is attached to the action only when it is not ``None``.
    """
    correlation = {} if tool_call_id is None else {"tool_call_id": tool_call_id}
    request: dict[str, Any] = {"name": name, "args": {}, **correlation}
    review = {"action_name": name, "allowed_decisions": ["approve"]}
    return {"action_requests": [request], "review_configs": [review]}
|
||||
|
||||
|
||||
def test_empty_state_has_no_pending_interrupts() -> None:
    """A state with no tasks and no root interrupts has nothing pending."""
    empty = _State()
    assert list_pending_interrupts(empty) == []
    assert first_pending_interrupt(empty) is None
|
||||
|
||||
|
||||
def test_single_pending_interrupt_in_task_is_returned() -> None:
    """One interrupt on one task is listed with its id, payload and task id."""
    interrupt = _Interrupt(value=_hitl("send_email"), id="int_1")
    task = _Task(id="task_1", interrupts=(interrupt,))

    pending = list_pending_interrupts(_State(tasks=(task,)))

    assert len(pending) == 1
    expected = PendingInterrupt(
        interrupt_id="int_1",
        value=_hitl("send_email"),
        source_task_id="task_1",
    )
    assert pending[0] == expected
|
||||
|
||||
|
||||
def test_pending_interrupts_returned_in_task_then_root_order() -> None:
    """Task interrupts are listed first, in task order, then root-level ones.

    The order is pinned because callers iterate in it to render the UI.
    """
    task_a = _Task(id="task_a", interrupts=(_Interrupt(value=_hitl("a"), id="int_a"),))
    task_b = _Task(id="task_b", interrupts=(_Interrupt(value=_hitl("b"), id="int_b"),))
    root_interrupt = _Interrupt(value=_hitl("c"), id="int_c")
    state = _State(tasks=(task_a, task_b), interrupts=(root_interrupt,))

    observed = [entry.interrupt_id for entry in list_pending_interrupts(state)]

    assert observed == ["int_a", "int_b", "int_c"]
|
||||
|
||||
|
||||
def test_get_by_id_finds_the_right_interrupt_under_parallel_load() -> None:
    """With several interrupts pending, id lookup must return the requested one.

    This is the behaviour that replaces first-wins selection.
    """
    tasks = tuple(
        _Task(interrupts=(_Interrupt(value=_hitl(name), id=f"int_{name}"),))
        for name in ("a", "b", "c")
    )

    match = get_pending_interrupt_by_id(_State(tasks=tasks), "int_b")

    assert match is not None
    assert match.value["action_requests"][0]["name"] == "b"
|
||||
|
||||
|
||||
def test_get_by_id_returns_none_when_id_is_not_pending() -> None:
    """Looking up an id that no pending interrupt carries yields ``None``."""
    only_task = _Task(interrupts=(_Interrupt(value=_hitl("a"), id="int_a"),))
    state = _State(tasks=(only_task,))
    assert get_pending_interrupt_by_id(state, "missing") is None
|
||||
|
||||
|
||||
def test_get_by_tool_call_id_matches_action_request_payload() -> None:
    """Lookup by tool call uses the ``tool_call_id`` carried on each action request."""
    first = _Interrupt(value=_hitl("a", tool_call_id="call_xxx"), id="int_a")
    second = _Interrupt(value=_hitl("b", tool_call_id="call_yyy"), id="int_b")
    state = _State(tasks=(_Task(interrupts=(first, second)),))

    match = get_pending_interrupt_for_tool_call(state, "call_yyy")

    assert match is not None
    assert match.interrupt_id == "int_b"
|
||||
|
||||
|
||||
def test_first_pending_interrupt_matches_legacy_first_wins_behaviour() -> None:
    """The explicit first-interrupt shortcut still returns the first one.

    Kept for sequential-turn safety: a task interrupt precedes a root one.
    """
    task_interrupt = _Interrupt(value=_hitl("first"), id="int_1")
    root_interrupt = _Interrupt(value=_hitl("second"), id="int_2")
    state = _State(
        tasks=(_Task(interrupts=(task_interrupt,)),),
        interrupts=(root_interrupt,),
    )

    head = first_pending_interrupt(state)

    assert head is not None
    assert head.interrupt_id == "int_1"
|
||||
|
||||
|
||||
def test_interrupt_without_id_falls_back_to_none() -> None:
    """An interrupt with no ``id`` is still listed, with ``interrupt_id`` of ``None``.

    Snapshots from older LangGraph versions may omit ``id``.
    """
    anonymous = _Interrupt(value=_hitl("a"), id=None)
    state = _State(tasks=(_Task(interrupts=(anonymous,)),))

    pending = list_pending_interrupts(state)

    assert len(pending) == 1
    assert pending[0].interrupt_id is None
|
||||
|
||||
|
||||
def test_non_dict_interrupt_values_are_ignored() -> None:
    """An interrupt whose value is not a dict is skipped instead of crashing."""

    class _Raw:
        value = "not a dict"

    task = _Task(interrupts=(_Raw(),))  # type: ignore[arg-type]
    state = _State(tasks=(task,))
    assert list_pending_interrupts(state) == []
|
||||
|
|
@@ -0,0 +1,91 @@
|
|||
"""Pin interrupt-payload normalisation and the optional correlation fields on the wire."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.streaming.events.interrupt import (
|
||||
format_interrupt_request,
|
||||
normalize_interrupt_payload,
|
||||
)
|
||||
|
||||
# Apply the ``unit`` marker to every test in this module.
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def _decode(frame: str) -> dict:
    """Parse the JSON body of a frame after dropping its ``data: `` prefix and trailing blank line."""
    payload_text = frame.removeprefix("data: ")
    payload_text = payload_text.removesuffix("\n\n")
    return json.loads(payload_text)
|
||||
|
||||
|
||||
def test_hitlrequest_shape_is_passed_through_unchanged() -> None:
    """A payload already in HITLRequest shape comes back equal to the input."""
    action = {"name": "send_email", "args": {"to": "a@b"}}
    review = {"action_name": "send_email", "allowed_decisions": ["approve"]}
    canonical = {"action_requests": [action], "review_configs": [review]}
    assert normalize_interrupt_payload(canonical) == canonical
|
||||
|
||||
|
||||
def test_custom_interrupt_primitive_is_converted_to_canonical_shape() -> None:
    """A custom ``type``/``action`` interrupt is normalised to the canonical keys."""
    custom = {
        "type": "permission",
        "message": "Allow send?",
        "action": {"tool": "send_email", "params": {"to": "a@b"}},
        "context": {"reason": "destructive"},
    }

    normalized = normalize_interrupt_payload(custom)

    expected_action = {"name": "send_email", "args": {"to": "a@b"}}
    assert normalized["action_requests"] == [expected_action]
    expected_review = {
        "action_name": "send_email",
        "allowed_decisions": ["approve", "edit", "reject"],
    }
    assert normalized["review_configs"] == [expected_review]
    assert normalized["interrupt_type"] == "permission"
    assert normalized["message"] == "Allow send?"
    assert normalized["context"] == {"reason": "destructive"}
|
||||
|
||||
|
||||
def test_custom_interrupt_without_message_omits_message_key() -> None:
    """With no message supplied, the normalised payload has no ``message`` key.

    Optional fields stay optional on the wire, so the FE never sees
    ``"message": None``.
    """
    normalized = normalize_interrupt_payload({"action": {"tool": "send_email"}})
    assert "message" not in normalized
|
||||
|
||||
|
||||
def test_custom_interrupt_without_tool_falls_back_to_unknown_tool() -> None:
    """An ``action`` block with no tool is named ``unknown_tool`` rather than crashing."""
    malformed = {"type": "x", "action": {}}
    normalized = normalize_interrupt_payload(malformed)
    assert normalized["action_requests"][0]["name"] == "unknown_tool"
    assert normalized["review_configs"][0]["action_name"] == "unknown_tool"
|
||||
|
||||
|
||||
def test_format_interrupt_request_carries_correlation_fields_on_the_wire() -> None:
    """Correlation fields passed to the formatter appear under ``data`` in the frame."""
    empty_request = {"action_requests": [], "review_configs": []}
    frame = format_interrupt_request(
        empty_request,
        interrupt_id="int_42",
        pending_interrupt_count=3,
        chat_turn_id="turn_99",
    )

    decoded = _decode(frame)

    assert decoded["type"] == "data-interrupt-request"
    data = decoded["data"]
    assert data["interrupt_id"] == "int_42"
    assert data["pending_interrupt_count"] == 3
    assert data["chat_turn_id"] == "turn_99"
|
||||
|
||||
|
||||
def test_format_interrupt_request_omits_correlation_fields_when_unset() -> None:
    """Without correlation arguments, none of the correlation keys are emitted.

    Backward compatibility: legacy single-interrupt callers need not supply ids.
    """
    empty_request = {"action_requests": [], "review_configs": []}
    data = _decode(format_interrupt_request(empty_request))["data"]
    for absent_key in ("interrupt_id", "pending_interrupt_count", "chat_turn_id"):
        assert absent_key not in data
|
||||
|
|
@@ -0,0 +1,142 @@
|
|||
"""Pin that sub-agent emitter reaches every wire event the relay emits."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.streaming.emitter import subagent_emitter
|
||||
from app.services.streaming.service import StreamingService
|
||||
|
||||
# Apply the ``unit`` marker to every test in this module.
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def _decode(frame: str) -> dict:
    """Parse the JSON body of a frame after dropping its ``data: `` prefix and trailing blank line."""
    payload_text = frame.removeprefix("data: ")
    payload_text = payload_text.removesuffix("\n\n")
    return json.loads(payload_text)
|
||||
|
||||
|
||||
@pytest.fixture
def service() -> StreamingService:
    """Provide a freshly constructed ``StreamingService`` for each test."""
    streaming_service = StreamingService()
    return streaming_service
|
||||
|
||||
|
||||
@pytest.fixture
def sub_emitter():
    """Provide a sub-agent emitter with fixed type, run id and parent tool call id."""
    identity = {
        "subagent_type": "deliverables",
        "subagent_run_id": "sub_xyz",
        "parent_tool_call_id": "call_parent",
    }
    return subagent_emitter(**identity)
|
||||
|
||||
|
||||
def test_text_delta_carries_subagent_emitter_on_the_wire(service, sub_emitter) -> None:
    """A text-delta frame carries the sub-agent run id and the delta text."""
    frame = service.format_text_delta("text_1", "hi", emitter=sub_emitter)
    decoded = _decode(frame)
    assert decoded["emitted_by"]["subagent_run_id"] == "sub_xyz"
    assert decoded["delta"] == "hi"
|
||||
|
||||
|
||||
def test_reasoning_delta_carries_subagent_emitter_on_the_wire(
    service, sub_emitter
) -> None:
    """A reasoning-delta frame carries the sub-agent run id."""
    frame = service.format_reasoning_delta("r_1", "thinking", emitter=sub_emitter)
    decoded = _decode(frame)
    assert decoded["emitted_by"]["subagent_run_id"] == "sub_xyz"
|
||||
|
||||
|
||||
def test_tool_input_start_carries_subagent_emitter_and_lc_id(
    service, sub_emitter
) -> None:
    """A tool-input-start frame carries the emitter, LangChain call id and tool name."""
    frame = service.format_tool_input_start(
        "call_1",
        "send_email",
        emitter=sub_emitter,
        langchain_tool_call_id="lc_1",
    )
    decoded = _decode(frame)
    assert decoded["emitted_by"]["subagent_type"] == "deliverables"
    assert decoded["langchainToolCallId"] == "lc_1"
    assert decoded["toolName"] == "send_email"
|
||||
|
||||
|
||||
def test_tool_output_available_carries_subagent_emitter(service, sub_emitter) -> None:
    """A tool-output frame carries the sub-agent run id and the tool output."""
    tool_output = {"ok": True}
    frame = service.format_tool_output_available(
        "call_1", tool_output, emitter=sub_emitter
    )
    decoded = _decode(frame)
    assert decoded["emitted_by"]["subagent_run_id"] == "sub_xyz"
    assert decoded["output"] == {"ok": True}
|
||||
|
||||
|
||||
def test_thinking_step_carries_subagent_emitter(service, sub_emitter) -> None:
    """A thinking-step frame has the expected type and carries the sub-agent run id."""
    step = {"step_id": "s1", "title": "Sending email", "status": "in_progress"}
    frame = service.format_thinking_step(**step, emitter=sub_emitter)
    decoded = _decode(frame)
    assert decoded["type"] == "data-thinking-step"
    assert decoded["emitted_by"]["subagent_run_id"] == "sub_xyz"
|
||||
|
||||
|
||||
def test_action_log_carries_subagent_emitter(service, sub_emitter) -> None:
    """An action-log frame carries the sub-agent run id and the logged tool name."""
    entry = {"id": 1, "tool_name": "send_email", "reversible": False}
    frame = service.format_action_log(entry, emitter=sub_emitter)
    decoded = _decode(frame)
    assert decoded["emitted_by"]["subagent_run_id"] == "sub_xyz"
    assert decoded["data"]["tool_name"] == "send_email"
|
||||
|
||||
|
||||
def test_subagent_lifecycle_events_share_run_id_for_pairing(
    service, sub_emitter
) -> None:
    """Start and finish frames carry the same run id and the expected event types."""
    lifecycle_kwargs = {
        "subagent_run_id": "sub_xyz",
        "subagent_type": "deliverables",
        "parent_tool_call_id": "call_parent",
        "emitter": sub_emitter,
    }
    start = _decode(service.format_subagent_start(**lifecycle_kwargs))
    finish = _decode(service.format_subagent_finish(**lifecycle_kwargs))

    assert start["data"]["subagent_run_id"] == finish["data"]["subagent_run_id"]
    assert start["type"] == "data-subagent-start"
    assert finish["type"] == "data-subagent-finish"
|
||||
|
||||
|
||||
def test_main_emitter_events_omit_emitted_by_field(service) -> None:
    """A frame formatted without an emitter has no ``emitted_by`` key."""
    frame = service.format_text_delta("text_1", "hi")
    decoded = _decode(frame)
    assert "emitted_by" not in decoded
|
||||
|
||||
|
||||
def test_resolve_emitter_through_service_uses_registry(service, sub_emitter) -> None:
    """An emitter registered under a run id is resolved for a run listing that id as a parent."""
    service.emitter_registry.register("run_task_1", sub_emitter)
    ancestry = ["root", "run_task_1"]
    found = service.resolve_emitter(run_id="run_chat_model", parent_ids=ancestry)
    assert found is sub_emitter
|
||||
|
||||
|
||||
def test_message_id_is_assigned_on_message_start_and_reused(service) -> None:
    """The message-start frame carries a ``msg_`` id that the service then exposes."""
    decoded = _decode(service.format_message_start())
    message_id = decoded["messageId"]
    assert message_id.startswith("msg_")
    assert service.message_id == message_id
|
||||
Loading…
Add table
Add a link
Reference in a new issue