mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-06 20:15:17 +02:00
refactor(chat): delete legacy stream_new_chat monolith (cutover complete)
The flows orchestrators (new_chat/resume_chat) are now the sole live path after the byte-for-byte differential proof, so the monolith and its monolith-vs-flows parity scaffolding are removed. - Repoint the last live importer (anonymous_chat_routes) to streaming.agent.event_loop.stream_agent_events + shared.stream_result.StreamResult (drop-in; the keyword-only fallback-commit params default to inert for anon). - Repoint e2e launcher patch targets to flows.shared.llm_bundle. - Repoint helper unit tests (chunk_parts, thinking-step ids, tool-input streaming) to their flows homes to preserve coverage. - Delete the monolith, the contract test, and the parity tests (parallel_refactor, stage_1, stage_2, orchestrator_frame) whose sole purpose was comparing against the now-removed monolith. Full suite green (2622 passed, 1 skipped); the two excluded live-app dirs (document_upload, composio) have a pre-existing, env-gated registration 404 unrelated to this change.
This commit is contained in:
parent
b9937cf4b1
commit
5b45f78a16
12 changed files with 25 additions and 5028 deletions
|
|
@ -356,7 +356,8 @@ async def stream_anonymous_chat(
|
||||||
from app.db import shielded_async_session
|
from app.db import shielded_async_session
|
||||||
from app.services.new_streaming_service import VercelStreamingService
|
from app.services.new_streaming_service import VercelStreamingService
|
||||||
from app.services.token_tracking_service import start_turn
|
from app.services.token_tracking_service import start_turn
|
||||||
from app.tasks.chat.stream_new_chat import StreamResult, _stream_agent_events
|
from app.tasks.chat.streaming.agent.event_loop import stream_agent_events
|
||||||
|
from app.tasks.chat.streaming.shared.stream_result import StreamResult
|
||||||
|
|
||||||
accumulator = start_turn()
|
accumulator = start_turn()
|
||||||
streaming_service = VercelStreamingService()
|
streaming_service = VercelStreamingService()
|
||||||
|
|
@ -419,7 +420,7 @@ async def stream_anonymous_chat(
|
||||||
|
|
||||||
stream_result = StreamResult()
|
stream_result = StreamResult()
|
||||||
|
|
||||||
async for sse in _stream_agent_events(
|
async for sse in stream_agent_events(
|
||||||
agent=agent,
|
agent=agent,
|
||||||
config=langgraph_config,
|
config=langgraph_config,
|
||||||
input_data=input_state,
|
input_data=input_state,
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load diff
|
|
@ -247,11 +247,11 @@ def _patch_llm_bindings() -> None:
|
||||||
fake_create_chat_litellm_from_config,
|
fake_create_chat_litellm_from_config,
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
"app.tasks.chat.stream_new_chat.create_chat_litellm_from_agent_config",
|
"app.tasks.chat.streaming.flows.shared.llm_bundle.create_chat_litellm_from_agent_config",
|
||||||
fake_create_chat_litellm_from_agent_config,
|
fake_create_chat_litellm_from_agent_config,
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
"app.tasks.chat.stream_new_chat.create_chat_litellm_from_config",
|
"app.tasks.chat.streaming.flows.shared.llm_bundle.create_chat_litellm_from_config",
|
||||||
fake_create_chat_litellm_from_config,
|
fake_create_chat_litellm_from_config,
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -220,11 +220,11 @@ def _patch_llm_bindings() -> None:
|
||||||
fake_create_chat_litellm_from_config,
|
fake_create_chat_litellm_from_config,
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
"app.tasks.chat.stream_new_chat.create_chat_litellm_from_agent_config",
|
"app.tasks.chat.streaming.flows.shared.llm_bundle.create_chat_litellm_from_agent_config",
|
||||||
fake_create_chat_litellm_from_agent_config,
|
fake_create_chat_litellm_from_agent_config,
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
"app.tasks.chat.stream_new_chat.create_chat_litellm_from_config",
|
"app.tasks.chat.streaming.flows.shared.llm_bundle.create_chat_litellm_from_config",
|
||||||
fake_create_chat_litellm_from_config,
|
fake_create_chat_litellm_from_config,
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -1,457 +0,0 @@
|
||||||
"""Byte-for-byte frame parity: legacy monolith vs refactored flows orchestrators.
|
|
||||||
|
|
||||||
The agent-content portion of the stream (`text-*`, tool cards, thinking-step
|
|
||||||
updates) flows through **shared** code in both implementations
|
|
||||||
(`stream_output` -> `EventRelay.relay` -> handlers), so it cannot diverge. The
|
|
||||||
only independently-written part is the *orchestrator glue*: the initial frames,
|
|
||||||
persistence-handshake frames, error/terminal branches, and final frames.
|
|
||||||
|
|
||||||
This module drives BOTH ``stream_new_chat`` implementations (legacy
|
|
||||||
``app.tasks.chat.stream_new_chat`` and the refactored
|
|
||||||
``app.tasks.chat.streaming.flows``) through the deterministic glue paths and
|
|
||||||
asserts the emitted SSE frame sequences are **byte-for-byte identical**. These
|
|
||||||
are the paths where divergence could hide; the agent-streaming portion is shared
|
|
||||||
and is covered separately.
|
|
||||||
|
|
||||||
Determinism is enforced by:
|
|
||||||
* freezing ``time.time`` (so ``turn_id = f"{chat_id}:{ms}"`` is stable),
|
|
||||||
* a deterministic ``uuid`` sequence for the streaming-service id generators,
|
|
||||||
* stubbing every DB/LLM/agent seam (LLM resolution, persistence, connector,
|
|
||||||
checkpointer, session) to fixed values.
|
|
||||||
|
|
||||||
Cutover gate: when these are green, the live callers can be flipped to the
|
|
||||||
flows orchestrators.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from types import SimpleNamespace
|
|
||||||
from typing import Any
|
|
||||||
from unittest.mock import AsyncMock, MagicMock
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
import app.services.new_streaming_service as _nss
|
|
||||||
from app.tasks.chat.stream_new_chat import (
|
|
||||||
stream_new_chat as old_stream_new_chat,
|
|
||||||
stream_resume_chat as old_stream_resume_chat,
|
|
||||||
)
|
|
||||||
from app.tasks.chat.streaming.flows import (
|
|
||||||
stream_new_chat as new_stream_new_chat,
|
|
||||||
stream_resume_chat as new_stream_resume_chat,
|
|
||||||
)
|
|
||||||
|
|
||||||
pytestmark = pytest.mark.unit
|
|
||||||
|
|
||||||
_FIXED_EPOCH = 1_700_000_000.0 # -> turn_id "<chat_id>:1700000000000"
|
|
||||||
|
|
||||||
|
|
||||||
# --------------------------------------------------------------------------- #
|
|
||||||
# Deterministic uuid for the streaming-service id generators
|
|
||||||
# --------------------------------------------------------------------------- #
|
|
||||||
|
|
||||||
|
|
||||||
class _SeqUUID:
|
|
||||||
"""Drop-in for the ``uuid`` module used by ``new_streaming_service``.
|
|
||||||
|
|
||||||
Only ``uuid4().hex`` is consumed by the id generators. We hand out a
|
|
||||||
monotonic, zero-padded hex so two runs that emit the same number of ids in
|
|
||||||
the same order produce identical bytes.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self._n = 0
|
|
||||||
|
|
||||||
def reset(self) -> None:
|
|
||||||
self._n = 0
|
|
||||||
|
|
||||||
def uuid4(self) -> SimpleNamespace:
|
|
||||||
self._n += 1
|
|
||||||
return SimpleNamespace(hex=f"{self._n:032x}")
|
|
||||||
|
|
||||||
|
|
||||||
_SEQ = _SeqUUID()
|
|
||||||
|
|
||||||
|
|
||||||
# --------------------------------------------------------------------------- #
|
|
||||||
# Fake session: the orchestrator owns ``async_session_maker()``; for the glue
|
|
||||||
# paths every real consumer is stubbed, so a no-op session suffices.
|
|
||||||
# --------------------------------------------------------------------------- #
|
|
||||||
|
|
||||||
|
|
||||||
class _FakeResult:
|
|
||||||
"""Empty-everything SQLAlchemy ``Result`` stand-in for pre-stream reads."""
|
|
||||||
|
|
||||||
def scalars(self) -> "_FakeResult":
|
|
||||||
return self
|
|
||||||
|
|
||||||
def first(self) -> None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
def all(self) -> list[Any]:
|
|
||||||
return []
|
|
||||||
|
|
||||||
def one_or_none(self) -> None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
def scalar_one_or_none(self) -> None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
def scalar(self) -> None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
def fetchall(self) -> list[Any]:
|
|
||||||
return []
|
|
||||||
|
|
||||||
def __iter__(self):
|
|
||||||
return iter(())
|
|
||||||
|
|
||||||
|
|
||||||
class _FakeSession:
|
|
||||||
async def commit(self) -> None: # pragma: no cover - trivial
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def rollback(self) -> None: # pragma: no cover - trivial
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def close(self) -> None: # pragma: no cover - trivial
|
|
||||||
return None
|
|
||||||
|
|
||||||
def expunge_all(self) -> None: # pragma: no cover - trivial
|
|
||||||
return None
|
|
||||||
|
|
||||||
def add(self, *a: Any, **k: Any) -> None: # pragma: no cover - trivial
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def flush(self, *a: Any, **k: Any) -> None: # pragma: no cover
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def execute(self, *a: Any, **k: Any) -> _FakeResult:
|
|
||||||
return _FakeResult()
|
|
||||||
|
|
||||||
|
|
||||||
class _FakeConnectorService:
|
|
||||||
def __init__(self, *a: Any, **k: Any) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def get_connector_by_type(self, *a: Any, **k: Any) -> None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def _patch(monkeypatch: pytest.MonkeyPatch, target: str, value: Any) -> None:
|
|
||||||
"""``setattr`` that tolerates a missing attr (binding may be local-import)."""
|
|
||||||
monkeypatch.setattr(target, value, raising=False)
|
|
||||||
|
|
||||||
|
|
||||||
def _apply_common(
|
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
|
||||||
*,
|
|
||||||
pin_raises: ValueError | None = None,
|
|
||||||
resolved_id: int = -1,
|
|
||||||
llm_load_ok: bool = True,
|
|
||||||
persist_user_id: int | None = 101,
|
|
||||||
persist_assistant_id: int | None = 102,
|
|
||||||
) -> None:
|
|
||||||
"""Patch every glue seam in BOTH implementations to deterministic values."""
|
|
||||||
# Time -> stable turn_id and any retry_after_at.
|
|
||||||
monkeypatch.setattr("time.time", lambda: _FIXED_EPOCH)
|
|
||||||
|
|
||||||
# Deterministic streaming-service ids.
|
|
||||||
monkeypatch.setattr(_nss, "uuid", _SEQ)
|
|
||||||
|
|
||||||
fake_model = MagicMock(name="scripted_llm")
|
|
||||||
|
|
||||||
# --- session ---
|
|
||||||
for tgt in (
|
|
||||||
"app.tasks.chat.stream_new_chat.async_session_maker",
|
|
||||||
"app.tasks.chat.streaming.flows.new_chat.orchestrator.async_session_maker",
|
|
||||||
"app.tasks.chat.streaming.flows.resume_chat.orchestrator.async_session_maker",
|
|
||||||
):
|
|
||||||
_patch(monkeypatch, tgt, _FakeSession)
|
|
||||||
|
|
||||||
# --- connector service ---
|
|
||||||
for tgt in (
|
|
||||||
"app.tasks.chat.stream_new_chat.ConnectorService",
|
|
||||||
"app.tasks.chat.streaming.flows.shared.pre_stream_setup.ConnectorService",
|
|
||||||
):
|
|
||||||
_patch(monkeypatch, tgt, _FakeConnectorService)
|
|
||||||
|
|
||||||
# --- checkpointer ---
|
|
||||||
for tgt in (
|
|
||||||
"app.tasks.chat.stream_new_chat.get_checkpointer",
|
|
||||||
"app.tasks.chat.streaming.flows.shared.pre_stream_setup.get_checkpointer",
|
|
||||||
):
|
|
||||||
_patch(monkeypatch, tgt, AsyncMock(return_value=MagicMock(name="checkpointer")))
|
|
||||||
|
|
||||||
# --- agent factory (built but never streamed on glue paths) ---
|
|
||||||
# Resume routing awaits ``agent.aget_state`` before persist, so the fake
|
|
||||||
# agent exposes async state accessors returning an empty (no-interrupt)
|
|
||||||
# snapshot. ``astream_events`` is never reached on glue paths.
|
|
||||||
fake_agent = MagicMock(name="agent")
|
|
||||||
fake_agent.aget_state = AsyncMock(
|
|
||||||
return_value=SimpleNamespace(values={}, tasks=[], interrupts=[], next=())
|
|
||||||
)
|
|
||||||
fake_agent.aupdate_state = AsyncMock(return_value=None)
|
|
||||||
agent_factory = AsyncMock(return_value=fake_agent)
|
|
||||||
for tgt in (
|
|
||||||
"app.tasks.chat.stream_new_chat.create_multi_agent_chat_deep_agent",
|
|
||||||
"app.tasks.chat.streaming.flows.new_chat.orchestrator.create_multi_agent_chat_deep_agent",
|
|
||||||
"app.tasks.chat.streaming.flows.resume_chat.orchestrator.create_multi_agent_chat_deep_agent",
|
|
||||||
):
|
|
||||||
_patch(monkeypatch, tgt, agent_factory)
|
|
||||||
|
|
||||||
# --- LLM resolution (auto-pin) ---
|
|
||||||
if pin_raises is not None:
|
|
||||||
async def _resolver(*a: Any, **k: Any):
|
|
||||||
raise pin_raises
|
|
||||||
else:
|
|
||||||
async def _resolver(*a: Any, **k: Any):
|
|
||||||
return SimpleNamespace(resolved_llm_config_id=resolved_id)
|
|
||||||
|
|
||||||
_patch(monkeypatch, "app.services.auto_model_pin_service.resolve_or_get_pinned_llm_config_id", _resolver)
|
|
||||||
_patch(monkeypatch, "app.tasks.chat.stream_new_chat.resolve_or_get_pinned_llm_config_id", _resolver)
|
|
||||||
_patch(
|
|
||||||
monkeypatch,
|
|
||||||
"app.tasks.chat.streaming.flows.new_chat.auto_pin.resolve_or_get_pinned_llm_config_id",
|
|
||||||
_resolver,
|
|
||||||
)
|
|
||||||
|
|
||||||
# --- LLM bundle ---
|
|
||||||
sentinel_cfg = object() if llm_load_ok else None
|
|
||||||
_patch(monkeypatch, "app.tasks.chat.stream_new_chat.load_global_llm_config_by_id", lambda cid: sentinel_cfg)
|
|
||||||
_patch(
|
|
||||||
monkeypatch,
|
|
||||||
"app.tasks.chat.streaming.flows.shared.llm_bundle.load_global_llm_config_by_id",
|
|
||||||
lambda cid: sentinel_cfg,
|
|
||||||
)
|
|
||||||
_patch(monkeypatch, "app.tasks.chat.stream_new_chat.create_chat_litellm_from_config", lambda cfg: fake_model)
|
|
||||||
_patch(
|
|
||||||
monkeypatch,
|
|
||||||
"app.tasks.chat.streaming.flows.shared.llm_bundle.create_chat_litellm_from_config",
|
|
||||||
lambda cfg: fake_model,
|
|
||||||
)
|
|
||||||
# agent_config := None keeps premium + capability gates inert and identical.
|
|
||||||
from app.agents.shared.llm_config import AgentConfig
|
|
||||||
|
|
||||||
monkeypatch.setattr(AgentConfig, "from_yaml_config", staticmethod(lambda cfg: None))
|
|
||||||
|
|
||||||
# --- persistence ---
|
|
||||||
async def _persist_user(*a: Any, **k: Any):
|
|
||||||
return persist_user_id
|
|
||||||
|
|
||||||
async def _persist_assistant(*a: Any, **k: Any):
|
|
||||||
return persist_assistant_id
|
|
||||||
|
|
||||||
async def _finalize(*a: Any, **k: Any):
|
|
||||||
return None
|
|
||||||
|
|
||||||
for mod in (
|
|
||||||
"app.tasks.chat.persistence",
|
|
||||||
"app.tasks.chat.streaming.flows.new_chat.persistence_spawn",
|
|
||||||
):
|
|
||||||
_patch(monkeypatch, f"{mod}.persist_user_turn", _persist_user)
|
|
||||||
_patch(monkeypatch, f"{mod}.persist_assistant_shell", _persist_assistant)
|
|
||||||
# Resume binds ``persist_assistant_shell`` in its own assistant_shell module.
|
|
||||||
_patch(
|
|
||||||
monkeypatch,
|
|
||||||
"app.tasks.chat.streaming.flows.resume_chat.assistant_shell.persist_assistant_shell",
|
|
||||||
_persist_assistant,
|
|
||||||
)
|
|
||||||
_patch(monkeypatch, "app.tasks.chat.persistence.finalize_assistant_turn", _finalize)
|
|
||||||
|
|
||||||
# --- collaboration flags ---
|
|
||||||
async def _noop(*a: Any, **k: Any):
|
|
||||||
return None
|
|
||||||
|
|
||||||
for tgt in (
|
|
||||||
"app.tasks.chat.stream_new_chat.set_ai_responding",
|
|
||||||
"app.tasks.chat.stream_new_chat.clear_ai_responding",
|
|
||||||
"app.tasks.chat.streaming.flows.new_chat.persistence_spawn.set_ai_responding",
|
|
||||||
"app.services.chat_session_state_service.set_ai_responding",
|
|
||||||
"app.services.chat_session_state_service.clear_ai_responding",
|
|
||||||
):
|
|
||||||
_patch(monkeypatch, tgt, _noop)
|
|
||||||
|
|
||||||
|
|
||||||
async def _collect(genfunc: Any, **kwargs: Any) -> list[str]:
|
|
||||||
frames: list[str] = []
|
|
||||||
async for frame in genfunc(**kwargs):
|
|
||||||
frames.append(frame)
|
|
||||||
return frames
|
|
||||||
|
|
||||||
|
|
||||||
async def _run_both(kwargs: dict[str, Any]) -> tuple[list[str], list[str]]:
|
|
||||||
"""Drive both NEW-chat implementations on identical inputs."""
|
|
||||||
_SEQ.reset()
|
|
||||||
old = await _collect(old_stream_new_chat, **kwargs)
|
|
||||||
_SEQ.reset()
|
|
||||||
new = await _collect(new_stream_new_chat, **kwargs)
|
|
||||||
return old, new
|
|
||||||
|
|
||||||
|
|
||||||
async def _run_both_resume(kwargs: dict[str, Any]) -> tuple[list[str], list[str]]:
|
|
||||||
"""Drive both RESUME-chat implementations on identical inputs."""
|
|
||||||
_SEQ.reset()
|
|
||||||
old = await _collect(old_stream_resume_chat, **kwargs)
|
|
||||||
_SEQ.reset()
|
|
||||||
new = await _collect(new_stream_resume_chat, **kwargs)
|
|
||||||
return old, new
|
|
||||||
|
|
||||||
|
|
||||||
def _assert_parity(old: list[str], new: list[str]) -> None:
|
|
||||||
"""Byte-for-byte equality with a readable first-divergence message."""
|
|
||||||
for i, (a, b) in enumerate(zip(old, new, strict=False)):
|
|
||||||
assert a == b, f"frame[{i}] differs:\n old={a!r}\n new={b!r}"
|
|
||||||
assert len(old) == len(new), (
|
|
||||||
f"frame count differs: old={len(old)} new={len(new)}\n"
|
|
||||||
f" old tail={old[len(new):]!r}\n new tail={new[len(old):]!r}"
|
|
||||||
)
|
|
||||||
assert old[-1].strip() == "data: [DONE]"
|
|
||||||
|
|
||||||
|
|
||||||
# --------------------------------------------------------------------------- #
|
|
||||||
# NEW-chat scenarios
|
|
||||||
# --------------------------------------------------------------------------- #
|
|
||||||
|
|
||||||
_NEW_KW = dict(user_query="hi", search_space_id=1, chat_id=42, user_id=None)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_auto_pin_failure_parity(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
||||||
"""Auto-pin raises -> identical ``[error, DONE]`` from both."""
|
|
||||||
_apply_common(monkeypatch, pin_raises=ValueError("no eligible config"))
|
|
||||||
old, new = await _run_both(dict(_NEW_KW))
|
|
||||||
_assert_parity(old, new)
|
|
||||||
assert len(old) == 2
|
|
||||||
assert '"errorCode": "SERVER_ERROR"' in old[0]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_llm_load_failure_parity(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
||||||
"""LLM bundle load fails -> identical ``[error, DONE]`` from both."""
|
|
||||||
_apply_common(monkeypatch, llm_load_ok=False)
|
|
||||||
old, new = await _run_both(dict(_NEW_KW))
|
|
||||||
_assert_parity(old, new)
|
|
||||||
assert len(old) == 2
|
|
||||||
assert '"errorCode": "SERVER_ERROR"' in old[0]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_persist_user_failure_parity(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
||||||
"""User-turn persist returns None.
|
|
||||||
|
|
||||||
Exercises the full initial-frame ordering (start, start-step, turn-info,
|
|
||||||
turn-status busy), the MESSAGE_PERSIST_FAILED error, and final frames.
|
|
||||||
"""
|
|
||||||
_apply_common(monkeypatch, persist_user_id=None)
|
|
||||||
old, new = await _run_both(dict(_NEW_KW))
|
|
||||||
_assert_parity(old, new)
|
|
||||||
assert '"type": "start"' in old[0]
|
|
||||||
assert '"chat_turn_id": "42:1700000000000"' in old[2]
|
|
||||||
assert any('"errorCode": "MESSAGE_PERSIST_FAILED"' in f for f in old)
|
|
||||||
assert any('"type": "finish"' in f for f in old)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_persist_assistant_failure_parity(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
||||||
"""Assistant-shell persist returns None.
|
|
||||||
|
|
||||||
Adds the ``data-user-message-id`` handshake frame ahead of the error.
|
|
||||||
"""
|
|
||||||
_apply_common(monkeypatch, persist_user_id=101, persist_assistant_id=None)
|
|
||||||
old, new = await _run_both(dict(_NEW_KW))
|
|
||||||
_assert_parity(old, new)
|
|
||||||
assert any('"data-user-message-id"' in f and '"message_id": 101' in f for f in old)
|
|
||||||
assert any('"errorCode": "MESSAGE_PERSIST_FAILED"' in f for f in old)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_prestream_exception_parity(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
||||||
"""A pre-stream failure routes both through the top-level ``except`` path.
|
|
||||||
|
|
||||||
Resolver returns a non-int so ``turn_id`` math / downstream use raises after
|
|
||||||
the span opens but before initial frames: both must emit the identical
|
|
||||||
``busy -> error -> idle -> finish-step -> finish -> DONE`` terminal sequence.
|
|
||||||
"""
|
|
||||||
|
|
||||||
async def _bad_resolver(*a: Any, **k: Any):
|
|
||||||
raise RuntimeError("boom in pre-stream")
|
|
||||||
|
|
||||||
_apply_common(monkeypatch)
|
|
||||||
# Override the resolver with a non-ValueError so the classified early-error
|
|
||||||
# branches don't catch it -> top-level except path.
|
|
||||||
for tgt in (
|
|
||||||
"app.services.auto_model_pin_service.resolve_or_get_pinned_llm_config_id",
|
|
||||||
"app.tasks.chat.stream_new_chat.resolve_or_get_pinned_llm_config_id",
|
|
||||||
"app.tasks.chat.streaming.flows.new_chat.auto_pin.resolve_or_get_pinned_llm_config_id",
|
|
||||||
):
|
|
||||||
_patch(monkeypatch, tgt, _bad_resolver)
|
|
||||||
old, new = await _run_both(dict(_NEW_KW))
|
|
||||||
_assert_parity(old, new)
|
|
||||||
assert any('"type": "error"' in f for f in old)
|
|
||||||
|
|
||||||
|
|
||||||
# --------------------------------------------------------------------------- #
|
|
||||||
# RESUME-chat scenarios (no title-generation path -> fully deterministic)
|
|
||||||
# --------------------------------------------------------------------------- #
|
|
||||||
|
|
||||||
_RESUME_KW = dict(chat_id=42, search_space_id=1, decisions=[], user_id=None)
|
|
||||||
|
|
||||||
|
|
||||||
async def _collect_resume_old() -> list[str]:
|
|
||||||
_SEQ.reset()
|
|
||||||
return await _collect(old_stream_resume_chat, **dict(_RESUME_KW))
|
|
||||||
|
|
||||||
|
|
||||||
# NOTE: KNOWN, INTENTIONAL DIVERGENCE (flows fixes a latent monolith bug).
|
|
||||||
#
|
|
||||||
# In ``stream_resume_chat`` the monolith defines ``_resume_premium_request_id``
|
|
||||||
# (line ~2363) AFTER the auto-pin / LLM-load early-return points (~2346 / ~2356).
|
|
||||||
# Its ``finally`` block (line ~2918) reads that variable, so a resume turn whose
|
|
||||||
# auto-pin raises or whose LLM bundle fails to load crashes with
|
|
||||||
# ``UnboundLocalError`` instead of emitting a clean terminal-error frame. The
|
|
||||||
# refactored flows orchestrator does NOT have this bug — it emits the proper
|
|
||||||
# ``[error, DONE]`` sequence. We assert the divergence explicitly so the cutover
|
|
||||||
# is a documented behavior IMPROVEMENT rather than a silent change.
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_resume_auto_pin_failure_flows_fixes_monolith_crash(
|
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
|
||||||
) -> None:
|
|
||||||
_apply_common(monkeypatch, pin_raises=ValueError("no eligible config"))
|
|
||||||
# Monolith: latent UnboundLocalError in the finally clause.
|
|
||||||
with pytest.raises(UnboundLocalError, match="_resume_premium_request_id"):
|
|
||||||
await _collect_resume_old()
|
|
||||||
# Flows: clean terminal error.
|
|
||||||
_SEQ.reset()
|
|
||||||
new = await _collect(new_stream_resume_chat, **dict(_RESUME_KW))
|
|
||||||
assert len(new) == 2
|
|
||||||
assert new[-1].strip() == "data: [DONE]"
|
|
||||||
assert '"type": "error"' in new[0]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_resume_llm_load_failure_flows_fixes_monolith_crash(
|
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
|
||||||
) -> None:
|
|
||||||
_apply_common(monkeypatch, llm_load_ok=False)
|
|
||||||
with pytest.raises(UnboundLocalError, match="_resume_premium_request_id"):
|
|
||||||
await _collect_resume_old()
|
|
||||||
_SEQ.reset()
|
|
||||||
new = await _collect(new_stream_resume_chat, **dict(_RESUME_KW))
|
|
||||||
assert len(new) == 2
|
|
||||||
assert new[-1].strip() == "data: [DONE]"
|
|
||||||
assert '"type": "error"' in new[0]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_resume_persist_assistant_failure_parity(
|
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
|
||||||
) -> None:
|
|
||||||
"""Resume emits NO user-message-id frame; only the assistant handshake path."""
|
|
||||||
_apply_common(monkeypatch, persist_assistant_id=None)
|
|
||||||
old, new = await _run_both_resume(dict(_RESUME_KW))
|
|
||||||
_assert_parity(old, new)
|
|
||||||
assert not any('"data-user-message-id"' in f for f in old)
|
|
||||||
assert any('"chat_turn_id": "42:1700000000000"' in f for f in old)
|
|
||||||
|
|
@ -1,584 +0,0 @@
|
||||||
"""Parity gate for the parallel refactor of ``stream_new_chat.py``.
|
|
||||||
|
|
||||||
The new tree under ``app.tasks.chat.streaming.flows`` is built side-by-side with
|
|
||||||
the legacy monolithic ``app.tasks.chat.stream_new_chat`` so we can cut over
|
|
||||||
atomically. This file pins externally-observable behaviour at module
|
|
||||||
boundaries so a divergence between the two trees fails loudly *before* the
|
|
||||||
cutover.
|
|
||||||
|
|
||||||
What we verify:
|
|
||||||
|
|
||||||
1. **Signature parity** — ``stream_new_chat`` / ``stream_resume_chat`` from
|
|
||||||
the new tree have the same call signature as the originals.
|
|
||||||
2. **Helper extraction parity** — the SRP modules in ``flows/`` produce the
|
|
||||||
same outputs as the inline code in the legacy file for representative
|
|
||||||
inputs (initial thinking step, image-capability gate, runtime context,
|
|
||||||
SSE frame sequences, token-usage frame shape, persistence guards).
|
|
||||||
3. **Wrapper delegation** — wrappers like ``load_llm_bundle`` /
|
|
||||||
``can_recover_provider_rate_limit`` exist and are addressable.
|
|
||||||
|
|
||||||
Delete this file along with ``stream_new_chat.py`` once the cutover is done
|
|
||||||
(see the parent refactor plan).
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import inspect
|
|
||||||
from typing import Any
|
|
||||||
from unittest.mock import AsyncMock, patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from app.agents.shared.context import SurfSenseContextSchema
|
|
||||||
from app.services.new_streaming_service import VercelStreamingService
|
|
||||||
from app.tasks.chat.stream_new_chat import (
|
|
||||||
stream_new_chat as old_stream_new_chat,
|
|
||||||
stream_resume_chat as old_stream_resume_chat,
|
|
||||||
)
|
|
||||||
from app.tasks.chat.streaming.flows import (
|
|
||||||
stream_new_chat as new_stream_new_chat,
|
|
||||||
stream_resume_chat as new_stream_resume_chat,
|
|
||||||
)
|
|
||||||
from app.tasks.chat.streaming.flows.new_chat.initial_thinking_step import (
|
|
||||||
build_initial_thinking_step,
|
|
||||||
)
|
|
||||||
from app.tasks.chat.streaming.flows.new_chat.llm_capability import (
|
|
||||||
check_image_input_capability,
|
|
||||||
)
|
|
||||||
from app.tasks.chat.streaming.flows.new_chat.persistence_spawn import (
|
|
||||||
await_persist_task,
|
|
||||||
spawn_persist_assistant_shell_task,
|
|
||||||
spawn_persist_user_task,
|
|
||||||
spawn_set_ai_responding_bg,
|
|
||||||
)
|
|
||||||
from app.tasks.chat.streaming.flows.new_chat.runtime_context import (
|
|
||||||
build_new_chat_runtime_context,
|
|
||||||
)
|
|
||||||
from app.tasks.chat.streaming.flows.resume_chat.runtime_context import (
|
|
||||||
build_resume_chat_runtime_context,
|
|
||||||
)
|
|
||||||
from app.tasks.chat.streaming.flows.shared.finalize_emit import iter_token_usage_frame
|
|
||||||
from app.tasks.chat.streaming.flows.shared.first_frames import (
|
|
||||||
iter_final_frames,
|
|
||||||
iter_initial_frames,
|
|
||||||
)
|
|
||||||
from app.tasks.chat.streaming.flows.shared.llm_bundle import load_llm_bundle
|
|
||||||
from app.tasks.chat.streaming.flows.shared.premium_quota import (
|
|
||||||
PremiumReservation,
|
|
||||||
needs_premium_quota,
|
|
||||||
)
|
|
||||||
from app.tasks.chat.streaming.flows.shared.rate_limit_recovery import (
|
|
||||||
can_recover_provider_rate_limit,
|
|
||||||
)
|
|
||||||
|
|
||||||
pytestmark = pytest.mark.unit
|
|
||||||
|
|
||||||
|
|
||||||
# --------------------------------------------------------------------- signature
|
|
||||||
|
|
||||||
|
|
||||||
def _normalize_annotation(ann: Any) -> str:
|
|
||||||
"""Compare-friendly form for an annotation.
|
|
||||||
|
|
||||||
The legacy ``stream_new_chat.py`` does NOT use ``from __future__ import
|
|
||||||
annotations``, so its annotations are evaluated at import time and come
|
|
||||||
back as type objects / typing generics. The new tree DOES use it, so its
|
|
||||||
annotations are PEP-563 strings.
|
|
||||||
|
|
||||||
Both reprs describe the same types — strip the module prefixes / typing
|
|
||||||
namespace + the ``<class 'X'>`` wrapper so we compare the canonical
|
|
||||||
declared form.
|
|
||||||
"""
|
|
||||||
if ann is inspect.Signature.empty:
|
|
||||||
return ""
|
|
||||||
raw = ann if isinstance(ann, str) else repr(ann)
|
|
||||||
cleaned = (
|
|
||||||
raw.replace("typing.", "")
|
|
||||||
.replace("collections.abc.", "")
|
|
||||||
.replace("app.db.", "")
|
|
||||||
.replace("app.agents.shared.filesystem_selection.", "")
|
|
||||||
.replace("app.agents.shared.context.", "")
|
|
||||||
)
|
|
||||||
# Unwrap ``<class 'int'>`` → ``int`` (legacy-side type objects).
|
|
||||||
if cleaned.startswith("<class '") and cleaned.endswith("'>"):
|
|
||||||
cleaned = cleaned[len("<class '") : -len("'>")]
|
|
||||||
return cleaned
|
|
||||||
|
|
||||||
|
|
||||||
def _normalize_sig(sig: inspect.Signature) -> list[tuple[str, Any, str]]:
|
|
||||||
return [
|
|
||||||
(p.name, p.default, _normalize_annotation(p.annotation))
|
|
||||||
for p in sig.parameters.values()
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def test_stream_new_chat_signature_matches_legacy() -> None:
|
|
||||||
old = inspect.signature(old_stream_new_chat)
|
|
||||||
new = inspect.signature(new_stream_new_chat)
|
|
||||||
assert _normalize_sig(new) == _normalize_sig(old)
|
|
||||||
assert _normalize_annotation(new.return_annotation) == _normalize_annotation(
|
|
||||||
old.return_annotation
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_stream_resume_chat_signature_matches_legacy() -> None:
|
|
||||||
old = inspect.signature(old_stream_resume_chat)
|
|
||||||
new = inspect.signature(new_stream_resume_chat)
|
|
||||||
assert _normalize_sig(new) == _normalize_sig(old)
|
|
||||||
assert _normalize_annotation(new.return_annotation) == _normalize_annotation(
|
|
||||||
old.return_annotation
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_orchestrators_are_async_generator_functions() -> None:
|
|
||||||
assert inspect.isasyncgenfunction(new_stream_new_chat)
|
|
||||||
assert inspect.isasyncgenfunction(new_stream_resume_chat)
|
|
||||||
|
|
||||||
|
|
||||||
# ------------------------------------------------------------ initial thinking
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"user_query, image_urls, expected_title, expected_action",
|
|
||||||
[
|
|
||||||
("hello world", None, "Understanding your request", "Processing"),
|
|
||||||
(
|
|
||||||
"",
|
|
||||||
["data:image/png;base64,AAA"],
|
|
||||||
"Understanding your request",
|
|
||||||
"Processing",
|
|
||||||
),
|
|
||||||
("", None, "Understanding your request", "Processing"),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_initial_thinking_step_branches(
|
|
||||||
user_query: str,
|
|
||||||
image_urls: list[str] | None,
|
|
||||||
expected_title: str,
|
|
||||||
expected_action: str,
|
|
||||||
) -> None:
|
|
||||||
step = build_initial_thinking_step(
|
|
||||||
user_query=user_query,
|
|
||||||
user_image_data_urls=image_urls,
|
|
||||||
)
|
|
||||||
assert step.step_id == "thinking-1"
|
|
||||||
assert step.title == expected_title
|
|
||||||
assert len(step.items) == 1
|
|
||||||
assert step.items[0].startswith(f"{expected_action}: ")
|
|
||||||
|
|
||||||
|
|
||||||
def test_initial_thinking_step_truncates_long_query() -> None:
|
|
||||||
long_query = "x" * 200
|
|
||||||
step = build_initial_thinking_step(
|
|
||||||
user_query=long_query,
|
|
||||||
user_image_data_urls=None,
|
|
||||||
)
|
|
||||||
# 80-char truncation + ellipsis, sandwiched after "Processing: ".
|
|
||||||
assert "..." in step.items[0]
|
|
||||||
item = step.items[0]
|
|
||||||
payload = item[len("Processing: ") :]
|
|
||||||
assert payload.startswith("x" * 80) and payload.endswith("...")
|
|
||||||
|
|
||||||
|
|
||||||
# ------------------------------------------------------------ capability gate
|
|
||||||
|
|
||||||
|
|
||||||
def test_image_capability_passes_without_images() -> None:
|
|
||||||
assert (
|
|
||||||
check_image_input_capability(user_image_data_urls=None, agent_config=None)
|
|
||||||
is None
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_image_capability_passes_when_capability_unknown() -> None:
|
|
||||||
"""Unknown / unmapped models are not blocked — only models LiteLLM has
|
|
||||||
*explicitly* marked text-only trip the gate."""
|
|
||||||
|
|
||||||
class _AgentConfig:
|
|
||||||
provider = "openrouter"
|
|
||||||
model_name = "unknown-mystery-model"
|
|
||||||
custom_provider = None
|
|
||||||
config_name = "Unknown"
|
|
||||||
litellm_params: dict[str, Any] = {}
|
|
||||||
|
|
||||||
with patch(
|
|
||||||
"app.services.provider_capabilities.is_known_text_only_chat_model",
|
|
||||||
return_value=False,
|
|
||||||
):
|
|
||||||
assert (
|
|
||||||
check_image_input_capability(
|
|
||||||
user_image_data_urls=["data:image/png;base64,AAA"],
|
|
||||||
agent_config=_AgentConfig(), # type: ignore[arg-type]
|
|
||||||
)
|
|
||||||
is None
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_image_capability_blocks_known_text_only_models() -> None:
|
|
||||||
class _AgentConfig:
|
|
||||||
provider = "openai"
|
|
||||||
model_name = "gpt-3.5-turbo"
|
|
||||||
custom_provider = None
|
|
||||||
config_name = "GPT-3.5"
|
|
||||||
litellm_params: dict[str, Any] = {"base_model": "gpt-3.5-turbo"}
|
|
||||||
|
|
||||||
with patch(
|
|
||||||
"app.services.provider_capabilities.is_known_text_only_chat_model",
|
|
||||||
return_value=True,
|
|
||||||
):
|
|
||||||
result = check_image_input_capability(
|
|
||||||
user_image_data_urls=["data:image/png;base64,AAA"],
|
|
||||||
agent_config=_AgentConfig(), # type: ignore[arg-type]
|
|
||||||
)
|
|
||||||
assert result is not None
|
|
||||||
message, error_code = result
|
|
||||||
assert error_code == "MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT"
|
|
||||||
assert "GPT-3.5" in message
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------- runtime ctx
|
|
||||||
|
|
||||||
|
|
||||||
def test_new_chat_runtime_context_prefers_accepted_folder_ids() -> None:
|
|
||||||
"""Post-resolve accepted folder ids win over the raw requested ids."""
|
|
||||||
ctx = build_new_chat_runtime_context(
|
|
||||||
search_space_id=7,
|
|
||||||
mentioned_document_ids=[1, 2],
|
|
||||||
accepted_folder_ids=[10],
|
|
||||||
mentioned_folder_ids=[20, 30],
|
|
||||||
mentioned_connector_ids=None,
|
|
||||||
mentioned_connectors=None,
|
|
||||||
request_id="req",
|
|
||||||
turn_id="t1",
|
|
||||||
)
|
|
||||||
assert isinstance(ctx, SurfSenseContextSchema)
|
|
||||||
assert ctx.search_space_id == 7
|
|
||||||
assert list(ctx.mentioned_document_ids) == [1, 2]
|
|
||||||
assert list(ctx.mentioned_folder_ids) == [10]
|
|
||||||
assert ctx.request_id == "req"
|
|
||||||
assert ctx.turn_id == "t1"
|
|
||||||
|
|
||||||
|
|
||||||
def test_new_chat_runtime_context_falls_back_to_mentioned_folder_ids() -> None:
|
|
||||||
"""With no accepted ids, the raw requested folder ids flow through."""
|
|
||||||
ctx = build_new_chat_runtime_context(
|
|
||||||
search_space_id=7,
|
|
||||||
mentioned_document_ids=None,
|
|
||||||
accepted_folder_ids=[],
|
|
||||||
mentioned_folder_ids=[20, 30],
|
|
||||||
mentioned_connector_ids=None,
|
|
||||||
mentioned_connectors=None,
|
|
||||||
request_id=None,
|
|
||||||
turn_id="t2",
|
|
||||||
)
|
|
||||||
assert list(ctx.mentioned_folder_ids) == [20, 30]
|
|
||||||
|
|
||||||
|
|
||||||
def test_new_chat_runtime_context_propagates_connector_mentions() -> None:
|
|
||||||
"""@-selected connector ids/accounts ride onto the runtime context schema.
|
|
||||||
|
|
||||||
Parity with the legacy ``stream_new_chat`` runtime context, which set both
|
|
||||||
``mentioned_connector_ids`` and ``mentioned_connectors`` on the schema.
|
|
||||||
"""
|
|
||||||
connectors = [{"id": 5, "connector_type": "SLACK_CONNECTOR", "title": "acme"}]
|
|
||||||
ctx = build_new_chat_runtime_context(
|
|
||||||
search_space_id=7,
|
|
||||||
mentioned_document_ids=None,
|
|
||||||
accepted_folder_ids=[],
|
|
||||||
mentioned_folder_ids=None,
|
|
||||||
mentioned_connector_ids=[5],
|
|
||||||
mentioned_connectors=connectors,
|
|
||||||
request_id=None,
|
|
||||||
turn_id="t3",
|
|
||||||
)
|
|
||||||
assert list(ctx.mentioned_connector_ids) == [5]
|
|
||||||
assert list(ctx.mentioned_connectors) == connectors
|
|
||||||
|
|
||||||
|
|
||||||
def test_resume_chat_runtime_context_empty_mention_lists() -> None:
|
|
||||||
ctx = build_resume_chat_runtime_context(
|
|
||||||
search_space_id=42, request_id="req-r", turn_id="t-r"
|
|
||||||
)
|
|
||||||
assert ctx.search_space_id == 42
|
|
||||||
assert ctx.request_id == "req-r"
|
|
||||||
assert ctx.turn_id == "t-r"
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------- SSE frames
|
|
||||||
|
|
||||||
|
|
||||||
def test_iter_initial_frames_emits_canonical_sequence() -> None:
|
|
||||||
svc = VercelStreamingService()
|
|
||||||
frames = list(iter_initial_frames(svc, turn_id="42:1700000000000"))
|
|
||||||
# Exactly 4 frames: message_start, start_step, turn-info (turn_id), turn-status (busy).
|
|
||||||
assert len(frames) == 4
|
|
||||||
assert "42:1700000000000" in frames[2]
|
|
||||||
assert '"status":"busy"' in frames[3] or '"status": "busy"' in frames[3]
|
|
||||||
|
|
||||||
|
|
||||||
def test_iter_final_frames_emits_idle_then_finish_done() -> None:
|
|
||||||
svc = VercelStreamingService()
|
|
||||||
frames = list(iter_final_frames(svc))
|
|
||||||
assert len(frames) == 4
|
|
||||||
assert '"status":"idle"' in frames[0] or '"status": "idle"' in frames[0]
|
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------------------------------------- token usage frame
|
|
||||||
|
|
||||||
|
|
||||||
class _FakeAccumulator:
|
|
||||||
"""Minimal stand-in covering only the fields ``iter_token_usage_frame`` reads."""
|
|
||||||
|
|
||||||
def __init__(self, summary: Any = None) -> None:
|
|
||||||
self._summary = summary
|
|
||||||
self.calls = [1, 2, 3]
|
|
||||||
self.grand_total = 100
|
|
||||||
self.total_cost_micros = 50_000
|
|
||||||
self.total_prompt_tokens = 60
|
|
||||||
self.total_completion_tokens = 40
|
|
||||||
|
|
||||||
def per_message_summary(self) -> Any:
|
|
||||||
return self._summary
|
|
||||||
|
|
||||||
def serialized_calls(self) -> list[Any]:
|
|
||||||
return list(self.calls)
|
|
||||||
|
|
||||||
|
|
||||||
def test_token_usage_frame_skipped_when_no_summary() -> None:
|
|
||||||
svc = VercelStreamingService()
|
|
||||||
frames = list(
|
|
||||||
iter_token_usage_frame(
|
|
||||||
svc,
|
|
||||||
accumulator=_FakeAccumulator(summary=None), # type: ignore[arg-type]
|
|
||||||
log_label="parity-empty",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
assert frames == []
|
|
||||||
|
|
||||||
|
|
||||||
def test_token_usage_frame_emitted_when_summary_present() -> None:
|
|
||||||
svc = VercelStreamingService()
|
|
||||||
frames = list(
|
|
||||||
iter_token_usage_frame(
|
|
||||||
svc,
|
|
||||||
accumulator=_FakeAccumulator(summary=[{"m": "x", "t": 100}]), # type: ignore[arg-type]
|
|
||||||
log_label="parity-populated",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
assert len(frames) == 1
|
|
||||||
# Field shape on the wire is fixed by the FE; assert each surfaces.
|
|
||||||
payload = frames[0]
|
|
||||||
for key in (
|
|
||||||
'"prompt_tokens":60',
|
|
||||||
'"completion_tokens":40',
|
|
||||||
'"total_tokens":100',
|
|
||||||
'"cost_micros":50000',
|
|
||||||
):
|
|
||||||
assert key in payload.replace(" ", "")
|
|
||||||
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------ llm_bundle
|
|
||||||
|
|
||||||
|
|
||||||
def test_load_llm_bundle_routes_negative_id_to_yaml_loader() -> None:
|
|
||||||
async def _run() -> tuple[Any, Any, str | None]:
|
|
||||||
with (
|
|
||||||
patch(
|
|
||||||
"app.tasks.chat.streaming.flows.shared.llm_bundle.load_global_llm_config_by_id",
|
|
||||||
return_value=None,
|
|
||||||
),
|
|
||||||
):
|
|
||||||
return await load_llm_bundle(
|
|
||||||
session=AsyncMock(), # type: ignore[arg-type]
|
|
||||||
config_id=-1,
|
|
||||||
search_space_id=7,
|
|
||||||
)
|
|
||||||
|
|
||||||
llm, agent_config, error = asyncio.run(_run())
|
|
||||||
assert llm is None
|
|
||||||
assert agent_config is None
|
|
||||||
assert error is not None and "id -1" in error
|
|
||||||
|
|
||||||
|
|
||||||
def test_load_llm_bundle_routes_nonnegative_id_to_db_loader() -> None:
|
|
||||||
async def _run() -> tuple[Any, Any, str | None]:
|
|
||||||
with (
|
|
||||||
patch(
|
|
||||||
"app.tasks.chat.streaming.flows.shared.llm_bundle.load_agent_config",
|
|
||||||
new=AsyncMock(return_value=None),
|
|
||||||
),
|
|
||||||
):
|
|
||||||
return await load_llm_bundle(
|
|
||||||
session=AsyncMock(), # type: ignore[arg-type]
|
|
||||||
config_id=12,
|
|
||||||
search_space_id=7,
|
|
||||||
)
|
|
||||||
|
|
||||||
llm, agent_config, error = asyncio.run(_run())
|
|
||||||
assert llm is None
|
|
||||||
assert agent_config is None
|
|
||||||
assert error is not None and "id 12" in error
|
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------------------------------------------- premium quota
|
|
||||||
|
|
||||||
|
|
||||||
def test_needs_premium_quota_requires_user_and_premium_flag() -> None:
|
|
||||||
class _AgentConfig:
|
|
||||||
is_premium = True
|
|
||||||
|
|
||||||
class _NonPremium:
|
|
||||||
is_premium = False
|
|
||||||
|
|
||||||
assert needs_premium_quota(_AgentConfig(), "user-1") is True # type: ignore[arg-type]
|
|
||||||
assert needs_premium_quota(_AgentConfig(), None) is False # type: ignore[arg-type]
|
|
||||||
assert needs_premium_quota(_NonPremium(), "user-1") is False # type: ignore[arg-type]
|
|
||||||
assert needs_premium_quota(None, "user-1") is False
|
|
||||||
|
|
||||||
|
|
||||||
def test_premium_reservation_dataclass_shape() -> None:
|
|
||||||
# Sanity: the dataclass exists and carries the fields the orchestrator uses.
|
|
||||||
r = PremiumReservation(request_id="abc", reserved_micros=100, allowed=True)
|
|
||||||
assert r.request_id == "abc"
|
|
||||||
assert r.reserved_micros == 100
|
|
||||||
assert r.allowed is True
|
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------------------------------------- rate-limit guard
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"first_event_seen, recovered, requested_id, current_id, expected",
|
|
||||||
[
|
|
||||||
(False, False, 0, -1, True),
|
|
||||||
# Already recovered: no second pass.
|
|
||||||
(False, True, 0, -1, False),
|
|
||||||
# User explicitly picked a config: don't silently switch.
|
|
||||||
(False, False, 5, -1, False),
|
|
||||||
# Already on a database-backed (positive) id.
|
|
||||||
(False, False, 0, 7, False),
|
|
||||||
# User has already seen output: silent rebuild not possible.
|
|
||||||
(True, False, 0, -1, False),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_can_recover_provider_rate_limit_truth_table(
|
|
||||||
first_event_seen: bool,
|
|
||||||
recovered: bool,
|
|
||||||
requested_id: int,
|
|
||||||
current_id: int,
|
|
||||||
expected: bool,
|
|
||||||
) -> None:
|
|
||||||
# Use a known rate-limit-shaped exception so the helper's last condition
|
|
||||||
# is satisfied; the guard only short-circuits to False when one of the
|
|
||||||
# *other* preconditions fails.
|
|
||||||
exc = Exception('{"error":{"type":"rate_limit_error","message":"slow"}}')
|
|
||||||
assert (
|
|
||||||
can_recover_provider_rate_limit(
|
|
||||||
exc,
|
|
||||||
first_event_seen=first_event_seen,
|
|
||||||
runtime_rate_limit_recovered=recovered,
|
|
||||||
requested_llm_config_id=requested_id,
|
|
||||||
current_llm_config_id=current_id,
|
|
||||||
)
|
|
||||||
is expected
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_can_recover_provider_rate_limit_rejects_non_rate_limit_exception() -> None:
|
|
||||||
assert (
|
|
||||||
can_recover_provider_rate_limit(
|
|
||||||
ValueError("not a rate limit"),
|
|
||||||
first_event_seen=False,
|
|
||||||
runtime_rate_limit_recovered=False,
|
|
||||||
requested_llm_config_id=0,
|
|
||||||
current_llm_config_id=-1,
|
|
||||||
)
|
|
||||||
is False
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# --------------------------------------------------------- persistence spawn
|
|
||||||
|
|
||||||
|
|
||||||
def test_spawn_set_ai_responding_bg_noop_without_user_id() -> None:
|
|
||||||
async def _run() -> set[asyncio.Task]:
|
|
||||||
background: set[asyncio.Task] = set()
|
|
||||||
spawn_set_ai_responding_bg(chat_id=1, user_id=None, background_tasks=background)
|
|
||||||
return background
|
|
||||||
|
|
||||||
bg = asyncio.run(_run())
|
|
||||||
assert bg == set()
|
|
||||||
|
|
||||||
|
|
||||||
def test_spawn_persist_user_task_registers_and_self_unregisters() -> None:
|
|
||||||
async def _run() -> tuple[int, int]:
|
|
||||||
background: set[asyncio.Task] = set()
|
|
||||||
with patch(
|
|
||||||
"app.tasks.chat.streaming.flows.new_chat.persistence_spawn.persist_user_turn",
|
|
||||||
new=AsyncMock(return_value=99),
|
|
||||||
):
|
|
||||||
task = spawn_persist_user_task(
|
|
||||||
chat_id=1,
|
|
||||||
user_id="u",
|
|
||||||
turn_id="t",
|
|
||||||
user_query="hi",
|
|
||||||
user_image_data_urls=None,
|
|
||||||
mentioned_documents=None,
|
|
||||||
background_tasks=background,
|
|
||||||
)
|
|
||||||
size_before_await = len(background)
|
|
||||||
result = await asyncio.shield(task)
|
|
||||||
# Give the done-callback one event-loop tick to run.
|
|
||||||
await asyncio.sleep(0)
|
|
||||||
return size_before_await, result # type: ignore[return-value]
|
|
||||||
|
|
||||||
size_before, result = asyncio.run(_run())
|
|
||||||
assert size_before == 1
|
|
||||||
assert result == 99
|
|
||||||
|
|
||||||
|
|
||||||
def test_spawn_persist_assistant_shell_task_registers() -> None:
|
|
||||||
async def _run() -> int | None:
|
|
||||||
background: set[asyncio.Task] = set()
|
|
||||||
with patch(
|
|
||||||
"app.tasks.chat.streaming.flows.new_chat.persistence_spawn.persist_assistant_shell",
|
|
||||||
new=AsyncMock(return_value=42),
|
|
||||||
):
|
|
||||||
task = spawn_persist_assistant_shell_task(
|
|
||||||
chat_id=1,
|
|
||||||
user_id="u",
|
|
||||||
turn_id="t",
|
|
||||||
background_tasks=background,
|
|
||||||
)
|
|
||||||
return await asyncio.shield(task)
|
|
||||||
|
|
||||||
assert asyncio.run(_run()) == 42
|
|
||||||
|
|
||||||
|
|
||||||
def test_await_persist_task_returns_none_on_failure() -> None:
|
|
||||||
async def _run() -> int | None:
|
|
||||||
async def _boom() -> int:
|
|
||||||
raise RuntimeError("DB down")
|
|
||||||
|
|
||||||
task = asyncio.create_task(_boom())
|
|
||||||
return await await_persist_task(
|
|
||||||
task,
|
|
||||||
chat_id=1,
|
|
||||||
turn_id="t",
|
|
||||||
log_label="parity-failure",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert asyncio.run(_run()) is None
|
|
||||||
|
|
||||||
|
|
||||||
def test_await_persist_task_returns_none_for_none_input() -> None:
|
|
||||||
async def _run() -> int | None:
|
|
||||||
return await await_persist_task(
|
|
||||||
None,
|
|
||||||
chat_id=1,
|
|
||||||
turn_id="t",
|
|
||||||
log_label="parity-none",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert asyncio.run(_run()) is None
|
|
||||||
|
|
@ -1,240 +0,0 @@
|
||||||
"""Pin Stage 1 extractions as faithful copies of the old helpers.
|
|
||||||
|
|
||||||
Extractions under ``app.tasks.chat.streaming`` are compared to
|
|
||||||
``app.tasks.chat.stream_new_chat`` helpers.
|
|
||||||
For each Stage 1 extraction we assert the new function returns the same
|
|
||||||
output as the old one for a representative input set. The moment the
|
|
||||||
two diverge - intentionally or otherwise - this file fails loudly so
|
|
||||||
the divergence is reviewed rather than shipped silently.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from app.agents.shared.errors import BusyError
|
|
||||||
from app.agents.shared.middleware.busy_mutex import request_cancel, reset_cancel
|
|
||||||
from app.tasks.chat.stream_new_chat import (
|
|
||||||
_classify_stream_exception as old_classify,
|
|
||||||
_emit_stream_terminal_error as old_emit_terminal_error,
|
|
||||||
_extract_chunk_parts as old_extract_chunk_parts,
|
|
||||||
_extract_resolved_file_path as old_extract_resolved_file_path,
|
|
||||||
_tool_output_has_error as old_tool_output_has_error,
|
|
||||||
_tool_output_to_text as old_tool_output_to_text,
|
|
||||||
)
|
|
||||||
from app.tasks.chat.streaming.errors.classifier import (
|
|
||||||
classify_stream_exception as new_classify,
|
|
||||||
)
|
|
||||||
from app.tasks.chat.streaming.errors.emitter import (
|
|
||||||
emit_stream_terminal_error as new_emit_terminal_error,
|
|
||||||
)
|
|
||||||
from app.tasks.chat.streaming.helpers.chunk_parts import (
|
|
||||||
extract_chunk_parts as new_extract_chunk_parts,
|
|
||||||
)
|
|
||||||
from app.tasks.chat.streaming.helpers.tool_output import (
|
|
||||||
extract_resolved_file_path as new_extract_resolved_file_path,
|
|
||||||
tool_output_has_error as new_tool_output_has_error,
|
|
||||||
tool_output_to_text as new_tool_output_to_text,
|
|
||||||
)
|
|
||||||
|
|
||||||
pytestmark = pytest.mark.unit
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------- chunk parts
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class _Chunk:
|
|
||||||
content: Any = ""
|
|
||||||
additional_kwargs: dict[str, Any] = field(default_factory=dict)
|
|
||||||
tool_call_chunks: list[dict[str, Any]] = field(default_factory=list)
|
|
||||||
|
|
||||||
|
|
||||||
_CHUNK_CASES: list[Any] = [
|
|
||||||
None,
|
|
||||||
_Chunk(content=""),
|
|
||||||
_Chunk(content="hello"),
|
|
||||||
_Chunk(content=42), # invalid type, defensively coerced to empty
|
|
||||||
_Chunk(
|
|
||||||
content=[
|
|
||||||
{"type": "text", "text": "Hello "},
|
|
||||||
{"type": "text", "text": "world"},
|
|
||||||
]
|
|
||||||
),
|
|
||||||
_Chunk(
|
|
||||||
content=[
|
|
||||||
{"type": "reasoning", "reasoning": "hmm "},
|
|
||||||
{"type": "reasoning", "text": "still"},
|
|
||||||
{"type": "text", "text": "answer"},
|
|
||||||
]
|
|
||||||
),
|
|
||||||
_Chunk(
|
|
||||||
content=[
|
|
||||||
{"type": "tool_call_chunk", "id": "c1", "name": "x", "args": "{"},
|
|
||||||
{"type": "tool_use", "id": "c2", "name": "y"},
|
|
||||||
{"type": "image_url", "url": "ignored"},
|
|
||||||
]
|
|
||||||
),
|
|
||||||
_Chunk(
|
|
||||||
content="visible",
|
|
||||||
additional_kwargs={"reasoning_content": "private"},
|
|
||||||
),
|
|
||||||
_Chunk(
|
|
||||||
tool_call_chunks=[
|
|
||||||
{"id": None, "name": None, "args": '{"a":1}', "index": 0},
|
|
||||||
{"id": "c", "name": "n", "args": "}", "index": 0},
|
|
||||||
]
|
|
||||||
),
|
|
||||||
_Chunk(
|
|
||||||
content=[{"type": "tool_call_chunk", "id": "from-block", "name": "x"}],
|
|
||||||
tool_call_chunks=[{"id": "from-attr", "name": "y"}],
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("chunk", _CHUNK_CASES)
|
|
||||||
def test_extract_chunk_parts_matches_old_implementation(chunk: Any) -> None:
|
|
||||||
assert new_extract_chunk_parts(chunk) == old_extract_chunk_parts(chunk)
|
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------------------------------------- error classifier
|
|
||||||
|
|
||||||
|
|
||||||
def _classify_cases() -> list[Exception]:
|
|
||||||
"""Inputs that the FE depends on being mapped to specific error codes."""
|
|
||||||
return [
|
|
||||||
Exception("totally generic error"),
|
|
||||||
Exception('{"error":{"type":"rate_limit_error","message":"slow down"}}'),
|
|
||||||
Exception(
|
|
||||||
'OpenrouterException - {"error":{"message":"Provider returned error",'
|
|
||||||
'"code":429}}'
|
|
||||||
),
|
|
||||||
BusyError(request_id="thread-busy-parity"),
|
|
||||||
Exception("Thread is busy with another request"),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("exc", _classify_cases())
|
|
||||||
def test_classify_stream_exception_matches_old_implementation(
|
|
||||||
exc: Exception,
|
|
||||||
) -> None:
|
|
||||||
new = new_classify(exc, flow_label="parity-test")
|
|
||||||
old = old_classify(exc, flow_label="parity-test")
|
|
||||||
# Strip the wall-clock retry timestamp before comparing — both
|
|
||||||
# implementations call ``time.time()`` independently and the call
|
|
||||||
# order is enough to differ by 1 ms in practice. Every other field
|
|
||||||
# in the tuple must match exactly.
|
|
||||||
new_extra = dict(new[5]) if isinstance(new[5], dict) else new[5]
|
|
||||||
old_extra = dict(old[5]) if isinstance(old[5], dict) else old[5]
|
|
||||||
if isinstance(new_extra, dict) and isinstance(old_extra, dict):
|
|
||||||
new_extra.pop("retry_after_at", None)
|
|
||||||
old_extra.pop("retry_after_at", None)
|
|
||||||
assert new[:5] == old[:5]
|
|
||||||
assert new_extra == old_extra
|
|
||||||
|
|
||||||
|
|
||||||
def test_classify_turn_cancelling_branch_parity() -> None:
|
|
||||||
"""The TURN_CANCELLING branch reads cancel state for the busy thread id;
|
|
||||||
both implementations must agree on retry-window semantics, not just the
|
|
||||||
plain THREAD_BUSY code."""
|
|
||||||
thread_id = "parity-cancelling-thread"
|
|
||||||
reset_cancel(thread_id)
|
|
||||||
request_cancel(thread_id)
|
|
||||||
exc = BusyError(request_id=thread_id)
|
|
||||||
new = new_classify(exc, flow_label="parity-test")
|
|
||||||
old = old_classify(exc, flow_label="parity-test")
|
|
||||||
assert new[0] == old[0] == "thread_busy"
|
|
||||||
assert new[1] == old[1] == "TURN_CANCELLING"
|
|
||||||
assert isinstance(new[5], dict) and isinstance(old[5], dict)
|
|
||||||
assert new[5]["retry_after_ms"] == old[5]["retry_after_ms"]
|
|
||||||
|
|
||||||
|
|
||||||
# ------------------------------------------------------------ terminal emitter
|
|
||||||
|
|
||||||
|
|
||||||
class _FakeStreamingService:
|
|
||||||
"""Duck-types ``format_error`` for both old and new emitters."""
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self.calls: list[dict[str, Any]] = []
|
|
||||||
|
|
||||||
def format_error(
|
|
||||||
self, message: str, *, error_code: str, extra: dict[str, Any] | None = None
|
|
||||||
) -> str:
|
|
||||||
self.calls.append(
|
|
||||||
{"message": message, "error_code": error_code, "extra": extra}
|
|
||||||
)
|
|
||||||
return f'data: {{"type":"error","errorText":"{message}"}}\n\n'
|
|
||||||
|
|
||||||
|
|
||||||
def test_emit_stream_terminal_error_matches_old_output_and_logs(caplog) -> None:
|
|
||||||
"""The new emitter must produce the same SSE frame and log the same
|
|
||||||
structured payload as the old one for the same arguments."""
|
|
||||||
args: dict[str, Any] = {
|
|
||||||
"flow": "new",
|
|
||||||
"request_id": "req-parity",
|
|
||||||
"thread_id": 7,
|
|
||||||
"search_space_id": 9,
|
|
||||||
"user_id": "user-parity",
|
|
||||||
"message": "boom",
|
|
||||||
"error_kind": "server_error",
|
|
||||||
"error_code": "SERVER_ERROR",
|
|
||||||
"severity": "error",
|
|
||||||
"is_expected": False,
|
|
||||||
"extra": {"foo": "bar"},
|
|
||||||
}
|
|
||||||
|
|
||||||
new_svc = _FakeStreamingService()
|
|
||||||
old_svc = _FakeStreamingService()
|
|
||||||
|
|
||||||
with caplog.at_level(logging.ERROR):
|
|
||||||
new_frame = new_emit_terminal_error(streaming_service=new_svc, **args)
|
|
||||||
old_frame = old_emit_terminal_error(streaming_service=old_svc, **args)
|
|
||||||
|
|
||||||
assert new_frame == old_frame
|
|
||||||
assert new_svc.calls == old_svc.calls
|
|
||||||
chat_error_records = [
|
|
||||||
r for r in caplog.records if "[chat_stream_error]" in r.message
|
|
||||||
]
|
|
||||||
# One log line per emit call (two emits -> two records).
|
|
||||||
assert len(chat_error_records) == 2
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------- tool output
|
|
||||||
|
|
||||||
|
|
||||||
def test_tool_output_helpers_match_old_implementation() -> None:
|
|
||||||
samples: list[Any] = [
|
|
||||||
{"result": "ok"},
|
|
||||||
{"error": "bad"},
|
|
||||||
{"result": "Error: x"},
|
|
||||||
"Error: plain",
|
|
||||||
"fine",
|
|
||||||
{"nested": {"a": 1}},
|
|
||||||
]
|
|
||||||
for s in samples:
|
|
||||||
assert new_tool_output_to_text(s) == old_tool_output_to_text(s)
|
|
||||||
assert new_tool_output_has_error(s) == old_tool_output_has_error(s)
|
|
||||||
|
|
||||||
assert new_extract_resolved_file_path(
|
|
||||||
tool_name="write_file",
|
|
||||||
tool_output={"path": " /tmp/x "},
|
|
||||||
tool_input=None,
|
|
||||||
) == old_extract_resolved_file_path(
|
|
||||||
tool_name="write_file",
|
|
||||||
tool_output={"path": " /tmp/x "},
|
|
||||||
tool_input=None,
|
|
||||||
)
|
|
||||||
assert new_extract_resolved_file_path(
|
|
||||||
tool_name="write_file",
|
|
||||||
tool_output={},
|
|
||||||
tool_input={"file_path": " /fallback "},
|
|
||||||
) == old_extract_resolved_file_path(
|
|
||||||
tool_name="write_file",
|
|
||||||
tool_output={},
|
|
||||||
tool_input={"file_path": " /fallback "},
|
|
||||||
)
|
|
||||||
|
|
@ -1,241 +0,0 @@
|
||||||
"""Parity tests for Stage 2 extractions (tool matching, thinking step, custom events)."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import Any
|
|
||||||
from unittest.mock import MagicMock
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from app.tasks.chat.stream_new_chat import _legacy_match_lc_id as old_legacy_match
|
|
||||||
from app.tasks.chat.streaming.handlers.custom_events import (
|
|
||||||
handle_action_log,
|
|
||||||
handle_action_log_updated,
|
|
||||||
handle_document_created,
|
|
||||||
handle_report_progress,
|
|
||||||
)
|
|
||||||
from app.tasks.chat.streaming.helpers.tool_call_matching import (
|
|
||||||
match_buffered_langchain_tool_call_id as new_legacy_match,
|
|
||||||
)
|
|
||||||
from app.tasks.chat.streaming.relay.state import AgentEventRelayState
|
|
||||||
from app.tasks.chat.streaming.relay.thinking_step_completion import (
|
|
||||||
complete_active_thinking_step,
|
|
||||||
)
|
|
||||||
from app.tasks.chat.streaming.relay.thinking_step_sse import emit_thinking_step_frame
|
|
||||||
|
|
||||||
pytestmark = pytest.mark.unit
|
|
||||||
|
|
||||||
|
|
||||||
def _copy_chunk_buffer(raw: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
|
||||||
return [dict(x) for x in raw]
|
|
||||||
|
|
||||||
|
|
||||||
def test_legacy_tool_call_match_matches_old_implementation() -> None:
|
|
||||||
cases: list[tuple[list[dict[str, Any]], str, str, dict[str, str]]] = [
|
|
||||||
(
|
|
||||||
[
|
|
||||||
{"name": "write_file", "id": "lc-a"},
|
|
||||||
{"name": "other", "id": "lc-b"},
|
|
||||||
],
|
|
||||||
"write_file",
|
|
||||||
"run-1",
|
|
||||||
{},
|
|
||||||
),
|
|
||||||
(
|
|
||||||
[{"name": "x", "id": None}, {"name": "y", "id": "lc-fallback"}],
|
|
||||||
"write_file",
|
|
||||||
"run-2",
|
|
||||||
{},
|
|
||||||
),
|
|
||||||
([{"name": "no_id"}], "write_file", "run-3", {}),
|
|
||||||
]
|
|
||||||
for chunks_template, tool_name, run_id, lc_map_seed in cases:
|
|
||||||
old_chunks = _copy_chunk_buffer(chunks_template)
|
|
||||||
new_chunks = _copy_chunk_buffer(chunks_template)
|
|
||||||
old_map = dict(lc_map_seed)
|
|
||||||
new_map = dict(lc_map_seed)
|
|
||||||
old_out = old_legacy_match(old_chunks, tool_name, run_id, old_map)
|
|
||||||
new_out = new_legacy_match(new_chunks, tool_name, run_id, new_map)
|
|
||||||
assert new_out == old_out
|
|
||||||
assert new_chunks == old_chunks
|
|
||||||
assert new_map == old_map
|
|
||||||
|
|
||||||
|
|
||||||
def test_emit_thinking_step_frame_invokes_builder_before_service() -> None:
|
|
||||||
order: list[str] = []
|
|
||||||
builder = MagicMock()
|
|
||||||
|
|
||||||
def on_ts(*args: Any, **kwargs: Any) -> None:
|
|
||||||
order.append("builder")
|
|
||||||
|
|
||||||
builder.on_thinking_step.side_effect = on_ts
|
|
||||||
|
|
||||||
svc = MagicMock()
|
|
||||||
|
|
||||||
def fmt(**kwargs: Any) -> str:
|
|
||||||
order.append("service")
|
|
||||||
return "frame"
|
|
||||||
|
|
||||||
svc.format_thinking_step.side_effect = fmt
|
|
||||||
|
|
||||||
out = emit_thinking_step_frame(
|
|
||||||
streaming_service=svc,
|
|
||||||
content_builder=builder,
|
|
||||||
step_id="thinking-1",
|
|
||||||
title="Working",
|
|
||||||
status="in_progress",
|
|
||||||
items=["a"],
|
|
||||||
)
|
|
||||||
assert out == "frame"
|
|
||||||
assert order == ["builder", "service"]
|
|
||||||
builder.on_thinking_step.assert_called_once()
|
|
||||||
svc.format_thinking_step.assert_called_once()
|
|
||||||
|
|
||||||
|
|
||||||
def test_emit_thinking_step_frame_skips_builder_when_none() -> None:
|
|
||||||
svc = MagicMock(return_value="x")
|
|
||||||
svc.format_thinking_step.return_value = "frame"
|
|
||||||
assert (
|
|
||||||
emit_thinking_step_frame(
|
|
||||||
streaming_service=svc,
|
|
||||||
content_builder=None,
|
|
||||||
step_id="s",
|
|
||||||
title="t",
|
|
||||||
)
|
|
||||||
== "frame"
|
|
||||||
)
|
|
||||||
svc.format_thinking_step.assert_called_once()
|
|
||||||
|
|
||||||
|
|
||||||
def test_complete_active_thinking_step_mirrors_closure_semantics() -> None:
|
|
||||||
svc = MagicMock()
|
|
||||||
svc.format_thinking_step.return_value = "done-frame"
|
|
||||||
completed: set[str] = set()
|
|
||||||
relay_state = AgentEventRelayState.for_invocation()
|
|
||||||
|
|
||||||
frame, new_id = complete_active_thinking_step(
|
|
||||||
state=relay_state,
|
|
||||||
streaming_service=svc,
|
|
||||||
content_builder=None,
|
|
||||||
last_active_step_id="thinking-1",
|
|
||||||
last_active_step_title="T",
|
|
||||||
last_active_step_items=["x"],
|
|
||||||
completed_step_ids=completed,
|
|
||||||
)
|
|
||||||
assert frame == "done-frame"
|
|
||||||
assert new_id is None
|
|
||||||
assert "thinking-1" in completed
|
|
||||||
|
|
||||||
frame2, id2 = complete_active_thinking_step(
|
|
||||||
state=relay_state,
|
|
||||||
streaming_service=svc,
|
|
||||||
content_builder=None,
|
|
||||||
last_active_step_id="thinking-1",
|
|
||||||
last_active_step_title="T",
|
|
||||||
last_active_step_items=[],
|
|
||||||
completed_step_ids=completed,
|
|
||||||
)
|
|
||||||
assert frame2 is None
|
|
||||||
assert id2 == "thinking-1"
|
|
||||||
|
|
||||||
|
|
||||||
def test_agent_event_relay_state_factory_matches_counter_rule() -> None:
|
|
||||||
s0 = AgentEventRelayState.for_invocation()
|
|
||||||
assert s0.thinking_step_counter == 0
|
|
||||||
assert s0.last_active_step_id is None
|
|
||||||
|
|
||||||
s1 = AgentEventRelayState.for_invocation(
|
|
||||||
initial_step_id="thinking-resume-1",
|
|
||||||
initial_step_title="Inherited",
|
|
||||||
initial_step_items=["Topic: X"],
|
|
||||||
)
|
|
||||||
assert s1.thinking_step_counter == 1
|
|
||||||
assert s1.last_active_step_id == "thinking-resume-1"
|
|
||||||
assert s1.next_thinking_step_id("thinking") == "thinking-2"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
("phase", "message", "start_items", "expected_tail"),
|
|
||||||
[
|
|
||||||
(
|
|
||||||
"revising_section",
|
|
||||||
"progress line",
|
|
||||||
["Topic: Foo", "Modifying bar", "stale..."],
|
|
||||||
["Topic: Foo", "Modifying bar", "progress line"],
|
|
||||||
),
|
|
||||||
(
|
|
||||||
"other",
|
|
||||||
"phase msg",
|
|
||||||
["Topic: Foo", "old line"],
|
|
||||||
["Topic: Foo", "phase msg"],
|
|
||||||
),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_report_progress_items_match_reference(
|
|
||||||
phase: str,
|
|
||||||
message: str,
|
|
||||||
start_items: list[str],
|
|
||||||
expected_tail: list[str],
|
|
||||||
) -> None:
|
|
||||||
svc = MagicMock()
|
|
||||||
svc.format_thinking_step.return_value = "sse"
|
|
||||||
|
|
||||||
items = list(start_items)
|
|
||||||
frame, new_items = handle_report_progress(
|
|
||||||
{"message": message, "phase": phase},
|
|
||||||
last_active_step_id="step-1",
|
|
||||||
last_active_step_title="Report",
|
|
||||||
last_active_step_items=items,
|
|
||||||
streaming_service=svc,
|
|
||||||
content_builder=None,
|
|
||||||
)
|
|
||||||
assert frame == "sse"
|
|
||||||
assert new_items == expected_tail
|
|
||||||
kwargs = svc.format_thinking_step.call_args.kwargs
|
|
||||||
assert kwargs["items"] == expected_tail
|
|
||||||
|
|
||||||
|
|
||||||
def test_report_progress_noop_when_missing_message_or_step() -> None:
|
|
||||||
svc = MagicMock()
|
|
||||||
items = ["Topic: A"]
|
|
||||||
f1, i1 = handle_report_progress(
|
|
||||||
{"message": "", "phase": "x"},
|
|
||||||
last_active_step_id="s",
|
|
||||||
last_active_step_title="t",
|
|
||||||
last_active_step_items=items,
|
|
||||||
streaming_service=svc,
|
|
||||||
content_builder=None,
|
|
||||||
)
|
|
||||||
assert f1 is None and i1 is items
|
|
||||||
|
|
||||||
f2, i2 = handle_report_progress(
|
|
||||||
{"message": "m", "phase": "x"},
|
|
||||||
last_active_step_id=None,
|
|
||||||
last_active_step_title="t",
|
|
||||||
last_active_step_items=items,
|
|
||||||
streaming_service=svc,
|
|
||||||
content_builder=None,
|
|
||||||
)
|
|
||||||
assert f2 is None and i2 is items
|
|
||||||
|
|
||||||
|
|
||||||
def test_document_action_handlers_match_format_data_guards() -> None:
|
|
||||||
svc = MagicMock()
|
|
||||||
svc.format_data.return_value = "data-frame"
|
|
||||||
|
|
||||||
assert handle_document_created({}, streaming_service=svc) is None
|
|
||||||
assert handle_document_created({"id": 0}, streaming_service=svc) is None
|
|
||||||
handle_document_created({"id": 42, "title": "x"}, streaming_service=svc)
|
|
||||||
svc.format_data.assert_called_with(
|
|
||||||
"documents-updated", {"action": "created", "document": {"id": 42, "title": "x"}}
|
|
||||||
)
|
|
||||||
|
|
||||||
svc.reset_mock()
|
|
||||||
assert handle_action_log({"id": None}, streaming_service=svc) is None
|
|
||||||
handle_action_log({"id": 1}, streaming_service=svc)
|
|
||||||
svc.format_data.assert_called_once_with("action-log", {"id": 1})
|
|
||||||
|
|
||||||
svc.reset_mock()
|
|
||||||
assert handle_action_log_updated({"id": None}, streaming_service=svc) is None
|
|
||||||
handle_action_log_updated({"id": 2}, streaming_service=svc)
|
|
||||||
svc.format_data.assert_called_once_with("action-log-updated", {"id": 2})
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
"""Unit tests for ``stream_new_chat._extract_chunk_parts``.
|
"""Unit tests for ``streaming.helpers.chunk_parts.extract_chunk_parts``.
|
||||||
|
|
||||||
Earlier versions only handled ``isinstance(chunk.content, str)`` and
|
Earlier versions only handled ``isinstance(chunk.content, str)`` and
|
||||||
silently dropped every other shape (Anthropic typed-block lists,
|
silently dropped every other shape (Anthropic typed-block lists,
|
||||||
|
|
@ -14,7 +14,9 @@ from typing import Any
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from app.tasks.chat.stream_new_chat import _extract_chunk_parts
|
from app.tasks.chat.streaming.helpers.chunk_parts import (
|
||||||
|
extract_chunk_parts as _extract_chunk_parts,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@ constructs a fresh :class:`AgentEventRelayState` with
|
||||||
``thinking_step_counter=0``), React renders sibling timeline rows with the
|
``thinking_step_counter=0``), React renders sibling timeline rows with the
|
||||||
same key — the warning the user reported in production.
|
same key — the warning the user reported in production.
|
||||||
|
|
||||||
The contract this module pins: each ``_stream_agent_events`` invocation must
|
The contract this module pins: each ``stream_agent_events`` invocation must
|
||||||
receive a ``step_prefix`` that is unique within the thread (we salt with the
|
receive a ``step_prefix`` that is unique within the thread (we salt with the
|
||||||
per-turn ``turn_id``), so the resulting step IDs across consecutive turns
|
per-turn ``turn_id``), so the resulting step IDs across consecutive turns
|
||||||
are always disjoint.
|
are always disjoint.
|
||||||
|
|
@ -23,10 +23,12 @@ from typing import Any
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from app.services.new_streaming_service import VercelStreamingService
|
from app.services.new_streaming_service import VercelStreamingService
|
||||||
from app.tasks.chat.stream_new_chat import (
|
from app.tasks.chat.streaming.agent.event_loop import (
|
||||||
StreamResult,
|
stream_agent_events as _stream_agent_events,
|
||||||
_resume_step_prefix,
|
)
|
||||||
_stream_agent_events,
|
from app.tasks.chat.streaming.shared.stream_result import StreamResult
|
||||||
|
from app.tasks.chat.streaming.shared.utils import (
|
||||||
|
resume_step_prefix as _resume_step_prefix,
|
||||||
)
|
)
|
||||||
|
|
||||||
pytestmark = pytest.mark.unit
|
pytestmark = pytest.mark.unit
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
"""Unit tests for live tool-call argument streaming.
|
"""Unit tests for live tool-call argument streaming.
|
||||||
|
|
||||||
Pins the wire format that ``_stream_agent_events`` emits:
|
Pins the wire format that ``stream_agent_events`` emits:
|
||||||
``tool-input-start`` → ``tool-input-delta``... → ``tool-input-available`` →
|
``tool-input-start`` → ``tool-input-delta``... → ``tool-input-available`` →
|
||||||
``tool-output-available``, keyed consistently with LangChain ``tool_call.id``
|
``tool-output-available``, keyed consistently with LangChain ``tool_call.id``
|
||||||
when the model streams indexed chunks.
|
when the model streams indexed chunks.
|
||||||
|
|
@ -20,11 +20,13 @@ from typing import Any
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from app.services.new_streaming_service import VercelStreamingService
|
from app.services.new_streaming_service import VercelStreamingService
|
||||||
from app.tasks.chat.stream_new_chat import (
|
from app.tasks.chat.streaming.agent.event_loop import (
|
||||||
StreamResult,
|
stream_agent_events as _stream_agent_events,
|
||||||
_legacy_match_lc_id,
|
|
||||||
_stream_agent_events,
|
|
||||||
)
|
)
|
||||||
|
from app.tasks.chat.streaming.helpers.tool_call_matching import (
|
||||||
|
match_buffered_langchain_tool_call_id as _legacy_match_lc_id,
|
||||||
|
)
|
||||||
|
from app.tasks.chat.streaming.shared.stream_result import StreamResult
|
||||||
|
|
||||||
pytestmark = pytest.mark.unit
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,438 +0,0 @@
|
||||||
import inspect
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import re
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
import app.tasks.chat.stream_new_chat as stream_new_chat_module
|
|
||||||
from app.agents.shared.errors import BusyError
|
|
||||||
from app.agents.shared.middleware.busy_mutex import request_cancel, reset_cancel
|
|
||||||
from app.tasks.chat.stream_new_chat import (
|
|
||||||
StreamResult,
|
|
||||||
_classify_stream_exception,
|
|
||||||
_contract_enforcement_active,
|
|
||||||
_evaluate_file_contract_outcome,
|
|
||||||
_extract_resolved_file_path,
|
|
||||||
_log_chat_stream_error,
|
|
||||||
_tool_output_has_error,
|
|
||||||
)
|
|
||||||
|
|
||||||
pytestmark = pytest.mark.unit
|
|
||||||
|
|
||||||
|
|
||||||
def test_tool_output_error_detection():
|
|
||||||
assert _tool_output_has_error("Error: failed to write file")
|
|
||||||
assert _tool_output_has_error({"error": "boom"})
|
|
||||||
assert _tool_output_has_error({"result": "Error: disk is full"})
|
|
||||||
assert not _tool_output_has_error({"result": "Updated file /notes.md"})
|
|
||||||
|
|
||||||
|
|
||||||
def test_extract_resolved_file_path_prefers_structured_path():
|
|
||||||
assert (
|
|
||||||
_extract_resolved_file_path(
|
|
||||||
tool_name="write_file",
|
|
||||||
tool_output={"status": "completed", "path": "/docs/note.md"},
|
|
||||||
tool_input=None,
|
|
||||||
)
|
|
||||||
== "/docs/note.md"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_extract_resolved_file_path_falls_back_to_tool_input():
|
|
||||||
assert (
|
|
||||||
_extract_resolved_file_path(
|
|
||||||
tool_name="edit_file",
|
|
||||||
tool_output={"status": "completed", "result": "updated"},
|
|
||||||
tool_input={"file_path": "/docs/edited.md"},
|
|
||||||
)
|
|
||||||
== "/docs/edited.md"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_extract_resolved_file_path_does_not_parse_result_text():
|
|
||||||
assert (
|
|
||||||
_extract_resolved_file_path(
|
|
||||||
tool_name="write_file",
|
|
||||||
tool_output={"result": "Updated file /docs/from-text.md"},
|
|
||||||
tool_input=None,
|
|
||||||
)
|
|
||||||
is None
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_file_write_contract_outcome_reasons():
|
|
||||||
result = StreamResult(intent_detected="file_write")
|
|
||||||
passed, reason = _evaluate_file_contract_outcome(result)
|
|
||||||
assert not passed
|
|
||||||
assert reason == "no_write_attempt"
|
|
||||||
|
|
||||||
result.write_attempted = True
|
|
||||||
passed, reason = _evaluate_file_contract_outcome(result)
|
|
||||||
assert not passed
|
|
||||||
assert reason == "write_failed"
|
|
||||||
|
|
||||||
result.write_succeeded = True
|
|
||||||
passed, reason = _evaluate_file_contract_outcome(result)
|
|
||||||
assert not passed
|
|
||||||
assert reason == "verification_failed"
|
|
||||||
|
|
||||||
result.verification_succeeded = True
|
|
||||||
passed, reason = _evaluate_file_contract_outcome(result)
|
|
||||||
assert passed
|
|
||||||
assert reason == ""
|
|
||||||
|
|
||||||
|
|
||||||
def test_contract_enforcement_local_only():
|
|
||||||
result = StreamResult(filesystem_mode="desktop_local_folder")
|
|
||||||
assert _contract_enforcement_active(result)
|
|
||||||
|
|
||||||
result.filesystem_mode = "cloud"
|
|
||||||
assert not _contract_enforcement_active(result)
|
|
||||||
|
|
||||||
|
|
||||||
def _extract_chat_stream_payload(record_message: str) -> dict:
|
|
||||||
prefix = "[chat_stream_error] "
|
|
||||||
assert record_message.startswith(prefix)
|
|
||||||
return json.loads(record_message[len(prefix) :])
|
|
||||||
|
|
||||||
|
|
||||||
def test_unified_chat_stream_error_log_schema(caplog):
|
|
||||||
with caplog.at_level(logging.INFO, logger="app.tasks.chat.stream_new_chat"):
|
|
||||||
_log_chat_stream_error(
|
|
||||||
flow="new",
|
|
||||||
error_kind="server_error",
|
|
||||||
error_code="SERVER_ERROR",
|
|
||||||
severity="warn",
|
|
||||||
is_expected=False,
|
|
||||||
request_id="req-123",
|
|
||||||
thread_id=101,
|
|
||||||
search_space_id=202,
|
|
||||||
user_id="user-1",
|
|
||||||
message="Error during chat: boom",
|
|
||||||
)
|
|
||||||
|
|
||||||
record = next(r for r in caplog.records if "[chat_stream_error]" in r.message)
|
|
||||||
payload = _extract_chat_stream_payload(record.message)
|
|
||||||
|
|
||||||
required_keys = {
|
|
||||||
"event",
|
|
||||||
"flow",
|
|
||||||
"error_kind",
|
|
||||||
"error_code",
|
|
||||||
"severity",
|
|
||||||
"is_expected",
|
|
||||||
"request_id",
|
|
||||||
"thread_id",
|
|
||||||
"search_space_id",
|
|
||||||
"user_id",
|
|
||||||
"message",
|
|
||||||
}
|
|
||||||
assert required_keys.issubset(payload.keys())
|
|
||||||
assert payload["event"] == "chat_stream_error"
|
|
||||||
assert payload["flow"] == "new"
|
|
||||||
assert payload["error_code"] == "SERVER_ERROR"
|
|
||||||
|
|
||||||
|
|
||||||
def test_premium_quota_uses_unified_chat_stream_log_shape(caplog):
|
|
||||||
with caplog.at_level(logging.INFO, logger="app.tasks.chat.stream_new_chat"):
|
|
||||||
_log_chat_stream_error(
|
|
||||||
flow="resume",
|
|
||||||
error_kind="premium_quota_exhausted",
|
|
||||||
error_code="PREMIUM_QUOTA_EXHAUSTED",
|
|
||||||
severity="info",
|
|
||||||
is_expected=True,
|
|
||||||
request_id="req-premium",
|
|
||||||
thread_id=303,
|
|
||||||
search_space_id=404,
|
|
||||||
user_id="user-2",
|
|
||||||
message="Buy more tokens to continue with this model, or switch to a free model",
|
|
||||||
extra={"auto_fallback": False},
|
|
||||||
)
|
|
||||||
|
|
||||||
record = next(r for r in caplog.records if "[chat_stream_error]" in r.message)
|
|
||||||
payload = _extract_chat_stream_payload(record.message)
|
|
||||||
assert payload["event"] == "chat_stream_error"
|
|
||||||
assert payload["error_kind"] == "premium_quota_exhausted"
|
|
||||||
assert payload["error_code"] == "PREMIUM_QUOTA_EXHAUSTED"
|
|
||||||
assert payload["flow"] == "resume"
|
|
||||||
assert payload["is_expected"] is True
|
|
||||||
assert payload["auto_fallback"] is False
|
|
||||||
|
|
||||||
|
|
||||||
def test_stream_error_emission_keeps_machine_error_codes():
|
|
||||||
source = inspect.getsource(stream_new_chat_module)
|
|
||||||
format_error_calls = re.findall(r"format_error\(", source)
|
|
||||||
emitted_error_codes = set(re.findall(r'error_code="([A-Z_]+)"', source))
|
|
||||||
|
|
||||||
# All stream paths should route through one shared terminal error emitter.
|
|
||||||
assert len(format_error_calls) == 1
|
|
||||||
assert {
|
|
||||||
"PREMIUM_QUOTA_EXHAUSTED",
|
|
||||||
"SERVER_ERROR",
|
|
||||||
}.issubset(emitted_error_codes)
|
|
||||||
assert 'flow: Literal["new", "regenerate"] = "new"' in source
|
|
||||||
assert "_emit_stream_terminal_error" in source
|
|
||||||
assert "flow=flow" in source
|
|
||||||
assert 'flow="resume"' in source
|
|
||||||
|
|
||||||
|
|
||||||
def test_stream_exception_classifies_rate_limited():
|
|
||||||
exc = Exception(
|
|
||||||
'{"error":{"type":"rate_limit_error","message":"Rate limited. Please try again later."}}'
|
|
||||||
)
|
|
||||||
kind, code, severity, is_expected, user_message, extra = _classify_stream_exception(
|
|
||||||
exc, flow_label="chat"
|
|
||||||
)
|
|
||||||
assert kind == "rate_limited"
|
|
||||||
assert code == "RATE_LIMITED"
|
|
||||||
assert severity == "warn"
|
|
||||||
assert is_expected is True
|
|
||||||
assert "temporarily rate-limited" in user_message
|
|
||||||
assert extra is None
|
|
||||||
|
|
||||||
|
|
||||||
def test_stream_exception_classifies_openrouter_429_payload():
|
|
||||||
exc = Exception(
|
|
||||||
'OpenrouterException - {"error":{"message":"Provider returned error","code":429,'
|
|
||||||
'"metadata":{"raw":"foo is temporarily rate-limited upstream"}}}'
|
|
||||||
)
|
|
||||||
kind, code, severity, is_expected, user_message, extra = _classify_stream_exception(
|
|
||||||
exc, flow_label="chat"
|
|
||||||
)
|
|
||||||
assert kind == "rate_limited"
|
|
||||||
assert code == "RATE_LIMITED"
|
|
||||||
assert severity == "warn"
|
|
||||||
assert is_expected is True
|
|
||||||
assert "temporarily rate-limited" in user_message
|
|
||||||
assert extra is None
|
|
||||||
|
|
||||||
|
|
||||||
def test_stream_exception_classifies_thread_busy():
|
|
||||||
exc = BusyError(request_id="thread-123")
|
|
||||||
kind, code, severity, is_expected, user_message, extra = _classify_stream_exception(
|
|
||||||
exc, flow_label="chat"
|
|
||||||
)
|
|
||||||
assert kind == "thread_busy"
|
|
||||||
assert code == "THREAD_BUSY"
|
|
||||||
assert severity == "warn"
|
|
||||||
assert is_expected is True
|
|
||||||
assert "still finishing for this thread" in user_message
|
|
||||||
assert extra is None
|
|
||||||
|
|
||||||
|
|
||||||
def test_stream_exception_classifies_thread_busy_from_message():
|
|
||||||
exc = Exception("Thread is busy with another request")
|
|
||||||
kind, code, severity, is_expected, user_message, extra = _classify_stream_exception(
|
|
||||||
exc, flow_label="chat"
|
|
||||||
)
|
|
||||||
assert kind == "thread_busy"
|
|
||||||
assert code == "THREAD_BUSY"
|
|
||||||
assert severity == "warn"
|
|
||||||
assert is_expected is True
|
|
||||||
assert "still finishing for this thread" in user_message
|
|
||||||
assert extra is None
|
|
||||||
|
|
||||||
|
|
||||||
def test_stream_exception_classifies_turn_cancelling_when_cancel_requested():
|
|
||||||
thread_id = "thread-cancelling-1"
|
|
||||||
reset_cancel(thread_id)
|
|
||||||
request_cancel(thread_id)
|
|
||||||
exc = BusyError(request_id=thread_id)
|
|
||||||
kind, code, severity, is_expected, user_message, extra = _classify_stream_exception(
|
|
||||||
exc, flow_label="chat"
|
|
||||||
)
|
|
||||||
assert kind == "thread_busy"
|
|
||||||
assert code == "TURN_CANCELLING"
|
|
||||||
assert severity == "info"
|
|
||||||
assert is_expected is True
|
|
||||||
assert "stopping" in user_message
|
|
||||||
assert isinstance(extra, dict)
|
|
||||||
assert "retry_after_ms" in extra
|
|
||||||
|
|
||||||
|
|
||||||
def test_premium_classification_is_error_code_driven():
|
|
||||||
classifier_path = (
|
|
||||||
Path(__file__).resolve().parents[3]
|
|
||||||
/ "surfsense_web/lib/chat/chat-error-classifier.ts"
|
|
||||||
)
|
|
||||||
source = classifier_path.read_text(encoding="utf-8")
|
|
||||||
|
|
||||||
assert "PREMIUM_KEYWORDS" not in source
|
|
||||||
assert "RATE_LIMIT_KEYWORDS" not in source
|
|
||||||
assert "normalized.includes(" not in source
|
|
||||||
assert 'if (errorCode === "PREMIUM_QUOTA_EXHAUSTED") {' in source
|
|
||||||
|
|
||||||
|
|
||||||
def test_stream_terminal_error_handler_has_pre_accept_soft_rollback_hook():
|
|
||||||
page_path = (
|
|
||||||
Path(__file__).resolve().parents[3]
|
|
||||||
/ "surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx"
|
|
||||||
)
|
|
||||||
source = page_path.read_text(encoding="utf-8")
|
|
||||||
|
|
||||||
assert "onPreAcceptFailure?: () => Promise<void>;" in source
|
|
||||||
assert "if (!accepted) {" in source
|
|
||||||
assert "await onPreAcceptFailure?.();" in source
|
|
||||||
assert "await onAcceptedStreamError?.();" in source
|
|
||||||
assert "setMessages((prev) => prev.filter((m) => m.id !== userMsgId));" in source
|
|
||||||
assert "setMessageDocumentsMap((prev) => {" in source
|
|
||||||
|
|
||||||
|
|
||||||
def test_toast_only_pre_accept_policy_has_no_inline_failed_marker():
|
|
||||||
user_message_path = (
|
|
||||||
Path(__file__).resolve().parents[3]
|
|
||||||
/ "surfsense_web/components/assistant-ui/user-message.tsx"
|
|
||||||
)
|
|
||||||
source = user_message_path.read_text(encoding="utf-8")
|
|
||||||
|
|
||||||
assert "Not sent. Edit and retry." not in source
|
|
||||||
assert "failed_pre_accept" not in source
|
|
||||||
|
|
||||||
|
|
||||||
def test_network_send_failures_use_unified_retry_toast_message():
|
|
||||||
classifier_path = (
|
|
||||||
Path(__file__).resolve().parents[3]
|
|
||||||
/ "surfsense_web/lib/chat/chat-error-classifier.ts"
|
|
||||||
)
|
|
||||||
classifier_source = classifier_path.read_text(encoding="utf-8")
|
|
||||||
request_errors_path = (
|
|
||||||
Path(__file__).resolve().parents[3]
|
|
||||||
/ "surfsense_web/lib/chat/chat-request-errors.ts"
|
|
||||||
)
|
|
||||||
request_errors_source = request_errors_path.read_text(encoding="utf-8")
|
|
||||||
|
|
||||||
assert '"send_failed_pre_accept"' in classifier_source
|
|
||||||
assert 'errorCode === "SEND_FAILED_PRE_ACCEPT"' in classifier_source
|
|
||||||
assert 'errorCode === "TURN_CANCELLING"' in classifier_source
|
|
||||||
assert "if (withCode.code) return withCode.code;" in classifier_source
|
|
||||||
assert 'userMessage: "Message not sent. Please retry."' in classifier_source
|
|
||||||
assert 'userMessage: "Connection issue. Please try again."' in classifier_source
|
|
||||||
assert "const passthroughCodes = new Set([" in request_errors_source
|
|
||||||
assert '"PREMIUM_QUOTA_EXHAUSTED"' in request_errors_source
|
|
||||||
assert '"THREAD_BUSY"' in request_errors_source
|
|
||||||
assert '"TURN_CANCELLING"' in request_errors_source
|
|
||||||
assert '"AUTH_EXPIRED"' in request_errors_source
|
|
||||||
assert '"UNAUTHORIZED"' in request_errors_source
|
|
||||||
assert '"RATE_LIMITED"' in request_errors_source
|
|
||||||
assert '"NETWORK_ERROR"' in request_errors_source
|
|
||||||
assert '"STREAM_PARSE_ERROR"' in request_errors_source
|
|
||||||
assert '"TOOL_EXECUTION_ERROR"' in request_errors_source
|
|
||||||
assert '"PERSIST_MESSAGE_FAILED"' in request_errors_source
|
|
||||||
assert '"SERVER_ERROR"' in request_errors_source
|
|
||||||
assert "passthroughCodes.has(existingCode)" in request_errors_source
|
|
||||||
assert 'errorCode: "SEND_FAILED_PRE_ACCEPT"' in request_errors_source
|
|
||||||
assert 'errorCode: "NETWORK_ERROR"' not in request_errors_source
|
|
||||||
assert "Failed to start chat. Please try again." not in classifier_source
|
|
||||||
|
|
||||||
|
|
||||||
def test_pre_post_accept_abort_contract_exists_for_new_resume_regenerate_flows():
|
|
||||||
page_path = (
|
|
||||||
Path(__file__).resolve().parents[3]
|
|
||||||
/ "surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx"
|
|
||||||
)
|
|
||||||
source = page_path.read_text(encoding="utf-8")
|
|
||||||
|
|
||||||
# Each flow tracks accepted boundary and passes it into shared terminal handling.
|
|
||||||
# The acceptance boundary is still meaningful post-refactor: it gates
|
|
||||||
# local-state cleanup (onPreAcceptFailure path) and lets the shared
|
|
||||||
# terminal handler distinguish pre-accept aborts from in-stream errors.
|
|
||||||
assert "let newAccepted = false;" in source
|
|
||||||
assert "let resumeAccepted = false;" in source
|
|
||||||
assert "let regenerateAccepted = false;" in source
|
|
||||||
assert "accepted: newAccepted," in source
|
|
||||||
assert "accepted: resumeAccepted," in source
|
|
||||||
assert "accepted: regenerateAccepted," in source
|
|
||||||
|
|
||||||
# NOTE: The FE-side persistence guards previously asserted here
|
|
||||||
# ("if (!resumeAccepted) return;", "if (!regenerateAccepted) return;",
|
|
||||||
# "if (newAccepted && !userPersisted) {") have been intentionally
|
|
||||||
# removed by the SSE-based message-id handshake refactor. Persistence
|
|
||||||
# is now server-authoritative: persist_user_turn / persist_assistant_shell
|
|
||||||
# run inside stream_new_chat / stream_resume_chat unconditionally and
|
|
||||||
# the FE consumes data-user-message-id / data-assistant-message-id
|
|
||||||
# SSE events to learn the canonical primary keys. There is therefore
|
|
||||||
# no FE call-site to guard, and the shared terminal handler relies
|
|
||||||
# purely on the `accepted` field above (forwarded to onAbort /
|
|
||||||
# onAcceptedStreamError) to drive UI cleanup. See
|
|
||||||
# tests/integration/chat/test_message_id_sse.py for the new
|
|
||||||
# cross-tier ID coherence guarantees.
|
|
||||||
|
|
||||||
# The TURN_CANCELLING / THREAD_BUSY retry plumbing is independent
|
|
||||||
# of the persistence refactor and must still exist on every
|
|
||||||
# start-stream fetch.
|
|
||||||
assert "const fetchWithTurnCancellingRetry = useCallback(" in source
|
|
||||||
assert "computeFallbackTurnCancellingRetryDelay" in source
|
|
||||||
assert 'withMeta.errorCode === "TURN_CANCELLING"' in source
|
|
||||||
assert 'withMeta.errorCode === "THREAD_BUSY"' in source
|
|
||||||
assert "await fetchWithTurnCancellingRetry(() =>" in source
|
|
||||||
|
|
||||||
|
|
||||||
def test_cancel_active_turn_route_contract_exists():
|
|
||||||
routes_path = (
|
|
||||||
Path(__file__).resolve().parents[3]
|
|
||||||
/ "surfsense_backend/app/routes/new_chat_routes.py"
|
|
||||||
)
|
|
||||||
source = routes_path.read_text(encoding="utf-8")
|
|
||||||
|
|
||||||
assert '@router.post(\n "/threads/{thread_id}/cancel-active-turn",' in source
|
|
||||||
assert "response_model=CancelActiveTurnResponse" in source
|
|
||||||
assert 'status="cancelling",' in source
|
|
||||||
assert 'error_code="TURN_CANCELLING",' in source
|
|
||||||
assert "retry_after_ms=retry_after_ms if retry_after_ms > 0 else None," in source
|
|
||||||
assert "retry_after_at=" in source
|
|
||||||
assert 'status="idle",' in source
|
|
||||||
assert 'error_code="NO_ACTIVE_TURN",' in source
|
|
||||||
|
|
||||||
|
|
||||||
def test_turn_status_route_contract_exists():
|
|
||||||
routes_path = (
|
|
||||||
Path(__file__).resolve().parents[3]
|
|
||||||
/ "surfsense_backend/app/routes/new_chat_routes.py"
|
|
||||||
)
|
|
||||||
source = routes_path.read_text(encoding="utf-8")
|
|
||||||
|
|
||||||
assert '@router.get(\n "/threads/{thread_id}/turn-status",' in source
|
|
||||||
assert "response_model=TurnStatusResponse" in source
|
|
||||||
assert "_build_turn_status_payload(thread_id)" in source
|
|
||||||
assert "Permission.CHATS_READ.value" in source
|
|
||||||
assert "_raise_if_thread_busy_for_start(" in source
|
|
||||||
|
|
||||||
|
|
||||||
def test_turn_cancelling_retry_policy_contract_exists():
|
|
||||||
routes_path = (
|
|
||||||
Path(__file__).resolve().parents[3]
|
|
||||||
/ "surfsense_backend/app/routes/new_chat_routes.py"
|
|
||||||
)
|
|
||||||
source = routes_path.read_text(encoding="utf-8")
|
|
||||||
|
|
||||||
assert "TURN_CANCELLING_INITIAL_DELAY_MS = 200" in source
|
|
||||||
assert "TURN_CANCELLING_BACKOFF_FACTOR = 2" in source
|
|
||||||
assert "TURN_CANCELLING_MAX_DELAY_MS = 1500" in source
|
|
||||||
assert "def _compute_turn_cancelling_retry_delay(" in source
|
|
||||||
assert "retry-after-ms" in source
|
|
||||||
assert '"Retry-After"' in source
|
|
||||||
assert '"errorCode": "TURN_CANCELLING"' in source
|
|
||||||
|
|
||||||
|
|
||||||
def test_turn_status_sse_contract_exists():
|
|
||||||
stream_source = (
|
|
||||||
Path(__file__).resolve().parents[3]
|
|
||||||
/ "surfsense_backend/app/tasks/chat/stream_new_chat.py"
|
|
||||||
).read_text(encoding="utf-8")
|
|
||||||
state_source = (
|
|
||||||
Path(__file__).resolve().parents[3]
|
|
||||||
/ "surfsense_web/lib/chat/streaming-state.ts"
|
|
||||||
).read_text(encoding="utf-8")
|
|
||||||
pipeline_source = (
|
|
||||||
Path(__file__).resolve().parents[3]
|
|
||||||
/ "surfsense_web/lib/chat/stream-pipeline.ts"
|
|
||||||
).read_text(encoding="utf-8")
|
|
||||||
|
|
||||||
assert '"turn-status"' in stream_source
|
|
||||||
assert '"status": "busy"' in stream_source
|
|
||||||
assert '"status": "idle"' in stream_source
|
|
||||||
assert 'type: "data-turn-status"' in state_source
|
|
||||||
assert 'case "data-turn-status":' in pipeline_source
|
|
||||||
assert "end_turn(str(chat_id))" in stream_source
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue