From 8b6ffd12b8649bd789a9e780dd90a0a64d04fbac Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Wed, 6 May 2026 20:08:48 +0200 Subject: [PATCH] Add parity unit tests for extracted chat streaming vs legacy. --- .../unit/tasks/chat/streaming/__init__.py | 0 .../chat/streaming/test_stage_1_parity.py | 292 ++++++++++++++++++ .../chat/streaming/test_stage_2_parity.py | 240 ++++++++++++++ 3 files changed, 532 insertions(+) create mode 100644 surfsense_backend/tests/unit/tasks/chat/streaming/__init__.py create mode 100644 surfsense_backend/tests/unit/tasks/chat/streaming/test_stage_1_parity.py create mode 100644 surfsense_backend/tests/unit/tasks/chat/streaming/test_stage_2_parity.py diff --git a/surfsense_backend/tests/unit/tasks/chat/streaming/__init__.py b/surfsense_backend/tests/unit/tasks/chat/streaming/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/tests/unit/tasks/chat/streaming/test_stage_1_parity.py b/surfsense_backend/tests/unit/tasks/chat/streaming/test_stage_1_parity.py new file mode 100644 index 000000000..9207f37d1 --- /dev/null +++ b/surfsense_backend/tests/unit/tasks/chat/streaming/test_stage_1_parity.py @@ -0,0 +1,292 @@ +"""Pin Stage 1 extractions as faithful copies of the old helpers. + +The new orchestrator under ``app.tasks.chat.streaming`` is built in +parallel with the production module ``app.tasks.chat.stream_new_chat``. +For each Stage 1 extraction we assert the new function returns the same +output as the old one for a representative input set. The moment the +two diverge - intentionally or otherwise - this file fails loudly so +the divergence is reviewed rather than shipped silently. +""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass, field +from typing import Any + +import pytest + +from app.agents.new_chat.errors import BusyError +from app.agents.new_chat.middleware.busy_mutex import request_cancel, reset_cancel +from app.tasks.chat.stream_new_chat import ( + _classify_stream_exception as old_classify, + _emit_stream_terminal_error as old_emit_terminal_error, + _extract_chunk_parts as old_extract_chunk_parts, + _extract_resolved_file_path as old_extract_resolved_file_path, + _first_interrupt_value as old_first_interrupt_value, + _tool_output_has_error as old_tool_output_has_error, + _tool_output_to_text as old_tool_output_to_text, +) +from app.tasks.chat.streaming.errors.classifier import ( + classify_stream_exception as new_classify, +) +from app.tasks.chat.streaming.errors.emitter import ( + emit_stream_terminal_error as new_emit_terminal_error, +) +from app.tasks.chat.streaming.helpers.chunk_parts import ( + extract_chunk_parts as new_extract_chunk_parts, +) +from app.tasks.chat.streaming.helpers.interrupt_inspector import ( + first_interrupt_value as new_first_interrupt_value, +) +from app.tasks.chat.streaming.helpers.tool_output import ( + extract_resolved_file_path as new_extract_resolved_file_path, + tool_output_has_error as new_tool_output_has_error, + tool_output_to_text as new_tool_output_to_text, +) + +pytestmark = pytest.mark.unit + + +# ---------------------------------------------------------------- chunk parts + + +@dataclass +class _Chunk: + content: Any = "" + additional_kwargs: dict[str, Any] = field(default_factory=dict) + tool_call_chunks: list[dict[str, Any]] = field(default_factory=list) + + +_CHUNK_CASES: list[Any] = [ + None, + _Chunk(content=""), + _Chunk(content="hello"), + _Chunk(content=42), # invalid type, defensively coerced to empty + _Chunk( + content=[ 
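+            # Illustrative sketch (an assumption, not pinned by this file):
+            # list-form content models provider "block" deltas, which the
+            # extractor is expected to reduce to plain text roughly as
+            #   "".join(b.get("text", "") for b in content
+            #           if isinstance(b, dict) and b.get("type") == "text")
+            # i.e. this case should surface as "Hello world" on both paths.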
+ {"type": "text", "text": "Hello "}, + {"type": "text", "text": "world"}, + ] + ), + _Chunk( + content=[ + {"type": "reasoning", "reasoning": "hmm "}, + {"type": "reasoning", "text": "still"}, + {"type": "text", "text": "answer"}, + ] + ), + _Chunk( + content=[ + {"type": "tool_call_chunk", "id": "c1", "name": "x", "args": "{"}, + {"type": "tool_use", "id": "c2", "name": "y"}, + {"type": "image_url", "url": "ignored"}, + ] + ), + _Chunk( + content="visible", + additional_kwargs={"reasoning_content": "private"}, + ), + _Chunk( + tool_call_chunks=[ + {"id": None, "name": None, "args": '{"a":1}', "index": 0}, + {"id": "c", "name": "n", "args": "}", "index": 0}, + ] + ), + _Chunk( + content=[{"type": "tool_call_chunk", "id": "from-block", "name": "x"}], + tool_call_chunks=[{"id": "from-attr", "name": "y"}], + ), +] + + +@pytest.mark.parametrize("chunk", _CHUNK_CASES) +def test_extract_chunk_parts_matches_old_implementation(chunk: Any) -> None: + assert new_extract_chunk_parts(chunk) == old_extract_chunk_parts(chunk) + + +# ---------------------------------------------------------- interrupt inspector + + +@dataclass +class _Interrupt: + value: dict[str, Any] + + +@dataclass +class _Task: + interrupts: tuple[Any, ...] = () + + +@dataclass +class _State: + tasks: tuple[Any, ...] = () + interrupts: tuple[Any, ...] = () + + +_INTERRUPT_CASES: list[Any] = [ + _State(), + _State(tasks=(_Task(interrupts=(_Interrupt(value={"name": "send"}),)),)), + # Multiple tasks: must return the FIRST one in iteration order. + _State( + tasks=( + _Task(interrupts=(_Interrupt(value={"name": "first"}),)), + _Task(interrupts=(_Interrupt(value={"name": "second"}),)), + ) + ), + # Empty task interrupts -> falls back to root state.interrupts. + _State( + tasks=(_Task(interrupts=()),), + interrupts=(_Interrupt(value={"name": "root"}),), + ), + # Interrupts as plain dicts (not wrapper objects). + _State(interrupts=({"value": {"name": "dict_root"}},)), + # A defective task whose `.interrupts` raises - must be tolerated. + _State(tasks=(object(),)), +] + + +@pytest.mark.parametrize("state", _INTERRUPT_CASES) +def test_first_interrupt_value_matches_old_implementation(state: Any) -> None: + assert new_first_interrupt_value(state) == old_first_interrupt_value(state) + + +# ----------------------------------------------------------- error classifier + + +def _classify_cases() -> list[Exception]: + """Inputs that the FE depends on being mapped to specific error codes.""" + return [ + Exception("totally generic error"), + Exception( + '{"error":{"type":"rate_limit_error","message":"slow down"}}' + ), + Exception( + 'OpenrouterException - {"error":{"message":"Provider returned error",' + '"code":429}}' + ), + BusyError(request_id="thread-busy-parity"), + Exception("Thread is busy with another request"), + ] + + +@pytest.mark.parametrize("exc", _classify_cases()) +def test_classify_stream_exception_matches_old_implementation( + exc: Exception, +) -> None: + new = new_classify(exc, flow_label="parity-test") + old = old_classify(exc, flow_label="parity-test") + # Strip the wall-clock retry timestamp before comparing — both + # implementations call ``time.time()`` independently and the call + # order is enough to differ by 1 ms in practice. Every other field + # in the tuple must match exactly. 
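+    # Hypothetical shape of the stripped field (only the key names come
+    # from the two implementations; the arithmetic is an assumption):
+    #   extra = {"retry_after_ms": 1500,
+    #            "retry_after_at": time.time() + 1.5}
+    # Each implementation makes its own time.time() call, so the absolute
+    # deadline can never be pinned byte-for-byte.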
+ new_extra = dict(new[5]) if isinstance(new[5], dict) else new[5] + old_extra = dict(old[5]) if isinstance(old[5], dict) else old[5] + if isinstance(new_extra, dict) and isinstance(old_extra, dict): + new_extra.pop("retry_after_at", None) + old_extra.pop("retry_after_at", None) + assert new[:5] == old[:5] + assert new_extra == old_extra + + +def test_classify_turn_cancelling_branch_parity() -> None: + """The TURN_CANCELLING branch reads cancel state for the busy thread id; + both implementations must agree on retry-window semantics, not just the + plain THREAD_BUSY code.""" + thread_id = "parity-cancelling-thread" + reset_cancel(thread_id) + request_cancel(thread_id) + exc = BusyError(request_id=thread_id) + new = new_classify(exc, flow_label="parity-test") + old = old_classify(exc, flow_label="parity-test") + assert new[0] == old[0] == "thread_busy" + assert new[1] == old[1] == "TURN_CANCELLING" + assert isinstance(new[5], dict) and isinstance(old[5], dict) + assert new[5]["retry_after_ms"] == old[5]["retry_after_ms"] + + +# ------------------------------------------------------------ terminal emitter + + +class _FakeStreamingService: + """Duck-types ``format_error`` for both old and new emitters.""" + + def __init__(self) -> None: + self.calls: list[dict[str, Any]] = [] + + def format_error( + self, message: str, *, error_code: str, extra: dict[str, Any] | None = None + ) -> str: + self.calls.append( + {"message": message, "error_code": error_code, "extra": extra} + ) + return f"data: {{\"type\":\"error\",\"errorText\":\"{message}\"}}\n\n" + + +def test_emit_stream_terminal_error_matches_old_output_and_logs(caplog) -> None: + """The new emitter must produce the same SSE frame and log the same + structured payload as the old one for the same arguments.""" + args: dict[str, Any] = { + "flow": "new", + "request_id": "req-parity", + "thread_id": 7, + "search_space_id": 9, + "user_id": "user-parity", + "message": "boom", + "error_kind": "server_error", + "error_code": "SERVER_ERROR", + "severity": "error", + "is_expected": False, + "extra": {"foo": "bar"}, + } + + new_svc = _FakeStreamingService() + old_svc = _FakeStreamingService() + + with caplog.at_level(logging.ERROR): + new_frame = new_emit_terminal_error(streaming_service=new_svc, **args) + old_frame = old_emit_terminal_error(streaming_service=old_svc, **args) + + assert new_frame == old_frame + assert new_svc.calls == old_svc.calls + chat_error_records = [ + r for r in caplog.records if "[chat_stream_error]" in r.message + ] + # One log line per emit call (two emits -> two records). 
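+    # (Hypothetical tightening, not required for parity: both records were
+    # produced from identical arguments, so their rendered text should
+    # also agree, e.g.
+    #   assert chat_error_records[0].getMessage()
+    #       == chat_error_records[1].getMessage()
+    # modulo anything the log formatter injects per record.)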
+ assert len(chat_error_records) == 2 + + +# ---------------------------------------------------------------- tool output + + +def test_tool_output_helpers_match_old_implementation() -> None: + samples: list[Any] = [ + {"result": "ok"}, + {"error": "bad"}, + {"result": "Error: x"}, + "Error: plain", + "fine", + {"nested": {"a": 1}}, + ] + for s in samples: + assert new_tool_output_to_text(s) == old_tool_output_to_text(s) + assert new_tool_output_has_error(s) == old_tool_output_has_error(s) + + assert new_extract_resolved_file_path( + tool_name="write_file", + tool_output={"path": " /tmp/x "}, + tool_input=None, + ) == old_extract_resolved_file_path( + tool_name="write_file", + tool_output={"path": " /tmp/x "}, + tool_input=None, + ) + assert new_extract_resolved_file_path( + tool_name="write_file", + tool_output={}, + tool_input={"file_path": " /fallback "}, + ) == old_extract_resolved_file_path( + tool_name="write_file", + tool_output={}, + tool_input={"file_path": " /fallback "}, + ) diff --git a/surfsense_backend/tests/unit/tasks/chat/streaming/test_stage_2_parity.py b/surfsense_backend/tests/unit/tasks/chat/streaming/test_stage_2_parity.py new file mode 100644 index 000000000..892bb7a6a --- /dev/null +++ b/surfsense_backend/tests/unit/tasks/chat/streaming/test_stage_2_parity.py @@ -0,0 +1,240 @@ +"""Parity tests for Stage 2 extractions (tool matching, thinking step, custom events).""" + +from __future__ import annotations + +from typing import Any +from unittest.mock import MagicMock + +import pytest + +from app.tasks.chat.stream_new_chat import _legacy_match_lc_id as old_legacy_match +from app.tasks.chat.streaming.handlers.custom_events import ( + handle_action_log, + handle_action_log_updated, + handle_document_created, + handle_report_progress, +) +from app.tasks.chat.streaming.helpers.tool_call_matching import ( + match_buffered_langchain_tool_call_id as new_legacy_match, +) +from app.tasks.chat.streaming.relay.state import AgentEventRelayState +from app.tasks.chat.streaming.relay.thinking_step_completion import ( + complete_active_thinking_step, +) +from app.tasks.chat.streaming.relay.thinking_step_sse import emit_thinking_step_frame + +pytestmark = pytest.mark.unit + + +def _copy_chunk_buffer(raw: list[dict[str, Any]]) -> list[dict[str, Any]]: + return [dict(x) for x in raw] + + +def test_legacy_tool_call_match_matches_old_implementation() -> None: + cases: list[tuple[list[dict[str, Any]], str, str, dict[str, str]]] = [ + ( + [ + {"name": "write_file", "id": "lc-a"}, + {"name": "other", "id": "lc-b"}, + ], + "write_file", + "run-1", + {}, + ), + ( + [{"name": "x", "id": None}, {"name": "y", "id": "lc-fallback"}], + "write_file", + "run-2", + {}, + ), + ([{"name": "no_id"}], "write_file", "run-3", {}), + ] + for chunks_template, tool_name, run_id, lc_map_seed in cases: + old_chunks = _copy_chunk_buffer(chunks_template) + new_chunks = _copy_chunk_buffer(chunks_template) + old_map = dict(lc_map_seed) + new_map = dict(lc_map_seed) + old_out = old_legacy_match(old_chunks, tool_name, run_id, old_map) + new_out = new_legacy_match(new_chunks, tool_name, run_id, new_map) + assert new_out == old_out + assert new_chunks == old_chunks + assert new_map == old_map + + +def test_emit_thinking_step_frame_invokes_builder_before_service() -> None: + order: list[str] = [] + builder = MagicMock() + + def on_ts(*args: Any, **kwargs: Any) -> None: + order.append("builder") + + builder.on_thinking_step.side_effect = on_ts + + svc = MagicMock() + + def fmt(**kwargs: Any) -> str: + 
order.append("service") + return "frame" + + svc.format_thinking_step.side_effect = fmt + + out = emit_thinking_step_frame( + streaming_service=svc, + content_builder=builder, + step_id="thinking-1", + title="Working", + status="in_progress", + items=["a"], + ) + assert out == "frame" + assert order == ["builder", "service"] + builder.on_thinking_step.assert_called_once() + svc.format_thinking_step.assert_called_once() + + +def test_emit_thinking_step_frame_skips_builder_when_none() -> None: + svc = MagicMock(return_value="x") + svc.format_thinking_step.return_value = "frame" + assert ( + emit_thinking_step_frame( + streaming_service=svc, + content_builder=None, + step_id="s", + title="t", + ) + == "frame" + ) + svc.format_thinking_step.assert_called_once() + + +def test_complete_active_thinking_step_mirrors_closure_semantics() -> None: + svc = MagicMock() + svc.format_thinking_step.return_value = "done-frame" + completed: set[str] = set() + + frame, new_id = complete_active_thinking_step( + streaming_service=svc, + content_builder=None, + last_active_step_id="thinking-1", + last_active_step_title="T", + last_active_step_items=["x"], + completed_step_ids=completed, + ) + assert frame == "done-frame" + assert new_id is None + assert "thinking-1" in completed + + frame2, id2 = complete_active_thinking_step( + streaming_service=svc, + content_builder=None, + last_active_step_id="thinking-1", + last_active_step_title="T", + last_active_step_items=[], + completed_step_ids=completed, + ) + assert frame2 is None + assert id2 == "thinking-1" + + +def test_agent_event_relay_state_factory_matches_counter_rule() -> None: + s0 = AgentEventRelayState.for_invocation(parity_v2=False) + assert s0.thinking_step_counter == 0 + assert s0.last_active_step_id is None + + s1 = AgentEventRelayState.for_invocation( + initial_step_id="thinking-resume-1", + initial_step_title="Inherited", + initial_step_items=["Topic: X"], + parity_v2=True, + ) + assert s1.thinking_step_counter == 1 + assert s1.last_active_step_id == "thinking-resume-1" + assert s1.parity_v2 is True + assert s1.next_thinking_step_id("thinking") == "thinking-2" + + +@pytest.mark.parametrize( + ("phase", "message", "start_items", "expected_tail"), + [ + ( + "revising_section", + "progress line", + ["Topic: Foo", "Modifying bar", "stale..."], + ["Topic: Foo", "Modifying bar", "progress line"], + ), + ( + "other", + "phase msg", + ["Topic: Foo", "old line"], + ["Topic: Foo", "phase msg"], + ), + ], +) +def test_report_progress_items_match_reference( + phase: str, + message: str, + start_items: list[str], + expected_tail: list[str], +) -> None: + svc = MagicMock() + svc.format_thinking_step.return_value = "sse" + + items = list(start_items) + frame, new_items = handle_report_progress( + {"message": message, "phase": phase}, + last_active_step_id="step-1", + last_active_step_title="Report", + last_active_step_items=items, + streaming_service=svc, + content_builder=None, + ) + assert frame == "sse" + assert new_items == expected_tail + kwargs = svc.format_thinking_step.call_args.kwargs + assert kwargs["items"] == expected_tail + + +def test_report_progress_noop_when_missing_message_or_step() -> None: + svc = MagicMock() + items = ["Topic: A"] + f1, i1 = handle_report_progress( + {"message": "", "phase": "x"}, + last_active_step_id="s", + last_active_step_title="t", + last_active_step_items=items, + streaming_service=svc, + content_builder=None, + ) + assert f1 is None and i1 is items + + f2, i2 = handle_report_progress( + {"message": "m", "phase": "x"}, + 
last_active_step_id=None, + last_active_step_title="t", + last_active_step_items=items, + streaming_service=svc, + content_builder=None, + ) + assert f2 is None and i2 is items + + +def test_document_action_handlers_match_format_data_guards() -> None: + svc = MagicMock() + svc.format_data.return_value = "data-frame" + + assert handle_document_created({}, streaming_service=svc) is None + assert handle_document_created({"id": 0}, streaming_service=svc) is None + handle_document_created({"id": 42, "title": "x"}, streaming_service=svc) + svc.format_data.assert_called_with( + "documents-updated", {"action": "created", "document": {"id": 42, "title": "x"}} + ) + + svc.reset_mock() + assert handle_action_log({"id": None}, streaming_service=svc) is None + handle_action_log({"id": 1}, streaming_service=svc) + svc.format_data.assert_called_once_with("action-log", {"id": 1}) + + svc.reset_mock() + assert handle_action_log_updated({"id": None}, streaming_service=svc) is None + handle_action_log_updated({"id": 2}, streaming_service=svc) + svc.format_data.assert_called_once_with("action-log-updated", {"id": 2})
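+
+
+# Illustrative sketch distilled from the parametrized reference above - it
+# assumes "Topic: " prefixed items survive every phase while other stale
+# lines are replaced by the incoming message; names mirror the real tests.
+def test_report_progress_keeps_topic_lines_sketch() -> None:
+    svc = MagicMock()
+    svc.format_thinking_step.return_value = "sse"
+    frame, items = handle_report_progress(
+        {"message": "new line", "phase": "other"},
+        last_active_step_id="s",
+        last_active_step_title="t",
+        last_active_step_items=["Topic: kept", "dropped"],
+        streaming_service=svc,
+        content_builder=None,
+    )
+    assert frame == "sse"
+    assert items == ["Topic: kept", "new line"]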