feat: add test for ollama stream errors
This commit is contained in:
parent
d3b2ee3047
commit
2dceece0d6
1 changed files with 149 additions and 0 deletions
149
test/test_stream_errors.py
Normal file
149
test/test_stream_errors.py
Normal file
|
|
@ -0,0 +1,149 @@
|
|||
"""
|
||||
Unit tests for transient backend-error handling in the four Ollama-native
|
||||
streaming generators (``/api/generate``, ``/api/chat``, ``/api/embeddings``,
|
||||
``/api/embed``).
|
||||
|
||||
These reproduce the reported failure mode: a backend (nginx in front of ollama)
|
||||
returns a 504 Gateway Time-out *while the response is being streamed*, so the
|
||||
``ollama`` client raises ``ResponseError`` from inside the StreamingResponse
|
||||
generator. Before the fix this escaped as an opaque "Exception in ASGI
|
||||
application" traceback; now ``_handle_stream_error`` logs the endpoint/model and
|
||||
emits a terminal Ollama-format ``{"error": ..., "status_code": ...}`` line.
|
||||
|
||||
No real backend required — the ollama client and routing are mocked.
|
||||
"""
|
||||
import json
|
||||
from contextlib import ExitStack
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import httpx
|
||||
import ollama
|
||||
import openai
|
||||
import pytest
|
||||
|
||||
from conftest import TEST_OLLAMA
|
||||
|
||||
# Module-level marker: applies ``pytest.mark.asyncio`` to every test below,
# all of which are ``async def`` coroutines.
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
# ── Fakes ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
class _Chunk:
    """Bare-bones stand-in for one Ollama-native streaming chunk.

    Every field is a class-level default, so an instance needs no constructor
    arguments. Only the attributes listed here and ``model_dump_json`` are
    provided.
    """

    # Payload fields: both empty.
    message = None
    response = None

    # Completion state: the chunk is not final.
    done = False
    done_reason = None

    # Token counters: nothing consumed or produced.
    prompt_eval_count = 0
    eval_count = 0

    def model_dump_json(self):
        """Return a fixed JSON document describing a non-final chunk."""
        return '{"model": "fake", "done": false}'
|
||||
|
||||
|
||||
def _one_then_raise(exc):
    """Build an async generator that emits one chunk and then raises ``exc``.

    The generator is returned un-started. ``exc`` only surfaces when the
    consumer asks for a second item, which imitates a failure in the middle
    of a stream.
    """
    async def _chunk_then_fail():
        yield _Chunk()
        raise exc

    stream = _chunk_then_fail()
    return stream
|
||||
|
||||
|
||||
class _FakeAsyncClient:
    """Replacement for ``ollama.AsyncClient`` whose every call ends in ``exc``.

    ``chat`` and ``generate`` hand back a stream that yields one chunk and
    *then* raises, imitating a 504 that arrives mid-stream. ``embeddings``
    and ``embed`` raise as soon as they are awaited.
    """

    def __init__(self, exc, *args, **kwargs):
        # Extra constructor arguments are accepted and ignored.
        self._exc = exc

    async def chat(self, **kwargs):
        """Return a stream that fails after its first chunk."""
        stream = _one_then_raise(self._exc)
        return stream

    async def generate(self, **kwargs):
        """Return a stream that fails after its first chunk."""
        stream = _one_then_raise(self._exc)
        return stream

    async def embeddings(self, **kwargs):
        """Raise the configured exception immediately."""
        failure = self._exc
        raise failure

    async def embed(self, **kwargs):
        """Raise the configured exception immediately."""
        failure = self._exc
        raise failure
|
||||
|
||||
|
||||
def _patches(exc, mark_unhealthy):
    """Patch routing and the ollama client so the native path hits ``exc``.

    The following names in ``api.ollama`` are replaced, in this order:

    * ``choose_endpoint``: async mock returning ``(TEST_OLLAMA, "fake")``.
    * ``is_openai_compatible``: always ``False``.
    * ``decrement_usage``: async mock.
    * ``_mark_backend_unhealthy``: the caller-supplied ``mark_unhealthy``.
    * ``ollama.AsyncClient``: factory returning ``_FakeAsyncClient(exc)``.

    Args:
        exc: Exception instance the fake client raises.
        mark_unhealthy: Replacement for ``_mark_backend_unhealthy``, usually
            an ``AsyncMock`` the test inspects afterwards.

    Returns:
        An ``ExitStack`` that owns every applied patch. The patches are
        already active on return; use the stack in a ``with`` statement (or
        call ``close()``) to undo them.

    If applying any patch fails, the patches applied before it are undone
    before the exception propagates, so a broken target cannot leak
    half-applied patches into other tests.
    """
    replacements = (
        ("api.ollama.choose_endpoint", AsyncMock(return_value=(TEST_OLLAMA, "fake"))),
        ("api.ollama.is_openai_compatible", lambda ep: False),
        ("api.ollama.decrement_usage", AsyncMock()),
        ("api.ollama._mark_backend_unhealthy", mark_unhealthy),
        ("api.ollama.ollama.AsyncClient", lambda *a, **k: _FakeAsyncClient(exc)),
    )
    with ExitStack() as stack:
        for target, replacement in replacements:
            stack.enter_context(patch(target, replacement))
        # Everything applied cleanly: move the registered exits to a fresh
        # stack so leaving this ``with`` does not undo them.
        return stack.pop_all()
