mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-19 18:45:15 +02:00
chat/stream_new_chat: emit one SSE frame per pending interrupt
This commit is contained in:
parent
583ac83735
commit
c06dd6e8ba
5 changed files with 259 additions and 118 deletions
|
|
@ -76,6 +76,9 @@ from app.services.chat_session_state_service import (
|
||||||
from app.services.connector_service import ConnectorService
|
from app.services.connector_service import ConnectorService
|
||||||
from app.services.new_streaming_service import VercelStreamingService
|
from app.services.new_streaming_service import VercelStreamingService
|
||||||
from app.tasks.chat.streaming.graph_stream.event_stream import stream_output
|
from app.tasks.chat.streaming.graph_stream.event_stream import stream_output
|
||||||
|
from app.tasks.chat.streaming.helpers.interrupt_inspector import (
|
||||||
|
all_interrupt_values,
|
||||||
|
)
|
||||||
from app.utils.content_utils import bootstrap_history_from_db
|
from app.utils.content_utils import bootstrap_history_from_db
|
||||||
from app.utils.perf import get_perf_logger, log_system_snapshot, trim_native_heap
|
from app.utils.perf import get_perf_logger, log_system_snapshot, trim_native_heap
|
||||||
from app.utils.user_message_multimodal import build_human_message_content
|
from app.utils.user_message_multimodal import build_human_message_content
|
||||||
|
|
@ -98,47 +101,6 @@ def _compute_turn_cancelling_retry_delay(attempt: int) -> int:
|
||||||
return min(delay, TURN_CANCELLING_MAX_DELAY_MS)
|
return min(delay, TURN_CANCELLING_MAX_DELAY_MS)
|
||||||
|
|
||||||
|
|
||||||
def _first_interrupt_value(state: Any) -> dict[str, Any] | None:
|
|
||||||
"""Return the first LangGraph interrupt payload across all snapshot tasks."""
|
|
||||||
|
|
||||||
def _extract_interrupt_value(candidate: Any) -> dict[str, Any] | None:
|
|
||||||
if isinstance(candidate, dict):
|
|
||||||
value = candidate.get("value", candidate)
|
|
||||||
return value if isinstance(value, dict) else None
|
|
||||||
value = getattr(candidate, "value", None)
|
|
||||||
if isinstance(value, dict):
|
|
||||||
return value
|
|
||||||
if isinstance(candidate, (list, tuple)):
|
|
||||||
for item in candidate:
|
|
||||||
extracted = _extract_interrupt_value(item)
|
|
||||||
if extracted is not None:
|
|
||||||
return extracted
|
|
||||||
return None
|
|
||||||
|
|
||||||
for task in getattr(state, "tasks", ()) or ():
|
|
||||||
try:
|
|
||||||
interrupts = getattr(task, "interrupts", ()) or ()
|
|
||||||
except (AttributeError, IndexError, TypeError):
|
|
||||||
interrupts = ()
|
|
||||||
if not interrupts:
|
|
||||||
extracted = _extract_interrupt_value(task)
|
|
||||||
if extracted is not None:
|
|
||||||
return extracted
|
|
||||||
continue
|
|
||||||
for interrupt_item in interrupts:
|
|
||||||
extracted = _extract_interrupt_value(interrupt_item)
|
|
||||||
if extracted is not None:
|
|
||||||
return extracted
|
|
||||||
try:
|
|
||||||
state_interrupts = getattr(state, "interrupts", ()) or ()
|
|
||||||
except (AttributeError, IndexError, TypeError):
|
|
||||||
state_interrupts = ()
|
|
||||||
extracted = _extract_interrupt_value(state_interrupts)
|
|
||||||
if extracted is not None:
|
|
||||||
return extracted
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def _extract_chunk_parts(chunk: Any) -> dict[str, Any]:
|
def _extract_chunk_parts(chunk: Any) -> dict[str, Any]:
|
||||||
"""Decompose an ``AIMessageChunk`` into typed text/reasoning/tool-call parts.
|
"""Decompose an ``AIMessageChunk`` into typed text/reasoning/tool-call parts.
|
||||||
|
|
||||||
|
|
@ -301,7 +263,6 @@ def extract_todos_from_deepagents(command_output) -> dict:
|
||||||
class StreamResult:
|
class StreamResult:
|
||||||
accumulated_text: str = ""
|
accumulated_text: str = ""
|
||||||
is_interrupted: bool = False
|
is_interrupted: bool = False
|
||||||
interrupt_value: dict[str, Any] | None = None
|
|
||||||
sandbox_files: list[str] = field(default_factory=list)
|
sandbox_files: list[str] = field(default_factory=list)
|
||||||
agent_called_update_memory: bool = False
|
agent_called_update_memory: bool = False
|
||||||
request_id: str | None = None
|
request_id: str | None = None
|
||||||
|
|
@ -915,11 +876,15 @@ async def _stream_agent_events(
|
||||||
result.accumulated_text = accumulated_text
|
result.accumulated_text = accumulated_text
|
||||||
_log_file_contract("turn_outcome", result)
|
_log_file_contract("turn_outcome", result)
|
||||||
|
|
||||||
interrupt_value = _first_interrupt_value(state)
|
pending_values = all_interrupt_values(state)
|
||||||
if interrupt_value is not None:
|
if pending_values:
|
||||||
result.is_interrupted = True
|
result.is_interrupted = True
|
||||||
result.interrupt_value = interrupt_value
|
# One frame per paused subagent so each parallel HITL renders its own
|
||||||
yield streaming_service.format_interrupt_request(result.interrupt_value)
|
# approval card on the wire. Order matches ``state.interrupts``, which
|
||||||
|
# the resume slicer in ``checkpointed_subagent_middleware.resume_routing``
|
||||||
|
# consumes in the same order — keeping emit and resume in lock-step.
|
||||||
|
for interrupt_value in pending_values:
|
||||||
|
yield streaming_service.format_interrupt_request(interrupt_value)
|
||||||
|
|
||||||
|
|
||||||
async def stream_new_chat(
|
async def stream_new_chat(
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,6 @@ from typing import Any
|
||||||
class StreamingResult:
|
class StreamingResult:
|
||||||
accumulated_text: str = ""
|
accumulated_text: str = ""
|
||||||
is_interrupted: bool = False
|
is_interrupted: bool = False
|
||||||
interrupt_value: dict[str, Any] | None = None
|
|
||||||
sandbox_files: list[str] = field(default_factory=list)
|
sandbox_files: list[str] = field(default_factory=list)
|
||||||
agent_called_update_memory: bool = False
|
agent_called_update_memory: bool = False
|
||||||
request_id: str | None = None
|
request_id: str | None = None
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,30 @@
|
||||||
"""Read the first interrupt payload from a LangGraph state snapshot."""
|
"""Read every pending interrupt payload from a LangGraph state snapshot.
|
||||||
|
|
||||||
|
The chat-stream emit loop yields one ``data-interrupt-request`` SSE frame per
|
||||||
|
pending interrupt so parallel HITL across siblings stays addressable on the
|
||||||
|
wire (the resume slicer in ``checkpointed_subagent_middleware.resume_routing``
|
||||||
|
correlates each frame back to the right paused subagent via the stamped
|
||||||
|
``tool_call_id``). This helper produces that flat, ordered list.
|
||||||
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
def first_interrupt_value(state: Any) -> dict[str, Any] | None:
|
def all_interrupt_values(state: Any) -> list[dict[str, Any]]:
|
||||||
"""Return the first interrupt payload across all snapshot tasks."""
|
"""Return every interrupt payload across the snapshot, in traversal order.
|
||||||
|
|
||||||
|
Walks ``state.tasks[*].interrupts`` first (langgraph's per-task buckets,
|
||||||
|
which carry one interrupt per paused subagent) and falls back to
|
||||||
|
``state.interrupts`` when the per-task lists are empty. Order matches the
|
||||||
|
snapshot's iteration order so the emit-time order on the SSE stream agrees
|
||||||
|
with ``collect_pending_tool_calls`` consumption order on resume.
|
||||||
|
|
||||||
|
Defensive against malformed snapshots: tasks/interrupts that raise on
|
||||||
|
attribute access are skipped silently. Non-dict values are skipped — the
|
||||||
|
chat-stream contract requires structured interrupt payloads.
|
||||||
|
"""
|
||||||
|
|
||||||
def _extract(candidate: Any) -> dict[str, Any] | None:
|
def _extract(candidate: Any) -> dict[str, Any] | None:
|
||||||
if isinstance(candidate, dict):
|
if isinstance(candidate, dict):
|
||||||
|
|
@ -15,33 +33,32 @@ def first_interrupt_value(state: Any) -> dict[str, Any] | None:
|
||||||
value = getattr(candidate, "value", None)
|
value = getattr(candidate, "value", None)
|
||||||
if isinstance(value, dict):
|
if isinstance(value, dict):
|
||||||
return value
|
return value
|
||||||
if isinstance(candidate, list | tuple):
|
|
||||||
for item in candidate:
|
|
||||||
extracted = _extract(item)
|
|
||||||
if extracted is not None:
|
|
||||||
return extracted
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
values: list[dict[str, Any]] = []
|
||||||
|
saw_task_interrupt = False
|
||||||
|
|
||||||
for task in getattr(state, "tasks", ()) or ():
|
for task in getattr(state, "tasks", ()) or ():
|
||||||
try:
|
try:
|
||||||
interrupts = getattr(task, "interrupts", ()) or ()
|
interrupts = getattr(task, "interrupts", ()) or ()
|
||||||
except (AttributeError, IndexError, TypeError):
|
except (AttributeError, IndexError, TypeError):
|
||||||
interrupts = ()
|
interrupts = ()
|
||||||
if not interrupts:
|
if interrupts:
|
||||||
extracted = _extract(task)
|
saw_task_interrupt = True
|
||||||
if extracted is not None:
|
for interrupt_item in interrupts:
|
||||||
return extracted
|
extracted = _extract(interrupt_item)
|
||||||
continue
|
if extracted is not None:
|
||||||
for interrupt_item in interrupts:
|
values.append(extracted)
|
||||||
extracted = _extract(interrupt_item)
|
|
||||||
if extracted is not None:
|
if saw_task_interrupt:
|
||||||
return extracted
|
return values
|
||||||
|
|
||||||
try:
|
try:
|
||||||
state_interrupts = getattr(state, "interrupts", ()) or ()
|
state_interrupts = getattr(state, "interrupts", ()) or ()
|
||||||
except (AttributeError, IndexError, TypeError):
|
except (AttributeError, IndexError, TypeError):
|
||||||
state_interrupts = ()
|
state_interrupts = ()
|
||||||
extracted = _extract(state_interrupts)
|
for interrupt_item in state_interrupts:
|
||||||
if extracted is not None:
|
extracted = _extract(interrupt_item)
|
||||||
return extracted
|
if extracted is not None:
|
||||||
return None
|
values.append(extracted)
|
||||||
|
return values
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,210 @@
|
||||||
|
"""Real-graph contract: ``all_interrupt_values`` surfaces every pending interrupt.
|
||||||
|
|
||||||
|
The chat-stream emit loop must yield one ``data-interrupt-request`` SSE frame
|
||||||
|
per paused subagent, in the same order ``state.interrupts`` reports them —
|
||||||
|
that's also the order the resume slicer consumes decisions. These tests pin
|
||||||
|
that contract against a **real** paused parent graph built via
|
||||||
|
:class:`~langgraph.types.Send` fan-out (no synthetic state mocks).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from langchain.tools import ToolRuntime
|
||||||
|
from langchain_core.messages import AIMessage, HumanMessage
|
||||||
|
from langchain_core.runnables import RunnableConfig
|
||||||
|
from langgraph.checkpoint.memory import InMemorySaver
|
||||||
|
from langgraph.graph import END, START, StateGraph
|
||||||
|
from langgraph.types import Send, interrupt
|
||||||
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
|
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.task_tool import (
|
||||||
|
build_task_tool_with_parent_config,
|
||||||
|
)
|
||||||
|
from app.tasks.chat.streaming.helpers.interrupt_inspector import (
|
||||||
|
all_interrupt_values,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class _SubState(TypedDict, total=False):
|
||||||
|
messages: list
|
||||||
|
|
||||||
|
|
||||||
|
class _DispatchState(TypedDict, total=False):
|
||||||
|
messages: list
|
||||||
|
tcid: str
|
||||||
|
desc: str
|
||||||
|
|
||||||
|
|
||||||
|
def _build_pausing_subagent(checkpointer: InMemorySaver):
|
||||||
|
def approve_node(_state):
|
||||||
|
decision = interrupt(
|
||||||
|
{
|
||||||
|
"action_requests": [
|
||||||
|
{"name": "do_thing", "args": {"x": 1}, "description": ""}
|
||||||
|
],
|
||||||
|
"review_configs": [{}],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return {"messages": [AIMessage(content=f"got:{decision}")]}
|
||||||
|
|
||||||
|
g = StateGraph(_SubState)
|
||||||
|
g.add_node("approve", approve_node)
|
||||||
|
g.add_edge(START, "approve")
|
||||||
|
g.add_edge("approve", END)
|
||||||
|
return g.compile(checkpointer=checkpointer)
|
||||||
|
|
||||||
|
|
||||||
|
def _parent_graph_dispatching_two_tasks_via_send(
|
||||||
|
task_tool, *, tool_call_id_a: str, tool_call_id_b: str, checkpointer
|
||||||
|
):
|
||||||
|
def fanout_edge(_state) -> list[Send]:
|
||||||
|
return [
|
||||||
|
Send("call_task", {"tcid": tool_call_id_a, "desc": "approve A"}),
|
||||||
|
Send("call_task", {"tcid": tool_call_id_b, "desc": "approve B"}),
|
||||||
|
]
|
||||||
|
|
||||||
|
async def call_task(state: _DispatchState, config: RunnableConfig):
|
||||||
|
rt = ToolRuntime(
|
||||||
|
state=state,
|
||||||
|
config=config,
|
||||||
|
context=None,
|
||||||
|
stream_writer=None,
|
||||||
|
tool_call_id=state["tcid"],
|
||||||
|
store=None,
|
||||||
|
)
|
||||||
|
return await task_tool.coroutine(
|
||||||
|
description=state["desc"], subagent_type="approver", runtime=rt
|
||||||
|
)
|
||||||
|
|
||||||
|
g = StateGraph(_DispatchState)
|
||||||
|
g.add_node("call_task", call_task)
|
||||||
|
g.add_conditional_edges(START, fanout_edge, ["call_task"])
|
||||||
|
g.add_edge("call_task", END)
|
||||||
|
return g.compile(checkpointer=checkpointer)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_returns_every_pending_interrupt_for_two_paused_subagents():
|
||||||
|
"""Two parallel subagents -> ``all_interrupt_values`` returns 2 dicts."""
|
||||||
|
checkpointer = InMemorySaver()
|
||||||
|
subagent = _build_pausing_subagent(checkpointer)
|
||||||
|
task_tool = build_task_tool_with_parent_config(
|
||||||
|
[{"name": "approver", "description": "approves", "runnable": subagent}]
|
||||||
|
)
|
||||||
|
parent = _parent_graph_dispatching_two_tasks_via_send(
|
||||||
|
task_tool,
|
||||||
|
tool_call_id_a="parent-tcid-A",
|
||||||
|
tool_call_id_b="parent-tcid-B",
|
||||||
|
checkpointer=checkpointer,
|
||||||
|
)
|
||||||
|
|
||||||
|
parent_config = {
|
||||||
|
"configurable": {"thread_id": "all-iv-thread"},
|
||||||
|
"recursion_limit": 100,
|
||||||
|
}
|
||||||
|
await parent.ainvoke({"messages": [HumanMessage(content="seed")]}, parent_config)
|
||||||
|
state = await parent.aget_state(parent_config)
|
||||||
|
|
||||||
|
values = all_interrupt_values(state)
|
||||||
|
|
||||||
|
assert isinstance(values, list)
|
||||||
|
assert len(values) == 2, (
|
||||||
|
f"REGRESSION: expected one value per pending subagent, got "
|
||||||
|
f"{len(values)}: {values!r}"
|
||||||
|
)
|
||||||
|
stamps = [v.get("tool_call_id") for v in values]
|
||||||
|
assert sorted(stamps) == ["parent-tcid-A", "parent-tcid-B"]
|
||||||
|
for v in values:
|
||||||
|
assert isinstance(v.get("action_requests"), list)
|
||||||
|
assert len(v["action_requests"]) == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_preserves_state_interrupts_traversal_order():
|
||||||
|
"""Order returned by inspector must match ``state.interrupts`` order.
|
||||||
|
|
||||||
|
The resume slicer consumes decisions left-to-right against
|
||||||
|
``collect_pending_tool_calls(state)`` which walks ``state.interrupts``
|
||||||
|
in iteration order — so the inspector (which drives the *emit* order)
|
||||||
|
must agree with that traversal or the slice and the wire fall out of sync.
|
||||||
|
"""
|
||||||
|
checkpointer = InMemorySaver()
|
||||||
|
subagent = _build_pausing_subagent(checkpointer)
|
||||||
|
task_tool = build_task_tool_with_parent_config(
|
||||||
|
[{"name": "approver", "description": "approves", "runnable": subagent}]
|
||||||
|
)
|
||||||
|
parent = _parent_graph_dispatching_two_tasks_via_send(
|
||||||
|
task_tool,
|
||||||
|
tool_call_id_a="parent-tcid-A",
|
||||||
|
tool_call_id_b="parent-tcid-B",
|
||||||
|
checkpointer=checkpointer,
|
||||||
|
)
|
||||||
|
parent_config = {
|
||||||
|
"configurable": {"thread_id": "order-thread"},
|
||||||
|
"recursion_limit": 100,
|
||||||
|
}
|
||||||
|
await parent.ainvoke({"messages": [HumanMessage(content="seed")]}, parent_config)
|
||||||
|
state = await parent.aget_state(parent_config)
|
||||||
|
|
||||||
|
inspector_order = [v["tool_call_id"] for v in all_interrupt_values(state)]
|
||||||
|
state_order = [
|
||||||
|
i.value["tool_call_id"]
|
||||||
|
for i in state.interrupts
|
||||||
|
if isinstance(getattr(i, "value", None), dict)
|
||||||
|
and "tool_call_id" in i.value
|
||||||
|
]
|
||||||
|
|
||||||
|
assert inspector_order == state_order, (
|
||||||
|
f"inspector order {inspector_order!r} diverged from state.interrupts "
|
||||||
|
f"order {state_order!r}; the resume slicer would mis-route decisions."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_returns_empty_list_when_nothing_paused():
|
||||||
|
"""A graph that completes normally produces no interrupts to surface."""
|
||||||
|
|
||||||
|
def done_node(_state):
|
||||||
|
return {"messages": [AIMessage(content="done")]}
|
||||||
|
|
||||||
|
g = StateGraph(_SubState)
|
||||||
|
g.add_node("done", done_node)
|
||||||
|
g.add_edge(START, "done")
|
||||||
|
g.add_edge("done", END)
|
||||||
|
graph = g.compile(checkpointer=InMemorySaver())
|
||||||
|
config = {"configurable": {"thread_id": "no-pause-thread"}}
|
||||||
|
await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config)
|
||||||
|
state = await graph.aget_state(config)
|
||||||
|
|
||||||
|
assert all_interrupt_values(state) == []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_single_paused_subagent_returns_a_list_of_one():
|
||||||
|
"""Single-pause case must still return a list (not unwrap to a dict)."""
|
||||||
|
|
||||||
|
def approve_node(_state):
|
||||||
|
decision = interrupt(
|
||||||
|
{
|
||||||
|
"action_requests": [{"name": "x", "args": {}, "description": ""}],
|
||||||
|
"review_configs": [{}],
|
||||||
|
"tool_call_id": "lonely-tcid",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return {"messages": [AIMessage(content=f"got:{decision}")]}
|
||||||
|
|
||||||
|
g = StateGraph(_SubState)
|
||||||
|
g.add_node("approve", approve_node)
|
||||||
|
g.add_edge(START, "approve")
|
||||||
|
g.add_edge("approve", END)
|
||||||
|
graph = g.compile(checkpointer=InMemorySaver())
|
||||||
|
config = {"configurable": {"thread_id": "single-thread"}}
|
||||||
|
await graph.ainvoke({"messages": [HumanMessage(content="seed")]}, config)
|
||||||
|
state = await graph.aget_state(config)
|
||||||
|
|
||||||
|
values = all_interrupt_values(state)
|
||||||
|
|
||||||
|
assert isinstance(values, list)
|
||||||
|
assert len(values) == 1
|
||||||
|
assert values[0].get("tool_call_id") == "lonely-tcid"
|
||||||
|
|
@ -23,7 +23,6 @@ from app.tasks.chat.stream_new_chat import (
|
||||||
_emit_stream_terminal_error as old_emit_terminal_error,
|
_emit_stream_terminal_error as old_emit_terminal_error,
|
||||||
_extract_chunk_parts as old_extract_chunk_parts,
|
_extract_chunk_parts as old_extract_chunk_parts,
|
||||||
_extract_resolved_file_path as old_extract_resolved_file_path,
|
_extract_resolved_file_path as old_extract_resolved_file_path,
|
||||||
_first_interrupt_value as old_first_interrupt_value,
|
|
||||||
_tool_output_has_error as old_tool_output_has_error,
|
_tool_output_has_error as old_tool_output_has_error,
|
||||||
_tool_output_to_text as old_tool_output_to_text,
|
_tool_output_to_text as old_tool_output_to_text,
|
||||||
)
|
)
|
||||||
|
|
@ -36,9 +35,6 @@ from app.tasks.chat.streaming.errors.emitter import (
|
||||||
from app.tasks.chat.streaming.helpers.chunk_parts import (
|
from app.tasks.chat.streaming.helpers.chunk_parts import (
|
||||||
extract_chunk_parts as new_extract_chunk_parts,
|
extract_chunk_parts as new_extract_chunk_parts,
|
||||||
)
|
)
|
||||||
from app.tasks.chat.streaming.helpers.interrupt_inspector import (
|
|
||||||
first_interrupt_value as new_first_interrupt_value,
|
|
||||||
)
|
|
||||||
from app.tasks.chat.streaming.helpers.tool_output import (
|
from app.tasks.chat.streaming.helpers.tool_output import (
|
||||||
extract_resolved_file_path as new_extract_resolved_file_path,
|
extract_resolved_file_path as new_extract_resolved_file_path,
|
||||||
tool_output_has_error as new_tool_output_has_error,
|
tool_output_has_error as new_tool_output_has_error,
|
||||||
|
|
@ -105,52 +101,6 @@ def test_extract_chunk_parts_matches_old_implementation(chunk: Any) -> None:
|
||||||
assert new_extract_chunk_parts(chunk) == old_extract_chunk_parts(chunk)
|
assert new_extract_chunk_parts(chunk) == old_extract_chunk_parts(chunk)
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------- interrupt inspector
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class _Interrupt:
|
|
||||||
value: dict[str, Any]
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class _Task:
|
|
||||||
interrupts: tuple[Any, ...] = ()
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class _State:
|
|
||||||
tasks: tuple[Any, ...] = ()
|
|
||||||
interrupts: tuple[Any, ...] = ()
|
|
||||||
|
|
||||||
|
|
||||||
_INTERRUPT_CASES: list[Any] = [
|
|
||||||
_State(),
|
|
||||||
_State(tasks=(_Task(interrupts=(_Interrupt(value={"name": "send"}),)),)),
|
|
||||||
# Multiple tasks: must return the FIRST one in iteration order.
|
|
||||||
_State(
|
|
||||||
tasks=(
|
|
||||||
_Task(interrupts=(_Interrupt(value={"name": "first"}),)),
|
|
||||||
_Task(interrupts=(_Interrupt(value={"name": "second"}),)),
|
|
||||||
)
|
|
||||||
),
|
|
||||||
# Empty task interrupts -> falls back to root state.interrupts.
|
|
||||||
_State(
|
|
||||||
tasks=(_Task(interrupts=()),),
|
|
||||||
interrupts=(_Interrupt(value={"name": "root"}),),
|
|
||||||
),
|
|
||||||
# Interrupts as plain dicts (not wrapper objects).
|
|
||||||
_State(interrupts=({"value": {"name": "dict_root"}},)),
|
|
||||||
# A defective task whose `.interrupts` raises - must be tolerated.
|
|
||||||
_State(tasks=(object(),)),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("state", _INTERRUPT_CASES)
|
|
||||||
def test_first_interrupt_value_matches_old_implementation(state: Any) -> None:
|
|
||||||
assert new_first_interrupt_value(state) == old_first_interrupt_value(state)
|
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------------------------------------- error classifier
|
# ----------------------------------------------------------- error classifier
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue