feat: adding automated tests
This commit is contained in:
parent
e484f12228
commit
29ee360082
18 changed files with 2886 additions and 4 deletions
181
test/test_openai_proxies.py
Normal file
181
test/test_openai_proxies.py
Normal file
|
|
@ -0,0 +1,181 @@
|
|||
"""Cache-hit short-circuit tests for the OpenAI-compatible proxy routes.
|
||||
|
||||
These tests verify that when the LLM cache reports a hit, the route returns
|
||||
the cached payload *without* selecting an endpoint or contacting any backend.
|
||||
"""
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import orjson
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
import router
|
||||
|
||||
|
||||
_BYPASS = HTTPException(status_code=599, detail="bypassed")
|
||||
|
||||
|
||||
class _FakeCache:
|
||||
"""Minimal stand-in for cache.LLMCache.get_chat."""
|
||||
def __init__(self, response_bytes: bytes | None):
|
||||
self._resp = response_bytes
|
||||
self.calls: list[tuple] = []
|
||||
|
||||
async def get_chat(self, route, model, messages):
|
||||
self.calls.append((route, model, messages))
|
||||
return self._resp
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def cache_hit_payload():
|
||||
return orjson.dumps({
|
||||
"id": "cmpl-xyz",
|
||||
"created": 1,
|
||||
"model": "test-model",
|
||||
"choices": [{"message": {"role": "assistant", "content": "from-cache"}}],
|
||||
"usage": {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2},
|
||||
})
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
# /v1/chat/completions
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
class TestOpenAIChatCompletionsCacheHit:
|
||||
async def test_nonstream_cache_hit_returns_cached_json(self, client, cache_hit_payload):
|
||||
fake = _FakeCache(cache_hit_payload)
|
||||
# Patch the route's references to both helpers — they're imported by name
|
||||
# into router's namespace at module load time.
|
||||
with (
|
||||
patch.object(router, "get_llm_cache", return_value=fake),
|
||||
patch.object(router, "choose_endpoint",
|
||||
AsyncMock(side_effect=AssertionError("backend must not be reached"))),
|
||||
):
|
||||
resp = await client.post(
|
||||
"/v1/chat/completions",
|
||||
json={
|
||||
"model": "test-model",
|
||||
"messages": [{"role": "user", "content": "ping"}],
|
||||
"stream": False,
|
||||
"nomyo": {"cache": True},
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
# Body is streamed; collect it
|
||||
body = resp.content
|
||||
parsed = orjson.loads(body)
|
||||
assert parsed["choices"][0]["message"]["content"] == "from-cache"
|
||||
assert fake.calls and fake.calls[0][0] == "openai_chat"
|
||||
|
||||
async def test_stream_cache_hit_returns_sse(self, client, cache_hit_payload):
|
||||
fake = _FakeCache(cache_hit_payload)
|
||||
with (
|
||||
patch.object(router, "get_llm_cache", return_value=fake),
|
||||
patch.object(router, "choose_endpoint",
|
||||
AsyncMock(side_effect=AssertionError("backend must not be reached"))),
|
||||
):
|
||||
resp = await client.post(
|
||||
"/v1/chat/completions",
|
||||
json={
|
||||
"model": "test-model",
|
||||
"messages": [{"role": "user", "content": "ping"}],
|
||||
"stream": True,
|
||||
"nomyo": {"cache": True},
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert resp.headers["content-type"].startswith("text/event-stream")
|
||||
text = resp.content.decode()
|
||||
# First SSE frame contains the cached content as a delta
|
||||
first_frame = text.split("\n\n")[0]
|
||||
assert first_frame.startswith("data: ")
|
||||
chunk = orjson.loads(first_frame[len("data: "):])
|
||||
assert chunk["choices"][0]["delta"]["content"] == "from-cache"
|
||||
# Stream is terminated with [DONE]
|
||||
assert "data: [DONE]" in text
|
||||
|
||||
async def test_cache_disabled_in_payload_bypasses_cache_check(self, client):
|
||||
"""When nomyo.cache=False, get_chat is never called even if a cache exists."""
|
||||
fake = _FakeCache(b"") # has a response, but should never be consulted
|
||||
with (
|
||||
patch.object(router, "get_llm_cache", return_value=fake),
|
||||
patch.object(router, "choose_endpoint",
|
||||
AsyncMock(side_effect=_BYPASS)),
|
||||
):
|
||||
resp = await client.post(
|
||||
"/v1/chat/completions",
|
||||
json={
|
||||
"model": "m",
|
||||
"messages": [{"role": "user", "content": "hi"}],
|
||||
"nomyo": {"cache": False},
|
||||
},
|
||||
)
|
||||
# Got past the cache short-circuit → endpoint selection invoked
|
||||
assert resp.status_code == 599
|
||||
assert fake.calls == []
|
||||
|
||||
async def test_no_cache_configured_bypasses_cache_check(self, client):
|
||||
"""get_llm_cache() returning None should not break the route."""
|
||||
with (
|
||||
patch.object(router, "get_llm_cache", return_value=None),
|
||||
patch.object(router, "choose_endpoint",
|
||||
AsyncMock(side_effect=_BYPASS)),
|
||||
):
|
||||
resp = await client.post(
|
||||
"/v1/chat/completions",
|
||||
json={
|
||||
"model": "m",
|
||||
"messages": [{"role": "user", "content": "hi"}],
|
||||
"nomyo": {"cache": True},
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 599
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
# /v1/completions
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
class TestOpenAICompletionsCacheHit:
|
||||
async def test_nonstream_cache_hit(self, client, cache_hit_payload):
|
||||
fake = _FakeCache(cache_hit_payload)
|
||||
with (
|
||||
patch.object(router, "get_llm_cache", return_value=fake),
|
||||
patch.object(router, "choose_endpoint",
|
||||
AsyncMock(side_effect=AssertionError("backend must not be reached"))),
|
||||
):
|
||||
resp = await client.post(
|
||||
"/v1/completions",
|
||||
json={
|
||||
"model": "test-model",
|
||||
"prompt": "Tell me a joke",
|
||||
"stream": False,
|
||||
"nomyo": {"cache": True},
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
# Prompt-style cache lookup is namespaced under "openai_completions"
|
||||
assert fake.calls[0][0] == "openai_completions"
|
||||
# Cache lookup receives the prompt as a single user message
|
||||
cached_msgs = fake.calls[0][2]
|
||||
assert cached_msgs == [{"role": "user", "content": "Tell me a joke"}]
|
||||
|
||||
async def test_stream_cache_hit(self, client, cache_hit_payload):
|
||||
fake = _FakeCache(cache_hit_payload)
|
||||
with (
|
||||
patch.object(router, "get_llm_cache", return_value=fake),
|
||||
patch.object(router, "choose_endpoint",
|
||||
AsyncMock(side_effect=AssertionError("backend must not be reached"))),
|
||||
):
|
||||
resp = await client.post(
|
||||
"/v1/completions",
|
||||
json={
|
||||
"model": "test-model",
|
||||
"prompt": "What is 2+2?",
|
||||
"stream": True,
|
||||
"nomyo": {"cache": True},
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert resp.headers["content-type"].startswith("text/event-stream")
|
||||
assert "data: [DONE]" in resp.content.decode()
|
||||
Loading…
Add table
Add a link
Reference in a new issue