fix format

This commit is contained in:
cotran 2024-10-16 14:18:32 -07:00
parent 65543f8baf
commit 10f0a027bf
4 changed files with 94 additions and 36 deletions

View file

@ -5,6 +5,7 @@ from app.main import app # Assuming your FastAPI app is in main.py
client = TestClient(app)
# Unit tests for the health check endpoint
@pytest.mark.asyncio
async def test_healthz():
@ -12,6 +13,7 @@ async def test_healthz():
assert response.status_code == 200
assert response.json() == {"status": "ok"}
# Unit test for the models endpoint
@pytest.mark.asyncio
async def test_models():
@ -20,13 +22,11 @@ async def test_models():
assert response.json()["object"] == "list"
assert len(response.json()["data"]) > 0
# Unit test for embeddings endpoint
@pytest.mark.asyncio
async def test_embedding():
request_data = {
"input": "Test embedding",
"model": "katanemo/bge-large-en-v1.5"
}
request_data = {"input": "Test embedding", "model": "katanemo/bge-large-en-v1.5"}
response = client.post("/embeddings", json=request_data)
if request_data["model"] == "katanemo/bge-large-en-v1.5":
assert response.status_code == 200
@ -35,24 +35,23 @@ async def test_embedding():
else:
assert response.status_code == 400
# Unit test for the guard endpoint
@pytest.mark.asyncio
async def test_guard():
request_data = {
"input": "Test for jailbreak and toxicity",
"task": "jailbreak"
}
request_data = {"input": "Test for jailbreak and toxicity", "task": "jailbreak"}
response = client.post("/guard", json=request_data)
assert response.status_code == 200
assert "jailbreak_verdict" in response.json()
# Unit test for the zero-shot endpoint
@pytest.mark.asyncio
async def test_zeroshot():
async def test_zeroshot():
request_data = {
"input": "Test input",
"labels": ["label1", "label2"],
"model": "katanemo/bart-large-mnli"
"model": "katanemo/bart-large-mnli",
}
response = client.post("/zeroshot", json=request_data)
if request_data["model"] == "katanemo/bart-large-mnli":
@ -61,13 +60,14 @@ async def test_zeroshot():
else:
assert response.status_code == 400
# Unit test for the hallucination endpoint
@pytest.mark.asyncio
async def test_hallucination():
request_data = {
"prompt": "Test hallucination",
"parameters": {"param1": "value1"},
"model": "katanemo/bart-large-mnli"
"model": "katanemo/bart-large-mnli",
}
response = client.post("/hallucination", json=request_data)
if request_data["model"] == "katanemo/bart-large-mnli":
@ -76,6 +76,7 @@ async def test_hallucination():
else:
assert response.status_code == 400
# Unit test for the chat completion endpoint
@pytest.mark.asyncio
async def test_chat_completion():
@ -84,7 +85,7 @@ async def test_chat_completion():
"messages": [{"role": "user", "content": "Hello!"}],
"model": "Arch-Function-1.5B",
"tools": [], # Assuming tools is part of the req as per the function
"metadata": {"x-arch-state": "[]"} # Assuming metadata is needed
"metadata": {"x-arch-state": "[]"}, # Assuming metadata is needed
}
response = await client.post("/v1/chat/completions", json=request_data)
assert response.status_code == 200