Add parity unit tests for extracted chat streaming vs legacy.

This commit is contained in:
CREDO23 2026-05-06 20:08:48 +02:00
parent ec26ca69a6
commit 8b6ffd12b8
3 changed files with 532 additions and 0 deletions

View file

@@ -0,0 +1,292 @@
"""Pin Stage 1 extractions as faithful copies of the old helpers.
The new orchestrator under ``app.tasks.chat.streaming`` is built in
parallel with the production module ``app.tasks.chat.stream_new_chat``.
For each Stage 1 extraction we assert the new function returns the same
output as the old one for a representative input set. The moment the
two diverge - intentionally or otherwise - this file fails loudly so
the divergence is reviewed rather than shipped silently.
"""
from __future__ import annotations
import logging
from dataclasses import dataclass, field
from typing import Any
import pytest
from app.agents.new_chat.errors import BusyError
from app.agents.new_chat.middleware.busy_mutex import request_cancel, reset_cancel
from app.tasks.chat.stream_new_chat import (
_classify_stream_exception as old_classify,
_emit_stream_terminal_error as old_emit_terminal_error,
_extract_chunk_parts as old_extract_chunk_parts,
_extract_resolved_file_path as old_extract_resolved_file_path,
_first_interrupt_value as old_first_interrupt_value,
_tool_output_has_error as old_tool_output_has_error,
_tool_output_to_text as old_tool_output_to_text,
)
from app.tasks.chat.streaming.errors.classifier import (
classify_stream_exception as new_classify,
)
from app.tasks.chat.streaming.errors.emitter import (
emit_stream_terminal_error as new_emit_terminal_error,
)
from app.tasks.chat.streaming.helpers.chunk_parts import (
extract_chunk_parts as new_extract_chunk_parts,
)
from app.tasks.chat.streaming.helpers.interrupt_inspector import (
first_interrupt_value as new_first_interrupt_value,
)
from app.tasks.chat.streaming.helpers.tool_output import (
extract_resolved_file_path as new_extract_resolved_file_path,
tool_output_has_error as new_tool_output_has_error,
tool_output_to_text as new_tool_output_to_text,
)
pytestmark = pytest.mark.unit
# ---------------------------------------------------------------- chunk parts
@dataclass
class _Chunk:
    """Minimal stand-in for a streamed LLM message chunk used as parity input."""

    # Either a plain string or a list of typed content blocks (see _CHUNK_CASES).
    content: Any = ""
    # Provider-specific extras, e.g. {"reasoning_content": ...}.
    additional_kwargs: dict[str, Any] = field(default_factory=dict)
    # Streaming tool-call fragments attached directly to the chunk.
    tool_call_chunks: list[dict[str, Any]] = field(default_factory=list)
# Representative chunk inputs: None, empty/plain/invalid string content,
# typed content-block lists, provider kwargs, and tool-call fragments
# arriving via both the content list and the dedicated attribute.
_CHUNK_CASES: list[Any] = [
    None,
    _Chunk(content=""),
    _Chunk(content="hello"),
    _Chunk(content=42),  # invalid type, defensively coerced to empty
    # Two consecutive text blocks.
    _Chunk(
        content=[
            {"type": "text", "text": "Hello "},
            {"type": "text", "text": "world"},
        ]
    ),
    # Reasoning blocks carrying either a "reasoning" or a "text" key,
    # followed by a regular text block.
    _Chunk(
        content=[
            {"type": "reasoning", "reasoning": "hmm "},
            {"type": "reasoning", "text": "still"},
            {"type": "text", "text": "answer"},
        ]
    ),
    # Tool-call-style blocks plus an unrelated block type (image_url).
    _Chunk(
        content=[
            {"type": "tool_call_chunk", "id": "c1", "name": "x", "args": "{"},
            {"type": "tool_use", "id": "c2", "name": "y"},
            {"type": "image_url", "url": "ignored"},
        ]
    ),
    # Visible text alongside provider-private reasoning kwargs.
    _Chunk(
        content="visible",
        additional_kwargs={"reasoning_content": "private"},
    ),
    # Fragments via the tool_call_chunks attribute, including None ids
    # and a malformed args tail sharing the same index.
    _Chunk(
        tool_call_chunks=[
            {"id": None, "name": None, "args": '{"a":1}', "index": 0},
            {"id": "c", "name": "n", "args": "}", "index": 0},
        ]
    ),
    # Both sources populated at once: content block vs. attribute.
    _Chunk(
        content=[{"type": "tool_call_chunk", "id": "from-block", "name": "x"}],
        tool_call_chunks=[{"id": "from-attr", "name": "y"}],
    ),
]
@pytest.mark.parametrize("chunk", _CHUNK_CASES)
def test_extract_chunk_parts_matches_old_implementation(chunk: Any) -> None:
    """Every representative chunk must decompose identically in both paths."""
    expected = old_extract_chunk_parts(chunk)
    actual = new_extract_chunk_parts(chunk)
    assert actual == expected
