# nomyo-router/test/test_openai_proxies.py

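"""Cache short-circuit tests for the OpenAI-compatible proxy routes.
These tests verify that when the LLM cache reports a hit, the route returns
the cached payload *without* selecting an endpoint or contacting any backend.
For /v1/chat/completions they also verify that a request reaches endpoint
selection when caching is disabled in the request or no cache is configured.
"""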
from unittest.mock import AsyncMock, patch
import orjson
import pytest
from fastapi import HTTPException
import router
# Raised by the mocked choose_endpoint in the bypass tests: a 599 response
# therefore shows the request got past the cache check to endpoint selection.
_BYPASS = HTTPException(detail="bypassed", status_code=599)
class _FakeCache:
"""Minimal stand-in for cache.LLMCache.get_chat."""
def __init__(self, response_bytes: bytes | None):
self._resp = response_bytes
self.calls: list[tuple] = []
async def get_chat(self, route, model, messages):
self.calls.append((route, model, messages))
return self._resp
@pytest.fixture
def cache_hit_payload():
    """Serialized chat-completion body that the fake cache returns on a hit.

    The tests look for the assistant content ``"from-cache"`` to prove the
    response came from the cache rather than from a backend.
    """
    body = {
        "id": "cmpl-xyz",
        "created": 1,
        "model": "test-model",
        "choices": [{"message": {"role": "assistant", "content": "from-cache"}}],
        "usage": {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2},
    }
    return orjson.dumps(body)
# ──────────────────────────────────────────────────────────────────────────────
# /v1/chat/completions
# ──────────────────────────────────────────────────────────────────────────────
class TestOpenAIChatCompletionsCacheHit:
    """Cache behaviour of ``POST /v1/chat/completions``.

    The hit tests replace ``choose_endpoint`` with a mock that raises
    ``AssertionError``, so any attempt to reach a backend fails the test.
    The bypass tests replace it with one that raises ``_BYPASS`` (status 599),
    so a 599 response shows the request went past the cache check.
    """

    async def test_nonstream_cache_hit_returns_cached_json(self, client, cache_hit_payload):
        """A non-streaming hit returns the cached completion as JSON."""
        fake = _FakeCache(cache_hit_payload)
        # Patch the route's references to both helpers — they're imported by name
        # into router's namespace at module load time.
        with (
            patch.object(router, "get_llm_cache", return_value=fake),
            patch.object(router, "choose_endpoint",
                         AsyncMock(side_effect=AssertionError("backend must not be reached"))),
        ):
            resp = await client.post(
                "/v1/chat/completions",
                json={
                    "model": "test-model",
                    "messages": [{"role": "user", "content": "ping"}],
                    "stream": False,
                    "nomyo": {"cache": True},
                },
            )
        assert resp.status_code == 200
        # resp.content holds the complete body, however the route delivered it.
        parsed = orjson.loads(resp.content)
        assert parsed["choices"][0]["message"]["content"] == "from-cache"
        # The lookup is namespaced under the chat route.
        assert fake.calls and fake.calls[0][0] == "openai_chat"

    async def test_stream_cache_hit_returns_sse(self, client, cache_hit_payload):
        """A streaming hit is replayed as SSE: one content delta, then [DONE]."""
        fake = _FakeCache(cache_hit_payload)
        with (
            patch.object(router, "get_llm_cache", return_value=fake),
            patch.object(router, "choose_endpoint",
                         AsyncMock(side_effect=AssertionError("backend must not be reached"))),
        ):
            resp = await client.post(
                "/v1/chat/completions",
                json={
                    "model": "test-model",
                    "messages": [{"role": "user", "content": "ping"}],
                    "stream": True,
                    "nomyo": {"cache": True},
                },
            )
        assert resp.status_code == 200
        assert resp.headers["content-type"].startswith("text/event-stream")
        text = resp.content.decode()
        # First SSE frame contains the cached content as a delta
        first_frame = text.split("\n\n")[0]
        assert first_frame.startswith("data: ")
        chunk = orjson.loads(first_frame[len("data: "):])
        assert chunk["choices"][0]["delta"]["content"] == "from-cache"
        # Stream is terminated with [DONE]
        assert "data: [DONE]" in text

    async def test_cache_disabled_in_payload_bypasses_cache_check(self, client, cache_hit_payload):
        """When nomyo.cache=False, get_chat is never called even if a cache exists."""
        # Back the cache with a genuine hit. The `fake.calls == []` check below
        # catches any lookup, and if the route also honoured the result it would
        # short-circuit with 200 rather than reach endpoint selection (599).
        # An empty b"" payload is falsy and would likely be treated as a miss,
        # which weakens the test.
        fake = _FakeCache(cache_hit_payload)
        with (
            patch.object(router, "get_llm_cache", return_value=fake),
            patch.object(router, "choose_endpoint",
                         AsyncMock(side_effect=_BYPASS)),
        ):
            resp = await client.post(
                "/v1/chat/completions",
                json={
                    "model": "m",
                    "messages": [{"role": "user", "content": "hi"}],
                    "nomyo": {"cache": False},
                },
            )
        # Got past the cache short-circuit → endpoint selection invoked
        assert resp.status_code == 599
        assert fake.calls == []

    async def test_no_cache_configured_bypasses_cache_check(self, client):
        """get_llm_cache() returning None should not break the route."""
        with (
            patch.object(router, "get_llm_cache", return_value=None),
            patch.object(router, "choose_endpoint",
                         AsyncMock(side_effect=_BYPASS)),
        ):
            resp = await client.post(
                "/v1/chat/completions",
                json={
                    "model": "m",
                    "messages": [{"role": "user", "content": "hi"}],
                    "nomyo": {"cache": True},
                },
            )
        # No cache to consult → request proceeds straight to endpoint selection
        assert resp.status_code == 599
# ──────────────────────────────────────────────────────────────────────────────
# /v1/completions
# ──────────────────────────────────────────────────────────────────────────────
class TestOpenAICompletionsCacheHit:
    """Cache-hit behaviour of ``POST /v1/completions``.

    ``choose_endpoint`` is replaced by a mock that raises ``AssertionError``,
    so any attempt to select a backend fails the test.
    """

    @staticmethod
    async def _post_with_backend_forbidden(client, cache, body):
        """POST *body* to /v1/completions with *cache* installed as the LLM cache.

        Both helpers are patched on ``router`` because they're imported by name
        into router's namespace at module load time.
        """
        forbid_backend = AsyncMock(side_effect=AssertionError("backend must not be reached"))
        with (
            patch.object(router, "get_llm_cache", return_value=cache),
            patch.object(router, "choose_endpoint", forbid_backend),
        ):
            return await client.post("/v1/completions", json=body)

    async def test_nonstream_cache_hit(self, client, cache_hit_payload):
        """A non-streaming hit is served from the cache, keyed on the prompt."""
        cache = _FakeCache(cache_hit_payload)
        resp = await self._post_with_backend_forbidden(client, cache, {
            "model": "test-model",
            "prompt": "Tell me a joke",
            "stream": False,
            "nomyo": {"cache": True},
        })
        assert resp.status_code == 200
        route, _model, looked_up_messages = cache.calls[0]
        # Prompt-style cache lookup is namespaced under "openai_completions"
        assert route == "openai_completions"
        # ...and receives the prompt as a single user message
        assert looked_up_messages == [{"role": "user", "content": "Tell me a joke"}]

    async def test_stream_cache_hit(self, client, cache_hit_payload):
        """A streaming hit comes back as an SSE stream terminated by [DONE]."""
        resp = await self._post_with_backend_forbidden(
            client,
            _FakeCache(cache_hit_payload),
            {
                "model": "test-model",
                "prompt": "What is 2+2?",
                "stream": True,
                "nomyo": {"cache": True},
            },
        )
        assert resp.status_code == 200
        assert resp.headers["content-type"].startswith("text/event-stream")
        body_text = resp.content.decode()
        assert "data: [DONE]" in body_text