refactor(chat): delete legacy stream_new_chat monolith (cutover complete)

The flows orchestrators (new_chat/resume_chat) are now the sole live path
after the byte-for-byte differential proof, so the monolith and its
monolith-vs-flows parity scaffolding are removed.

- Repoint the last live importer (anonymous_chat_routes) to
  streaming.agent.event_loop.stream_agent_events + shared.stream_result.StreamResult
  (drop-in; the keyword-only fallback-commit params default to inert for anon).
- Repoint e2e launcher patch targets to flows.shared.llm_bundle.
- Repoint helper unit tests (chunk_parts, thinking-step ids, tool-input
  streaming) to their flows homes to preserve coverage.
- Delete the monolith, the contract test, and the parity tests
  (parallel_refactor, stage_1, stage_2, orchestrator_frame) whose sole
  purpose was comparing against the now-removed monolith.

Full suite green (2622 passed, 1 skipped); the two excluded live-app dirs
(document_upload, composio) have a pre-existing, env-gated registration 404
unrelated to this change.
This commit is contained in:
CREDO23 2026-06-04 14:35:45 +02:00
parent b9937cf4b1
commit 5b45f78a16
12 changed files with 25 additions and 5028 deletions

View file

@ -247,11 +247,11 @@ def _patch_llm_bindings() -> None:
fake_create_chat_litellm_from_config,
),
(
"app.tasks.chat.stream_new_chat.create_chat_litellm_from_agent_config",
"app.tasks.chat.streaming.flows.shared.llm_bundle.create_chat_litellm_from_agent_config",
fake_create_chat_litellm_from_agent_config,
),
(
"app.tasks.chat.stream_new_chat.create_chat_litellm_from_config",
"app.tasks.chat.streaming.flows.shared.llm_bundle.create_chat_litellm_from_config",
fake_create_chat_litellm_from_config,
),
]

View file

@ -220,11 +220,11 @@ def _patch_llm_bindings() -> None:
fake_create_chat_litellm_from_config,
),
(
"app.tasks.chat.stream_new_chat.create_chat_litellm_from_agent_config",
"app.tasks.chat.streaming.flows.shared.llm_bundle.create_chat_litellm_from_agent_config",
fake_create_chat_litellm_from_agent_config,
),
(
"app.tasks.chat.stream_new_chat.create_chat_litellm_from_config",
"app.tasks.chat.streaming.flows.shared.llm_bundle.create_chat_litellm_from_config",
fake_create_chat_litellm_from_config,
),
]

View file

@ -1,457 +0,0 @@
"""Byte-for-byte frame parity: legacy monolith vs refactored flows orchestrators.
The agent-content portion of the stream (`text-*`, tool cards, thinking-step
updates) flows through **shared** code in both implementations
(`stream_output` -> `EventRelay.relay` -> handlers), so it cannot diverge. The
only independently-written part is the *orchestrator glue*: the initial frames,
persistence-handshake frames, error/terminal branches, and final frames.
This module drives BOTH ``stream_new_chat`` implementations (legacy
``app.tasks.chat.stream_new_chat`` and the refactored
``app.tasks.chat.streaming.flows``) through the deterministic glue paths and
asserts the emitted SSE frame sequences are **byte-for-byte identical**. These
are the paths where divergence could hide; the agent-streaming portion is shared
and is covered separately.
Determinism is enforced by:
* freezing ``time.time`` (so ``turn_id = f"{chat_id}:{ms}"`` is stable),
* a deterministic ``uuid`` sequence for the streaming-service id generators,
* stubbing every DB/LLM/agent seam (LLM resolution, persistence, connector,
checkpointer, session) to fixed values.
Cutover gate: when these are green, the live callers can be flipped to the
flows orchestrators.
"""
from __future__ import annotations
from types import SimpleNamespace
from typing import Any
from unittest.mock import AsyncMock, MagicMock
import pytest
import app.services.new_streaming_service as _nss
from app.tasks.chat.stream_new_chat import (
stream_new_chat as old_stream_new_chat,
stream_resume_chat as old_stream_resume_chat,
)
from app.tasks.chat.streaming.flows import (
stream_new_chat as new_stream_new_chat,
stream_resume_chat as new_stream_resume_chat,
)
# Every test collected from this module carries the ``unit`` marker.
pytestmark = pytest.mark.unit
# Value substituted for ``time.time()`` while the glue seams are patched.
_FIXED_EPOCH = 1_700_000_000.0 # -> turn_id "<chat_id>:1700000000000"
# --------------------------------------------------------------------------- #
# Deterministic uuid for the streaming-service id generators
# --------------------------------------------------------------------------- #
class _SeqUUID:
    """Deterministic replacement for the ``uuid`` module in ``new_streaming_service``.

    Every ``uuid4()`` call bumps an internal counter and returns an object whose
    ``hex`` attribute is that counter rendered as 32 zero-padded lowercase hex
    digits, so identical call sequences always produce identical ids.
    """

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        """Restart the sequence so the next id handed out is number 1."""
        self._n = 0

    def uuid4(self) -> SimpleNamespace:
        """Return the next sequential fake UUID (only ``hex`` is populated)."""
        self._n = self._n + 1
        return SimpleNamespace(hex=format(self._n, "032x"))
# Shared module-level id sequence; it is reset before each implementation run.
_SEQ = _SeqUUID()
# --------------------------------------------------------------------------- #
# Fake session: the orchestrator owns ``async_session_maker()``; for the glue
# paths every real consumer is stubbed, so a no-op session suffices.
# --------------------------------------------------------------------------- #
class _FakeResult:
    """Stand-in for a SQLAlchemy ``Result`` that never holds any rows.

    Single-row accessors return ``None``, multi-row accessors return a fresh
    empty list, and iterating the object yields nothing.
    """

    def scalars(self) -> "_FakeResult":
        # ``result.scalars()`` chains straight back to the same empty object.
        return self

    # --- single-row accessors -------------------------------------------
    def first(self) -> None:
        return None

    def one_or_none(self) -> None:
        return None

    def scalar_one_or_none(self) -> None:
        return None

    def scalar(self) -> None:
        return None

    # --- multi-row accessors --------------------------------------------
    def all(self) -> list[Any]:
        return list()

    def fetchall(self) -> list[Any]:
        return list()

    def __iter__(self):
        return iter(())
class _FakeSession:
    """Async session double whose every operation is a no-op.

    ``execute`` is the only method that returns something: a new, empty
    ``_FakeResult``.
    """

    async def commit(self) -> None:  # pragma: no cover - trivial
        """Pretend to commit; nothing is persisted."""

    async def rollback(self) -> None:  # pragma: no cover - trivial
        """Pretend to roll back; there is nothing to undo."""

    async def close(self) -> None:  # pragma: no cover - trivial
        """Pretend to close the session."""

    def expunge_all(self) -> None:  # pragma: no cover - trivial
        """Pretend to detach all tracked objects."""

    def add(self, *a: Any, **k: Any) -> None:  # pragma: no cover - trivial
        """Accept any object(s) and discard them."""

    async def flush(self, *a: Any, **k: Any) -> None:  # pragma: no cover
        """Pretend to flush pending state."""

    async def execute(self, *a: Any, **k: Any) -> _FakeResult:
        """Ignore the statement and hand back an empty result."""
        empty = _FakeResult()
        return empty
class _FakeConnectorService:
    """Connector-service double: accepts any constructor args, finds no connector."""

    def __init__(self, *a: Any, **k: Any) -> None:
        """Swallow whatever arguments are passed; nothing is stored."""

    async def get_connector_by_type(self, *a: Any, **k: Any) -> None:
        """Report that no connector exists, whatever was asked for."""
        return None
def _patch(monkeypatch: pytest.MonkeyPatch, target: str, value: Any) -> None:
    """Install ``value`` at the dotted ``target`` without requiring the attr to exist.

    ``raising=False`` is passed through to ``monkeypatch.setattr`` so a name
    that is only bound via a local import does not make the patch fail.
    """
    tolerant = {"raising": False}
    monkeypatch.setattr(target, value, **tolerant)
def _apply_common(
    monkeypatch: pytest.MonkeyPatch,
    *,
    pin_raises: ValueError | None = None,
    resolved_id: int = -1,
    llm_load_ok: bool = True,
    persist_user_id: int | None = 101,
    persist_assistant_id: int | None = 102,
) -> None:
    """Replace every glue seam of the legacy and flows trees with fixed stubs.

    Args:
        monkeypatch: pytest fixture used to install the stubs.
        pin_raises: When given, the auto-pin resolver stub raises this error
            instead of returning a resolved config id.
        resolved_id: Config id reported by the resolver stub on success.
        llm_load_ok: ``False`` makes the global LLM config loader stub return
            ``None`` rather than a sentinel config object.
        persist_user_id: Value returned by the ``persist_user_turn`` stub.
        persist_assistant_id: Value returned by the ``persist_assistant_shell``
            stub.
    """
    # Clock and id generation are frozen so turn ids and frame ids repeat.
    monkeypatch.setattr("time.time", lambda: _FIXED_EPOCH)
    monkeypatch.setattr(_nss, "uuid", _SEQ)

    scripted_llm = MagicMock(name="scripted_llm")

    # Session factory: every orchestrator gets the no-op session class.
    session_targets = (
        "app.tasks.chat.stream_new_chat.async_session_maker",
        "app.tasks.chat.streaming.flows.new_chat.orchestrator.async_session_maker",
        "app.tasks.chat.streaming.flows.resume_chat.orchestrator.async_session_maker",
    )
    for session_target in session_targets:
        _patch(monkeypatch, session_target, _FakeSession)

    # Connector service.
    connector_targets = (
        "app.tasks.chat.stream_new_chat.ConnectorService",
        "app.tasks.chat.streaming.flows.shared.pre_stream_setup.ConnectorService",
    )
    for connector_target in connector_targets:
        _patch(monkeypatch, connector_target, _FakeConnectorService)

    # Checkpointer: a separate async mock is built for each patched name.
    checkpointer_targets = (
        "app.tasks.chat.stream_new_chat.get_checkpointer",
        "app.tasks.chat.streaming.flows.shared.pre_stream_setup.get_checkpointer",
    )
    for checkpointer_target in checkpointer_targets:
        checkpointer_stub = AsyncMock(return_value=MagicMock(name="checkpointer"))
        _patch(monkeypatch, checkpointer_target, checkpointer_stub)

    # Agent factory. The fake agent exposes async state accessors that return
    # an empty snapshot (no values, tasks, interrupts or next nodes) because
    # resume routing awaits ``aget_state`` before persisting.
    idle_snapshot = SimpleNamespace(values={}, tasks=[], interrupts=[], next=())
    stub_agent = MagicMock(name="agent")
    stub_agent.aget_state = AsyncMock(return_value=idle_snapshot)
    stub_agent.aupdate_state = AsyncMock(return_value=None)
    build_agent = AsyncMock(return_value=stub_agent)
    agent_targets = (
        "app.tasks.chat.stream_new_chat.create_multi_agent_chat_deep_agent",
        "app.tasks.chat.streaming.flows.new_chat.orchestrator.create_multi_agent_chat_deep_agent",
        "app.tasks.chat.streaming.flows.resume_chat.orchestrator.create_multi_agent_chat_deep_agent",
    )
    for agent_target in agent_targets:
        _patch(monkeypatch, agent_target, build_agent)

    # LLM resolution (auto-pin): either raise the supplied error or report
    # the configured id.
    async def _resolver(*a: Any, **k: Any):
        if pin_raises is not None:
            raise pin_raises
        return SimpleNamespace(resolved_llm_config_id=resolved_id)

    resolver_targets = (
        "app.services.auto_model_pin_service.resolve_or_get_pinned_llm_config_id",
        "app.tasks.chat.stream_new_chat.resolve_or_get_pinned_llm_config_id",
        "app.tasks.chat.streaming.flows.new_chat.auto_pin.resolve_or_get_pinned_llm_config_id",
    )
    for resolver_target in resolver_targets:
        _patch(monkeypatch, resolver_target, _resolver)

    # LLM bundle: config loader first, then the model constructor.
    loaded_cfg = None if not llm_load_ok else object()

    def _load_config(cid):
        return loaded_cfg

    def _build_model(cfg):
        return scripted_llm

    bundle_stubs = (
        ("app.tasks.chat.stream_new_chat.load_global_llm_config_by_id", _load_config),
        (
            "app.tasks.chat.streaming.flows.shared.llm_bundle.load_global_llm_config_by_id",
            _load_config,
        ),
        ("app.tasks.chat.stream_new_chat.create_chat_litellm_from_config", _build_model),
        (
            "app.tasks.chat.streaming.flows.shared.llm_bundle.create_chat_litellm_from_config",
            _build_model,
        ),
    )
    for bundle_target, bundle_stub in bundle_stubs:
        _patch(monkeypatch, bundle_target, bundle_stub)

    # ``from_yaml_config`` always yields None, so no agent config is produced.
    from app.agents.shared.llm_config import AgentConfig

    monkeypatch.setattr(AgentConfig, "from_yaml_config", staticmethod(lambda cfg: None))

    # Persistence.
    async def _persist_user(*a: Any, **k: Any):
        return persist_user_id

    async def _persist_assistant(*a: Any, **k: Any):
        return persist_assistant_id

    async def _finalize(*a: Any, **k: Any):
        return None

    persistence_modules = (
        "app.tasks.chat.persistence",
        "app.tasks.chat.streaming.flows.new_chat.persistence_spawn",
    )
    for mod in persistence_modules:
        _patch(monkeypatch, f"{mod}.persist_user_turn", _persist_user)
        _patch(monkeypatch, f"{mod}.persist_assistant_shell", _persist_assistant)
    # The resume flow has its own binding of ``persist_assistant_shell``.
    _patch(
        monkeypatch,
        "app.tasks.chat.streaming.flows.resume_chat.assistant_shell.persist_assistant_shell",
        _persist_assistant,
    )
    _patch(monkeypatch, "app.tasks.chat.persistence.finalize_assistant_turn", _finalize)

    # Collaboration flags become awaitable no-ops.
    async def _noop(*a: Any, **k: Any):
        return None

    flag_targets = (
        "app.tasks.chat.stream_new_chat.set_ai_responding",
        "app.tasks.chat.stream_new_chat.clear_ai_responding",
        "app.tasks.chat.streaming.flows.new_chat.persistence_spawn.set_ai_responding",
        "app.services.chat_session_state_service.set_ai_responding",
        "app.services.chat_session_state_service.clear_ai_responding",
    )
    for flag_target in flag_targets:
        _patch(monkeypatch, flag_target, _noop)
async def _collect(genfunc: Any, **kwargs: Any) -> list[str]:
    """Call ``genfunc(**kwargs)`` and drain the async iterator into a list.

    Returns:
        Every frame yielded, in emission order.
    """
    return [frame async for frame in genfunc(**kwargs)]
async def _run_both(kwargs: dict[str, Any]) -> tuple[list[str], list[str]]:
    """Run the legacy and then the flows NEW-chat generator on the same kwargs.

    The shared id sequence is reset before each run so both start from id 1.

    Returns:
        ``(legacy_frames, flows_frames)``.
    """
    captured: list[list[str]] = []
    for implementation in (old_stream_new_chat, new_stream_new_chat):
        _SEQ.reset()
        captured.append(await _collect(implementation, **kwargs))
    legacy_frames, flows_frames = captured
    return legacy_frames, flows_frames
async def _run_both_resume(kwargs: dict[str, Any]) -> tuple[list[str], list[str]]:
    """Run the legacy and then the flows RESUME-chat generator on the same kwargs.

    The shared id sequence is reset before each run so both start from id 1.

    Returns:
        ``(legacy_frames, flows_frames)``.
    """
    captured: list[list[str]] = []
    for implementation in (old_stream_resume_chat, new_stream_resume_chat):
        _SEQ.reset()
        captured.append(await _collect(implementation, **kwargs))
    legacy_frames, flows_frames = captured
    return legacy_frames, flows_frames
def _assert_parity(old: list[str], new: list[str]) -> None:
"""Byte-for-byte equality with a readable first-divergence message."""
for i, (a, b) in enumerate(zip(old, new, strict=False)):
assert a == b, f"frame[{i}] differs:\n old={a!r}\n new={b!r}"
assert len(old) == len(new), (
f"frame count differs: old={len(old)} new={len(new)}\n"
f" old tail={old[len(new):]!r}\n new tail={new[len(old):]!r}"
)
assert old[-1].strip() == "data: [DONE]"
# --------------------------------------------------------------------------- #
# NEW-chat scenarios
# --------------------------------------------------------------------------- #
# Baseline keyword arguments shared by every NEW-chat scenario.
_NEW_KW = {
    "user_query": "hi",
    "search_space_id": 1,
    "chat_id": 42,
    "user_id": None,
}
@pytest.mark.asyncio
async def test_auto_pin_failure_parity(monkeypatch: pytest.MonkeyPatch) -> None:
    """When auto-pin raises, both trees emit the same two frames: error, DONE."""
    pin_error = ValueError("no eligible config")
    _apply_common(monkeypatch, pin_raises=pin_error)

    legacy_frames, flows_frames = await _run_both(dict(_NEW_KW))

    _assert_parity(legacy_frames, flows_frames)
    assert len(legacy_frames) == 2
    error_frame = legacy_frames[0]
    assert '"errorCode": "SERVER_ERROR"' in error_frame
@pytest.mark.asyncio
async def test_llm_load_failure_parity(monkeypatch: pytest.MonkeyPatch) -> None:
    """When the LLM config fails to load, both trees emit the same error, DONE pair."""
    _apply_common(monkeypatch, llm_load_ok=False)

    legacy_frames, flows_frames = await _run_both(dict(_NEW_KW))

    _assert_parity(legacy_frames, flows_frames)
    assert len(legacy_frames) == 2
    error_frame = legacy_frames[0]
    assert '"errorCode": "SERVER_ERROR"' in error_frame
@pytest.mark.asyncio
async def test_persist_user_failure_parity(monkeypatch: pytest.MonkeyPatch) -> None:
    """The user-turn persist stub returns ``None``.

    Checks that the stream opens with a ``start`` frame, that the third frame
    carries the frozen turn id, and that a ``MESSAGE_PERSIST_FAILED`` error and
    a ``finish`` frame both appear somewhere in the stream.
    """
    _apply_common(monkeypatch, persist_user_id=None)

    legacy_frames, flows_frames = await _run_both(dict(_NEW_KW))

    _assert_parity(legacy_frames, flows_frames)
    assert '"type": "start"' in legacy_frames[0]
    assert '"chat_turn_id": "42:1700000000000"' in legacy_frames[2]
    for marker in ('"errorCode": "MESSAGE_PERSIST_FAILED"', '"type": "finish"'):
        assert any(marker in frame for frame in legacy_frames)
@pytest.mark.asyncio
async def test_persist_assistant_failure_parity(monkeypatch: pytest.MonkeyPatch) -> None:
    """The assistant-shell persist stub returns ``None``.

    The user turn persists (id 101), so a ``data-user-message-id`` frame with
    that id must be present alongside the ``MESSAGE_PERSIST_FAILED`` error.
    """
    _apply_common(monkeypatch, persist_user_id=101, persist_assistant_id=None)

    legacy_frames, flows_frames = await _run_both(dict(_NEW_KW))

    _assert_parity(legacy_frames, flows_frames)
    handshake_seen = any(
        '"data-user-message-id"' in frame and '"message_id": 101' in frame
        for frame in legacy_frames
    )
    assert handshake_seen
    persist_error_seen = any(
        '"errorCode": "MESSAGE_PERSIST_FAILED"' in frame for frame in legacy_frames
    )
    assert persist_error_seen
@pytest.mark.asyncio
async def test_prestream_exception_parity(monkeypatch: pytest.MonkeyPatch) -> None:
    """A resolver that raises ``RuntimeError`` yields identical frames in both trees.

    The common stubs are installed first and the auto-pin resolver is then
    overridden with one that raises a non-``ValueError`` exception. Both
    implementations must emit byte-identical frames that include an ``error``
    frame.
    """
    _apply_common(monkeypatch)

    async def _bad_resolver(*a: Any, **k: Any):
        raise RuntimeError("boom in pre-stream")

    # Replace the resolver at every name it is bound under.
    resolver_targets = (
        "app.services.auto_model_pin_service.resolve_or_get_pinned_llm_config_id",
        "app.tasks.chat.stream_new_chat.resolve_or_get_pinned_llm_config_id",
        "app.tasks.chat.streaming.flows.new_chat.auto_pin.resolve_or_get_pinned_llm_config_id",
    )
    for resolver_target in resolver_targets:
        _patch(monkeypatch, resolver_target, _bad_resolver)

    legacy_frames, flows_frames = await _run_both(dict(_NEW_KW))

    _assert_parity(legacy_frames, flows_frames)
    assert any('"type": "error"' in frame for frame in legacy_frames)
# --------------------------------------------------------------------------- #
# RESUME-chat scenarios (no title-generation path -> fully deterministic)
# --------------------------------------------------------------------------- #
# Baseline keyword arguments shared by every RESUME-chat scenario.
_RESUME_KW = {
    "chat_id": 42,
    "search_space_id": 1,
    "decisions": [],
    "user_id": None,
}
async def _collect_resume_old() -> list[str]:
    """Reset the id sequence and drain the legacy resume generator."""
    _SEQ.reset()
    resume_kwargs = dict(_RESUME_KW)
    legacy_frames = await _collect(old_stream_resume_chat, **resume_kwargs)
    return legacy_frames
# NOTE: KNOWN, INTENTIONAL DIVERGENCE (flows fixes a latent monolith bug).
#
# In ``stream_resume_chat`` the monolith defines ``_resume_premium_request_id``
# (line ~2363) AFTER the auto-pin / LLM-load early-return points (~2346 / ~2356).
# Its ``finally`` block (line ~2918) reads that variable, so a resume turn whose
# auto-pin raises or whose LLM bundle fails to load crashes with
# ``UnboundLocalError`` instead of emitting a clean terminal-error frame. The
# refactored flows orchestrator does NOT have this bug — it emits the proper
# ``[error, DONE]`` sequence. We assert the divergence explicitly so the cutover
# is a documented behavior IMPROVEMENT rather than a silent change.
@pytest.mark.asyncio
async def test_resume_auto_pin_failure_flows_fixes_monolith_crash(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Auto-pin failure on resume: legacy crashes, flows ends with error, DONE."""
    _apply_common(monkeypatch, pin_raises=ValueError("no eligible config"))

    # Legacy side: expected to raise UnboundLocalError naming the variable.
    with pytest.raises(UnboundLocalError, match="_resume_premium_request_id"):
        await _collect_resume_old()

    # Flows side: exactly two frames, an error followed by DONE.
    _SEQ.reset()
    flows_frames = await _collect(new_stream_resume_chat, **dict(_RESUME_KW))
    assert len(flows_frames) == 2
    first_frame, last_frame = flows_frames[0], flows_frames[-1]
    assert last_frame.strip() == "data: [DONE]"
    assert '"type": "error"' in first_frame
@pytest.mark.asyncio
async def test_resume_llm_load_failure_flows_fixes_monolith_crash(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """LLM load failure on resume: legacy crashes, flows ends with error, DONE."""
    _apply_common(monkeypatch, llm_load_ok=False)

    # Legacy side: expected to raise UnboundLocalError naming the variable.
    with pytest.raises(UnboundLocalError, match="_resume_premium_request_id"):
        await _collect_resume_old()

    # Flows side: exactly two frames, an error followed by DONE.
    _SEQ.reset()
    flows_frames = await _collect(new_stream_resume_chat, **dict(_RESUME_KW))
    assert len(flows_frames) == 2
    first_frame, last_frame = flows_frames[0], flows_frames[-1]
    assert last_frame.strip() == "data: [DONE]"
    assert '"type": "error"' in first_frame
@pytest.mark.asyncio
async def test_resume_persist_assistant_failure_parity(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Resume with a failing assistant-shell persist stays identical in both trees.

    No frame may mention ``data-user-message-id``; at least one frame must
    carry the frozen turn id.
    """
    _apply_common(monkeypatch, persist_assistant_id=None)

    legacy_frames, flows_frames = await _run_both_resume(dict(_RESUME_KW))

    _assert_parity(legacy_frames, flows_frames)
    user_handshake_seen = any(
        '"data-user-message-id"' in frame for frame in legacy_frames
    )
    assert not user_handshake_seen
    assert any(
        '"chat_turn_id": "42:1700000000000"' in frame for frame in legacy_frames
    )

View file

@ -1,584 +0,0 @@
"""Parity gate for the parallel refactor of ``stream_new_chat.py``.
The new tree under ``app.tasks.chat.streaming.flows`` is built side-by-side with
the legacy monolithic ``app.tasks.chat.stream_new_chat`` so we can cut over
atomically. This file pins externally-observable behaviour at module
boundaries so a divergence between the two trees fails loudly *before* the
cutover.
What we verify:
1. **Signature parity** — ``stream_new_chat`` / ``stream_resume_chat`` from
the new tree have the same call signature as the originals.
2. **Helper extraction parity** — the SRP modules in ``flows/`` produce the
same outputs as the inline code in the legacy file for representative
inputs (initial thinking step, image-capability gate, runtime context,
SSE frame sequences, token-usage frame shape, persistence guards).
3. **Wrapper delegation** — wrappers like ``load_llm_bundle`` /
``can_recover_provider_rate_limit`` exist and are addressable.
Delete this file along with ``stream_new_chat.py`` once the cutover is done
(see the parent refactor plan).
"""
from __future__ import annotations
import asyncio
import inspect
from typing import Any
from unittest.mock import AsyncMock, patch
import pytest
from app.agents.shared.context import SurfSenseContextSchema
from app.services.new_streaming_service import VercelStreamingService
from app.tasks.chat.stream_new_chat import (
stream_new_chat as old_stream_new_chat,
stream_resume_chat as old_stream_resume_chat,
)
from app.tasks.chat.streaming.flows import (
stream_new_chat as new_stream_new_chat,
stream_resume_chat as new_stream_resume_chat,
)
from app.tasks.chat.streaming.flows.new_chat.initial_thinking_step import (
build_initial_thinking_step,
)
from app.tasks.chat.streaming.flows.new_chat.llm_capability import (
check_image_input_capability,
)
from app.tasks.chat.streaming.flows.new_chat.persistence_spawn import (
await_persist_task,
spawn_persist_assistant_shell_task,
spawn_persist_user_task,
spawn_set_ai_responding_bg,
)
from app.tasks.chat.streaming.flows.new_chat.runtime_context import (
build_new_chat_runtime_context,
)
from app.tasks.chat.streaming.flows.resume_chat.runtime_context import (
build_resume_chat_runtime_context,
)
from app.tasks.chat.streaming.flows.shared.finalize_emit import iter_token_usage_frame
from app.tasks.chat.streaming.flows.shared.first_frames import (
iter_final_frames,
iter_initial_frames,
)
from app.tasks.chat.streaming.flows.shared.llm_bundle import load_llm_bundle
from app.tasks.chat.streaming.flows.shared.premium_quota import (
PremiumReservation,
needs_premium_quota,
)
from app.tasks.chat.streaming.flows.shared.rate_limit_recovery import (
can_recover_provider_rate_limit,
)
# Every test collected from this module carries the ``unit`` marker.
pytestmark = pytest.mark.unit
# --------------------------------------------------------------------- signature
def _normalize_annotation(ann: Any) -> str:
    """Reduce an annotation to a string that can be compared across both trees.

    One tree yields evaluated annotation objects and the other yields plain
    strings, so both are brought to text (strings as-is, everything else via
    ``repr``). A fixed set of module prefixes is then removed, and a
    ``<class 'X'>`` wrapper, if the text is exactly that, is unwrapped to ``X``.

    Args:
        ann: An annotation taken from an ``inspect.Signature``.

    Returns:
        The normalized text, or ``""`` when the annotation is absent.
    """
    if ann is inspect.Signature.empty:
        return ""

    text = ann if isinstance(ann, str) else repr(ann)

    # Applied in this order, exactly as a chain of ``str.replace`` calls.
    noise_prefixes = (
        "typing.",
        "collections.abc.",
        "app.db.",
        "app.agents.shared.filesystem_selection.",
        "app.agents.shared.context.",
    )
    for prefix in noise_prefixes:
        text = text.replace(prefix, "")

    # ``<class 'int'>`` -> ``int`` for evaluated type objects.
    opener, closer = "<class '", "'>"
    if text.startswith(opener) and text.endswith(closer):
        text = text[len(opener) : -len(closer)]
    return text
def _normalize_sig(sig: inspect.Signature) -> list[tuple[str, Any, str]]:
    """Flatten a signature into ``(name, default, normalized annotation)`` rows.

    Rows follow the declaration order of the parameters.
    """
    rows: list[tuple[str, Any, str]] = []
    for param in sig.parameters.values():
        rows.append(
            (param.name, param.default, _normalize_annotation(param.annotation))
        )
    return rows
def test_stream_new_chat_signature_matches_legacy() -> None:
    """``stream_new_chat`` has the same parameters and return annotation in both trees."""
    legacy_sig = inspect.signature(old_stream_new_chat)
    flows_sig = inspect.signature(new_stream_new_chat)

    assert _normalize_sig(flows_sig) == _normalize_sig(legacy_sig)

    flows_return = _normalize_annotation(flows_sig.return_annotation)
    legacy_return = _normalize_annotation(legacy_sig.return_annotation)
    assert flows_return == legacy_return
def test_stream_resume_chat_signature_matches_legacy() -> None:
    """``stream_resume_chat`` has the same parameters and return annotation in both trees."""
    legacy_sig = inspect.signature(old_stream_resume_chat)
    flows_sig = inspect.signature(new_stream_resume_chat)

    assert _normalize_sig(flows_sig) == _normalize_sig(legacy_sig)

    flows_return = _normalize_annotation(flows_sig.return_annotation)
    legacy_return = _normalize_annotation(legacy_sig.return_annotation)
    assert flows_return == legacy_return
def test_orchestrators_are_async_generator_functions() -> None:
    """Both flows entry points must be async generator functions."""
    for orchestrator in (new_stream_new_chat, new_stream_resume_chat):
        assert inspect.isasyncgenfunction(orchestrator)
# ------------------------------------------------------------ initial thinking
@pytest.mark.parametrize(
    "user_query, image_urls, expected_title, expected_action",
    [
        # Text only.
        ("hello world", None, "Understanding your request", "Processing"),
        # Image only.
        ("", ["data:image/png;base64,AAA"], "Understanding your request", "Processing"),
        # Neither text nor image.
        ("", None, "Understanding your request", "Processing"),
    ],
)
def test_initial_thinking_step_branches(
    user_query: str,
    image_urls: list[str] | None,
    expected_title: str,
    expected_action: str,
) -> None:
    """The initial step has a fixed id, the expected title, and one prefixed item."""
    thinking = build_initial_thinking_step(
        user_image_data_urls=image_urls,
        user_query=user_query,
    )

    assert thinking.step_id == "thinking-1"
    assert thinking.title == expected_title
    assert len(thinking.items) == 1
    only_item = thinking.items[0]
    assert only_item.startswith(f"{expected_action}: ")
def test_initial_thinking_step_truncates_long_query() -> None:
    """A 200-character query is shortened: 80 leading chars kept, ``...`` appended."""
    thinking = build_initial_thinking_step(
        user_image_data_urls=None,
        user_query="x" * 200,
    )

    assert "..." in thinking.items[0]
    label = "Processing: "
    shown_query = thinking.items[0][len(label) :]
    assert shown_query.startswith("x" * 80)
    assert shown_query.endswith("...")
# ------------------------------------------------------------ capability gate
def test_image_capability_passes_without_images() -> None:
    """With no images attached the capability gate returns ``None``."""
    verdict = check_image_input_capability(
        agent_config=None,
        user_image_data_urls=None,
    )
    assert verdict is None
def test_image_capability_passes_when_capability_unknown() -> None:
    """A model not reported as text-only passes the gate even with an image attached."""

    class _AgentConfig:
        config_name = "Unknown"
        provider = "openrouter"
        custom_provider = None
        model_name = "unknown-mystery-model"
        litellm_params: dict[str, Any] = {}

    text_only_lookup = patch(
        "app.services.provider_capabilities.is_known_text_only_chat_model",
        return_value=False,
    )
    with text_only_lookup:
        verdict = check_image_input_capability(
            user_image_data_urls=["data:image/png;base64,AAA"],
            agent_config=_AgentConfig(),  # type: ignore[arg-type]
        )
        assert verdict is None
def test_image_capability_blocks_known_text_only_models() -> None:
    """A model reported as text-only is rejected with the image-input error code."""

    class _AgentConfig:
        config_name = "GPT-3.5"
        provider = "openai"
        custom_provider = None
        model_name = "gpt-3.5-turbo"
        litellm_params: dict[str, Any] = {"base_model": "gpt-3.5-turbo"}

    text_only_lookup = patch(
        "app.services.provider_capabilities.is_known_text_only_chat_model",
        return_value=True,
    )
    with text_only_lookup:
        verdict = check_image_input_capability(
            user_image_data_urls=["data:image/png;base64,AAA"],
            agent_config=_AgentConfig(),  # type: ignore[arg-type]
        )

    assert verdict is not None
    message, error_code = verdict
    assert error_code == "MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT"
    # The human-readable message names the configuration.
    assert "GPT-3.5" in message
# ---------------------------------------------------------------- runtime ctx
def test_new_chat_runtime_context_prefers_accepted_folder_ids() -> None:
    """When accepted folder ids are supplied they replace the requested ones."""
    builder_args = {
        "search_space_id": 7,
        "mentioned_document_ids": [1, 2],
        "accepted_folder_ids": [10],
        "mentioned_folder_ids": [20, 30],
        "mentioned_connector_ids": None,
        "mentioned_connectors": None,
        "request_id": "req",
        "turn_id": "t1",
    }
    ctx = build_new_chat_runtime_context(**builder_args)

    assert isinstance(ctx, SurfSenseContextSchema)
    assert ctx.search_space_id == 7
    assert list(ctx.mentioned_document_ids) == [1, 2]
    assert list(ctx.mentioned_folder_ids) == [10]
    assert ctx.request_id == "req"
    assert ctx.turn_id == "t1"
def test_new_chat_runtime_context_falls_back_to_mentioned_folder_ids() -> None:
    """An empty accepted list lets the requested folder ids through unchanged."""
    builder_args = {
        "search_space_id": 7,
        "mentioned_document_ids": None,
        "accepted_folder_ids": [],
        "mentioned_folder_ids": [20, 30],
        "mentioned_connector_ids": None,
        "mentioned_connectors": None,
        "request_id": None,
        "turn_id": "t2",
    }
    ctx = build_new_chat_runtime_context(**builder_args)

    assert list(ctx.mentioned_folder_ids) == [20, 30]
def test_new_chat_runtime_context_propagates_connector_mentions() -> None:
    """Connector ids and connector records passed in appear on the context.

    Both ``mentioned_connector_ids`` and ``mentioned_connectors`` must be
    carried through to the returned schema.
    """
    slack_mention = {"id": 5, "connector_type": "SLACK_CONNECTOR", "title": "acme"}
    connectors = [slack_mention]
    builder_args = {
        "search_space_id": 7,
        "mentioned_document_ids": None,
        "accepted_folder_ids": [],
        "mentioned_folder_ids": None,
        "mentioned_connector_ids": [5],
        "mentioned_connectors": connectors,
        "request_id": None,
        "turn_id": "t3",
    }
    ctx = build_new_chat_runtime_context(**builder_args)

    assert list(ctx.mentioned_connector_ids) == [5]
    assert list(ctx.mentioned_connectors) == connectors
def test_resume_chat_runtime_context_empty_mention_lists() -> None:
    """The resume context echoes back the ids it was built from."""
    expected = {"search_space_id": 42, "request_id": "req-r", "turn_id": "t-r"}

    ctx = build_resume_chat_runtime_context(
        search_space_id=42, request_id="req-r", turn_id="t-r"
    )

    assert ctx.search_space_id == expected["search_space_id"]
    assert ctx.request_id == expected["request_id"]
    assert ctx.turn_id == expected["turn_id"]
# ---------------------------------------------------------------- SSE frames
def test_iter_initial_frames_emits_canonical_sequence() -> None:
    """Four initial frames: the third carries the turn id, the fourth says busy."""
    turn_id = "42:1700000000000"
    service = VercelStreamingService()

    emitted = list(iter_initial_frames(service, turn_id=turn_id))

    assert len(emitted) == 4
    assert "42:1700000000000" in emitted[2]
    # Accept either JSON separator style for the status field.
    busy_spellings = ('"status":"busy"', '"status": "busy"')
    assert any(spelling in emitted[3] for spelling in busy_spellings)
def test_iter_final_frames_emits_idle_then_finish_done() -> None:
    """Four final frames are emitted and the first one reports idle status."""
    service = VercelStreamingService()

    emitted = list(iter_final_frames(service))

    assert len(emitted) == 4
    # Accept either JSON separator style for the status field.
    idle_spellings = ('"status":"idle"', '"status": "idle"')
    assert any(spelling in emitted[0] for spelling in idle_spellings)
# ----------------------------------------------------------- token usage frame
class _FakeAccumulator:
    """Token-usage accumulator double with fixed totals.

    The summary handed to the constructor is returned untouched by
    ``per_message_summary``; all other figures are hard-coded.
    """

    def __init__(self, summary: Any = None) -> None:
        # Fixed usage figures: 60 prompt + 40 completion = 100 total tokens.
        self.total_prompt_tokens = 60
        self.total_completion_tokens = 40
        self.grand_total = 100
        self.total_cost_micros = 50_000
        self.calls = [1, 2, 3]
        self._summary = summary

    def per_message_summary(self) -> Any:
        """Return the summary supplied at construction (possibly ``None``)."""
        return self._summary

    def serialized_calls(self) -> list[Any]:
        """Return a shallow copy of ``calls``."""
        return [*self.calls]
def test_token_usage_frame_skipped_when_no_summary() -> None:
    """No frame is produced when the accumulator has no per-message summary."""
    service = VercelStreamingService()
    empty_accumulator = _FakeAccumulator(summary=None)

    emitted = list(
        iter_token_usage_frame(
            service,
            accumulator=empty_accumulator,  # type: ignore[arg-type]
            log_label="parity-empty",
        )
    )

    assert emitted == []
def test_token_usage_frame_emitted_when_summary_present() -> None:
    """One frame is produced and it carries all four usage figures."""
    service = VercelStreamingService()
    populated = _FakeAccumulator(summary=[{"m": "x", "t": 100}])

    emitted = list(
        iter_token_usage_frame(
            service,
            accumulator=populated,  # type: ignore[arg-type]
            log_label="parity-populated",
        )
    )

    assert len(emitted) == 1
    # Strip spaces once so either JSON separator style matches.
    compact = emitted[0].replace(" ", "")
    expected_fields = (
        '"prompt_tokens":60',
        '"completion_tokens":40',
        '"total_tokens":100',
        '"cost_micros":50000',
    )
    for field in expected_fields:
        assert field in compact
# ------------------------------------------------------------------ llm_bundle
def test_load_llm_bundle_routes_negative_id_to_yaml_loader() -> None:
    """With the global loader returning ``None``, id -1 produces an error mentioning it."""

    async def _run() -> tuple[Any, Any, str | None]:
        yaml_loader = patch(
            "app.tasks.chat.streaming.flows.shared.llm_bundle.load_global_llm_config_by_id",
            return_value=None,
        )
        with yaml_loader:
            return await load_llm_bundle(
                session=AsyncMock(),  # type: ignore[arg-type]
                config_id=-1,
                search_space_id=7,
            )

    bundle = asyncio.run(_run())
    llm, agent_config, error = bundle
    assert llm is None
    assert agent_config is None
    assert error is not None
    assert "id -1" in error
def test_load_llm_bundle_routes_nonnegative_id_to_db_loader() -> None:
    """With ``load_agent_config`` returning ``None``, id 12 produces an error mentioning it."""

    async def _run() -> tuple[Any, Any, str | None]:
        db_loader = patch(
            "app.tasks.chat.streaming.flows.shared.llm_bundle.load_agent_config",
            new=AsyncMock(return_value=None),
        )
        with db_loader:
            return await load_llm_bundle(
                session=AsyncMock(),  # type: ignore[arg-type]
                config_id=12,
                search_space_id=7,
            )

    bundle = asyncio.run(_run())
    llm, agent_config, error = bundle
    assert llm is None
    assert agent_config is None
    assert error is not None
    assert "id 12" in error
# ----------------------------------------------------------------- premium quota
def test_needs_premium_quota_requires_user_and_premium_flag() -> None:
    """Quota is needed only for a premium config combined with a user id."""

    class _AgentConfig:
        is_premium = True

    class _NonPremium:
        is_premium = False

    cases = (
        (_AgentConfig(), "user-1", True),
        (_AgentConfig(), None, False),
        (_NonPremium(), "user-1", False),
        (None, "user-1", False),
    )
    for agent_config, user_id, expected in cases:
        assert needs_premium_quota(agent_config, user_id) is expected  # type: ignore[arg-type]
def test_premium_reservation_dataclass_shape() -> None:
    """``PremiumReservation`` accepts and exposes its three fields."""
    reservation = PremiumReservation(
        request_id="abc",
        reserved_micros=100,
        allowed=True,
    )

    assert reservation.request_id == "abc"
    assert reservation.reserved_micros == 100
    assert reservation.allowed is True
# ----------------------------------------------------------- rate-limit guard
@pytest.mark.parametrize(
    "first_event_seen, recovered, requested_id, current_id, expected",
    [
        # All preconditions hold: recovery is allowed.
        (False, False, 0, -1, True),
        # A recovery already happened this turn.
        (False, True, 0, -1, False),
        # The caller asked for a specific config (non-zero requested id).
        (False, False, 5, -1, False),
        # The current config id is positive.
        (False, False, 0, 7, False),
        # Output has already been streamed.
        (True, False, 0, -1, False),
    ],
)
def test_can_recover_provider_rate_limit_truth_table(
    first_event_seen: bool,
    recovered: bool,
    requested_id: int,
    current_id: int,
    expected: bool,
) -> None:
    """Recovery is permitted only when every precondition holds.

    The exception is always rate-limit shaped, so a ``False`` verdict can only
    come from one of the other inputs.
    """
    rate_limited = Exception('{"error":{"type":"rate_limit_error","message":"slow"}}')

    verdict = can_recover_provider_rate_limit(
        rate_limited,
        first_event_seen=first_event_seen,
        runtime_rate_limit_recovered=recovered,
        requested_llm_config_id=requested_id,
        current_llm_config_id=current_id,
    )

    assert verdict is expected
def test_can_recover_provider_rate_limit_rejects_non_rate_limit_exception() -> None:
    """The same flags that allow recovery in the truth table must not allow it for a plain ``ValueError``."""
    unrelated_exc = ValueError("not a rate limit")
    verdict = can_recover_provider_rate_limit(
        unrelated_exc,
        first_event_seen=False,
        runtime_rate_limit_recovered=False,
        requested_llm_config_id=0,
        current_llm_config_id=-1,
    )
    assert verdict is False
# --------------------------------------------------------- persistence spawn
def test_spawn_set_ai_responding_bg_noop_without_user_id() -> None:
    """With ``user_id=None`` the spawner must leave the background-task set empty."""

    async def _collect_registered() -> set[asyncio.Task]:
        registered: set[asyncio.Task] = set()
        spawn_set_ai_responding_bg(
            chat_id=1,
            user_id=None,
            background_tasks=registered,
        )
        return registered

    registered_tasks = asyncio.run(_collect_registered())
    assert registered_tasks == set()
def test_spawn_persist_user_task_registers_and_self_unregisters() -> None:
    """The spawned persist task is tracked while in flight and dropped once done.

    ``persist_user_turn`` is patched to return ``99``. The test records the
    size of the background set right after spawning (expects 1), awaits the
    task, yields one event-loop tick, and records the size again (expects 0).
    """

    async def _run() -> tuple[int, int, int]:
        background: set[asyncio.Task] = set()
        with patch(
            "app.tasks.chat.streaming.flows.new_chat.persistence_spawn.persist_user_turn",
            new=AsyncMock(return_value=99),
        ):
            task = spawn_persist_user_task(
                chat_id=1,
                user_id="u",
                turn_id="t",
                user_query="hi",
                user_image_data_urls=None,
                mentioned_documents=None,
                background_tasks=background,
            )
            size_before_await = len(background)
            # shield() keeps the spawned task alive even if this awaiter is cancelled.
            result = await asyncio.shield(task)
            # Give the done-callback one event-loop tick to run.
            await asyncio.sleep(0)
            # Previously this tick was waited for but never observed, so the
            # "self_unregisters" half of the test name was not checked.
            size_after_done = len(background)
        return size_before_await, size_after_done, result  # type: ignore[return-value]

    size_before, size_after, result = asyncio.run(_run())
    assert size_before == 1
    assert result == 99
    # NOTE(review): assumes the spawner's done-callback removes the task from
    # ``background_tasks`` (as the test name states) — confirm against
    # persistence_spawn.
    assert size_after == 0
def test_spawn_persist_assistant_shell_task_registers() -> None:
    """Awaiting the spawned shell task yields the patched ``persist_assistant_shell`` result (42)."""
    shell_target = (
        "app.tasks.chat.streaming.flows.new_chat.persistence_spawn.persist_assistant_shell"
    )

    async def _spawn_and_wait() -> int | None:
        tracked: set[asyncio.Task] = set()
        with patch(shell_target, new=AsyncMock(return_value=42)):
            shell_task = spawn_persist_assistant_shell_task(
                chat_id=1,
                user_id="u",
                turn_id="t",
                background_tasks=tracked,
            )
            persisted_id = await asyncio.shield(shell_task)
            return persisted_id

    outcome = asyncio.run(_spawn_and_wait())
    assert outcome == 42
def test_await_persist_task_returns_none_on_failure() -> None:
    """A persist task that raises must surface as ``None`` from ``await_persist_task``."""

    async def _failing_persist() -> int:
        raise RuntimeError("DB down")

    async def _await_failed_task() -> int | None:
        failing_task = asyncio.create_task(_failing_persist())
        outcome = await await_persist_task(
            failing_task,
            chat_id=1,
            turn_id="t",
            log_label="parity-failure",
        )
        return outcome

    outcome = asyncio.run(_await_failed_task())
    assert outcome is None
def test_await_persist_task_returns_none_for_none_input() -> None:
    """Passing ``None`` instead of a task must yield ``None``."""

    async def _await_nothing() -> int | None:
        outcome = await await_persist_task(
            None,
            chat_id=1,
            turn_id="t",
            log_label="parity-none",
        )
        return outcome

    outcome = asyncio.run(_await_nothing())
    assert outcome is None

View file

@ -1,240 +0,0 @@
"""Pin Stage 1 extractions as faithful copies of the old helpers.
Extractions under ``app.tasks.chat.streaming`` are compared to
``app.tasks.chat.stream_new_chat`` helpers.
For each Stage 1 extraction we assert the new function returns the same
output as the old one for a representative input set. The moment the
two diverge - intentionally or otherwise - this file fails loudly so
the divergence is reviewed rather than shipped silently.
"""
from __future__ import annotations
import logging
from dataclasses import dataclass, field
from typing import Any
import pytest
from app.agents.shared.errors import BusyError
from app.agents.shared.middleware.busy_mutex import request_cancel, reset_cancel
from app.tasks.chat.stream_new_chat import (
_classify_stream_exception as old_classify,
_emit_stream_terminal_error as old_emit_terminal_error,
_extract_chunk_parts as old_extract_chunk_parts,
_extract_resolved_file_path as old_extract_resolved_file_path,
_tool_output_has_error as old_tool_output_has_error,
_tool_output_to_text as old_tool_output_to_text,
)
from app.tasks.chat.streaming.errors.classifier import (
classify_stream_exception as new_classify,
)
from app.tasks.chat.streaming.errors.emitter import (
emit_stream_terminal_error as new_emit_terminal_error,
)
from app.tasks.chat.streaming.helpers.chunk_parts import (
extract_chunk_parts as new_extract_chunk_parts,
)
from app.tasks.chat.streaming.helpers.tool_output import (
extract_resolved_file_path as new_extract_resolved_file_path,
tool_output_has_error as new_tool_output_has_error,
tool_output_to_text as new_tool_output_to_text,
)
pytestmark = pytest.mark.unit
# ---------------------------------------------------------------- chunk parts
@dataclass
class _Chunk:
    """Lightweight test double handed to both ``extract_chunk_parts`` implementations.

    Carries only the three attributes that the parity cases in this module set.
    """

    # Message content: the cases use a string, a list of typed blocks, or an int.
    content: Any = ""
    # Extra keyword payload; one case stores ``reasoning_content`` here.
    additional_kwargs: dict[str, Any] = field(default_factory=dict)
    # Tool-call fragments supplied as an attribute rather than as content blocks.
    tool_call_chunks: list[dict[str, Any]] = field(default_factory=list)
# Representative inputs fed to both the old and the new chunk-part extractor.
_CHUNK_CASES: list[Any] = [
    # No chunk object at all.
    None,
    # Empty string content.
    _Chunk(content=""),
    # Plain string content.
    _Chunk(content="hello"),
    _Chunk(content=42),  # invalid type, defensively coerced to empty
    # Content given as a list of typed text blocks.
    _Chunk(
        content=[
            {"type": "text", "text": "Hello "},
            {"type": "text", "text": "world"},
        ]
    ),
    # Reasoning blocks (payload under either "reasoning" or "text") followed by a text block.
    _Chunk(
        content=[
            {"type": "reasoning", "reasoning": "hmm "},
            {"type": "reasoning", "text": "still"},
            {"type": "text", "text": "answer"},
        ]
    ),
    # Tool-call style blocks mixed with an image_url block.
    _Chunk(
        content=[
            {"type": "tool_call_chunk", "id": "c1", "name": "x", "args": "{"},
            {"type": "tool_use", "id": "c2", "name": "y"},
            {"type": "image_url", "url": "ignored"},
        ]
    ),
    # String content with reasoning carried in additional_kwargs.
    _Chunk(
        content="visible",
        additional_kwargs={"reasoning_content": "private"},
    ),
    # Tool-call fragments on the attribute only; the first fragment has no id or name.
    _Chunk(
        tool_call_chunks=[
            {"id": None, "name": None, "args": '{"a":1}', "index": 0},
            {"id": "c", "name": "n", "args": "}", "index": 0},
        ]
    ),
    # A tool call present both as a content block and on the attribute.
    _Chunk(
        content=[{"type": "tool_call_chunk", "id": "from-block", "name": "x"}],
        tool_call_chunks=[{"id": "from-attr", "name": "y"}],
    ),
]
@pytest.mark.parametrize("chunk", _CHUNK_CASES)
def test_extract_chunk_parts_matches_old_implementation(chunk: Any) -> None:
    """The extracted helper and the legacy helper must return equal results for each case."""
    extracted_parts = new_extract_chunk_parts(chunk)
    legacy_parts = old_extract_chunk_parts(chunk)
    assert extracted_parts == legacy_parts
# ----------------------------------------------------------- error classifier
def _classify_cases() -> list[Exception]:
    """Inputs that the FE depends on being mapped to specific error codes."""
    generic = Exception("totally generic error")
    rate_limit_payload = Exception(
        '{"error":{"type":"rate_limit_error","message":"slow down"}}'
    )
    openrouter_429 = Exception(
        'OpenrouterException - {"error":{"message":"Provider returned error",'
        '"code":429}}'
    )
    busy_typed = BusyError(request_id="thread-busy-parity")
    busy_by_message = Exception("Thread is busy with another request")
    return [generic, rate_limit_payload, openrouter_429, busy_typed, busy_by_message]
@pytest.mark.parametrize("exc", _classify_cases())
def test_classify_stream_exception_matches_old_implementation(
    exc: Exception,
) -> None:
    """Both classifiers must return the same 6-tuple, ignoring the wall-clock retry stamp."""
    new_result = new_classify(exc, flow_label="parity-test")
    old_result = old_classify(exc, flow_label="parity-test")

    new_extra, old_extra = new_result[5], old_result[5]
    # Each implementation calls ``time.time()`` on its own, so the absolute
    # retry timestamp can differ by about a millisecond between the two.
    # Drop it (working on copies) and require everything else to match.
    if isinstance(new_extra, dict) and isinstance(old_extra, dict):
        new_extra = {
            key: value
            for key, value in new_extra.items()
            if key != "retry_after_at"
        }
        old_extra = {
            key: value
            for key, value in old_extra.items()
            if key != "retry_after_at"
        }

    assert new_result[:5] == old_result[:5]
    assert new_extra == old_extra
def test_classify_turn_cancelling_branch_parity() -> None:
    """The TURN_CANCELLING branch reads cancel state for the busy thread id;
    both implementations must agree on retry-window semantics, not just the
    plain THREAD_BUSY code.

    The cancel request made here is undone in a ``finally`` so the cancel
    state for this thread id does not leak into other tests.
    """
    thread_id = "parity-cancelling-thread"
    reset_cancel(thread_id)
    request_cancel(thread_id)
    try:
        exc = BusyError(request_id=thread_id)
        new = new_classify(exc, flow_label="parity-test")
        old = old_classify(exc, flow_label="parity-test")
    finally:
        # Same call used above to start from a clean slate; run it again so
        # the pending cancel does not outlive this test.
        reset_cancel(thread_id)

    assert new[0] == old[0] == "thread_busy"
    assert new[1] == old[1] == "TURN_CANCELLING"
    assert isinstance(new[5], dict) and isinstance(old[5], dict)
    assert new[5]["retry_after_ms"] == old[5]["retry_after_ms"]
# ------------------------------------------------------------ terminal emitter
class _FakeStreamingService:
"""Duck-types ``format_error`` for both old and new emitters."""
def __init__(self) -> None:
self.calls: list[dict[str, Any]] = []
def format_error(
self, message: str, *, error_code: str, extra: dict[str, Any] | None = None
) -> str:
self.calls.append(
{"message": message, "error_code": error_code, "extra": extra}
)
return f'data: {{"type":"error","errorText":"{message}"}}\n\n'
def test_emit_stream_terminal_error_matches_old_output_and_logs(caplog) -> None:
    """Old and new emitters must return the same SSE frame, make the same
    ``format_error`` call, and each log one ``[chat_stream_error]`` record."""
    emit_kwargs: dict[str, Any] = dict(
        flow="new",
        request_id="req-parity",
        thread_id=7,
        search_space_id=9,
        user_id="user-parity",
        message="boom",
        error_kind="server_error",
        error_code="SERVER_ERROR",
        severity="error",
        is_expected=False,
        extra={"foo": "bar"},
    )
    new_svc, old_svc = _FakeStreamingService(), _FakeStreamingService()

    with caplog.at_level(logging.ERROR):
        new_frame = new_emit_terminal_error(streaming_service=new_svc, **emit_kwargs)
        old_frame = old_emit_terminal_error(streaming_service=old_svc, **emit_kwargs)

    assert new_frame == old_frame
    assert new_svc.calls == old_svc.calls

    # One log line per emit call (two emits -> two records).
    error_record_count = sum(
        1 for record in caplog.records if "[chat_stream_error]" in record.message
    )
    assert error_record_count == 2
# ---------------------------------------------------------------- tool output
def test_tool_output_helpers_match_old_implementation() -> None:
    """Text conversion, error detection and path resolution must agree between old and new helpers."""
    tool_outputs: list[Any] = [
        {"result": "ok"},
        {"error": "bad"},
        {"result": "Error: x"},
        "Error: plain",
        "fine",
        {"nested": {"a": 1}},
    ]
    for tool_output in tool_outputs:
        assert new_tool_output_to_text(tool_output) == old_tool_output_to_text(
            tool_output
        )
        assert new_tool_output_has_error(tool_output) == old_tool_output_has_error(
            tool_output
        )

    def _path_cases() -> list[dict[str, Any]]:
        # Built fresh on every call so each helper gets its own argument objects.
        return [
            {
                "tool_name": "write_file",
                "tool_output": {"path": " /tmp/x "},
                "tool_input": None,
            },
            {
                "tool_name": "write_file",
                "tool_output": {},
                "tool_input": {"file_path": " /fallback "},
            },
        ]

    for new_kwargs, old_kwargs in zip(_path_cases(), _path_cases()):
        resolved_new = new_extract_resolved_file_path(**new_kwargs)
        resolved_old = old_extract_resolved_file_path(**old_kwargs)
        assert resolved_new == resolved_old

View file

@ -1,241 +0,0 @@
"""Parity tests for Stage 2 extractions (tool matching, thinking step, custom events)."""
from __future__ import annotations
from typing import Any
from unittest.mock import MagicMock
import pytest
from app.tasks.chat.stream_new_chat import _legacy_match_lc_id as old_legacy_match
from app.tasks.chat.streaming.handlers.custom_events import (
handle_action_log,
handle_action_log_updated,
handle_document_created,
handle_report_progress,
)
from app.tasks.chat.streaming.helpers.tool_call_matching import (
match_buffered_langchain_tool_call_id as new_legacy_match,
)
from app.tasks.chat.streaming.relay.state import AgentEventRelayState
from app.tasks.chat.streaming.relay.thinking_step_completion import (
complete_active_thinking_step,
)
from app.tasks.chat.streaming.relay.thinking_step_sse import emit_thinking_step_frame
pytestmark = pytest.mark.unit
def _copy_chunk_buffer(raw: list[dict[str, Any]]) -> list[dict[str, Any]]:
return [dict(x) for x in raw]
def test_legacy_tool_call_match_matches_old_implementation() -> None:
    """Old and new matchers must return the same value and leave equal buffers and id maps."""
    scenarios: list[tuple[list[dict[str, Any]], str, str, dict[str, str]]] = [
        (
            [
                {"name": "write_file", "id": "lc-a"},
                {"name": "other", "id": "lc-b"},
            ],
            "write_file",
            "run-1",
            {},
        ),
        (
            [{"name": "x", "id": None}, {"name": "y", "id": "lc-fallback"}],
            "write_file",
            "run-2",
            {},
        ),
        ([{"name": "no_id"}], "write_file", "run-3", {}),
    ]
    for buffered, tool_name, run_id, seed_map in scenarios:
        # Independent copies per implementation so the state each leaves
        # behind can be compared afterwards.
        legacy_chunks = _copy_chunk_buffer(buffered)
        current_chunks = _copy_chunk_buffer(buffered)
        legacy_map = dict(seed_map)
        current_map = dict(seed_map)

        legacy_result = old_legacy_match(legacy_chunks, tool_name, run_id, legacy_map)
        current_result = new_legacy_match(
            current_chunks, tool_name, run_id, current_map
        )

        assert current_result == legacy_result
        assert current_chunks == legacy_chunks
        assert current_map == legacy_map
def test_emit_thinking_step_frame_invokes_builder_before_service() -> None:
    """The content builder is notified first, then the service formats the frame; each exactly once."""
    call_order: list[str] = []

    def _record_service(**kwargs: Any) -> str:
        call_order.append("service")
        return "frame"

    def _record_builder(*args: Any, **kwargs: Any) -> None:
        call_order.append("builder")

    svc = MagicMock()
    svc.format_thinking_step.side_effect = _record_service
    builder = MagicMock()
    builder.on_thinking_step.side_effect = _record_builder

    frame = emit_thinking_step_frame(
        streaming_service=svc,
        content_builder=builder,
        step_id="thinking-1",
        title="Working",
        status="in_progress",
        items=["a"],
    )

    assert frame == "frame"
    assert call_order == ["builder", "service"]
    builder.on_thinking_step.assert_called_once()
    svc.format_thinking_step.assert_called_once()
def test_emit_thinking_step_frame_skips_builder_when_none() -> None:
    """With ``content_builder=None`` the service still formats exactly one frame."""
    svc = MagicMock(return_value="x")
    svc.format_thinking_step.return_value = "frame"

    frame = emit_thinking_step_frame(
        streaming_service=svc,
        content_builder=None,
        step_id="s",
        title="t",
    )

    assert frame == "frame"
    svc.format_thinking_step.assert_called_once()
def test_complete_active_thinking_step_mirrors_closure_semantics() -> None:
    """First completion emits a frame and clears the id; repeating it for the
    same step emits nothing and hands the id back unchanged."""
    svc = MagicMock()
    svc.format_thinking_step.return_value = "done-frame"
    completed: set[str] = set()
    shared_kwargs: dict[str, Any] = dict(
        state=AgentEventRelayState.for_invocation(),
        streaming_service=svc,
        content_builder=None,
        last_active_step_id="thinking-1",
        last_active_step_title="T",
        completed_step_ids=completed,
    )

    first_frame, first_id = complete_active_thinking_step(
        last_active_step_items=["x"], **shared_kwargs
    )
    assert first_frame == "done-frame"
    assert first_id is None
    assert "thinking-1" in completed

    # "thinking-1" is now in ``completed``; run the same step id again.
    second_frame, second_id = complete_active_thinking_step(
        last_active_step_items=[], **shared_kwargs
    )
    assert second_frame is None
    assert second_id == "thinking-1"
def test_agent_event_relay_state_factory_matches_counter_rule() -> None:
    """A bare state starts at counter 0 with no active step; a state seeded
    with an initial step starts at 1, so the next generated id ends in ``-2``."""
    fresh_state = AgentEventRelayState.for_invocation()
    assert fresh_state.thinking_step_counter == 0
    assert fresh_state.last_active_step_id is None

    seeded_state = AgentEventRelayState.for_invocation(
        initial_step_id="thinking-resume-1",
        initial_step_title="Inherited",
        initial_step_items=["Topic: X"],
    )
    assert seeded_state.thinking_step_counter == 1
    assert seeded_state.last_active_step_id == "thinking-resume-1"
    next_id = seeded_state.next_thinking_step_id("thinking")
    assert next_id == "thinking-2"
@pytest.mark.parametrize(
    ("phase", "message", "start_items", "expected_tail"),
    [
        (
            "revising_section",
            "progress line",
            ["Topic: Foo", "Modifying bar", "stale..."],
            ["Topic: Foo", "Modifying bar", "progress line"],
        ),
        (
            "other",
            "phase msg",
            ["Topic: Foo", "old line"],
            ["Topic: Foo", "phase msg"],
        ),
    ],
)
def test_report_progress_items_match_reference(
    phase: str,
    message: str,
    start_items: list[str],
    expected_tail: list[str],
) -> None:
    """The handler returns the expected item list and passes that same list to the service."""
    svc = MagicMock()
    svc.format_thinking_step.return_value = "sse"
    working_items = list(start_items)
    event_payload = {"message": message, "phase": phase}

    frame, new_items = handle_report_progress(
        event_payload,
        last_active_step_id="step-1",
        last_active_step_title="Report",
        last_active_step_items=working_items,
        streaming_service=svc,
        content_builder=None,
    )

    assert frame == "sse"
    assert new_items == expected_tail
    formatted_kwargs = svc.format_thinking_step.call_args.kwargs
    assert formatted_kwargs["items"] == expected_tail
def test_report_progress_noop_when_missing_message_or_step() -> None:
    """An empty message or a missing active step id yields no frame and the very same items list."""
    svc = MagicMock()
    items = ["Topic: A"]

    noop_inputs = [
        # Empty message.
        ({"message": "", "phase": "x"}, "s"),
        # No active step.
        ({"message": "m", "phase": "x"}, None),
    ]
    for event_payload, active_step_id in noop_inputs:
        frame, returned_items = handle_report_progress(
            event_payload,
            last_active_step_id=active_step_id,
            last_active_step_title="t",
            last_active_step_items=items,
            streaming_service=svc,
            content_builder=None,
        )
        assert frame is None
        assert returned_items is items
def test_document_action_handlers_match_format_data_guards() -> None:
    """Payloads lacking a usable id return ``None``; valid ones call ``format_data`` with the expected event."""
    svc = MagicMock()
    svc.format_data.return_value = "data-frame"

    for rejected_payload in ({}, {"id": 0}):
        assert handle_document_created(rejected_payload, streaming_service=svc) is None
    handle_document_created({"id": 42, "title": "x"}, streaming_service=svc)
    svc.format_data.assert_called_with(
        "documents-updated", {"action": "created", "document": {"id": 42, "title": "x"}}
    )

    action_log_cases = (
        (handle_action_log, "action-log", 1),
        (handle_action_log_updated, "action-log-updated", 2),
    )
    for handler, event_name, entry_id in action_log_cases:
        svc.reset_mock()
        assert handler({"id": None}, streaming_service=svc) is None
        handler({"id": entry_id}, streaming_service=svc)
        svc.format_data.assert_called_once_with(event_name, {"id": entry_id})

View file

@ -1,4 +1,4 @@
"""Unit tests for ``stream_new_chat._extract_chunk_parts``.
"""Unit tests for ``streaming.helpers.chunk_parts.extract_chunk_parts``.
Earlier versions only handled ``isinstance(chunk.content, str)`` and
silently dropped every other shape (Anthropic typed-block lists,
@ -14,7 +14,9 @@ from typing import Any
import pytest
from app.tasks.chat.stream_new_chat import _extract_chunk_parts
from app.tasks.chat.streaming.helpers.chunk_parts import (
extract_chunk_parts as _extract_chunk_parts,
)
@dataclass

View file

@ -7,7 +7,7 @@ constructs a fresh :class:`AgentEventRelayState` with
``thinking_step_counter=0``), React renders sibling timeline rows with the
same key the warning the user reported in production.
The contract this module pins: each ``_stream_agent_events`` invocation must
The contract this module pins: each ``stream_agent_events`` invocation must
receive a ``step_prefix`` that is unique within the thread (we salt with the
per-turn ``turn_id``), so the resulting step IDs across consecutive turns
are always disjoint.
@ -23,10 +23,12 @@ from typing import Any
import pytest
from app.services.new_streaming_service import VercelStreamingService
from app.tasks.chat.stream_new_chat import (
StreamResult,
_resume_step_prefix,
_stream_agent_events,
from app.tasks.chat.streaming.agent.event_loop import (
stream_agent_events as _stream_agent_events,
)
from app.tasks.chat.streaming.shared.stream_result import StreamResult
from app.tasks.chat.streaming.shared.utils import (
resume_step_prefix as _resume_step_prefix,
)
pytestmark = pytest.mark.unit

View file

@ -1,6 +1,6 @@
"""Unit tests for live tool-call argument streaming.
Pins the wire format that ``_stream_agent_events`` emits:
Pins the wire format that ``stream_agent_events`` emits:
``tool-input-start`` ``tool-input-delta``... ``tool-input-available``
``tool-output-available``, keyed consistently with LangChain ``tool_call.id``
when the model streams indexed chunks.
@ -20,11 +20,13 @@ from typing import Any
import pytest
from app.services.new_streaming_service import VercelStreamingService
from app.tasks.chat.stream_new_chat import (
StreamResult,
_legacy_match_lc_id,
_stream_agent_events,
from app.tasks.chat.streaming.agent.event_loop import (
stream_agent_events as _stream_agent_events,
)
from app.tasks.chat.streaming.helpers.tool_call_matching import (
match_buffered_langchain_tool_call_id as _legacy_match_lc_id,
)
from app.tasks.chat.streaming.shared.stream_result import StreamResult
pytestmark = pytest.mark.unit

View file

@ -1,438 +0,0 @@
import inspect
import json
import logging
import re
from pathlib import Path
import pytest
import app.tasks.chat.stream_new_chat as stream_new_chat_module
from app.agents.shared.errors import BusyError
from app.agents.shared.middleware.busy_mutex import request_cancel, reset_cancel
from app.tasks.chat.stream_new_chat import (
StreamResult,
_classify_stream_exception,
_contract_enforcement_active,
_evaluate_file_contract_outcome,
_extract_resolved_file_path,
_log_chat_stream_error,
_tool_output_has_error,
)
pytestmark = pytest.mark.unit
def test_tool_output_error_detection():
    """Error-shaped outputs are flagged; an ordinary success result is not."""
    error_outputs = (
        "Error: failed to write file",
        {"error": "boom"},
        {"result": "Error: disk is full"},
    )
    for tool_output in error_outputs:
        assert _tool_output_has_error(tool_output)

    success_output = {"result": "Updated file /notes.md"}
    assert not _tool_output_has_error(success_output)
def test_extract_resolved_file_path_prefers_structured_path():
    """A ``path`` key in the tool output is returned as the resolved path."""
    resolved = _extract_resolved_file_path(
        tool_name="write_file",
        tool_output={"status": "completed", "path": "/docs/note.md"},
        tool_input=None,
    )
    assert resolved == "/docs/note.md"
def test_extract_resolved_file_path_falls_back_to_tool_input():
    """Without a ``path`` in the output, the input's ``file_path`` is used."""
    resolved = _extract_resolved_file_path(
        tool_name="edit_file",
        tool_output={"status": "completed", "result": "updated"},
        tool_input={"file_path": "/docs/edited.md"},
    )
    assert resolved == "/docs/edited.md"
def test_extract_resolved_file_path_does_not_parse_result_text():
    """A path that appears only inside free-form ``result`` text must not be extracted."""
    resolved = _extract_resolved_file_path(
        tool_name="write_file",
        tool_output={"result": "Updated file /docs/from-text.md"},
        tool_input=None,
    )
    assert resolved is None
def test_file_write_contract_outcome_reasons():
    """Flip the result flags one by one and check the failure reason reported at each stage."""
    result = StreamResult(intent_detected="file_write")

    passed, reason = _evaluate_file_contract_outcome(result)
    assert not passed
    assert reason == "no_write_attempt"

    # Each flag, once set, moves the failure on to the next reason.
    failing_stages = (
        ("write_attempted", "write_failed"),
        ("write_succeeded", "verification_failed"),
    )
    for flag_name, expected_reason in failing_stages:
        setattr(result, flag_name, True)
        passed, reason = _evaluate_file_contract_outcome(result)
        assert not passed
        assert reason == expected_reason

    result.verification_succeeded = True
    passed, reason = _evaluate_file_contract_outcome(result)
    assert passed
    assert reason == ""
def test_contract_enforcement_local_only():
    """Enforcement is active for ``desktop_local_folder`` and inactive for ``cloud``."""
    result = StreamResult(filesystem_mode="desktop_local_folder")
    active_for_local = _contract_enforcement_active(result)
    assert active_for_local

    result.filesystem_mode = "cloud"
    active_for_cloud = _contract_enforcement_active(result)
    assert not active_for_cloud
def _extract_chat_stream_payload(record_message: str) -> dict:
prefix = "[chat_stream_error] "
assert record_message.startswith(prefix)
return json.loads(record_message[len(prefix) :])
def test_unified_chat_stream_error_log_schema(caplog):
    """The structured error log carries every required key plus the expected event, flow and code."""
    log_fields = dict(
        flow="new",
        error_kind="server_error",
        error_code="SERVER_ERROR",
        severity="warn",
        is_expected=False,
        request_id="req-123",
        thread_id=101,
        search_space_id=202,
        user_id="user-1",
        message="Error during chat: boom",
    )
    with caplog.at_level(logging.INFO, logger="app.tasks.chat.stream_new_chat"):
        _log_chat_stream_error(**log_fields)

    matching_records = (
        record for record in caplog.records if "[chat_stream_error]" in record.message
    )
    payload = _extract_chat_stream_payload(next(matching_records).message)

    required_keys = (
        "event",
        "flow",
        "error_kind",
        "error_code",
        "severity",
        "is_expected",
        "request_id",
        "thread_id",
        "search_space_id",
        "user_id",
        "message",
    )
    assert all(key in payload for key in required_keys)
    assert payload["event"] == "chat_stream_error"
    assert payload["flow"] == "new"
    assert payload["error_code"] == "SERVER_ERROR"
def test_premium_quota_uses_unified_chat_stream_log_shape(caplog):
    """A premium-quota error logs through the same payload shape, with ``extra`` keys present in the payload."""
    with caplog.at_level(logging.INFO, logger="app.tasks.chat.stream_new_chat"):
        _log_chat_stream_error(
            flow="resume",
            error_kind="premium_quota_exhausted",
            error_code="PREMIUM_QUOTA_EXHAUSTED",
            severity="info",
            is_expected=True,
            request_id="req-premium",
            thread_id=303,
            search_space_id=404,
            user_id="user-2",
            message="Buy more tokens to continue with this model, or switch to a free model",
            extra={"auto_fallback": False},
        )

    matching_records = (
        record for record in caplog.records if "[chat_stream_error]" in record.message
    )
    payload = _extract_chat_stream_payload(next(matching_records).message)

    expected_values = {
        "event": "chat_stream_error",
        "error_kind": "premium_quota_exhausted",
        "error_code": "PREMIUM_QUOTA_EXHAUSTED",
        "flow": "resume",
    }
    for key, expected in expected_values.items():
        assert payload[key] == expected
    assert payload["is_expected"] is True
    assert payload["auto_fallback"] is False
def test_stream_error_emission_keeps_machine_error_codes():
    """Source-level check: one ``format_error(`` call site, the core error codes, and the flow plumbing."""
    source = inspect.getsource(stream_new_chat_module)

    # All stream paths should route through one shared terminal error emitter.
    format_error_call_count = sum(1 for _ in re.finditer(r"format_error\(", source))
    assert format_error_call_count == 1

    emitted_error_codes = set(re.findall(r'error_code="([A-Z_]+)"', source))
    required_codes = {
        "PREMIUM_QUOTA_EXHAUSTED",
        "SERVER_ERROR",
    }
    assert required_codes.issubset(emitted_error_codes)

    required_fragments = (
        'flow: Literal["new", "regenerate"] = "new"',
        "_emit_stream_terminal_error",
        "flow=flow",
        'flow="resume"',
    )
    for fragment in required_fragments:
        assert fragment in source
def test_stream_exception_classifies_rate_limited():
    """A ``rate_limit_error`` JSON payload classifies as an expected RATE_LIMITED warning with no extra."""
    exc = Exception(
        '{"error":{"type":"rate_limit_error","message":"Rate limited. Please try again later."}}'
    )
    classification = _classify_stream_exception(exc, flow_label="chat")
    kind, code, severity, is_expected, user_message, extra = classification

    assert (kind, code, severity) == ("rate_limited", "RATE_LIMITED", "warn")
    assert is_expected is True
    assert "temporarily rate-limited" in user_message
    assert extra is None
def test_stream_exception_classifies_openrouter_429_payload():
    """An OpenRouter-style message carrying code 429 classifies as RATE_LIMITED."""
    exc = Exception(
        'OpenrouterException - {"error":{"message":"Provider returned error","code":429,'
        '"metadata":{"raw":"foo is temporarily rate-limited upstream"}}}'
    )
    classification = _classify_stream_exception(exc, flow_label="chat")
    kind, code, severity, is_expected, user_message, extra = classification

    assert (kind, code, severity) == ("rate_limited", "RATE_LIMITED", "warn")
    assert is_expected is True
    assert "temporarily rate-limited" in user_message
    assert extra is None
def test_stream_exception_classifies_thread_busy():
    """A ``BusyError`` classifies as an expected THREAD_BUSY warning with no extra."""
    exc = BusyError(request_id="thread-123")
    classification = _classify_stream_exception(exc, flow_label="chat")
    kind, code, severity, is_expected, user_message, extra = classification

    assert (kind, code, severity) == ("thread_busy", "THREAD_BUSY", "warn")
    assert is_expected is True
    assert "still finishing for this thread" in user_message
    assert extra is None
def test_stream_exception_classifies_thread_busy_from_message():
    """A plain exception with a "Thread is busy" message classifies the same way as ``BusyError``."""
    exc = Exception("Thread is busy with another request")
    classification = _classify_stream_exception(exc, flow_label="chat")
    kind, code, severity, is_expected, user_message, extra = classification

    assert (kind, code, severity) == ("thread_busy", "THREAD_BUSY", "warn")
    assert is_expected is True
    assert "still finishing for this thread" in user_message
    assert extra is None
def test_stream_exception_classifies_turn_cancelling_when_cancel_requested():
    """With a cancel requested for the thread, ``BusyError`` classifies as TURN_CANCELLING.

    The cancel request made here is undone in a ``finally`` so the cancel
    state for this thread id does not leak into other tests.
    """
    thread_id = "thread-cancelling-1"
    reset_cancel(thread_id)
    request_cancel(thread_id)
    try:
        exc = BusyError(request_id=thread_id)
        kind, code, severity, is_expected, user_message, extra = (
            _classify_stream_exception(exc, flow_label="chat")
        )
    finally:
        # Same call used above to start from a clean slate; run it again so
        # the pending cancel does not outlive this test.
        reset_cancel(thread_id)

    assert kind == "thread_busy"
    assert code == "TURN_CANCELLING"
    assert severity == "info"
    assert is_expected is True
    assert "stopping" in user_message
    assert isinstance(extra, dict)
    assert "retry_after_ms" in extra
def test_premium_classification_is_error_code_driven():
    """The frontend classifier source has no keyword-matching fragments and branches on the premium error code."""
    repo_root = Path(__file__).resolve().parents[3]
    classifier_path = repo_root.joinpath(
        "surfsense_web/lib/chat/chat-error-classifier.ts"
    )
    source = classifier_path.read_text(encoding="utf-8")

    forbidden_fragments = (
        "PREMIUM_KEYWORDS",
        "RATE_LIMIT_KEYWORDS",
        "normalized.includes(",
    )
    for fragment in forbidden_fragments:
        assert fragment not in source
    assert 'if (errorCode === "PREMIUM_QUOTA_EXHAUSTED") {' in source
def test_stream_terminal_error_handler_has_pre_accept_soft_rollback_hook():
    """The chat page source must contain the pre-accept failure hook and the local-state rollback snippets."""
    repo_root = Path(__file__).resolve().parents[3]
    page_path = repo_root.joinpath(
        "surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx"
    )
    source = page_path.read_text(encoding="utf-8")

    required_fragments = (
        "onPreAcceptFailure?: () => Promise<void>;",
        "if (!accepted) {",
        "await onPreAcceptFailure?.();",
        "await onAcceptedStreamError?.();",
        "setMessages((prev) => prev.filter((m) => m.id !== userMsgId));",
        "setMessageDocumentsMap((prev) => {",
    )
    for fragment in required_fragments:
        assert fragment in source
def test_toast_only_pre_accept_policy_has_no_inline_failed_marker():
    """The user-message component source must not contain the inline "not sent" marker strings."""
    repo_root = Path(__file__).resolve().parents[3]
    user_message_path = repo_root.joinpath(
        "surfsense_web/components/assistant-ui/user-message.tsx"
    )
    source = user_message_path.read_text(encoding="utf-8")

    for forbidden in ("Not sent. Edit and retry.", "failed_pre_accept"):
        assert forbidden not in source
def test_network_send_failures_use_unified_retry_toast_message():
    """Source-level contract over the frontend error classifier and the request-error normaliser."""
    chat_lib_dir = Path(__file__).resolve().parents[3]
    classifier_source = chat_lib_dir.joinpath(
        "surfsense_web/lib/chat/chat-error-classifier.ts"
    ).read_text(encoding="utf-8")
    request_errors_source = chat_lib_dir.joinpath(
        "surfsense_web/lib/chat/chat-request-errors.ts"
    ).read_text(encoding="utf-8")

    classifier_required = (
        '"send_failed_pre_accept"',
        'errorCode === "SEND_FAILED_PRE_ACCEPT"',
        'errorCode === "TURN_CANCELLING"',
        "if (withCode.code) return withCode.code;",
        'userMessage: "Message not sent. Please retry."',
        'userMessage: "Connection issue. Please try again."',
    )
    for fragment in classifier_required:
        assert fragment in classifier_source

    request_errors_required = (
        "const passthroughCodes = new Set([",
        '"PREMIUM_QUOTA_EXHAUSTED"',
        '"THREAD_BUSY"',
        '"TURN_CANCELLING"',
        '"AUTH_EXPIRED"',
        '"UNAUTHORIZED"',
        '"RATE_LIMITED"',
        '"NETWORK_ERROR"',
        '"STREAM_PARSE_ERROR"',
        '"TOOL_EXECUTION_ERROR"',
        '"PERSIST_MESSAGE_FAILED"',
        '"SERVER_ERROR"',
        "passthroughCodes.has(existingCode)",
        'errorCode: "SEND_FAILED_PRE_ACCEPT"',
    )
    for fragment in request_errors_required:
        assert fragment in request_errors_source

    assert 'errorCode: "NETWORK_ERROR"' not in request_errors_source
    assert "Failed to start chat. Please try again." not in classifier_source
def test_pre_post_accept_abort_contract_exists_for_new_resume_regenerate_flows():
    """The chat page source keeps a per-flow ``accepted`` flag and the busy/cancelling retry plumbing."""
    repo_root = Path(__file__).resolve().parents[3]
    page_path = repo_root.joinpath(
        "surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx"
    )
    source = page_path.read_text(encoding="utf-8")

    required_fragments = (
        # Each flow tracks accepted boundary and passes it into shared terminal handling.
        # The acceptance boundary is still meaningful post-refactor: it gates
        # local-state cleanup (onPreAcceptFailure path) and lets the shared
        # terminal handler distinguish pre-accept aborts from in-stream errors.
        "let newAccepted = false;",
        "let resumeAccepted = false;",
        "let regenerateAccepted = false;",
        "accepted: newAccepted,",
        "accepted: resumeAccepted,",
        "accepted: regenerateAccepted,",
        # NOTE: The FE-side persistence guards previously asserted here
        # ("if (!resumeAccepted) return;", "if (!regenerateAccepted) return;",
        # "if (newAccepted && !userPersisted) {") have been intentionally
        # removed by the SSE-based message-id handshake refactor. Persistence
        # is now server-authoritative: persist_user_turn / persist_assistant_shell
        # run inside stream_new_chat / stream_resume_chat unconditionally and
        # the FE consumes data-user-message-id / data-assistant-message-id
        # SSE events to learn the canonical primary keys. There is therefore
        # no FE call-site to guard, and the shared terminal handler relies
        # purely on the `accepted` field above (forwarded to onAbort /
        # onAcceptedStreamError) to drive UI cleanup. See
        # tests/integration/chat/test_message_id_sse.py for the new
        # cross-tier ID coherence guarantees.
        # The TURN_CANCELLING / THREAD_BUSY retry plumbing is independent
        # of the persistence refactor and must still exist on every
        # start-stream fetch.
        "const fetchWithTurnCancellingRetry = useCallback(",
        "computeFallbackTurnCancellingRetryDelay",
        'withMeta.errorCode === "TURN_CANCELLING"',
        'withMeta.errorCode === "THREAD_BUSY"',
        "await fetchWithTurnCancellingRetry(() =>",
    )
    for fragment in required_fragments:
        assert fragment in source
def test_cancel_active_turn_route_contract_exists():
    """Check new_chat_routes.py for the cancel-active-turn route snippets.

    This is a text-level contract test: the routes module is read as a
    plain string and every expected fragment must appear in it verbatim.
    """
    base_dir = Path(__file__).resolve().parents[3]
    routes_text = base_dir.joinpath(
        "surfsense_backend/app/routes/new_chat_routes.py"
    ).read_text(encoding="utf-8")

    # Checked in declaration order; the first missing fragment fails the test.
    expected_fragments = (
        '@router.post(\n "/threads/{thread_id}/cancel-active-turn",',
        "response_model=CancelActiveTurnResponse",
        'status="cancelling",',
        'error_code="TURN_CANCELLING",',
        "retry_after_ms=retry_after_ms if retry_after_ms > 0 else None,",
        "retry_after_at=",
        'status="idle",',
        'error_code="NO_ACTIVE_TURN",',
    )
    for fragment in expected_fragments:
        assert fragment in routes_text
def test_turn_status_route_contract_exists():
    """Check new_chat_routes.py for the turn-status route snippets.

    The routes module is inspected as text only; nothing is imported or
    executed from it.
    """
    routes_file = (
        Path(__file__).resolve().parents[3]
        / "surfsense_backend/app/routes/new_chat_routes.py"
    )
    contents = routes_file.read_text(encoding="utf-8")

    # Order matches the original assertions so the first failure is the same.
    for snippet in (
        '@router.get(\n "/threads/{thread_id}/turn-status",',
        "response_model=TurnStatusResponse",
        "_build_turn_status_payload(thread_id)",
        "Permission.CHATS_READ.value",
        "_raise_if_thread_busy_for_start(",
    ):
        assert snippet in contents
def test_turn_cancelling_retry_policy_contract_exists():
    """Check new_chat_routes.py for the TURN_CANCELLING retry-policy snippets.

    Looks for the literal constant assignments, the delay helper's ``def``
    line, and the retry header / error-code strings in the routes source.
    """
    project_dir = Path(__file__).resolve().parents[3]
    module_path = project_dir / "surfsense_backend/app/routes/new_chat_routes.py"
    module_text = module_path.read_text(encoding="utf-8")

    # Each entry must appear verbatim; evaluated top to bottom.
    required_text = [
        "TURN_CANCELLING_INITIAL_DELAY_MS = 200",
        "TURN_CANCELLING_BACKOFF_FACTOR = 2",
        "TURN_CANCELLING_MAX_DELAY_MS = 1500",
        "def _compute_turn_cancelling_retry_delay(",
        "retry-after-ms",
        '"Retry-After"',
        '"errorCode": "TURN_CANCELLING"',
    ]
    for needle in required_text:
        assert needle in module_text
def test_turn_status_sse_contract_exists():
    """Check backend and web sources for the turn-status SSE snippets.

    Three files are read as text (the backend stream task, the web
    streaming-state module, and the web stream pipeline) and each expected
    fragment must appear verbatim in its corresponding file.
    """
    base_dir = Path(__file__).resolve().parents[3]

    def _load(relative_path):
        # Read one repo file, relative to base_dir, as UTF-8 text.
        return (base_dir / relative_path).read_text(encoding="utf-8")

    # All three files are read up front, before any assertion runs.
    # NOTE(review): this targets the stream_new_chat.py module; if that file
    # is removed by the flows cutover, this path needs repointing - confirm.
    backend_stream = _load("surfsense_backend/app/tasks/chat/stream_new_chat.py")
    web_state = _load("surfsense_web/lib/chat/streaming-state.ts")
    web_pipeline = _load("surfsense_web/lib/chat/stream-pipeline.ts")

    # (fragment, text to search) pairs, in the order they must be checked.
    expectations = (
        ('"turn-status"', backend_stream),
        ('"status": "busy"', backend_stream),
        ('"status": "idle"', backend_stream),
        ('type: "data-turn-status"', web_state),
        ('case "data-turn-status":', web_pipeline),
        ("end_turn(str(chat_id))", backend_stream),
    )
    for fragment, haystack in expectations:
        assert fragment in haystack