149 lines
5.4 KiB
Python
149 lines
5.4 KiB
Python
"""
|
|
Unit tests for transient backend-error handling in the four Ollama-native
streaming generators (``/api/generate``, ``/api/chat``, ``/api/embeddings``,
``/api/embed``).

These reproduce the reported failure mode: a backend (nginx in front of ollama)
returns a 504 Gateway Time-out *while the response is being streamed*, so the
``ollama`` client raises ``ResponseError`` from inside the StreamingResponse
generator. Before the fix this escaped as an opaque "Exception in ASGI
application" traceback; now ``_handle_stream_error`` logs the endpoint/model and
emits a terminal Ollama-format ``{"error": ..., "status_code": ...}`` line.

No real backend required — the ollama client and routing are mocked.
|
|
"""
|
|
import json
|
|
from contextlib import ExitStack
|
|
from unittest.mock import AsyncMock, patch
|
|
|
|
import httpx
|
|
import ollama
|
|
import openai
|
|
import pytest
|
|
|
|
from conftest import TEST_OLLAMA
|
|
|
|
# Module-level pytest marker: applies ``pytest.mark.asyncio`` to every test
# defined in this file, so the ``async def`` tests need no per-test decorator.
pytestmark = pytest.mark.asyncio
|
|
|
|
|
|
# ── Fakes ─────────────────────────────────────────────────────────────────────
|
|
|
|
class _Chunk:
    """Bare-bones stand-in for one Ollama-native streaming chunk.

    Every field is a class-level default, so a fresh instance exposes all of
    them without needing a constructor.
    """

    # Completion state.
    done = False
    done_reason = None

    # Token counters.
    prompt_eval_count = 0
    eval_count = 0

    # Payload fields; left unset on this fake.
    message = None
    response = None

    def model_dump_json(self):
        """Return the fixed JSON text used as this chunk's serialized form."""
        return '{"model": "fake", "done": false}'
|
|
|
|
|
|
def _one_then_raise(exc):
    """Build an async generator that yields a single chunk and then raises.

    Args:
        exc: Exception instance raised after the first chunk was yielded.

    Returns:
        A started-on-demand async generator object (not a coroutine).
    """

    async def _stream():
        # One well-formed chunk first, so the failure lands mid-stream rather
        # than before any output was produced.
        yield _Chunk()
        raise exc

    return _stream()
|
|
|
|
|
|
class _FakeAsyncClient:
    """Replacement for ``ollama.AsyncClient`` whose every call ends in ``exc``.

    ``chat`` and ``generate`` hand back a stream that yields one chunk and then
    raises, mimicking a mid-stream 504. ``embeddings`` and ``embed`` raise as
    soon as they are awaited.
    """

    def __init__(self, exc, *args, **kwargs):
        # Extra positional/keyword arguments are accepted and ignored so the
        # fake can be constructed with whatever the real client would receive.
        self._exc = exc

    def _failing_stream(self):
        """Return a new stream that fails with the stored exception."""
        return _one_then_raise(self._exc)

    async def chat(self, **kwargs):
        return self._failing_stream()

    async def generate(self, **kwargs):
        return self._failing_stream()

    async def embeddings(self, **kwargs):
        raise self._exc

    async def embed(self, **kwargs):
        raise self._exc
|
|
|
|
|
|
def _patches(exc, mark_unhealthy):
    """Patch routing + the ollama client so the native path hits ``exc``.

    The following names in ``api.ollama`` are replaced:

    * ``choose_endpoint`` -> ``AsyncMock`` returning ``(TEST_OLLAMA, "fake")``
    * ``is_openai_compatible`` -> always ``False``
    * ``decrement_usage`` -> a throwaway ``AsyncMock``
    * ``_mark_backend_unhealthy`` -> ``mark_unhealthy``
    * ``ollama.AsyncClient`` -> factory producing ``_FakeAsyncClient(exc)``

    Args:
        exc: Exception the fake client fails with.
        mark_unhealthy: Replacement for ``_mark_backend_unhealthy``; callers
            pass a mock so they can assert on how it was called.

    Returns:
        An ``ExitStack`` whose patches are already active when this function
        returns. Use it as ``with _patches(...):`` (or call ``.close()``) to
        undo them.
    """
    patchers = (
        patch("api.ollama.choose_endpoint", AsyncMock(return_value=(TEST_OLLAMA, "fake"))),
        patch("api.ollama.is_openai_compatible", lambda ep: False),
        patch("api.ollama.decrement_usage", AsyncMock()),
        patch("api.ollama._mark_backend_unhealthy", mark_unhealthy),
        patch("api.ollama.ollama.AsyncClient", lambda *a, **k: _FakeAsyncClient(exc)),
    )
    # Enter the patches under a guard: if any of them fails to start, the
    # ``with`` block unwinds the ones already applied instead of leaking them
    # into later tests. On success ``pop_all()`` moves the cleanup callbacks to
    # a fresh stack owned by the caller, so leaving this ``with`` undoes nothing.
    with ExitStack() as stack:
        for patcher in patchers:
            stack.enter_context(patcher)
        return stack.pop_all()
|
|
|
|
|
|
# Request body to POST for each native route under test. Only the chat and
# generate payloads carry ``stream=True``; the embedding payloads have no
# ``stream`` key.
_ROUTES = {
    "/api/chat": {
        "model": "fake",
        "stream": True,
        "messages": [{"role": "user", "content": "hi"}],
    },
    "/api/generate": {
        "model": "fake",
        "stream": True,
        "prompt": "hi",
    },
    "/api/embeddings": {
        "model": "fake",
        "prompt": "hi",
    },
    "/api/embed": {
        "model": "fake",
        "input": "hi",
    },
}
|
|
|
|
|
|
def _last_json_line(text):
    """Decode the final non-blank newline-delimited JSON line of ``text``.

    Args:
        text: Response body made of ``\\n``-separated JSON documents.

    Returns:
        The parsed value of the last line that is not empty or whitespace.

    Raises:
        AssertionError: If ``text`` contains no non-blank line.
    """
    non_blank = [line for line in text.strip().split("\n") if line.strip()]
    assert non_blank, "expected at least one ndjson line in the response body"
    final_line = non_blank[-1]
    return json.loads(final_line)
|
|
|
|
|
|
# ── Tests ─────────────────────────────────────────────────────────────────────
|
|
|
|
@pytest.mark.parametrize("route, payload", list(_ROUTES.items()))
async def test_504_surfaces_as_error_line(client, route, payload):
    """A 504 ResponseError ends the body with an {"error", "status_code"} line."""
    gateway_timeout = ollama.ResponseError("<html>504 Gateway Time-out</html>", 504)
    mark_unhealthy = AsyncMock()

    with _patches(gateway_timeout, mark_unhealthy):
        response = await client.post(route, json=payload)

    # The failure is reported in-band: the HTTP status stays 200 (the stream
    # had already started, or the call was single-shot) instead of a 5xx crash.
    assert response.status_code == 200

    final_line = _last_json_line(response.text)
    assert "error" in final_line
    assert "504" in final_line["error"]
    assert final_line["status_code"] == 504

    # A plain 504 is not a connection-class failure, so the endpoint must not
    # have been flagged unhealthy.
    mark_unhealthy.assert_not_called()
|
|
|
|
|
|
@pytest.mark.parametrize("route, payload", list(_ROUTES.items()))
async def test_no_asgi_500_on_backend_failure(client, route, payload):
    """The generator must never let the backend error escape as a 500.

    Every native route is driven against a fake client that fails with a 502
    ``ResponseError``; the HTTP response must still come back as 200.
    """
    exc = ollama.ResponseError("boom", 502)
    with _patches(exc, AsyncMock()):
        resp = await client.post(route, json=payload)

    # A single check covers the contract: a status of 200 already rules out
    # 500, so a separate ``!= 500`` assertion could never fail on its own.
    # The message names the route and the status actually seen.
    assert resp.status_code == 200, (
        f"{route}: backend failure surfaced as HTTP {resp.status_code}"
    )
|
|
|
|
|
|
async def test_connection_error_marks_backend_unhealthy(client):
    """A connection-class failure mid-stream flags (endpoint, model) unhealthy."""
    request = httpx.Request("POST", "http://x")
    connection_error = openai.APIConnectionError(request=request)
    mark_unhealthy = AsyncMock()

    with _patches(connection_error, mark_unhealthy):
        response = await client.post("/api/chat", json=_ROUTES["/api/chat"])

    assert response.status_code == 200
    final_line = _last_json_line(response.text)
    assert "error" in final_line

    mark_unhealthy.assert_awaited_once()
    # The first two positional arguments must be the routed endpoint and model.
    positional = mark_unhealthy.await_args.args
    called_ep = positional[0]
    called_model = positional[1]
    assert called_ep == TEST_OLLAMA
    assert called_model == "fake"
|