Add unit tests for streaming interrupts and service propagation.

This commit is contained in:
CREDO23 2026-05-06 20:08:48 +02:00
parent 619a8362b7
commit 366122da6e
4 changed files with 397 additions and 0 deletions

View file

@ -0,0 +1,164 @@
"""Pin id-aware pending-interrupt lookup that replaces the buggy first-wins."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any
import pytest
from app.services.streaming.interrupt_correlation import (
PendingInterrupt,
first_pending_interrupt,
get_pending_interrupt_by_id,
get_pending_interrupt_for_tool_call,
list_pending_interrupts,
)
pytestmark = pytest.mark.unit
@dataclass
class _Interrupt:
value: dict[str, Any]
id: str | None = None
@dataclass
class _Task:
interrupts: tuple[_Interrupt, ...] = ()
id: str | None = None
@dataclass
class _State:
tasks: tuple[_Task, ...] = ()
interrupts: tuple[_Interrupt, ...] = ()
def _hitl(name: str, tool_call_id: str | None = None) -> dict[str, Any]:
"""Minimal LangChain HITLRequest payload for one action."""
action: dict[str, Any] = {"name": name, "args": {}}
if tool_call_id is not None:
action["tool_call_id"] = tool_call_id
return {
"action_requests": [action],
"review_configs": [{"action_name": name, "allowed_decisions": ["approve"]}],
}
def test_empty_state_has_no_pending_interrupts() -> None:
state = _State()
assert list_pending_interrupts(state) == []
assert first_pending_interrupt(state) is None
def test_single_pending_interrupt_in_task_is_returned() -> None:
state = _State(
tasks=(
_Task(
id="task_1",
interrupts=(_Interrupt(value=_hitl("send_email"), id="int_1"),),
),
)
)
pending = list_pending_interrupts(state)
assert len(pending) == 1
assert pending[0] == PendingInterrupt(
interrupt_id="int_1",
value=_hitl("send_email"),
source_task_id="task_1",
)
def test_pending_interrupts_returned_in_task_then_root_order() -> None:
"""Determinism matters: callers iterate in this order to render the UI."""
state = _State(
tasks=(
_Task(
id="task_a",
interrupts=(_Interrupt(value=_hitl("a"), id="int_a"),),
),
_Task(
id="task_b",
interrupts=(_Interrupt(value=_hitl("b"), id="int_b"),),
),
),
interrupts=(_Interrupt(value=_hitl("c"), id="int_c"),),
)
pending = list_pending_interrupts(state)
ids = [p.interrupt_id for p in pending]
assert ids == ["int_a", "int_b", "int_c"]
def test_get_by_id_finds_the_right_interrupt_under_parallel_load() -> None:
"""Replacing first-wins: id-aware lookup MUST pick the requested one."""
state = _State(
tasks=(
_Task(interrupts=(_Interrupt(value=_hitl("a"), id="int_a"),)),
_Task(interrupts=(_Interrupt(value=_hitl("b"), id="int_b"),)),
_Task(interrupts=(_Interrupt(value=_hitl("c"), id="int_c"),)),
)
)
found = get_pending_interrupt_by_id(state, "int_b")
assert found is not None
assert found.value["action_requests"][0]["name"] == "b"
def test_get_by_id_returns_none_when_id_is_not_pending() -> None:
state = _State(
tasks=(_Task(interrupts=(_Interrupt(value=_hitl("a"), id="int_a"),)),)
)
assert get_pending_interrupt_by_id(state, "missing") is None
def test_get_by_tool_call_id_matches_action_request_payload() -> None:
"""HITLRequest carries ``tool_call_id`` per action; lookup uses that."""
state = _State(
tasks=(
_Task(
interrupts=(
_Interrupt(
value=_hitl("a", tool_call_id="call_xxx"), id="int_a"
),
_Interrupt(
value=_hitl("b", tool_call_id="call_yyy"), id="int_b"
),
)
),
)
)
found = get_pending_interrupt_for_tool_call(state, "call_yyy")
assert found is not None
assert found.interrupt_id == "int_b"
def test_first_pending_interrupt_matches_legacy_first_wins_behaviour() -> None:
"""Sequential-turn safety: the explicit shortcut still returns the first."""
state = _State(
tasks=(_Task(interrupts=(_Interrupt(value=_hitl("first"), id="int_1"),)),),
interrupts=(_Interrupt(value=_hitl("second"), id="int_2"),),
)
first = first_pending_interrupt(state)
assert first is not None
assert first.interrupt_id == "int_1"
def test_interrupt_without_id_falls_back_to_none() -> None:
"""Snapshots from older LangGraph versions may omit ``id`` — preserve that."""
state = _State(
tasks=(_Task(interrupts=(_Interrupt(value=_hitl("a"), id=None),)),)
)
pending = list_pending_interrupts(state)
assert len(pending) == 1
assert pending[0].interrupt_id is None
def test_non_dict_interrupt_values_are_ignored() -> None:
"""Defensive: a non-dict value should not crash the iteration."""
class _Raw:
value = "not a dict"
state = _State(tasks=(_Task(interrupts=(_Raw(),)),)) # type: ignore[arg-type]
assert list_pending_interrupts(state) == []

View file

@ -0,0 +1,91 @@
"""Pin interrupt-payload normalisation and the optional correlation fields on the wire."""
from __future__ import annotations
import json
import pytest
from app.services.streaming.events.interrupt import (
format_interrupt_request,
normalize_interrupt_payload,
)
pytestmark = pytest.mark.unit
def _decode(frame: str) -> dict:
body = frame.removeprefix("data: ").removesuffix("\n\n")
return json.loads(body)
def test_hitlrequest_shape_is_passed_through_unchanged() -> None:
raw = {
"action_requests": [{"name": "send_email", "args": {"to": "a@b"}}],
"review_configs": [
{"action_name": "send_email", "allowed_decisions": ["approve"]}
],
}
assert normalize_interrupt_payload(raw) == raw
def test_custom_interrupt_primitive_is_converted_to_canonical_shape() -> None:
raw = {
"type": "permission",
"message": "Allow send?",
"action": {"tool": "send_email", "params": {"to": "a@b"}},
"context": {"reason": "destructive"},
}
out = normalize_interrupt_payload(raw)
assert out["action_requests"] == [
{"name": "send_email", "args": {"to": "a@b"}}
]
assert out["review_configs"] == [
{
"action_name": "send_email",
"allowed_decisions": ["approve", "edit", "reject"],
}
]
assert out["interrupt_type"] == "permission"
assert out["message"] == "Allow send?"
assert out["context"] == {"reason": "destructive"}
def test_custom_interrupt_without_message_omits_message_key() -> None:
"""Optional fields stay optional on the wire; FE does not see ``"message": None``."""
raw = {"action": {"tool": "send_email"}}
out = normalize_interrupt_payload(raw)
assert "message" not in out
def test_custom_interrupt_without_tool_falls_back_to_unknown_tool() -> None:
"""Defensive: a malformed ``action`` block must not crash the relay."""
out = normalize_interrupt_payload({"type": "x", "action": {}})
assert out["action_requests"][0]["name"] == "unknown_tool"
assert out["review_configs"][0]["action_name"] == "unknown_tool"
def test_format_interrupt_request_carries_correlation_fields_on_the_wire() -> None:
frame = format_interrupt_request(
{"action_requests": [], "review_configs": []},
interrupt_id="int_42",
pending_interrupt_count=3,
chat_turn_id="turn_99",
)
payload = _decode(frame)
assert payload["type"] == "data-interrupt-request"
inner = payload["data"]
assert inner["interrupt_id"] == "int_42"
assert inner["pending_interrupt_count"] == 3
assert inner["chat_turn_id"] == "turn_99"
def test_format_interrupt_request_omits_correlation_fields_when_unset() -> None:
"""Backward compat: legacy single-interrupt callers don't have to supply ids."""
frame = format_interrupt_request(
{"action_requests": [], "review_configs": []},
)
inner = _decode(frame)["data"]
assert "interrupt_id" not in inner
assert "pending_interrupt_count" not in inner
assert "chat_turn_id" not in inner

View file

@ -0,0 +1,142 @@
"""Pin that sub-agent emitter reaches every wire event the relay emits."""
from __future__ import annotations
import json
import pytest
from app.services.streaming.emitter import subagent_emitter
from app.services.streaming.service import StreamingService
pytestmark = pytest.mark.unit
def _decode(frame: str) -> dict:
body = frame.removeprefix("data: ").removesuffix("\n\n")
return json.loads(body)
@pytest.fixture
def service() -> StreamingService:
return StreamingService()
@pytest.fixture
def sub_emitter():
return subagent_emitter(
subagent_type="deliverables",
subagent_run_id="sub_xyz",
parent_tool_call_id="call_parent",
)
def test_text_delta_carries_subagent_emitter_on_the_wire(service, sub_emitter) -> None:
payload = _decode(service.format_text_delta("text_1", "hi", emitter=sub_emitter))
assert payload["emitted_by"]["subagent_run_id"] == "sub_xyz"
assert payload["delta"] == "hi"
def test_reasoning_delta_carries_subagent_emitter_on_the_wire(
service, sub_emitter
) -> None:
payload = _decode(
service.format_reasoning_delta("r_1", "thinking", emitter=sub_emitter)
)
assert payload["emitted_by"]["subagent_run_id"] == "sub_xyz"
def test_tool_input_start_carries_subagent_emitter_and_lc_id(
service, sub_emitter
) -> None:
payload = _decode(
service.format_tool_input_start(
"call_1",
"send_email",
langchain_tool_call_id="lc_1",
emitter=sub_emitter,
)
)
assert payload["emitted_by"]["subagent_type"] == "deliverables"
assert payload["langchainToolCallId"] == "lc_1"
assert payload["toolName"] == "send_email"
def test_tool_output_available_carries_subagent_emitter(service, sub_emitter) -> None:
payload = _decode(
service.format_tool_output_available(
"call_1", {"ok": True}, emitter=sub_emitter
)
)
assert payload["emitted_by"]["subagent_run_id"] == "sub_xyz"
assert payload["output"] == {"ok": True}
def test_thinking_step_carries_subagent_emitter(service, sub_emitter) -> None:
payload = _decode(
service.format_thinking_step(
step_id="s1",
title="Sending email",
status="in_progress",
emitter=sub_emitter,
)
)
assert payload["type"] == "data-thinking-step"
assert payload["emitted_by"]["subagent_run_id"] == "sub_xyz"
def test_action_log_carries_subagent_emitter(service, sub_emitter) -> None:
payload = _decode(
service.format_action_log(
{"id": 1, "tool_name": "send_email", "reversible": False},
emitter=sub_emitter,
)
)
assert payload["emitted_by"]["subagent_run_id"] == "sub_xyz"
assert payload["data"]["tool_name"] == "send_email"
def test_subagent_lifecycle_events_share_run_id_for_pairing(
service, sub_emitter
) -> None:
start = _decode(
service.format_subagent_start(
subagent_run_id="sub_xyz",
subagent_type="deliverables",
parent_tool_call_id="call_parent",
emitter=sub_emitter,
)
)
finish = _decode(
service.format_subagent_finish(
subagent_run_id="sub_xyz",
subagent_type="deliverables",
parent_tool_call_id="call_parent",
emitter=sub_emitter,
)
)
assert start["data"]["subagent_run_id"] == finish["data"]["subagent_run_id"]
assert start["type"] == "data-subagent-start"
assert finish["type"] == "data-subagent-finish"
def test_main_emitter_events_omit_emitted_by_field(service) -> None:
payload = _decode(service.format_text_delta("text_1", "hi"))
assert "emitted_by" not in payload
def test_resolve_emitter_through_service_uses_registry(service, sub_emitter) -> None:
service.emitter_registry.register("run_task_1", sub_emitter)
resolved = service.resolve_emitter(
run_id="run_chat_model",
parent_ids=["root", "run_task_1"],
)
assert resolved is sub_emitter
def test_message_id_is_assigned_on_message_start_and_reused(service) -> None:
frame = service.format_message_start()
payload = _decode(frame)
assigned = payload["messageId"]
assert assigned.startswith("msg_")
assert service.message_id == assigned