# ---------------------------------------------------------- interrupt inspector
@dataclass
class _Interrupt:
    """Wrapper object exposing the interrupt payload via ``.value``."""

    # Payload dict surfaced when the graph interrupts.
    value: dict[str, Any]
@dataclass
class _Task:
    """Stand-in for a graph task that may carry pending interrupts."""

    # Zero or more interrupt wrappers (or plain dicts) for this task.
    interrupts: tuple[Any, ...] = ()
@dataclass
class _State:
    """Stand-in for a graph state snapshot with task-level and root interrupts."""

    # Tasks inspected first when looking for an interrupt.
    tasks: tuple[Any, ...] = ()
    # Root-level interrupts used as the fallback source.
    interrupts: tuple[Any, ...] = ()
# State snapshots exercising the lookup order (task interrupts before the
# root state's interrupts) and defensive handling of odd shapes.
_INTERRUPT_CASES: list[Any] = [
    _State(),
    _State(tasks=(_Task(interrupts=(_Interrupt(value={"name": "send"}),)),)),
    # Multiple tasks: must return the FIRST one in iteration order.
    _State(
        tasks=(
            _Task(interrupts=(_Interrupt(value={"name": "first"}),)),
            _Task(interrupts=(_Interrupt(value={"name": "second"}),)),
        )
    ),
    # Empty task interrupts -> falls back to root state.interrupts.
    _State(
        tasks=(_Task(interrupts=()),),
        interrupts=(_Interrupt(value={"name": "root"}),),
    ),
    # Interrupts as plain dicts (not wrapper objects).
    _State(interrupts=({"value": {"name": "dict_root"}},)),
    # A task without a usable `.interrupts` attribute - must be tolerated.
    _State(tasks=(object(),)),
]
@pytest.mark.parametrize("state", _INTERRUPT_CASES)
def test_first_interrupt_value_matches_old_implementation(state: Any) -> None:
    """Both inspectors must pick the same interrupt payload for each state."""
    reference = old_first_interrupt_value(state)
    candidate = new_first_interrupt_value(state)
    assert candidate == reference
# ----------------------------------------------------------- error classifier
def _classify_cases() -> list[Exception]:
    """Inputs the FE depends on being mapped to specific error codes."""
    # Name the long payloads so the case list below stays scannable; the
    # literal values are unchanged from the original fixtures.
    rate_limit_json = '{"error":{"type":"rate_limit_error","message":"slow down"}}'
    openrouter_429 = (
        'OpenrouterException - {"error":{"message":"Provider returned error",'
        '"code":429}}'
    )
    cases: list[Exception] = [
        Exception("totally generic error"),
        Exception(rate_limit_json),
        Exception(openrouter_429),
        BusyError(request_id="thread-busy-parity"),
        Exception("Thread is busy with another request"),
    ]
    return cases
@pytest.mark.parametrize("exc", _classify_cases())
def test_classify_stream_exception_matches_old_implementation(
    exc: Exception,
) -> None:
    """Old and new classifiers must agree field-for-field (modulo wall clock)."""
    candidate = new_classify(exc, flow_label="parity-test")
    reference = old_classify(exc, flow_label="parity-test")

    def _without_retry_stamp(extra: Any) -> Any:
        # Strip the wall-clock retry timestamp before comparing: both
        # implementations call ``time.time()`` independently, so the two
        # stamps can differ by ~1 ms in practice. Everything else in the
        # extras dict must match exactly.
        if isinstance(extra, dict):
            trimmed = dict(extra)
            trimmed.pop("retry_after_at", None)
            return trimmed
        return extra

    assert candidate[:5] == reference[:5]
    assert _without_retry_stamp(candidate[5]) == _without_retry_stamp(reference[5])
def test_classify_turn_cancelling_branch_parity() -> None:
    """The TURN_CANCELLING branch reads cancel state for the busy thread id;
    both implementations must agree on retry-window semantics, not just the
    plain THREAD_BUSY code."""
    thread_id = "parity-cancelling-thread"
    # Put the thread into a known "cancel requested" state first.
    reset_cancel(thread_id)
    request_cancel(thread_id)

    busy_error = BusyError(request_id=thread_id)
    candidate = new_classify(busy_error, flow_label="parity-test")
    reference = old_classify(busy_error, flow_label="parity-test")

    assert candidate[0] == reference[0] == "thread_busy"
    assert candidate[1] == reference[1] == "TURN_CANCELLING"
    assert isinstance(candidate[5], dict)
    assert isinstance(reference[5], dict)
    assert candidate[5]["retry_after_ms"] == reference[5]["retry_after_ms"]
# ------------------------------------------------------------ terminal emitter
class _FakeStreamingService:
"""Duck-types ``format_error`` for both old and new emitters."""
def __init__(self) -> None:
self.calls: list[dict[str, Any]] = []
def format_error(
self, message: str, *, error_code: str, extra: dict[str, Any] | None = None
) -> str:
self.calls.append(
{"message": message, "error_code": error_code, "extra": extra}
)
return f"data: {{\"type\":\"error\",\"errorText\":\"{message}\"}}\n\n"
def test_emit_stream_terminal_error_matches_old_output_and_logs(caplog) -> None:
    """The new emitter must produce the same SSE frame and log the same
    structured payload as the old one for identical arguments."""
    emit_kwargs: dict[str, Any] = {
        "flow": "new",
        "request_id": "req-parity",
        "thread_id": 7,
        "search_space_id": 9,
        "user_id": "user-parity",
        "message": "boom",
        "error_kind": "server_error",
        "error_code": "SERVER_ERROR",
        "severity": "error",
        "is_expected": False,
        "extra": {"foo": "bar"},
    }
    service_for_new = _FakeStreamingService()
    service_for_old = _FakeStreamingService()

    with caplog.at_level(logging.ERROR):
        frame_from_new = new_emit_terminal_error(
            streaming_service=service_for_new, **emit_kwargs
        )
        frame_from_old = old_emit_terminal_error(
            streaming_service=service_for_old, **emit_kwargs
        )

    assert frame_from_new == frame_from_old
    assert service_for_new.calls == service_for_old.calls

    # One [chat_stream_error] log line per emit call (two emits -> two records).
    error_records = [
        record
        for record in caplog.records
        if "[chat_stream_error]" in record.message
    ]
    assert len(error_records) == 2
# ---------------------------------------------------------------- tool output
def test_tool_output_helpers_match_old_implementation() -> None:
    """Text rendering, error detection and path resolution must all agree."""
    samples: list[Any] = [
        {"result": "ok"},
        {"error": "bad"},
        {"result": "Error: x"},
        "Error: plain",
        "fine",
        {"nested": {"a": 1}},
    ]
    for sample in samples:
        assert new_tool_output_to_text(sample) == old_tool_output_to_text(sample)
        assert new_tool_output_has_error(sample) == old_tool_output_has_error(sample)

    # Path present on the tool output itself (with whitespace padding).
    path_from_new = new_extract_resolved_file_path(
        tool_name="write_file",
        tool_output={"path": " /tmp/x "},
        tool_input=None,
    )
    path_from_old = old_extract_resolved_file_path(
        tool_name="write_file",
        tool_output={"path": " /tmp/x "},
        tool_input=None,
    )
    assert path_from_new == path_from_old

    # Empty output: both must fall back to the tool input's file_path.
    fallback_from_new = new_extract_resolved_file_path(
        tool_name="write_file",
        tool_output={},
        tool_input={"file_path": " /fallback "},
    )
    fallback_from_old = old_extract_resolved_file_path(
        tool_name="write_file",
        tool_output={},
        tool_input={"file_path": " /fallback "},
    )
    assert fallback_from_new == fallback_from_old

View file

