"""Cache-hit short-circuit tests for the OpenAI-compatible proxy routes.

These tests verify that when the LLM cache reports a hit, the route returns
the cached payload *without* selecting an endpoint or contacting any backend.
They also check that the cache is skipped when it is disabled per request or
not configured at all.
"""

from unittest.mock import AsyncMock, patch

import orjson
import pytest
from fastapi import HTTPException

import router


# Sentinel error raised by the patched ``choose_endpoint`` in bypass tests.
# Its unusual status code (599) lets a test observe that the route got past
# the cache check and reached endpoint selection.
# NOTE(review): a single instance is shared and re-raised by every test that
# uses it; harmless here, but a fresh instance per test would avoid carrying
# tracebacks between tests.
_BYPASS = HTTPException(detail="bypassed", status_code=599)


class _FakeCache:
|
||
|
|
"""Minimal stand-in for cache.LLMCache.get_chat."""
|
||
|
|
def __init__(self, response_bytes: bytes | None):
|
||
|
|
self._resp = response_bytes
|
||
|
|
self.calls: list[tuple] = []
|
||
|
|
|
||
|
|
async def get_chat(self, route, model, messages):
|
||
|
|
self.calls.append((route, model, messages))
|
||
|
|
return self._resp
|
||
|
|
|
||
|
|
|
||
|
|
@pytest.fixture
def cache_hit_payload():
    """Serialized chat-completion body that a fake cache hands back on a hit.

    The assistant text is ``"from-cache"`` so tests can recognise a reply that
    came from the cache rather than from a backend.
    """
    assistant_message = {"role": "assistant", "content": "from-cache"}
    token_usage = {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}
    body = {
        "id": "cmpl-xyz",
        "created": 1,
        "model": "test-model",
        "choices": [{"message": assistant_message}],
        "usage": token_usage,
    }
    return orjson.dumps(body)


# ──────────────────────────────────────────────────────────────────────────────
# /v1/chat/completions
# ──────────────────────────────────────────────────────────────────────────────

class TestOpenAIChatCompletionsCacheHit:
    """Cache behaviour of ``POST /v1/chat/completions``.

    ``choose_endpoint`` is patched in every test. In hit tests it raises
    ``AssertionError``, so any call fails the test. In bypass tests it raises
    the 599 ``_BYPASS`` error, so reaching endpoint selection is observable
    through the response status.
    """

    async def test_nonstream_cache_hit_returns_cached_json(self, client, cache_hit_payload):
        """A non-streaming hit returns the cached completion as JSON."""
        fake = _FakeCache(cache_hit_payload)
        # Patch the route's references to both helpers — they're imported by name
        # into router's namespace at module load time.
        with (
            patch.object(router, "get_llm_cache", return_value=fake),
            patch.object(
                router,
                "choose_endpoint",
                AsyncMock(side_effect=AssertionError("backend must not be reached")),
            ),
        ):
            resp = await client.post(
                "/v1/chat/completions",
                json={
                    "model": "test-model",
                    "messages": [{"role": "user", "content": "ping"}],
                    "stream": False,
                    "nomyo": {"cache": True},
                },
            )
        assert resp.status_code == 200
        # Parse the full response body returned by the test client.
        parsed = orjson.loads(resp.content)
        assert parsed["choices"][0]["message"]["content"] == "from-cache"
        # The lookup must have been made under the chat namespace.
        assert fake.calls and fake.calls[0][0] == "openai_chat"

    async def test_stream_cache_hit_returns_sse(self, client, cache_hit_payload):
        """A streaming hit replays the cached content as SSE ending in ``[DONE]``."""
        fake = _FakeCache(cache_hit_payload)
        with (
            patch.object(router, "get_llm_cache", return_value=fake),
            patch.object(
                router,
                "choose_endpoint",
                AsyncMock(side_effect=AssertionError("backend must not be reached")),
            ),
        ):
            resp = await client.post(
                "/v1/chat/completions",
                json={
                    "model": "test-model",
                    "messages": [{"role": "user", "content": "ping"}],
                    "stream": True,
                    "nomyo": {"cache": True},
                },
            )
        assert resp.status_code == 200
        assert resp.headers["content-type"].startswith("text/event-stream")
        # Same namespace as the non-streaming path.
        assert fake.calls and fake.calls[0][0] == "openai_chat"
        # SSE frames are separated by a blank line; drop empty trailing pieces.
        frames = [frame for frame in resp.content.decode().split("\n\n") if frame.strip()]
        # First SSE frame contains the cached content as a delta
        first_frame = frames[0]
        assert first_frame.startswith("data: ")
        chunk = orjson.loads(first_frame.removeprefix("data: "))
        assert chunk["choices"][0]["delta"]["content"] == "from-cache"
        # Stream is terminated with [DONE]: it must be the final frame, not merely present.
        assert frames[-1].strip() == "data: [DONE]"

    async def test_cache_disabled_in_payload_bypasses_cache_check(self, client, cache_hit_payload):
        """When nomyo.cache=False, get_chat is never called even if a cache exists."""
        # Arm the fake with a genuine hit (b"" is falsy and looks like a miss).
        # If the route wrongly consulted the cache, it would most likely
        # short-circuit with 200 instead of reaching endpoint selection. Both
        # the status and the call-log assertions below then catch it.
        fake = _FakeCache(cache_hit_payload)
        with (
            patch.object(router, "get_llm_cache", return_value=fake),
            patch.object(router, "choose_endpoint", AsyncMock(side_effect=_BYPASS)),
        ):
            resp = await client.post(
                "/v1/chat/completions",
                json={
                    "model": "m",
                    "messages": [{"role": "user", "content": "hi"}],
                    "nomyo": {"cache": False},
                },
            )
        # Got past the cache short-circuit → endpoint selection invoked
        assert resp.status_code == 599
        assert fake.calls == []

    async def test_no_cache_configured_bypasses_cache_check(self, client):
        """get_llm_cache() returning None should not break the route."""
        with (
            patch.object(router, "get_llm_cache", return_value=None),
            patch.object(router, "choose_endpoint", AsyncMock(side_effect=_BYPASS)),
        ):
            resp = await client.post(
                "/v1/chat/completions",
                json={
                    "model": "m",
                    "messages": [{"role": "user", "content": "hi"}],
                    "nomyo": {"cache": True},
                },
            )
        # Reaching the patched choose_endpoint proves the missing cache was tolerated.
        assert resp.status_code == 599


# ──────────────────────────────────────────────────────────────────────────────
# /v1/completions
# ──────────────────────────────────────────────────────────────────────────────

class TestOpenAICompletionsCacheHit:
    """Cache behaviour of ``POST /v1/completions`` (prompt-style requests).

    ``choose_endpoint`` is patched to raise ``AssertionError`` so any attempt to
    reach a backend on a cache hit fails the test.
    """

    async def test_nonstream_cache_hit(self, client, cache_hit_payload):
        """A non-streaming hit is served from the cache, looked up by the prompt."""
        fake = _FakeCache(cache_hit_payload)
        with (
            patch.object(router, "get_llm_cache", return_value=fake),
            patch.object(
                router,
                "choose_endpoint",
                AsyncMock(side_effect=AssertionError("backend must not be reached")),
            ),
        ):
            resp = await client.post(
                "/v1/completions",
                json={
                    "model": "test-model",
                    "prompt": "Tell me a joke",
                    "stream": False,
                    "nomyo": {"cache": True},
                },
            )
        assert resp.status_code == 200
        # Fail with a clear message rather than an IndexError if no lookup happened.
        assert fake.calls, "cache was never consulted"
        route, _model, cached_msgs = fake.calls[0]
        # Prompt-style cache lookup is namespaced under "openai_completions"
        assert route == "openai_completions"
        # Cache lookup receives the prompt as a single user message
        assert cached_msgs == [{"role": "user", "content": "Tell me a joke"}]

    async def test_stream_cache_hit(self, client, cache_hit_payload):
        """A streaming hit is served from the cache as SSE ending in ``[DONE]``."""
        fake = _FakeCache(cache_hit_payload)
        with (
            patch.object(router, "get_llm_cache", return_value=fake),
            patch.object(
                router,
                "choose_endpoint",
                AsyncMock(side_effect=AssertionError("backend must not be reached")),
            ),
        ):
            resp = await client.post(
                "/v1/completions",
                json={
                    "model": "test-model",
                    "prompt": "What is 2+2?",
                    "stream": True,
                    "nomyo": {"cache": True},
                },
            )
        assert resp.status_code == 200
        assert resp.headers["content-type"].startswith("text/event-stream")
        # The streaming path must use the same lookup as the non-streaming one.
        assert fake.calls, "cache was never consulted"
        route, _model, cached_msgs = fake.calls[0]
        assert route == "openai_completions"
        assert cached_msgs == [{"role": "user", "content": "What is 2+2?"}]
        # SSE frames are separated by a blank line; [DONE] must be the last one.
        frames = [frame for frame in resp.content.decode().split("\n\n") if frame.strip()]
        assert frames[-1].strip() == "data: [DONE]"