mirror of
https://github.com/katanemo/plano.git
synced 2026-04-29 02:46:28 +02:00
Refactor model server hardware config + add unit tests to load/request to the server (#189)
* remove mode/hardware * add test and pre commit hook * add pytest dependieces * fix format * fix lint * fix precommit * fix pre commit * fix pre commit * fix precommit * fix precommit * fix precommit * fix precommit * fix precommit * fix precommit * fix precommit * fix precommit * fix precommit * fix precommit
This commit is contained in:
parent
3bd2ffe9fb
commit
8e54ac20d8
13 changed files with 480 additions and 43 deletions
|
|
@ -18,14 +18,16 @@ arch_function_generation_params = {
|
|||
"stop_token_ids": [151645],
|
||||
}
|
||||
|
||||
arch_guard_model_type = {"cpu": "katanemo/Arch-Guard-cpu", "gpu": "katanemo/Arch-Guard"}
|
||||
arch_guard_model_type = {
|
||||
"cpu": "katanemo/Arch-Guard-cpu",
|
||||
"cuda": "katanemo/Arch-Guard",
|
||||
"mps": "katanemo/Arch-Guard",
|
||||
}
|
||||
|
||||
# Model definition
|
||||
embedding_model = loader.get_embedding_model()
|
||||
zero_shot_model = loader.get_zero_shot_model()
|
||||
|
||||
prompt_guard_dict = loader.get_prompt_guard(
|
||||
arch_guard_model_type[glb.HARDWARE], glb.HARDWARE
|
||||
)
|
||||
prompt_guard_dict = loader.get_prompt_guard(arch_guard_model_type[glb.DEVICE])
|
||||
|
||||
arch_guard_handler = ArchGuardHanlder(model_dict=prompt_guard_dict)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue