mirror of
https://github.com/katanemo/plano.git
synced 2026-05-05 13:53:03 +02:00
Refactor model server hardware config + add unit tests to load/request to the server (#189)
* remove mode/hardware * add test and pre commit hook * add pytest dependencies * fix format * fix lint * fix precommit * fix pre commit * fix pre commit * fix pre commit * fix precommit * fix precommit * fix precommit * fix precommit * fix precommit * fix precommit * fix precommit * fix precommit * fix precommit * fix precommit
This commit is contained in:
parent
3bd2ffe9fb
commit
8e54ac20d8
13 changed files with 480 additions and 43 deletions
|
|
@ -18,14 +18,16 @@ arch_function_generation_params = {
|
|||
"stop_token_ids": [151645],
|
||||
}
|
||||
|
||||
arch_guard_model_type = {"cpu": "katanemo/Arch-Guard-cpu", "gpu": "katanemo/Arch-Guard"}
|
||||
arch_guard_model_type = {
|
||||
"cpu": "katanemo/Arch-Guard-cpu",
|
||||
"cuda": "katanemo/Arch-Guard",
|
||||
"mps": "katanemo/Arch-Guard",
|
||||
}
|
||||
|
||||
# Model definition
|
||||
embedding_model = loader.get_embedding_model()
|
||||
zero_shot_model = loader.get_zero_shot_model()
|
||||
|
||||
prompt_guard_dict = loader.get_prompt_guard(
|
||||
arch_guard_model_type[glb.HARDWARE], glb.HARDWARE
|
||||
)
|
||||
prompt_guard_dict = loader.get_prompt_guard(arch_guard_model_type[glb.DEVICE])
|
||||
|
||||
arch_guard_handler = ArchGuardHanlder(model_dict=prompt_guard_dict)
|
||||
|
|
|
|||
|
|
@ -2,5 +2,3 @@ import app.commons.utilities as utils
|
|||
|
||||
|
||||
DEVICE = utils.get_device()
|
||||
MODE = utils.get_serving_mode()
|
||||
HARDWARE = utils.get_hardware(MODE)
|
||||
|
|
|
|||
|
|
@ -22,9 +22,11 @@ def get_device():
|
|||
available_device = {
|
||||
"cpu": True,
|
||||
"cuda": torch.cuda.is_available(),
|
||||
"mps": torch.backends.mps.is_available()
|
||||
if hasattr(torch.backends, "mps")
|
||||
else False,
|
||||
"mps": (
|
||||
torch.backends.mps.is_available()
|
||||
if hasattr(torch.backends, "mps")
|
||||
else False
|
||||
),
|
||||
}
|
||||
|
||||
if available_device["cuda"]:
|
||||
|
|
@ -37,24 +39,6 @@ def get_device():
|
|||
return device
|
||||
|
||||
|
||||
def get_serving_mode():
|
||||
mode = os.getenv("MODE", "cloud")
|
||||
|
||||
if mode not in ["cloud", "local-gpu", "local-cpu"]:
|
||||
raise ValueError(f"Invalid serving mode: {mode}")
|
||||
|
||||
return mode
|
||||
|
||||
|
||||
def get_hardware(mode):
|
||||
if mode == "local-cpu":
|
||||
hardware = "cpu"
|
||||
else:
|
||||
hardware = "gpu" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
return hardware
|
||||
|
||||
|
||||
def get_client(endpoint):
|
||||
client = OpenAI(base_url=endpoint, api_key="EMPTY")
|
||||
return client
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue