plano/model_server/app/tests/test_app.py
2024-10-16 14:15:13 -07:00

91 lines
3.1 KiB
Python

import pytest
import httpx
from fastapi.testclient import TestClient
from app.main import app # Assuming your FastAPI app is in main.py
# Module-level TestClient bound to the FastAPI `app` imported above; most of
# the tests in this module send their requests through this shared instance.
client = TestClient(app)
# Unit test for the health check endpoint
@pytest.mark.asyncio
async def test_healthz():
    """GET /healthz must answer 200 with the JSON body {"status": "ok"}."""
    reply = client.get("/healthz")
    assert reply.status_code == 200
    # Parse the body only after the status check so a failing status is
    # reported first.
    payload = reply.json()
    assert payload == {"status": "ok"}
# Unit test for the models endpoint
@pytest.mark.asyncio
async def test_models():
    """GET /models must answer 200 with a non-empty list-style payload.

    The JSON body is expected to carry ``"object": "list"`` and a ``"data"``
    entry whose length is greater than zero.
    """
    reply = client.get("/models")
    assert reply.status_code == 200
    listing = reply.json()
    assert listing["object"] == "list"
    assert len(listing["data"]) > 0
# Unit test for embeddings endpoint
@pytest.mark.asyncio
async def test_embedding():
    """POST /embeddings with the bge-large model must return a list payload.

    The request always names ``katanemo/bge-large-en-v1.5``, so the success
    expectations are asserted unconditionally: status 200, ``"object"`` equal
    to ``"list"``, and a ``"data"`` key present in the JSON body.
    """
    request_data = {
        "input": "Test embedding",
        "model": "katanemo/bge-large-en-v1.5",
    }
    response = client.post("/embeddings", json=request_data)

    # No branching on the model name: it is a literal fixed just above, so a
    # conditional here could only ever take one path. A rejected-model (400)
    # scenario needs its own request and belongs in a separate test.
    assert response.status_code == 200
    body = response.json()
    assert body["object"] == "list"
    assert "data" in body
# Unit test for the guard endpoint
@pytest.mark.asyncio
async def test_guard():
    """POST /guard with the jailbreak task must answer 200 with a verdict.

    The JSON body is expected to contain the key ``"jailbreak_verdict"``.
    """
    reply = client.post(
        "/guard",
        json={
            "input": "Test for jailbreak and toxicity",
            "task": "jailbreak",
        },
    )
    assert reply.status_code == 200
    verdict_body = reply.json()
    assert "jailbreak_verdict" in verdict_body
# Unit test for the zero-shot endpoint
@pytest.mark.asyncio
async def test_zeroshot():
    """POST /zeroshot with the bart-large-mnli model must return a prediction.

    The request always names ``katanemo/bart-large-mnli``, so the success
    expectations are asserted unconditionally: status 200 and a
    ``"predicted_class"`` key in the JSON body.
    """
    request_data = {
        "input": "Test input",
        "labels": ["label1", "label2"],
        "model": "katanemo/bart-large-mnli",
    }
    response = client.post("/zeroshot", json=request_data)

    # No branching on the model name: it is a literal fixed just above, so a
    # conditional here could only ever take one path. A rejected-model (400)
    # scenario needs its own request and belongs in a separate test.
    assert response.status_code == 200
    assert "predicted_class" in response.json()
# Unit test for the hallucination endpoint
@pytest.mark.asyncio
async def test_hallucination():
    """POST /hallucination with bart-large-mnli must return parameter scores.

    The request always names ``katanemo/bart-large-mnli``, so the success
    expectations are asserted unconditionally: status 200 and a
    ``"params_scores"`` key in the JSON body.
    """
    request_data = {
        "prompt": "Test hallucination",
        "parameters": {"param1": "value1"},
        "model": "katanemo/bart-large-mnli",
    }
    response = client.post("/hallucination", json=request_data)

    # No branching on the model name: it is a literal fixed just above, so a
    # conditional here could only ever take one path. A rejected-model (400)
    # scenario needs its own request and belongs in a separate test.
    assert response.status_code == 200
    assert "params_scores" in response.json()
# Unit test for the chat completion endpoint
@pytest.mark.asyncio
async def test_chat_completion():
    """POST /v1/chat/completions must answer 200 with a ``"choices"`` key.

    The request is sent through an ``httpx.AsyncClient`` that talks to the
    FastAPI ``app`` in-process over an ASGI transport, so no real network
    connection is made.
    """
    request_data = {
        "messages": [{"role": "user", "content": "Hello!"}],
        "model": "Arch-Function-1.5B",
        "tools": [],  # Assuming tools is part of the request as per the handler
        "metadata": {"x-arch-state": "[]"},  # Assuming metadata is needed
    }

    # httpx deprecated the ``AsyncClient(app=...)`` shortcut in 0.27 and
    # removed it in 0.28; passing an explicit ASGITransport is the supported
    # equivalent on both older and newer releases.
    transport = httpx.ASGITransport(app=app)

    # Named ``async_client`` so it does not shadow the module-level ``client``.
    async with httpx.AsyncClient(
        transport=transport, base_url="http://test"
    ) as async_client:
        response = await async_client.post(
            "/v1/chat/completions", json=request_data
        )
        assert response.status_code == 200
        assert "choices" in response.json()