|
||||
|
||||
|
||||
# Request payload to POST for each route under test. ``stream=True`` is only
# relevant for chat/generate.
_ROUTES = {
    "/api/chat": {
        "model": "fake",
        "stream": True,
        "messages": [{"role": "user", "content": "hi"}],
    },
    "/api/generate": {
        "model": "fake",
        "stream": True,
        "prompt": "hi",
    },
    "/api/embeddings": {
        "model": "fake",
        "prompt": "hi",
    },
    "/api/embed": {
        "model": "fake",
        "input": "hi",
    },
}
|
||||
|
||||
|
||||
def _last_json_line(text):
    """Decode and return the final non-blank line of an NDJSON body.

    ``text`` is stripped, split on ``"\\n"``, and whitespace-only lines are
    discarded. The last remaining line is parsed as JSON.

    Raises ``AssertionError`` when no non-blank line is present.
    """
    stripped_body = text.strip()
    non_blank = [line for line in stripped_body.split("\n") if line.strip()]
    assert non_blank, "expected at least one ndjson line in the response body"
    final_line = non_blank[-1]
    return json.loads(final_line)
|
||||
|
||||
|
||||
# ── Tests ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
@pytest.mark.parametrize(("route", "payload"), list(_ROUTES.items()))
async def test_504_surfaces_as_error_line(client, route, payload):
    """A 504 ``ResponseError`` ends the body with an {"error", "status_code"} line."""
    gateway_timeout = ollama.ResponseError("<html>504 Gateway Time-out</html>", 504)
    unhealthy_spy = AsyncMock()

    with _patches(gateway_timeout, unhealthy_spy):
        response = await client.post(route, json=payload)

    # The failure is reported in-band: the HTTP status stays 200 (the stream
    # had already started, or the call is single-shot) instead of a 5xx crash.
    assert response.status_code == 200

    final_line = _last_json_line(response.text)
    assert "error" in final_line
    assert "504" in final_line["error"]
    assert final_line["status_code"] == 504

    # A plain 504 is not a connection-class failure, so the endpoint must not
    # be flagged unhealthy.
    unhealthy_spy.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("route, payload", list(_ROUTES.items()))
async def test_no_asgi_500_on_backend_failure(client, route, payload):
    """The generator must never let the backend error escape as a 500.

    A 502 ``ResponseError`` is injected for each route; the request must
    still complete with HTTP 200.
    """
    exc = ollama.ResponseError("boom", 502)
    with _patches(exc, AsyncMock()):
        resp = await client.post(route, json=payload)

    # Requiring exactly 200 already excludes 500 (and every other status), so
    # a separate ``!= 500`` assertion could never fail independently.
    assert resp.status_code == 200
|
||||
|
||||
|
||||
async def test_connection_error_marks_backend_unhealthy(client):
    """A mid-stream connection-class failure flags (endpoint, model) as unhealthy."""
    probe_request = httpx.Request("POST", "http://x")
    connection_failure = openai.APIConnectionError(request=probe_request)
    unhealthy_spy = AsyncMock()

    with _patches(connection_failure, unhealthy_spy):
        response = await client.post("/api/chat", json=_ROUTES["/api/chat"])

    assert response.status_code == 200
    final_line = _last_json_line(response.text)
    assert "error" in final_line

    unhealthy_spy.assert_awaited_once()
    # The first two positional arguments must be the routed endpoint and model.
    awaited_with = unhealthy_spy.await_args.args
    called_ep = awaited_with[0]
    called_model = awaited_with[1]
    assert called_ep == TEST_OLLAMA
    assert called_model == "fake"
|
||||
Loading…
Add table
Add a link
Reference in a new issue