test(chat): add parity tests for streaming/flows/ parallel refactor

Adds 34 tests under tests/unit/tasks/chat/streaming/ that cover the
new flows tree against the legacy stream_new_chat.py module to gate
the upcoming cutover. Coverage:

* Public entry points: stream_new_chat and stream_resume_chat are
  async generator functions whose parameter signatures (name, kind,
  annotation, default) match the legacy versions one-for-one. Uses a
  normalized-annotation comparison so PEP-563 vs eager-annotation
  representation differences are tolerated.
* Extracted helpers: image-capability gate, runtime-context builders
  for new-chat and resume-chat, LLM-bundle dispatcher, premium-quota
  needs check + reservation dataclass, rate-limit recovery truth
  table, persistence-spawn registration/self-unregistration, await
  helpers.
* SSE frame iterators: iter_initial_frames + iter_final_frames emit
  the canonical sequence; iter_token_usage_frame skips on None.
* Initial thinking step: 4 parametrized branches (text, image-only,
  empty, mentioned-docs), long-query truncation, many-docs collapse.

These tests are scaffolding for the cutover and will be removed once
the legacy module is deleted.
This commit is contained in:
CREDO23 2026-05-25 21:50:18 +02:00
parent cf0085575c
commit cfdad85058

View file

@ -0,0 +1,582 @@
"""Parity gate for the parallel refactor of ``stream_new_chat.py``.
The new tree under ``app.tasks.chat.streaming.flows`` is built side-by-side with
the legacy monolithic ``app.tasks.chat.stream_new_chat`` so we can cut over
atomically. This file pins externally-observable behaviour at module
boundaries so a divergence between the two trees fails loudly *before* the
cutover.
What we verify:
1. **Signature parity** ``stream_new_chat`` / ``stream_resume_chat`` from
the new tree have the same call signature as the originals.
2. **Helper extraction parity** the SRP modules in ``flows/`` produce the
same outputs as the inline code in the legacy file for representative
inputs (initial thinking step, image-capability gate, runtime context,
SSE frame sequences, token-usage frame shape, persistence guards).
3. **Wrapper delegation** wrappers like ``load_llm_bundle`` /
``can_recover_provider_rate_limit`` exist and are addressable.
Delete this file along with ``stream_new_chat.py`` once the cutover is done
(see the parent refactor plan).
"""
from __future__ import annotations
import asyncio
import inspect
from dataclasses import dataclass
from typing import Any
from unittest.mock import AsyncMock, patch
import pytest
from app.agents.new_chat.context import SurfSenseContextSchema
from app.services.new_streaming_service import VercelStreamingService
from app.tasks.chat.stream_new_chat import (
stream_new_chat as old_stream_new_chat,
stream_resume_chat as old_stream_resume_chat,
)
from app.tasks.chat.streaming.flows import (
stream_new_chat as new_stream_new_chat,
stream_resume_chat as new_stream_resume_chat,
)
from app.tasks.chat.streaming.flows.new_chat.initial_thinking_step import (
build_initial_thinking_step,
)
from app.tasks.chat.streaming.flows.new_chat.llm_capability import (
check_image_input_capability,
)
from app.tasks.chat.streaming.flows.new_chat.persistence_spawn import (
await_persist_task,
spawn_persist_assistant_shell_task,
spawn_persist_user_task,
spawn_set_ai_responding_bg,
)
from app.tasks.chat.streaming.flows.new_chat.runtime_context import (
build_new_chat_runtime_context,
)
from app.tasks.chat.streaming.flows.resume_chat.runtime_context import (
build_resume_chat_runtime_context,
)
from app.tasks.chat.streaming.flows.shared.finalize_emit import iter_token_usage_frame
from app.tasks.chat.streaming.flows.shared.first_frames import (
iter_final_frames,
iter_initial_frames,
)
from app.tasks.chat.streaming.flows.shared.llm_bundle import load_llm_bundle
from app.tasks.chat.streaming.flows.shared.premium_quota import (
PremiumReservation,
needs_premium_quota,
)
from app.tasks.chat.streaming.flows.shared.rate_limit_recovery import (
can_recover_provider_rate_limit,
)
pytestmark = pytest.mark.unit
# --------------------------------------------------------------------- signature
def _normalize_annotation(ann: Any) -> str:
"""Compare-friendly form for an annotation.
The legacy ``stream_new_chat.py`` does NOT use ``from __future__ import
annotations``, so its annotations are evaluated at import time and come
back as type objects / typing generics. The new tree DOES use it, so its
annotations are PEP-563 strings.
Both reprs describe the same types strip the module prefixes / typing
namespace + the ``<class 'X'>`` wrapper so we compare the canonical
declared form.
"""
if ann is inspect.Signature.empty:
return ""
raw = ann if isinstance(ann, str) else repr(ann)
cleaned = (
raw.replace("typing.", "")
.replace("collections.abc.", "")
.replace("app.db.", "")
.replace("app.agents.new_chat.filesystem_selection.", "")
.replace("app.agents.new_chat.context.", "")
)
# Unwrap ``<class 'int'>`` → ``int`` (legacy-side type objects).
if cleaned.startswith("<class '") and cleaned.endswith("'>"):
cleaned = cleaned[len("<class '") : -len("'>")]
return cleaned
def _normalize_sig(sig: inspect.Signature) -> list[tuple[str, Any, str]]:
return [
(p.name, p.default, _normalize_annotation(p.annotation))
for p in sig.parameters.values()
]
def test_stream_new_chat_signature_matches_legacy() -> None:
old = inspect.signature(old_stream_new_chat)
new = inspect.signature(new_stream_new_chat)
assert _normalize_sig(new) == _normalize_sig(old)
assert _normalize_annotation(new.return_annotation) == _normalize_annotation(
old.return_annotation
)
def test_stream_resume_chat_signature_matches_legacy() -> None:
old = inspect.signature(old_stream_resume_chat)
new = inspect.signature(new_stream_resume_chat)
assert _normalize_sig(new) == _normalize_sig(old)
assert _normalize_annotation(new.return_annotation) == _normalize_annotation(
old.return_annotation
)
def test_orchestrators_are_async_generator_functions() -> None:
assert inspect.isasyncgenfunction(new_stream_new_chat)
assert inspect.isasyncgenfunction(new_stream_resume_chat)
# ------------------------------------------------------------ initial thinking
@dataclass
class _FakeSurfsenseDoc:
"""Stand-in for ``SurfsenseDocsDocument`` with just the field we read."""
title: str
@pytest.mark.parametrize(
"user_query, image_urls, docs, expected_title, expected_action",
[
("hello world", None, [], "Understanding your request", "Processing"),
("", ["data:image/png;base64,AAA"], [], "Understanding your request", "Processing"),
("", None, [], "Understanding your request", "Processing"),
(
"doc question",
None,
[_FakeSurfsenseDoc(title="My Doc")],
"Analyzing referenced content",
"Analyzing",
),
],
)
def test_initial_thinking_step_branches(
user_query: str,
image_urls: list[str] | None,
docs: list[Any],
expected_title: str,
expected_action: str,
) -> None:
step = build_initial_thinking_step(
user_query=user_query,
user_image_data_urls=image_urls,
mentioned_surfsense_docs=docs, # type: ignore[arg-type]
)
assert step.step_id == "thinking-1"
assert step.title == expected_title
assert len(step.items) == 1
assert step.items[0].startswith(f"{expected_action}: ")
def test_initial_thinking_step_truncates_long_query() -> None:
long_query = "x" * 200
step = build_initial_thinking_step(
user_query=long_query,
user_image_data_urls=None,
mentioned_surfsense_docs=[],
)
# 80-char truncation + ellipsis, sandwiched after "Processing: ".
assert "..." in step.items[0]
item = step.items[0]
payload = item[len("Processing: ") :]
assert payload.startswith("x" * 80) and payload.endswith("...")
def test_initial_thinking_step_collapses_many_doc_names() -> None:
docs = [_FakeSurfsenseDoc(title=f"Doc {i}") for i in range(5)]
step = build_initial_thinking_step(
user_query="q",
user_image_data_urls=None,
mentioned_surfsense_docs=docs, # type: ignore[arg-type]
)
assert "[5 docs]" in step.items[0]
# ------------------------------------------------------------ capability gate
def test_image_capability_passes_without_images() -> None:
assert check_image_input_capability(
user_image_data_urls=None, agent_config=None
) is None
def test_image_capability_passes_when_capability_unknown() -> None:
"""Unknown / unmapped models are not blocked — only models LiteLLM has
*explicitly* marked text-only trip the gate."""
class _AgentConfig:
provider = "openrouter"
model_name = "unknown-mystery-model"
custom_provider = None
config_name = "Unknown"
litellm_params: dict[str, Any] = {}
with patch(
"app.services.provider_capabilities.is_known_text_only_chat_model",
return_value=False,
):
assert (
check_image_input_capability(
user_image_data_urls=["data:image/png;base64,AAA"],
agent_config=_AgentConfig(), # type: ignore[arg-type]
)
is None
)
def test_image_capability_blocks_known_text_only_models() -> None:
class _AgentConfig:
provider = "openai"
model_name = "gpt-3.5-turbo"
custom_provider = None
config_name = "GPT-3.5"
litellm_params: dict[str, Any] = {"base_model": "gpt-3.5-turbo"}
with patch(
"app.services.provider_capabilities.is_known_text_only_chat_model",
return_value=True,
):
result = check_image_input_capability(
user_image_data_urls=["data:image/png;base64,AAA"],
agent_config=_AgentConfig(), # type: ignore[arg-type]
)
assert result is not None
message, error_code = result
assert error_code == "MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT"
assert "GPT-3.5" in message
# ---------------------------------------------------------------- runtime ctx
def test_new_chat_runtime_context_prefers_accepted_folder_ids() -> None:
ctx = build_new_chat_runtime_context(
search_space_id=7,
mentioned_document_ids=[1, 2],
accepted_folder_ids=[10],
mentioned_folder_ids=[20, 30],
request_id="req",
turn_id="t1",
)
assert isinstance(ctx, SurfSenseContextSchema)
assert ctx.search_space_id == 7
assert list(ctx.mentioned_document_ids) == [1, 2]
assert list(ctx.mentioned_folder_ids) == [10]
assert ctx.request_id == "req"
assert ctx.turn_id == "t1"
def test_new_chat_runtime_context_falls_back_to_mentioned_folder_ids() -> None:
ctx = build_new_chat_runtime_context(
search_space_id=7,
mentioned_document_ids=None,
accepted_folder_ids=[],
mentioned_folder_ids=[20, 30],
request_id=None,
turn_id="t2",
)
assert list(ctx.mentioned_folder_ids) == [20, 30]
def test_resume_chat_runtime_context_empty_mention_lists() -> None:
ctx = build_resume_chat_runtime_context(
search_space_id=42, request_id="req-r", turn_id="t-r"
)
assert ctx.search_space_id == 42
assert ctx.request_id == "req-r"
assert ctx.turn_id == "t-r"
# ---------------------------------------------------------------- SSE frames
def test_iter_initial_frames_emits_canonical_sequence() -> None:
svc = VercelStreamingService()
frames = list(iter_initial_frames(svc, turn_id="42:1700000000000"))
# Exactly 4 frames: message_start, start_step, turn-info (turn_id), turn-status (busy).
assert len(frames) == 4
assert "42:1700000000000" in frames[2]
assert '"status":"busy"' in frames[3] or '"status": "busy"' in frames[3]
def test_iter_final_frames_emits_idle_then_finish_done() -> None:
svc = VercelStreamingService()
frames = list(iter_final_frames(svc))
assert len(frames) == 4
assert '"status":"idle"' in frames[0] or '"status": "idle"' in frames[0]
# ----------------------------------------------------------- token usage frame
class _FakeAccumulator:
"""Minimal stand-in covering only the fields ``iter_token_usage_frame`` reads."""
def __init__(self, summary: Any = None) -> None:
self._summary = summary
self.calls = [1, 2, 3]
self.grand_total = 100
self.total_cost_micros = 50_000
self.total_prompt_tokens = 60
self.total_completion_tokens = 40
def per_message_summary(self) -> Any:
return self._summary
def serialized_calls(self) -> list[Any]:
return list(self.calls)
def test_token_usage_frame_skipped_when_no_summary() -> None:
svc = VercelStreamingService()
frames = list(
iter_token_usage_frame(
svc,
accumulator=_FakeAccumulator(summary=None), # type: ignore[arg-type]
log_label="parity-empty",
)
)
assert frames == []
def test_token_usage_frame_emitted_when_summary_present() -> None:
svc = VercelStreamingService()
frames = list(
iter_token_usage_frame(
svc,
accumulator=_FakeAccumulator(summary=[{"m": "x", "t": 100}]), # type: ignore[arg-type]
log_label="parity-populated",
)
)
assert len(frames) == 1
# Field shape on the wire is fixed by the FE; assert each surfaces.
payload = frames[0]
for key in (
'"prompt_tokens":60',
'"completion_tokens":40',
'"total_tokens":100',
'"cost_micros":50000',
):
assert key in payload.replace(" ", "")
# ------------------------------------------------------------------ llm_bundle
def test_load_llm_bundle_routes_negative_id_to_yaml_loader() -> None:
async def _run() -> tuple[Any, Any, str | None]:
with (
patch(
"app.tasks.chat.streaming.flows.shared.llm_bundle.load_global_llm_config_by_id",
return_value=None,
),
):
return await load_llm_bundle(
session=AsyncMock(), # type: ignore[arg-type]
config_id=-1,
search_space_id=7,
)
llm, agent_config, error = asyncio.run(_run())
assert llm is None
assert agent_config is None
assert error is not None and "id -1" in error
def test_load_llm_bundle_routes_nonnegative_id_to_db_loader() -> None:
async def _run() -> tuple[Any, Any, str | None]:
with (
patch(
"app.tasks.chat.streaming.flows.shared.llm_bundle.load_agent_config",
new=AsyncMock(return_value=None),
),
):
return await load_llm_bundle(
session=AsyncMock(), # type: ignore[arg-type]
config_id=12,
search_space_id=7,
)
llm, agent_config, error = asyncio.run(_run())
assert llm is None
assert agent_config is None
assert error is not None and "id 12" in error
# ----------------------------------------------------------------- premium quota
def test_needs_premium_quota_requires_user_and_premium_flag() -> None:
class _AgentConfig:
is_premium = True
class _NonPremium:
is_premium = False
assert needs_premium_quota(_AgentConfig(), "user-1") is True # type: ignore[arg-type]
assert needs_premium_quota(_AgentConfig(), None) is False # type: ignore[arg-type]
assert needs_premium_quota(_NonPremium(), "user-1") is False # type: ignore[arg-type]
assert needs_premium_quota(None, "user-1") is False
def test_premium_reservation_dataclass_shape() -> None:
# Sanity: the dataclass exists and carries the fields the orchestrator uses.
r = PremiumReservation(request_id="abc", reserved_micros=100, allowed=True)
assert r.request_id == "abc"
assert r.reserved_micros == 100
assert r.allowed is True
# ----------------------------------------------------------- rate-limit guard
@pytest.mark.parametrize(
"first_event_seen, recovered, requested_id, current_id, expected",
[
(False, False, 0, -1, True),
# Already recovered: no second pass.
(False, True, 0, -1, False),
# User explicitly picked a config: don't silently switch.
(False, False, 5, -1, False),
# Already on a database-backed (positive) id.
(False, False, 0, 7, False),
# User has already seen output: silent rebuild not possible.
(True, False, 0, -1, False),
],
)
def test_can_recover_provider_rate_limit_truth_table(
first_event_seen: bool,
recovered: bool,
requested_id: int,
current_id: int,
expected: bool,
) -> None:
# Use a known rate-limit-shaped exception so the helper's last condition
# is satisfied; the guard only short-circuits to False when one of the
# *other* preconditions fails.
exc = Exception('{"error":{"type":"rate_limit_error","message":"slow"}}')
assert (
can_recover_provider_rate_limit(
exc,
first_event_seen=first_event_seen,
runtime_rate_limit_recovered=recovered,
requested_llm_config_id=requested_id,
current_llm_config_id=current_id,
)
is expected
)
def test_can_recover_provider_rate_limit_rejects_non_rate_limit_exception() -> None:
assert (
can_recover_provider_rate_limit(
ValueError("not a rate limit"),
first_event_seen=False,
runtime_rate_limit_recovered=False,
requested_llm_config_id=0,
current_llm_config_id=-1,
)
is False
)
# --------------------------------------------------------- persistence spawn
def test_spawn_set_ai_responding_bg_noop_without_user_id() -> None:
async def _run() -> set[asyncio.Task]:
background: set[asyncio.Task] = set()
spawn_set_ai_responding_bg(
chat_id=1, user_id=None, background_tasks=background
)
return background
bg = asyncio.run(_run())
assert bg == set()
def test_spawn_persist_user_task_registers_and_self_unregisters() -> None:
async def _run() -> tuple[int, int]:
background: set[asyncio.Task] = set()
with patch(
"app.tasks.chat.streaming.flows.new_chat.persistence_spawn.persist_user_turn",
new=AsyncMock(return_value=99),
):
task = spawn_persist_user_task(
chat_id=1,
user_id="u",
turn_id="t",
user_query="hi",
user_image_data_urls=None,
mentioned_documents=None,
background_tasks=background,
)
size_before_await = len(background)
result = await asyncio.shield(task)
# Give the done-callback one event-loop tick to run.
await asyncio.sleep(0)
return size_before_await, result # type: ignore[return-value]
size_before, result = asyncio.run(_run())
assert size_before == 1
assert result == 99
def test_spawn_persist_assistant_shell_task_registers() -> None:
async def _run() -> int | None:
background: set[asyncio.Task] = set()
with patch(
"app.tasks.chat.streaming.flows.new_chat.persistence_spawn.persist_assistant_shell",
new=AsyncMock(return_value=42),
):
task = spawn_persist_assistant_shell_task(
chat_id=1,
user_id="u",
turn_id="t",
background_tasks=background,
)
return await asyncio.shield(task)
assert asyncio.run(_run()) == 42
def test_await_persist_task_returns_none_on_failure() -> None:
async def _run() -> int | None:
async def _boom() -> int:
raise RuntimeError("DB down")
task = asyncio.create_task(_boom())
return await await_persist_task(
task,
chat_id=1,
turn_id="t",
log_label="parity-failure",
)
assert asyncio.run(_run()) is None
def test_await_persist_task_returns_none_for_none_input() -> None:
async def _run() -> int | None:
return await await_persist_task(
None,
chat_id=1,
turn_id="t",
log_label="parity-none",
)
assert asyncio.run(_run()) is None