mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
fix format
This commit is contained in:
parent
65543f8baf
commit
10f0a027bf
4 changed files with 94 additions and 36 deletions
|
|
@ -5,6 +5,7 @@ from app.main import app # Assuming your FastAPI app is in main.py
|
|||
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
# Unit tests for the health check endpoint
|
||||
@pytest.mark.asyncio
|
||||
async def test_healthz():
|
||||
|
|
@ -12,6 +13,7 @@ async def test_healthz():
|
|||
assert response.status_code == 200
|
||||
assert response.json() == {"status": "ok"}
|
||||
|
||||
|
||||
# Unit test for the models endpoint
|
||||
@pytest.mark.asyncio
|
||||
async def test_models():
|
||||
|
|
@ -20,13 +22,11 @@ async def test_models():
|
|||
assert response.json()["object"] == "list"
|
||||
assert len(response.json()["data"]) > 0
|
||||
|
||||
|
||||
# Unit test for embeddings endpoint
|
||||
@pytest.mark.asyncio
|
||||
async def test_embedding():
|
||||
request_data = {
|
||||
"input": "Test embedding",
|
||||
"model": "katanemo/bge-large-en-v1.5"
|
||||
}
|
||||
request_data = {"input": "Test embedding", "model": "katanemo/bge-large-en-v1.5"}
|
||||
response = client.post("/embeddings", json=request_data)
|
||||
if request_data["model"] == "katanemo/bge-large-en-v1.5":
|
||||
assert response.status_code == 200
|
||||
|
|
@ -35,24 +35,23 @@ async def test_embedding():
|
|||
else:
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
# Unit test for the guard endpoint
|
||||
@pytest.mark.asyncio
|
||||
async def test_guard():
|
||||
request_data = {
|
||||
"input": "Test for jailbreak and toxicity",
|
||||
"task": "jailbreak"
|
||||
}
|
||||
request_data = {"input": "Test for jailbreak and toxicity", "task": "jailbreak"}
|
||||
response = client.post("/guard", json=request_data)
|
||||
assert response.status_code == 200
|
||||
assert "jailbreak_verdict" in response.json()
|
||||
|
||||
|
||||
# Unit test for the zero-shot endpoint
|
||||
@pytest.mark.asyncio
|
||||
async def test_zeroshot():
|
||||
async def test_zeroshot():
|
||||
request_data = {
|
||||
"input": "Test input",
|
||||
"labels": ["label1", "label2"],
|
||||
"model": "katanemo/bart-large-mnli"
|
||||
"model": "katanemo/bart-large-mnli",
|
||||
}
|
||||
response = client.post("/zeroshot", json=request_data)
|
||||
if request_data["model"] == "katanemo/bart-large-mnli":
|
||||
|
|
@ -61,13 +60,14 @@ async def test_zeroshot():
|
|||
else:
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
# Unit test for the hallucination endpoint
|
||||
@pytest.mark.asyncio
|
||||
async def test_hallucination():
|
||||
request_data = {
|
||||
"prompt": "Test hallucination",
|
||||
"parameters": {"param1": "value1"},
|
||||
"model": "katanemo/bart-large-mnli"
|
||||
"model": "katanemo/bart-large-mnli",
|
||||
}
|
||||
response = client.post("/hallucination", json=request_data)
|
||||
if request_data["model"] == "katanemo/bart-large-mnli":
|
||||
|
|
@ -76,6 +76,7 @@ async def test_hallucination():
|
|||
else:
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
# Unit test for the chat completion endpoint
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_completion():
|
||||
|
|
@ -84,7 +85,7 @@ async def test_chat_completion():
|
|||
"messages": [{"role": "user", "content": "Hello!"}],
|
||||
"model": "Arch-Function-1.5B",
|
||||
"tools": [], # Assuming tools is part of the req as per the function
|
||||
"metadata": {"x-arch-state": "[]"} # Assuming metadata is needed
|
||||
"metadata": {"x-arch-state": "[]"}, # Assuming metadata is needed
|
||||
}
|
||||
response = await client.post("/v1/chat/completions", json=request_data)
|
||||
assert response.status_code == 200
|
||||
|
|
|
|||
|
|
@ -11,12 +11,15 @@ arch_guard_model_type = {
|
|||
"cuda": "katanemo/Arch-Guard",
|
||||
"mps": "katanemo/Arch-Guard",
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_env():
|
||||
# Mock environment variables
|
||||
os.environ["MODELS"] = "katanemo/bge-large-en-v1.5"
|
||||
os.environ["ZERO_SHOT_MODELS"] = "katanemo/bart-large-mnli"
|
||||
|
||||
|
||||
# Test for get_embedding_model function
|
||||
@patch("app.loader.ORTModelForFeatureExtraction.from_pretrained")
|
||||
@patch("app.loader.AutoModel.from_pretrained")
|
||||
|
|
@ -30,11 +33,18 @@ def test_get_embedding_model(mock_tokenizer, mock_automodel, mock_ort_model, moc
|
|||
|
||||
# Assertions
|
||||
assert embedding_model["model_name"] == "katanemo/bge-large-en-v1.5"
|
||||
assert mock_tokenizer.called_once_with("katanemo/bge-large-en-v1.5", trust_remote_code=True)
|
||||
assert mock_tokenizer.called_once_with(
|
||||
"katanemo/bge-large-en-v1.5", trust_remote_code=True
|
||||
)
|
||||
if glb.DEVICE != "cuda":
|
||||
assert mock_ort_model.called_once_with("katanemo/bge-large-en-v1.5", file_name="onnx/model.onnx")
|
||||
assert mock_ort_model.called_once_with(
|
||||
"katanemo/bge-large-en-v1.5", file_name="onnx/model.onnx"
|
||||
)
|
||||
else:
|
||||
assert mock_automodel.called_once_with("katanemo/bge-large-en-v1.5", device_map=glb.DEVICE)
|
||||
assert mock_automodel.called_once_with(
|
||||
"katanemo/bge-large-en-v1.5", device_map=glb.DEVICE
|
||||
)
|
||||
|
||||
|
||||
# Test for get_zero_shot_model function
|
||||
@patch("app.loader.ORTModelForSequenceClassification.from_pretrained")
|
||||
|
|
@ -51,10 +61,13 @@ def test_get_zero_shot_model(mock_tokenizer, mock_pipeline, mock_ort_model, mock
|
|||
assert zero_shot_model["model_name"] == "katanemo/bart-large-mnli"
|
||||
assert mock_tokenizer.called_once_with("katanemo/bart-large-mnli")
|
||||
if glb.DEVICE != "cuda":
|
||||
assert mock_ort_model.called_once_with("katanemo/bart-large-mnli", file_name="onnx/model.onnx")
|
||||
assert mock_ort_model.called_once_with(
|
||||
"katanemo/bart-large-mnli", file_name="onnx/model.onnx"
|
||||
)
|
||||
else:
|
||||
assert mock_pipeline.called_once()
|
||||
|
||||
|
||||
# Test for get_prompt_guard function
|
||||
@patch("app.loader.AutoTokenizer.from_pretrained")
|
||||
@patch("app.loader.OVModelForSequenceClassification.from_pretrained")
|
||||
|
|
@ -65,19 +78,25 @@ def test_get_prompt_guard(mock_ov_model, mock_auto_model, mock_tokenizer):
|
|||
mock_ov_model.return_value = MagicMock()
|
||||
else:
|
||||
mock_auto_model.return_value = MagicMock()
|
||||
|
||||
|
||||
mock_tokenizer.return_value = MagicMock()
|
||||
|
||||
prompt_guard = get_prompt_guard(arch_guard_model_type[glb.DEVICE])
|
||||
|
||||
# Assertions
|
||||
assert prompt_guard["model_name"] == arch_guard_model_type[glb.DEVICE]
|
||||
assert mock_tokenizer.called_once_with(arch_guard_model_type[glb.DEVICE], trust_remote_code=True)
|
||||
assert mock_tokenizer.called_once_with(
|
||||
arch_guard_model_type[glb.DEVICE], trust_remote_code=True
|
||||
)
|
||||
if glb.DEVICE == "cpu":
|
||||
assert mock_ov_model.called_once_with(
|
||||
arch_guard_model_type[glb.DEVICE], device_map=glb.DEVICE, low_cpu_mem_usage=True
|
||||
arch_guard_model_type[glb.DEVICE],
|
||||
device_map=glb.DEVICE,
|
||||
low_cpu_mem_usage=True,
|
||||
)
|
||||
else:
|
||||
assert mock_auto_model.called_once_with(
|
||||
arch_guard_model_type[glb.DEVICE], device_map=glb.DEVICE, low_cpu_mem_usage=True
|
||||
arch_guard_model_type[glb.DEVICE],
|
||||
device_map=glb.DEVICE,
|
||||
low_cpu_mem_usage=True,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -11,12 +11,15 @@ arch_guard_model_type = {
|
|||
"cuda": "katanemo/Arch-Guard",
|
||||
"mps": "katanemo/Arch-Guard",
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_env():
|
||||
# Mock environment variables
|
||||
os.environ["MODELS"] = "katanemo/bge-large-en-v1.5"
|
||||
os.environ["ZERO_SHOT_MODELS"] = "katanemo/bart-large-mnli"
|
||||
|
||||
|
||||
# Test for get_embedding_model function
|
||||
@patch("app.loader.ORTModelForFeatureExtraction.from_pretrained")
|
||||
@patch("app.loader.AutoModel.from_pretrained")
|
||||
|
|
@ -30,11 +33,18 @@ def test_get_embedding_model(mock_tokenizer, mock_automodel, mock_ort_model, moc
|
|||
|
||||
# Assertions
|
||||
assert embedding_model["model_name"] == "katanemo/bge-large-en-v1.5"
|
||||
assert mock_tokenizer.called_once_with("katanemo/bge-large-en-v1.5", trust_remote_code=True)
|
||||
assert mock_tokenizer.called_once_with(
|
||||
"katanemo/bge-large-en-v1.5", trust_remote_code=True
|
||||
)
|
||||
if glb.DEVICE != "cuda":
|
||||
assert mock_ort_model.called_once_with("katanemo/bge-large-en-v1.5", file_name="onnx/model.onnx")
|
||||
assert mock_ort_model.called_once_with(
|
||||
"katanemo/bge-large-en-v1.5", file_name="onnx/model.onnx"
|
||||
)
|
||||
else:
|
||||
assert mock_automodel.called_once_with("katanemo/bge-large-en-v1.5", device_map=glb.DEVICE)
|
||||
assert mock_automodel.called_once_with(
|
||||
"katanemo/bge-large-en-v1.5", device_map=glb.DEVICE
|
||||
)
|
||||
|
||||
|
||||
# Test for get_zero_shot_model function
|
||||
@patch("app.loader.ORTModelForSequenceClassification.from_pretrained")
|
||||
|
|
@ -51,10 +61,13 @@ def test_get_zero_shot_model(mock_tokenizer, mock_pipeline, mock_ort_model, mock
|
|||
assert zero_shot_model["model_name"] == "katanemo/bart-large-mnli"
|
||||
assert mock_tokenizer.called_once_with("katanemo/bart-large-mnli")
|
||||
if glb.DEVICE != "cuda":
|
||||
assert mock_ort_model.called_once_with("katanemo/bart-large-mnli", file_name="onnx/model.onnx")
|
||||
assert mock_ort_model.called_once_with(
|
||||
"katanemo/bart-large-mnli", file_name="onnx/model.onnx"
|
||||
)
|
||||
else:
|
||||
assert mock_pipeline.called_once()
|
||||
|
||||
|
||||
# Test for get_prompt_guard function
|
||||
@patch("app.loader.AutoTokenizer.from_pretrained")
|
||||
@patch("app.loader.OVModelForSequenceClassification.from_pretrained")
|
||||
|
|
@ -65,19 +78,25 @@ def test_get_prompt_guard(mock_ov_model, mock_auto_model, mock_tokenizer):
|
|||
mock_ov_model.return_value = MagicMock()
|
||||
else:
|
||||
mock_auto_model.return_value = MagicMock()
|
||||
|
||||
|
||||
mock_tokenizer.return_value = MagicMock()
|
||||
|
||||
prompt_guard = get_prompt_guard(arch_guard_model_type[glb.DEVICE])
|
||||
|
||||
# Assertions
|
||||
assert prompt_guard["model_name"] == arch_guard_model_type[glb.DEVICE]
|
||||
assert mock_tokenizer.called_once_with(arch_guard_model_type[glb.DEVICE], trust_remote_code=True)
|
||||
assert mock_tokenizer.called_once_with(
|
||||
arch_guard_model_type[glb.DEVICE], trust_remote_code=True
|
||||
)
|
||||
if glb.DEVICE == "cpu":
|
||||
assert mock_ov_model.called_once_with(
|
||||
arch_guard_model_type[glb.DEVICE], device_map=glb.DEVICE, low_cpu_mem_usage=True
|
||||
arch_guard_model_type[glb.DEVICE],
|
||||
device_map=glb.DEVICE,
|
||||
low_cpu_mem_usage=True,
|
||||
)
|
||||
else:
|
||||
assert mock_auto_model.called_once_with(
|
||||
arch_guard_model_type[glb.DEVICE], device_map=glb.DEVICE, low_cpu_mem_usage=True
|
||||
arch_guard_model_type[glb.DEVICE],
|
||||
device_map=glb.DEVICE,
|
||||
low_cpu_mem_usage=True,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -11,12 +11,15 @@ arch_guard_model_type = {
|
|||
"cuda": "katanemo/Arch-Guard",
|
||||
"mps": "katanemo/Arch-Guard",
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_env():
|
||||
# Mock environment variables
|
||||
os.environ["MODELS"] = "katanemo/bge-large-en-v1.5"
|
||||
os.environ["ZERO_SHOT_MODELS"] = "katanemo/bart-large-mnli"
|
||||
|
||||
|
||||
# Test for get_embedding_model function
|
||||
@patch("app.loader.ORTModelForFeatureExtraction.from_pretrained")
|
||||
@patch("app.loader.AutoModel.from_pretrained")
|
||||
|
|
@ -30,11 +33,18 @@ def test_get_embedding_model(mock_tokenizer, mock_automodel, mock_ort_model, moc
|
|||
|
||||
# Assertions
|
||||
assert embedding_model["model_name"] == "katanemo/bge-large-en-v1.5"
|
||||
assert mock_tokenizer.called_once_with("katanemo/bge-large-en-v1.5", trust_remote_code=True)
|
||||
assert mock_tokenizer.called_once_with(
|
||||
"katanemo/bge-large-en-v1.5", trust_remote_code=True
|
||||
)
|
||||
if glb.DEVICE != "cuda":
|
||||
assert mock_ort_model.called_once_with("katanemo/bge-large-en-v1.5", file_name="onnx/model.onnx")
|
||||
assert mock_ort_model.called_once_with(
|
||||
"katanemo/bge-large-en-v1.5", file_name="onnx/model.onnx"
|
||||
)
|
||||
else:
|
||||
assert mock_automodel.called_once_with("katanemo/bge-large-en-v1.5", device_map=glb.DEVICE)
|
||||
assert mock_automodel.called_once_with(
|
||||
"katanemo/bge-large-en-v1.5", device_map=glb.DEVICE
|
||||
)
|
||||
|
||||
|
||||
# Test for get_zero_shot_model function
|
||||
@patch("app.loader.ORTModelForSequenceClassification.from_pretrained")
|
||||
|
|
@ -51,10 +61,13 @@ def test_get_zero_shot_model(mock_tokenizer, mock_pipeline, mock_ort_model, mock
|
|||
assert zero_shot_model["model_name"] == "katanemo/bart-large-mnli"
|
||||
assert mock_tokenizer.called_once_with("katanemo/bart-large-mnli")
|
||||
if glb.DEVICE != "cuda":
|
||||
assert mock_ort_model.called_once_with("katanemo/bart-large-mnli", file_name="onnx/model.onnx")
|
||||
assert mock_ort_model.called_once_with(
|
||||
"katanemo/bart-large-mnli", file_name="onnx/model.onnx"
|
||||
)
|
||||
else:
|
||||
assert mock_pipeline.called_once()
|
||||
|
||||
|
||||
# Test for get_prompt_guard function
|
||||
@patch("app.loader.AutoTokenizer.from_pretrained")
|
||||
@patch("app.loader.OVModelForSequenceClassification.from_pretrained")
|
||||
|
|
@ -65,19 +78,25 @@ def test_get_prompt_guard(mock_ov_model, mock_auto_model, mock_tokenizer):
|
|||
mock_ov_model.return_value = MagicMock()
|
||||
else:
|
||||
mock_auto_model.return_value = MagicMock()
|
||||
|
||||
|
||||
mock_tokenizer.return_value = MagicMock()
|
||||
|
||||
prompt_guard = get_prompt_guard(arch_guard_model_type[glb.DEVICE])
|
||||
|
||||
# Assertions
|
||||
assert prompt_guard["model_name"] == arch_guard_model_type[glb.DEVICE]
|
||||
assert mock_tokenizer.called_once_with(arch_guard_model_type[glb.DEVICE], trust_remote_code=True)
|
||||
assert mock_tokenizer.called_once_with(
|
||||
arch_guard_model_type[glb.DEVICE], trust_remote_code=True
|
||||
)
|
||||
if glb.DEVICE == "cpu":
|
||||
assert mock_ov_model.called_once_with(
|
||||
arch_guard_model_type[glb.DEVICE], device_map=glb.DEVICE, low_cpu_mem_usage=True
|
||||
arch_guard_model_type[glb.DEVICE],
|
||||
device_map=glb.DEVICE,
|
||||
low_cpu_mem_usage=True,
|
||||
)
|
||||
else:
|
||||
assert mock_auto_model.called_once_with(
|
||||
arch_guard_model_type[glb.DEVICE], device_map=glb.DEVICE, low_cpu_mem_usage=True
|
||||
arch_guard_model_type[glb.DEVICE],
|
||||
device_map=glb.DEVICE,
|
||||
low_cpu_mem_usage=True,
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue