mirror of
https://github.com/katanemo/plano.git
synced 2026-04-28 10:26:36 +02:00
* remove mode/hardware * add test and pre commit hook * add pytest dependencies * fix format * fix lint * fix precommit * fix pre commit * fix pre commit * fix precommit * fix precommit * fix precommit * fix precommit * fix precommit * fix precommit * fix precommit * fix precommit * fix precommit * fix precommit
106 lines
3.7 KiB
Python
106 lines
3.7 KiB
Python
import pytest
|
|
import httpx
|
|
from fastapi.testclient import TestClient
|
|
from app.main import app # Assuming your FastAPI app is in main.py
|
|
from unittest.mock import patch
|
|
import app.commons.globals as glb
|
|
import logging
|
|
|
|
logger = logging.getLogger(__name__)

# Module-level synchronous client shared by the TestClient-based tests below.
client = TestClient(app)

# Lazy %-style argument: the message is only formatted if INFO is enabled.
logger.info("Model will be loaded on device: %s", glb.DEVICE)
|
|
|
|
# Unit test for the health check endpoint
@pytest.mark.asyncio
# Patches app.loader.glb.DEVICE with the value glb.DEVICE held when this module
# was imported (it does not force "cpu"). NOTE(review): likely a no-op if
# app.loader.glb is the same module object as app.commons.globals — confirm.
@patch("app.loader.glb.DEVICE", glb.DEVICE)
async def test_healthz():
    """GET /healthz responds 200 with the JSON body {"status": "ok"}."""
    response = client.get("/healthz")
    assert response.status_code == 200
    assert response.json() == {"status": "ok"}
|
|
|
|
# Unit test for the models endpoint
@pytest.mark.asyncio
# Patches app.loader.glb.DEVICE with the value glb.DEVICE held when this module
# was imported (it does not force "cpu"). NOTE(review): likely a no-op if
# app.loader.glb is the same module object as app.commons.globals — confirm.
@patch("app.loader.glb.DEVICE", glb.DEVICE)
async def test_models():
    """GET /models responds 200 with object == "list" and a non-empty "data" array."""
    response = client.get("/models")
    assert response.status_code == 200

    body = response.json()  # parse the payload once
    assert body["object"] == "list"
    assert len(body["data"]) > 0
|
|
|
|
# Unit test for the embeddings endpoint
@pytest.mark.asyncio
# Patches app.loader.glb.DEVICE with the value glb.DEVICE held when this module
# was imported (it does not force "cpu"). NOTE(review): likely a no-op if
# app.loader.glb is the same module object as app.commons.globals — confirm.
@patch("app.loader.glb.DEVICE", glb.DEVICE)
async def test_embedding():
    """POST /embeddings with the bge-large model responds 200 with a "list" payload.

    The request always names the same fixed model, so only the success path is
    asserted; a rejected-model case would need its own request.
    """
    request_data = {"input": "Test embedding", "model": "katanemo/bge-large-en-v1.5"}
    response = client.post("/embeddings", json=request_data)
    assert response.status_code == 200

    body = response.json()  # parse the payload once
    assert body["object"] == "list"
    assert "data" in body
|
|
|
|
# Unit test for the guard endpoint
@pytest.mark.asyncio
# Patches app.loader.glb.DEVICE with the value glb.DEVICE held when this module
# was imported (it does not force "cpu"). NOTE(review): likely a no-op if
# app.loader.glb is the same module object as app.commons.globals — confirm.
@patch("app.loader.glb.DEVICE", glb.DEVICE)
async def test_guard():
    """POST /guard with task "jailbreak" responds 200 and includes "jailbreak_verdict"."""
    request_data = {"input": "Test for jailbreak and toxicity", "task": "jailbreak"}
    response = client.post("/guard", json=request_data)
    assert response.status_code == 200
    assert "jailbreak_verdict" in response.json()
|
|
|
|
# Unit test for the zero-shot endpoint
@pytest.mark.asyncio
# Patches app.loader.glb.DEVICE with the value glb.DEVICE held when this module
# was imported (it does not force "cpu"). NOTE(review): likely a no-op if
# app.loader.glb is the same module object as app.commons.globals — confirm.
@patch("app.loader.glb.DEVICE", glb.DEVICE)
async def test_zeroshot():
    """POST /zeroshot with the bart-large-mnli model responds 200 with "predicted_class".

    The request always names the same fixed model, so only the success path is
    asserted; a rejected-model case would need its own request.
    """
    request_data = {
        "input": "Test input",
        "labels": ["label1", "label2"],
        "model": "katanemo/bart-large-mnli",
    }
    response = client.post("/zeroshot", json=request_data)
    assert response.status_code == 200
    assert "predicted_class" in response.json()
|
|
|
|
# Unit test for the hallucination endpoint
@pytest.mark.asyncio
# Patches app.loader.glb.DEVICE with the value glb.DEVICE held when this module
# was imported (it does not force "cpu"). NOTE(review): likely a no-op if
# app.loader.glb is the same module object as app.commons.globals — confirm.
@patch("app.loader.glb.DEVICE", glb.DEVICE)
async def test_hallucination():
    """POST /hallucination with the bart-large-mnli model responds 200 with "params_scores".

    The request always names the same fixed model, so only the success path is
    asserted; a rejected-model case would need its own request.
    """
    request_data = {
        "prompt": "Test hallucination",
        "parameters": {"param1": "value1"},
        "model": "katanemo/bart-large-mnli",
    }
    response = client.post("/hallucination", json=request_data)
    assert response.status_code == 200
    assert "params_scores" in response.json()
|
|
|
|
# Unit test for the chat completion endpoint
@pytest.mark.asyncio
# Patches app.loader.glb.DEVICE with the value glb.DEVICE held when this module
# was imported (it does not force "cpu"). NOTE(review): likely a no-op if
# app.loader.glb is the same module object as app.commons.globals — confirm.
@patch("app.loader.glb.DEVICE", glb.DEVICE)
async def test_chat_completion():
    """POST /v1/chat/completions in-process via httpx responds 200 with "choices"."""
    request_data = {
        "messages": [{"role": "user", "content": "Hello!"}],
        "model": "Arch-Function-1.5B",
        "tools": [],  # Assuming tools is part of the req as per the function
        "metadata": {"x-arch-state": "[]"},  # Assuming metadata is needed
    }

    # httpx.AsyncClient(app=...) was deprecated in httpx 0.27 and removed in
    # 0.28; route requests to the ASGI app through an explicit ASGITransport.
    # A distinct name avoids shadowing the module-level sync `client`.
    transport = httpx.ASGITransport(app=app)
    async with httpx.AsyncClient(transport=transport, base_url="http://test") as async_client:
        response = await async_client.post("/v1/chat/completions", json=request_data)

    # The non-streaming response body is already read, so it is usable after close.
    assert response.status_code == 200
    assert "choices" in response.json()