mirror of
https://github.com/katanemo/plano.git
synced 2026-04-28 18:36:34 +02:00
release 0.1.2 (#266)
This commit is contained in:
parent
31749bfc74
commit
d1dd8710a4
10 changed files with 4090 additions and 502 deletions
|
|
@ -33,15 +33,15 @@ def test_get_embedding_model(mock_tokenizer, mock_automodel, mock_ort_model, moc
|
|||
|
||||
# Assertions
|
||||
assert embedding_model["model_name"] == "katanemo/bge-large-en-v1.5"
|
||||
assert mock_tokenizer.called_once_with(
|
||||
mock_tokenizer.assert_called_once_with(
|
||||
"katanemo/bge-large-en-v1.5", trust_remote_code=True
|
||||
)
|
||||
if glb.DEVICE != "cuda":
|
||||
assert mock_ort_model.called_once_with(
|
||||
mock_ort_model.assert_called_once_with(
|
||||
"katanemo/bge-large-en-v1.5", file_name="onnx/model.onnx"
|
||||
)
|
||||
else:
|
||||
assert mock_automodel.called_once_with(
|
||||
mock_automodel.assert_called_once_with(
|
||||
"katanemo/bge-large-en-v1.5", device_map=glb.DEVICE
|
||||
)
|
||||
|
||||
|
|
@ -59,9 +59,9 @@ def test_get_zero_shot_model(mock_tokenizer, mock_pipeline, mock_ort_model, mock
|
|||
|
||||
# Assertions
|
||||
assert zero_shot_model["model_name"] == "katanemo/bart-large-mnli"
|
||||
assert mock_tokenizer.called_once_with("katanemo/bart-large-mnli")
|
||||
mock_tokenizer.assert_called_once_with("katanemo/bart-large-mnli")
|
||||
if glb.DEVICE != "cuda":
|
||||
assert mock_ort_model.called_once_with(
|
||||
mock_ort_model.assert_called_once_with(
|
||||
"katanemo/bart-large-mnli", file_name="onnx/model.onnx"
|
||||
)
|
||||
else:
|
||||
|
|
@ -72,7 +72,7 @@ def test_get_zero_shot_model(mock_tokenizer, mock_pipeline, mock_ort_model, mock
|
|||
@patch("app.loader.AutoTokenizer.from_pretrained")
|
||||
@patch("app.loader.OVModelForSequenceClassification.from_pretrained")
|
||||
@patch("app.loader.AutoModelForSequenceClassification.from_pretrained")
|
||||
def test_get_prompt_guard(mock_ov_model, mock_auto_model, mock_tokenizer):
|
||||
def test_get_prompt_guard(mock_auto_model, mock_ov_model, mock_tokenizer):
|
||||
# Mock model based on device
|
||||
if glb.DEVICE == "cpu":
|
||||
mock_ov_model.return_value = MagicMock()
|
||||
|
|
@ -85,17 +85,17 @@ def test_get_prompt_guard(mock_ov_model, mock_auto_model, mock_tokenizer):
|
|||
|
||||
# Assertions
|
||||
assert prompt_guard["model_name"] == arch_guard_model_type[glb.DEVICE]
|
||||
assert mock_tokenizer.called_once_with(
|
||||
mock_tokenizer.assert_called_once_with(
|
||||
arch_guard_model_type[glb.DEVICE], trust_remote_code=True
|
||||
)
|
||||
if glb.DEVICE == "cpu":
|
||||
assert mock_ov_model.called_once_with(
|
||||
mock_ov_model.assert_called_once_with(
|
||||
arch_guard_model_type[glb.DEVICE],
|
||||
device_map=glb.DEVICE,
|
||||
low_cpu_mem_usage=True,
|
||||
)
|
||||
else:
|
||||
assert mock_auto_model.called_once_with(
|
||||
mock_auto_model.assert_called_once_with(
|
||||
arch_guard_model_type[glb.DEVICE],
|
||||
device_map=glb.DEVICE,
|
||||
low_cpu_mem_usage=True,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue