From 366122da6e4568289e131e2ac20be63e8bb5bd90 Mon Sep 17 00:00:00 2001
From: CREDO23
Date: Wed, 6 May 2026 20:08:48 +0200
Subject: [PATCH] Add unit tests for streaming interrupts and service propagation.

---
Reviewer notes (placed after the three-dash line, so they are outside the
commit message and outside every hunk):

* tests/unit/services/streaming/__init__.py is added as an empty file.

* test_interrupt_correlation.py builds state stand-ins from three local
  dataclasses (_Interrupt, _Task, _State) and checks list_pending_interrupts,
  first_pending_interrupt, get_pending_interrupt_by_id and
  get_pending_interrupt_for_tool_call. Pinned expectations: an empty state
  lists nothing; task interrupts are listed before root-level interrupts;
  lookup by interrupt id or by an action request's tool_call_id returns the
  matching entry, and an unknown id returns None; an interrupt whose id is
  None is still listed with interrupt_id None; an interrupt whose value is
  not a dict is left out of the listing.

* test_interrupt_events.py checks normalize_interrupt_payload and
  format_interrupt_request. Pinned expectations: a payload that already has
  action_requests/review_configs comes back equal to the input; a payload
  shaped as type/message/action/context is converted to
  action_requests/review_configs (allowed decisions approve, edit, reject)
  plus interrupt_type, message and context; the message key is absent when
  no message was supplied; the action name falls back to "unknown_tool" when
  the action block has no tool; the "data-interrupt-request" frame carries
  interrupt_id, pending_interrupt_count and chat_turn_id when they are
  passed and omits all three when they are not.

* test_service_emitter_propagation.py checks that StreamingService frames
  from format_text_delta, format_reasoning_delta, format_tool_input_start,
  format_tool_output_available, format_thinking_step and format_action_log
  carry an emitted_by object when an emitter built by subagent_emitter(...)
  is passed, and that format_text_delta without an emitter has no emitted_by
  key. It also checks that the sub-agent start and finish frames carry the
  same subagent_run_id, that resolve_emitter returns the emitter registered
  under an id listed in parent_ids, and that format_message_start yields a
  messageId starting with "msg_" that equals service.message_id.

* test_interrupt_events.py and test_service_emitter_propagation.py each
  define the same private _decode helper: it removes the "data: " prefix and
  the trailing "\n\n" from a frame and parses the rest with json.loads.

 .../tests/unit/services/streaming/__init__.py |   0
 .../streaming/test_interrupt_correlation.py   | 164 ++++++++++++++++++
 .../streaming/test_interrupt_events.py        |  91 ++++++++++
 .../test_service_emitter_propagation.py       | 142 +++++++++++++++
 4 files changed, 397 insertions(+)
 create mode 100644 surfsense_backend/tests/unit/services/streaming/__init__.py
 create mode 100644 surfsense_backend/tests/unit/services/streaming/test_interrupt_correlation.py
 create mode 100644 surfsense_backend/tests/unit/services/streaming/test_interrupt_events.py
 create mode 100644 surfsense_backend/tests/unit/services/streaming/test_service_emitter_propagation.py

diff --git a/surfsense_backend/tests/unit/services/streaming/__init__.py b/surfsense_backend/tests/unit/services/streaming/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/surfsense_backend/tests/unit/services/streaming/test_interrupt_correlation.py b/surfsense_backend/tests/unit/services/streaming/test_interrupt_correlation.py
new file mode 100644
index 000000000..edf4ecb9a
--- /dev/null
+++ b/surfsense_backend/tests/unit/services/streaming/test_interrupt_correlation.py
@@ -0,0 +1,164 @@
+"""Pin id-aware pending-interrupt lookup that replaces the buggy first-wins."""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Any
+
+import pytest
+
+from app.services.streaming.interrupt_correlation import (
+    PendingInterrupt,
+    first_pending_interrupt,
+    get_pending_interrupt_by_id,
+    get_pending_interrupt_for_tool_call,
+    list_pending_interrupts,
+)
+
+pytestmark = pytest.mark.unit
+
+
+@dataclass
+class _Interrupt:
+    value: dict[str, Any]
+    id: str | None = None
+
+
+@dataclass
+class _Task:
+    interrupts: tuple[_Interrupt, ...] = ()
+    id: str | None = None
+
+
+@dataclass
+class _State:
+    tasks: tuple[_Task, ...] = ()
+    interrupts: tuple[_Interrupt, ...] = ()
+
+
+def _hitl(name: str, tool_call_id: str | None = None) -> dict[str, Any]:
+    """Minimal LangChain HITLRequest payload for one action."""
+    action: dict[str, Any] = {"name": name, "args": {}}
+    if tool_call_id is not None:
+        action["tool_call_id"] = tool_call_id
+    return {
+        "action_requests": [action],
+        "review_configs": [{"action_name": name, "allowed_decisions": ["approve"]}],
+    }
+
+
+def test_empty_state_has_no_pending_interrupts() -> None:
+    state = _State()
+    assert list_pending_interrupts(state) == []
+    assert first_pending_interrupt(state) is None
+
+
+def test_single_pending_interrupt_in_task_is_returned() -> None:
+    state = _State(
+        tasks=(
+            _Task(
+                id="task_1",
+                interrupts=(_Interrupt(value=_hitl("send_email"), id="int_1"),),
+            ),
+        )
+    )
+    pending = list_pending_interrupts(state)
+    assert len(pending) == 1
+    assert pending[0] == PendingInterrupt(
+        interrupt_id="int_1",
+        value=_hitl("send_email"),
+        source_task_id="task_1",
+    )
+
+
+def test_pending_interrupts_returned_in_task_then_root_order() -> None:
+    """Determinism matters: callers iterate in this order to render the UI."""
+    state = _State(
+        tasks=(
+            _Task(
+                id="task_a",
+                interrupts=(_Interrupt(value=_hitl("a"), id="int_a"),),
+            ),
+            _Task(
+                id="task_b",
+                interrupts=(_Interrupt(value=_hitl("b"), id="int_b"),),
+            ),
+        ),
+        interrupts=(_Interrupt(value=_hitl("c"), id="int_c"),),
+    )
+    pending = list_pending_interrupts(state)
+    ids = [p.interrupt_id for p in pending]
+    assert ids == ["int_a", "int_b", "int_c"]
+
+
+def test_get_by_id_finds_the_right_interrupt_under_parallel_load() -> None:
+    """Replacing first-wins: id-aware lookup MUST pick the requested one."""
+    state = _State(
+        tasks=(
+            _Task(interrupts=(_Interrupt(value=_hitl("a"), id="int_a"),)),
+            _Task(interrupts=(_Interrupt(value=_hitl("b"), id="int_b"),)),
+            _Task(interrupts=(_Interrupt(value=_hitl("c"), id="int_c"),)),
+        )
+    )
+    found = get_pending_interrupt_by_id(state, "int_b")
+    assert found is not None
+    assert found.value["action_requests"][0]["name"] == "b"
+
+
+def test_get_by_id_returns_none_when_id_is_not_pending() -> None:
+    state = _State(
+        tasks=(_Task(interrupts=(_Interrupt(value=_hitl("a"), id="int_a"),)),)
+    )
+    assert get_pending_interrupt_by_id(state, "missing") is None
+
+
+def test_get_by_tool_call_id_matches_action_request_payload() -> None:
+    """HITLRequest carries ``tool_call_id`` per action; lookup uses that."""
+    state = _State(
+        tasks=(
+            _Task(
+                interrupts=(
+                    _Interrupt(
+                        value=_hitl("a", tool_call_id="call_xxx"), id="int_a"
+                    ),
+                    _Interrupt(
+                        value=_hitl("b", tool_call_id="call_yyy"), id="int_b"
+                    ),
+                )
+            ),
+        )
+    )
+    found = get_pending_interrupt_for_tool_call(state, "call_yyy")
+    assert found is not None
+    assert found.interrupt_id == "int_b"
+
+
+def test_first_pending_interrupt_matches_legacy_first_wins_behaviour() -> None:
+    """Sequential-turn safety: the explicit shortcut still returns the first."""
+    state = _State(
+        tasks=(_Task(interrupts=(_Interrupt(value=_hitl("first"), id="int_1"),)),),
+        interrupts=(_Interrupt(value=_hitl("second"), id="int_2"),),
+    )
+    first = first_pending_interrupt(state)
+    assert first is not None
+    assert first.interrupt_id == "int_1"
+
+
+def test_interrupt_without_id_falls_back_to_none() -> None:
+    """Snapshots from older LangGraph versions may omit ``id`` — preserve that."""
+    state = _State(
+        tasks=(_Task(interrupts=(_Interrupt(value=_hitl("a"), id=None),)),)
+    )
+    pending = list_pending_interrupts(state)
+    assert len(pending) == 1
+    assert pending[0].interrupt_id is None
+
+
+def test_non_dict_interrupt_values_are_ignored() -> None:
+    """Defensive: a non-dict value should not crash the iteration."""
+
+    class _Raw:
+        value = "not a dict"
+
+    state = _State(tasks=(_Task(interrupts=(_Raw(),)),))  # type: ignore[arg-type]
+    assert list_pending_interrupts(state) == []
diff --git a/surfsense_backend/tests/unit/services/streaming/test_interrupt_events.py b/surfsense_backend/tests/unit/services/streaming/test_interrupt_events.py
new file mode 100644
index 000000000..dbdd607bf
--- /dev/null
+++ b/surfsense_backend/tests/unit/services/streaming/test_interrupt_events.py
@@ -0,0 +1,91 @@
+"""Pin interrupt-payload normalisation and the optional correlation fields on the wire."""
+
+from __future__ import annotations
+
+import json
+
+import pytest
+
+from app.services.streaming.events.interrupt import (
+    format_interrupt_request,
+    normalize_interrupt_payload,
+)
+
+pytestmark = pytest.mark.unit
+
+
+def _decode(frame: str) -> dict:
+    body = frame.removeprefix("data: ").removesuffix("\n\n")
+    return json.loads(body)
+
+
+def test_hitlrequest_shape_is_passed_through_unchanged() -> None:
+    raw = {
+        "action_requests": [{"name": "send_email", "args": {"to": "a@b"}}],
+        "review_configs": [
+            {"action_name": "send_email", "allowed_decisions": ["approve"]}
+        ],
+    }
+    assert normalize_interrupt_payload(raw) == raw
+
+
+def test_custom_interrupt_primitive_is_converted_to_canonical_shape() -> None:
+    raw = {
+        "type": "permission",
+        "message": "Allow send?",
+        "action": {"tool": "send_email", "params": {"to": "a@b"}},
+        "context": {"reason": "destructive"},
+    }
+    out = normalize_interrupt_payload(raw)
+    assert out["action_requests"] == [
+        {"name": "send_email", "args": {"to": "a@b"}}
+    ]
+    assert out["review_configs"] == [
+        {
+            "action_name": "send_email",
+            "allowed_decisions": ["approve", "edit", "reject"],
+        }
+    ]
+    assert out["interrupt_type"] == "permission"
+    assert out["message"] == "Allow send?"
+    assert out["context"] == {"reason": "destructive"}
+
+
+def test_custom_interrupt_without_message_omits_message_key() -> None:
+    """Optional fields stay optional on the wire; FE does not see ``"message": None``."""
+    raw = {"action": {"tool": "send_email"}}
+    out = normalize_interrupt_payload(raw)
+    assert "message" not in out
+
+
+def test_custom_interrupt_without_tool_falls_back_to_unknown_tool() -> None:
+    """Defensive: a malformed ``action`` block must not crash the relay."""
+    out = normalize_interrupt_payload({"type": "x", "action": {}})
+    assert out["action_requests"][0]["name"] == "unknown_tool"
+    assert out["review_configs"][0]["action_name"] == "unknown_tool"
+
+
+def test_format_interrupt_request_carries_correlation_fields_on_the_wire() -> None:
+    frame = format_interrupt_request(
+        {"action_requests": [], "review_configs": []},
+        interrupt_id="int_42",
+        pending_interrupt_count=3,
+        chat_turn_id="turn_99",
+    )
+    payload = _decode(frame)
+    assert payload["type"] == "data-interrupt-request"
+    inner = payload["data"]
+    assert inner["interrupt_id"] == "int_42"
+    assert inner["pending_interrupt_count"] == 3
+    assert inner["chat_turn_id"] == "turn_99"
+
+
+def test_format_interrupt_request_omits_correlation_fields_when_unset() -> None:
+    """Backward compat: legacy single-interrupt callers don't have to supply ids."""
+    frame = format_interrupt_request(
+        {"action_requests": [], "review_configs": []},
+    )
+    inner = _decode(frame)["data"]
+    assert "interrupt_id" not in inner
+    assert "pending_interrupt_count" not in inner
+    assert "chat_turn_id" not in inner
diff --git a/surfsense_backend/tests/unit/services/streaming/test_service_emitter_propagation.py b/surfsense_backend/tests/unit/services/streaming/test_service_emitter_propagation.py
new file mode 100644
index 000000000..b381f13bc
--- /dev/null
+++ b/surfsense_backend/tests/unit/services/streaming/test_service_emitter_propagation.py
@@ -0,0 +1,142 @@
+"""Pin that sub-agent emitter reaches every wire event the relay emits."""
+
+from __future__ import annotations
+
+import json
+
+import pytest
+
+from app.services.streaming.emitter import subagent_emitter
+from app.services.streaming.service import StreamingService
+
+pytestmark = pytest.mark.unit
+
+
+def _decode(frame: str) -> dict:
+    body = frame.removeprefix("data: ").removesuffix("\n\n")
+    return json.loads(body)
+
+
+@pytest.fixture
+def service() -> StreamingService:
+    return StreamingService()
+
+
+@pytest.fixture
+def sub_emitter():
+    return subagent_emitter(
+        subagent_type="deliverables",
+        subagent_run_id="sub_xyz",
+        parent_tool_call_id="call_parent",
+    )
+
+
+def test_text_delta_carries_subagent_emitter_on_the_wire(service, sub_emitter) -> None:
+    payload = _decode(service.format_text_delta("text_1", "hi", emitter=sub_emitter))
+    assert payload["emitted_by"]["subagent_run_id"] == "sub_xyz"
+    assert payload["delta"] == "hi"
+
+
+def test_reasoning_delta_carries_subagent_emitter_on_the_wire(
+    service, sub_emitter
+) -> None:
+    payload = _decode(
+        service.format_reasoning_delta("r_1", "thinking", emitter=sub_emitter)
+    )
+    assert payload["emitted_by"]["subagent_run_id"] == "sub_xyz"
+
+
+def test_tool_input_start_carries_subagent_emitter_and_lc_id(
+    service, sub_emitter
+) -> None:
+    payload = _decode(
+        service.format_tool_input_start(
+            "call_1",
+            "send_email",
+            langchain_tool_call_id="lc_1",
+            emitter=sub_emitter,
+        )
+    )
+    assert payload["emitted_by"]["subagent_type"] == "deliverables"
+    assert payload["langchainToolCallId"] == "lc_1"
+    assert payload["toolName"] == "send_email"
+
+
+def test_tool_output_available_carries_subagent_emitter(service, sub_emitter) -> None:
+    payload = _decode(
+        service.format_tool_output_available(
+            "call_1", {"ok": True}, emitter=sub_emitter
+        )
+    )
+    assert payload["emitted_by"]["subagent_run_id"] == "sub_xyz"
+    assert payload["output"] == {"ok": True}
+
+
+def test_thinking_step_carries_subagent_emitter(service, sub_emitter) -> None:
+    payload = _decode(
+        service.format_thinking_step(
+            step_id="s1",
+            title="Sending email",
+            status="in_progress",
+            emitter=sub_emitter,
+        )
+    )
+    assert payload["type"] == "data-thinking-step"
+    assert payload["emitted_by"]["subagent_run_id"] == "sub_xyz"
+
+
+def test_action_log_carries_subagent_emitter(service, sub_emitter) -> None:
+    payload = _decode(
+        service.format_action_log(
+            {"id": 1, "tool_name": "send_email", "reversible": False},
+            emitter=sub_emitter,
+        )
+    )
+    assert payload["emitted_by"]["subagent_run_id"] == "sub_xyz"
+    assert payload["data"]["tool_name"] == "send_email"
+
+
+def test_subagent_lifecycle_events_share_run_id_for_pairing(
+    service, sub_emitter
+) -> None:
+    start = _decode(
+        service.format_subagent_start(
+            subagent_run_id="sub_xyz",
+            subagent_type="deliverables",
+            parent_tool_call_id="call_parent",
+            emitter=sub_emitter,
+        )
+    )
+    finish = _decode(
+        service.format_subagent_finish(
+            subagent_run_id="sub_xyz",
+            subagent_type="deliverables",
+            parent_tool_call_id="call_parent",
+            emitter=sub_emitter,
+        )
+    )
+    assert start["data"]["subagent_run_id"] == finish["data"]["subagent_run_id"]
+    assert start["type"] == "data-subagent-start"
+    assert finish["type"] == "data-subagent-finish"
+
+
+def test_main_emitter_events_omit_emitted_by_field(service) -> None:
+    payload = _decode(service.format_text_delta("text_1", "hi"))
+    assert "emitted_by" not in payload
+
+
+def test_resolve_emitter_through_service_uses_registry(service, sub_emitter) -> None:
+    service.emitter_registry.register("run_task_1", sub_emitter)
+    resolved = service.resolve_emitter(
+        run_id="run_chat_model",
+        parent_ids=["root", "run_task_1"],
+    )
+    assert resolved is sub_emitter
+
+
+def test_message_id_is_assigned_on_message_start_and_reused(service) -> None:
+    frame = service.format_message_start()
+    payload = _decode(frame)
+    assigned = payload["messageId"]
+    assert assigned.startswith("msg_")
+    assert service.message_id == assigned