"""
HTTP-level validation and auth middleware tests.

These tests use an in-process httpx client and never reach a real backend:
requests are either rejected at the validation or auth layer, or target
routes that do not need a backend (such as ``/api/usage``), so no
endpoint-selection or upstream HTTP calls occur.
"""
|
||
|
|
import pytest
|
||
|
|
|
||
|
|
|
||
|
|
class TestChatValidation:
    """Request validation for ``POST /api/chat``; every case expects HTTP 400."""

    async def test_missing_model_returns_400(self, client):
        payload = {"messages": [{"role": "user", "content": "hello"}]}
        response = await client.post("/api/chat", json=payload)
        assert response.status_code == 400
        # The error detail should name the missing field.
        detail = response.json()["detail"]
        assert "model" in detail.lower()

    async def test_missing_messages_returns_400(self, client):
        payload = {"model": "llama3.2"}
        response = await client.post("/api/chat", json=payload)
        assert response.status_code == 400

    async def test_invalid_json_returns_400(self, client):
        # Body is declared as JSON but does not parse.
        json_header = {"Content-Type": "application/json"}
        response = await client.post(
            "/api/chat", content=b"not-json", headers=json_header
        )
        assert response.status_code == 400

    async def test_messages_not_list_returns_400(self, client):
        payload = {"model": "m", "messages": "not-a-list"}
        response = await client.post("/api/chat", json=payload)
        assert response.status_code == 400

    async def test_options_not_dict_returns_400(self, client):
        payload = {
            "model": "m",
            "messages": [{"role": "user", "content": "hi"}],
            "options": "bad",
        }
        response = await client.post("/api/chat", json=payload)
        assert response.status_code == 400
|
||
|
|
|
||
|
|
|
||
|
|
class TestGenerateValidation:
    """Request validation for ``POST /api/generate``; every case expects HTTP 400."""

    async def test_missing_model_returns_400(self, client):
        response = await client.post("/api/generate", json={"prompt": "hello"})
        assert response.status_code == 400
        # The error detail should name the missing field.
        detail = response.json()["detail"]
        assert "model" in detail.lower()

    async def test_missing_prompt_returns_400(self, client):
        response = await client.post("/api/generate", json={"model": "m"})
        assert response.status_code == 400
        detail = response.json()["detail"]
        assert "prompt" in detail.lower()

    async def test_invalid_json_returns_400(self, client):
        # Truncated JSON object sent with a JSON content type.
        json_header = {"Content-Type": "application/json"}
        response = await client.post(
            "/api/generate", content=b"{bad-json", headers=json_header
        )
        assert response.status_code == 400
|
||
|
|
|
||
|
|
|
||
|
|
class TestEmbedValidation:
    """Request validation for ``POST /api/embed``; every case expects HTTP 400."""

    async def test_missing_model_returns_400(self, client):
        payload = {"input": "hello"}
        response = await client.post("/api/embed", json=payload)
        assert response.status_code == 400

    async def test_missing_input_returns_400(self, client):
        payload = {"model": "nomic-embed-text"}
        response = await client.post("/api/embed", json=payload)
        assert response.status_code == 400
|
||
|
|
|
||
|
|
|
||
|
|
class TestEmbeddingsValidation:
    """Request validation for ``POST /api/embeddings``; every case expects HTTP 400."""

    async def test_missing_model_returns_400(self, client):
        payload = {"prompt": "hello"}
        response = await client.post("/api/embeddings", json=payload)
        assert response.status_code == 400

    async def test_missing_prompt_returns_400(self, client):
        payload = {"model": "nomic-embed-text"}
        response = await client.post("/api/embeddings", json=payload)
        assert response.status_code == 400
|
||
|
|
|
||
|
|
|
||
|
|
class TestOpenAIChatValidation:
    """Request validation for ``POST /v1/chat/completions``; every case expects HTTP 400."""

    async def test_missing_model_returns_400(self, client):
        payload = {"messages": [{"role": "user", "content": "hello"}]}
        response = await client.post("/v1/chat/completions", json=payload)
        assert response.status_code == 400

    async def test_missing_messages_returns_400(self, client):
        payload = {"model": "gpt-4o"}
        response = await client.post("/v1/chat/completions", json=payload)
        assert response.status_code == 400

    async def test_invalid_json_returns_400(self, client):
        # Syntactically broken body sent with a JSON content type.
        json_header = {"Content-Type": "application/json"}
        response = await client.post(
            "/v1/chat/completions", content=b"}{", headers=json_header
        )
        assert response.status_code == 400

    async def test_svg_image_rejected(self, client):
        # A multi-part user message whose image part is an SVG data URL.
        text_part = {"type": "text", "text": "describe"}
        svg_part = {
            "type": "image_url",
            "image_url": {"url": "data:image/svg+xml;base64,abc"},
        }
        message = {"role": "user", "content": [text_part, svg_part]}
        payload = {"model": "vision-model", "messages": [message]}

        response = await client.post("/v1/chat/completions", json=payload)
        assert response.status_code == 400
        # The rejection reason should mention SVG.
        detail = response.json()["detail"]
        assert "svg" in detail.lower()
|
||
|
|
|
||
|
|
|
||
|
|
class TestOpenAICompletionsValidation:
    """Request validation for ``POST /v1/completions``; every case expects HTTP 400."""

    async def test_missing_model_returns_400(self, client):
        payload = {"prompt": "hello"}
        response = await client.post("/v1/completions", json=payload)
        assert response.status_code == 400

    async def test_missing_prompt_returns_400(self, client):
        payload = {"model": "m"}
        response = await client.post("/v1/completions", json=payload)
        assert response.status_code == 400
|
||
|
|
|
||
|
|
|
||
|
|
class TestRerankValidation:
    """Request validation for ``POST /v1/rerank``; every case expects HTTP 400."""

    async def test_missing_model_returns_400(self, client):
        payload = {"query": "search query", "documents": ["doc1"]}
        response = await client.post("/v1/rerank", json=payload)
        assert response.status_code == 400

    async def test_missing_query_returns_400(self, client):
        payload = {"model": "reranker", "documents": ["doc1"]}
        response = await client.post("/v1/rerank", json=payload)
        assert response.status_code == 400

    async def test_empty_documents_returns_400(self, client):
        # All fields present, but there is nothing to rerank.
        payload = {"model": "reranker", "query": "search", "documents": []}
        response = await client.post("/v1/rerank", json=payload)
        assert response.status_code == 400
|
||
|
|
|
||
|
|
|
||
|
|
class TestShowValidation:
    """Request validation for ``POST /api/show``."""

    async def test_missing_model_returns_400(self, client):
        # An empty JSON object carries no model name.
        empty_payload = {}
        response = await client.post("/api/show", json=empty_payload)
        assert response.status_code == 400
|
||
|
|
|
||
|
|
|
||
|
|
class TestCopyValidation:
    """Request validation for ``POST /api/copy``; both fields are required."""

    async def test_missing_source_returns_400(self, client):
        payload = {"destination": "dst"}
        response = await client.post("/api/copy", json=payload)
        assert response.status_code == 400

    async def test_missing_destination_returns_400(self, client):
        payload = {"source": "src"}
        response = await client.post("/api/copy", json=payload)
        assert response.status_code == 400
|
||
|
|
|
||
|
|
|
||
|
|
class TestDeleteValidation:
    """Request validation for ``DELETE /api/delete``."""

    async def test_missing_model_returns_400(self, client):
        # httpx's ``delete()`` helper does not accept a request body, so the
        # generic ``request()`` is used. Its ``json=`` argument serializes the
        # (empty) payload and sets ``Content-Type: application/json`` itself,
        # replacing the hand-rolled ``json.dumps(...).encode()`` + header.
        resp = await client.request("DELETE", "/api/delete", json={})
        assert resp.status_code == 400
|
||
|
|
|
||
|
|
|
||
|
|
class TestAuthMiddleware:
    """API-key auth middleware behaviour, exercised through ``client_auth``.

    Requests without a key get 401, a wrong key gets 403, and
    ``test-secret-key`` (the key these tests expect the ``client_auth`` app
    to accept) passes via header or ``api_key`` query parameter. OPTIONS,
    ``/`` and ``/favicon.ico`` must not be rejected by auth.
    """

    async def test_no_key_returns_401(self, client_auth):
        payload = {"model": "m", "messages": [{"role": "user", "content": "hi"}]}
        response = await client_auth.post("/api/chat", json=payload)
        assert response.status_code == 401
        assert "Missing" in response.json()["detail"]

    async def test_invalid_key_returns_403(self, client_auth):
        payload = {"model": "m", "messages": [{"role": "user", "content": "hi"}]}
        bad_auth = {"Authorization": "Bearer wrong-key"}
        response = await client_auth.post("/api/chat", json=payload, headers=bad_auth)
        assert response.status_code == 403
        assert "Invalid" in response.json()["detail"]

    async def test_valid_key_passes_middleware(self, client_auth):
        # /api/usage reads in-memory counters only — no backend call needed
        good_auth = {"Authorization": "Bearer test-secret-key"}
        response = await client_auth.get("/api/usage", headers=good_auth)
        assert response.status_code == 200

    async def test_key_via_query_param(self, client_auth):
        # Same accepted key, supplied as a query parameter instead of a header.
        response = await client_auth.get("/api/usage?api_key=test-secret-key")
        assert response.status_code == 200

    async def test_options_bypasses_auth(self, client_auth):
        # Only checks that auth did not reject the request; the route may
        # answer OPTIONS with any other status.
        response = await client_auth.options("/api/chat")
        assert response.status_code not in {401, 403}

    async def test_root_path_bypasses_auth(self, client_auth):
        response = await client_auth.get("/")
        assert response.status_code not in {401, 403}

    async def test_favicon_bypasses_auth(self, client_auth):
        response = await client_auth.get("/favicon.ico")
        # Should not be blocked by auth (may 404 in test but not 401/403)
        assert response.status_code not in {401, 403}
|