@@ -0,0 +1,240 @@
"""Parity tests for Stage 2 extractions (tool matching, thinking step, custom events)."""
from __future__ import annotations
from typing import Any
from unittest.mock import MagicMock
import pytest
from app.tasks.chat.stream_new_chat import _legacy_match_lc_id as old_legacy_match
from app.tasks.chat.streaming.handlers.custom_events import (
handle_action_log,
handle_action_log_updated,
handle_document_created,
handle_report_progress,
)
from app.tasks.chat.streaming.helpers.tool_call_matching import (
match_buffered_langchain_tool_call_id as new_legacy_match,
)
from app.tasks.chat.streaming.relay.state import AgentEventRelayState
from app.tasks.chat.streaming.relay.thinking_step_completion import (
complete_active_thinking_step,
)
from app.tasks.chat.streaming.relay.thinking_step_sse import emit_thinking_step_frame
pytestmark = pytest.mark.unit
def _copy_chunk_buffer(raw: list[dict[str, Any]]) -> list[dict[str, Any]]:
return [dict(x) for x in raw]
def test_legacy_tool_call_match_matches_old_implementation() -> None:
    """Both matchers must agree on the result AND on buffer/map mutations."""
    cases: list[tuple[list[dict[str, Any]], str, str, dict[str, str]]] = [
        (
            [
                {"name": "write_file", "id": "lc-a"},
                {"name": "other", "id": "lc-b"},
            ],
            "write_file",
            "run-1",
            {},
        ),
        (
            [{"name": "x", "id": None}, {"name": "y", "id": "lc-fallback"}],
            "write_file",
            "run-2",
            {},
        ),
        ([{"name": "no_id"}], "write_file", "run-3", {}),
    ]
    for buffer_template, tool_name, run_id, seed_map in cases:
        # Give each implementation its own mutable copies so in-place
        # mutation by one cannot mask a divergence in the other.
        legacy_buffer = _copy_chunk_buffer(buffer_template)
        extracted_buffer = _copy_chunk_buffer(buffer_template)
        legacy_map = dict(seed_map)
        extracted_map = dict(seed_map)

        legacy_result = old_legacy_match(legacy_buffer, tool_name, run_id, legacy_map)
        extracted_result = new_legacy_match(
            extracted_buffer, tool_name, run_id, extracted_map
        )

        assert extracted_result == legacy_result
        assert extracted_buffer == legacy_buffer
        assert extracted_map == legacy_map
def test_emit_thinking_step_frame_invokes_builder_before_service() -> None:
    """The content builder hook must fire before the service formats the frame."""
    call_sequence: list[str] = []

    builder = MagicMock()
    builder.on_thinking_step.side_effect = (
        lambda *args, **kwargs: call_sequence.append("builder")
    )

    svc = MagicMock()

    def _format_step(**kwargs: Any) -> str:
        call_sequence.append("service")
        return "frame"

    svc.format_thinking_step.side_effect = _format_step

    frame = emit_thinking_step_frame(
        streaming_service=svc,
        content_builder=builder,
        step_id="thinking-1",
        title="Working",
        status="in_progress",
        items=["a"],
    )

    assert frame == "frame"
    assert call_sequence == ["builder", "service"]
    builder.on_thinking_step.assert_called_once()
    svc.format_thinking_step.assert_called_once()
def test_emit_thinking_step_frame_skips_builder_when_none() -> None:
    """With no content builder, the frame is still formatted and returned.

    Fix: the previous version constructed ``MagicMock(return_value="x")``,
    configuring a top-level return value on the service mock that nothing
    ever used (only ``format_thinking_step`` is called). The dead
    configuration is dropped so the mock setup states exactly what the
    test depends on.
    """
    svc = MagicMock()
    svc.format_thinking_step.return_value = "frame"
    frame = emit_thinking_step_frame(
        streaming_service=svc,
        content_builder=None,
        step_id="s",
        title="t",
    )
    assert frame == "frame"
    svc.format_thinking_step.assert_called_once()
def test_complete_active_thinking_step_mirrors_closure_semantics() -> None:
    """First completion emits a frame and records the id; a repeat is a no-op."""
    svc = MagicMock()
    svc.format_thinking_step.return_value = "done-frame"
    seen_step_ids: set[str] = set()

    first_frame, cleared_id = complete_active_thinking_step(
        streaming_service=svc,
        content_builder=None,
        last_active_step_id="thinking-1",
        last_active_step_title="T",
        last_active_step_items=["x"],
        completed_step_ids=seen_step_ids,
    )
    assert first_frame == "done-frame"
    assert cleared_id is None
    assert "thinking-1" in seen_step_ids

    # Completing the same step again must emit nothing and keep the id.
    repeat_frame, kept_id = complete_active_thinking_step(
        streaming_service=svc,
        content_builder=None,
        last_active_step_id="thinking-1",
        last_active_step_title="T",
        last_active_step_items=[],
        completed_step_ids=seen_step_ids,
    )
    assert repeat_frame is None
    assert kept_id == "thinking-1"
def test_agent_event_relay_state_factory_matches_counter_rule() -> None:
    """A fresh state starts its counter at 0; a seeded state starts at 1."""
    fresh_state = AgentEventRelayState.for_invocation(parity_v2=False)
    assert fresh_state.thinking_step_counter == 0
    assert fresh_state.last_active_step_id is None

    seeded_state = AgentEventRelayState.for_invocation(
        initial_step_id="thinking-resume-1",
        initial_step_title="Inherited",
        initial_step_items=["Topic: X"],
        parity_v2=True,
    )
    assert seeded_state.thinking_step_counter == 1
    assert seeded_state.last_active_step_id == "thinking-resume-1"
    assert seeded_state.parity_v2 is True
    assert seeded_state.next_thinking_step_id("thinking") == "thinking-2"
@pytest.mark.parametrize(
    ("phase", "message", "start_items", "expected_tail"),
    [
        (
            "revising_section",
            "progress line",
            ["Topic: Foo", "Modifying bar", "stale..."],
            ["Topic: Foo", "Modifying bar", "progress line"],
        ),
        (
            "other",
            "phase msg",
            ["Topic: Foo", "old line"],
            ["Topic: Foo", "phase msg"],
        ),
    ],
)
def test_report_progress_items_match_reference(
    phase: str,
    message: str,
    start_items: list[str],
    expected_tail: list[str],
) -> None:
    """The progress handler rewrites the item tail and formats one frame."""
    svc = MagicMock()
    svc.format_thinking_step.return_value = "sse"

    working_items = list(start_items)
    frame, resulting_items = handle_report_progress(
        {"message": message, "phase": phase},
        last_active_step_id="step-1",
        last_active_step_title="Report",
        last_active_step_items=working_items,
        streaming_service=svc,
        content_builder=None,
    )

    assert frame == "sse"
    assert resulting_items == expected_tail
    format_kwargs = svc.format_thinking_step.call_args.kwargs
    assert format_kwargs["items"] == expected_tail
def test_report_progress_noop_when_missing_message_or_step() -> None:
    """An empty message or a missing active step leaves the items untouched."""
    svc = MagicMock()
    items = ["Topic: A"]

    # Empty message: nothing to report.
    frame_for_empty, items_for_empty = handle_report_progress(
        {"message": "", "phase": "x"},
        last_active_step_id="s",
        last_active_step_title="t",
        last_active_step_items=items,
        streaming_service=svc,
        content_builder=None,
    )
    assert frame_for_empty is None
    assert items_for_empty is items

    # No active step id: nowhere to attach the progress line.
    frame_for_no_step, items_for_no_step = handle_report_progress(
        {"message": "m", "phase": "x"},
        last_active_step_id=None,
        last_active_step_title="t",
        last_active_step_items=items,
        streaming_service=svc,
        content_builder=None,
    )
    assert frame_for_no_step is None
    assert items_for_no_step is items
def test_document_action_handlers_match_format_data_guards() -> None:
    """Handlers must no-op on missing/falsy ids and forward valid payloads."""
    svc = MagicMock()
    svc.format_data.return_value = "data-frame"

    # document-created: absent or falsy id (0) is rejected.
    assert handle_document_created({}, streaming_service=svc) is None
    assert handle_document_created({"id": 0}, streaming_service=svc) is None
    handle_document_created({"id": 42, "title": "x"}, streaming_service=svc)
    svc.format_data.assert_called_with(
        "documents-updated", {"action": "created", "document": {"id": 42, "title": "x"}}
    )

    # action-log: None id is rejected; a valid id is forwarded verbatim.
    svc.reset_mock()
    assert handle_action_log({"id": None}, streaming_service=svc) is None
    handle_action_log({"id": 1}, streaming_service=svc)
    svc.format_data.assert_called_once_with("action-log", {"id": 1})

    # action-log-updated: same guard, different event name.
    svc.reset_mock()
    assert handle_action_log_updated({"id": None}, streaming_service=svc) is None
    handle_action_log_updated({"id": 2}, streaming_service=svc)
    svc.format_data.assert_called_once_with("action-log-updated", {"id": 2})