ArchFC endpoint integration (#94)

* integration

* modify docker file

* add params and fix python lint

* fix empty context and tool calls

* address comments

* revert port

* fix bug merge

* fix environment

* fix bug

* fix compose

* fix merge
This commit is contained in:
Co Tran 2024-10-01 12:47:26 -07:00 committed by GitHub
parent 1a7c1ad0a5
commit 17a643c410
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 98 additions and 41 deletions

View file

@ -6,6 +6,7 @@ from app.load_models import (
load_guard_model,
load_zero_shot_models,
)
import os
from app.utils import GuardHandler, split_text_into_chunks
import torch
import yaml
@ -25,14 +26,27 @@ zero_shot_models = load_zero_shot_models()
# Module-level start-up: read the YAML configs, pick the hardware target,
# load the jailbreak guard model, build the GuardHandler and create the app.
# NOTE(review): this listing carries no +/- markers and no indentation, so it
# presumably interleaves removed and added lines of the change -- confirm the
# final statement order against the repository before relying on it.

# Parse the guard-model config with the safe YAML loader.
with open("guard_model_config.yaml") as f:
guard_model_config = yaml.safe_load(f)
# Parse the arch config from its absolute path, again with safe_load.
with open('/root/arch_config.yaml') as f:
config = yaml.safe_load(f)
# Serving mode is read from the MODE environment variable; "cloud" is the
# fallback when the variable is not set.
mode = os.getenv("MODE", "cloud")
logger.info(f"Serving model mode: {mode}")
# Any value other than the three listed modes raises ValueError.
if mode not in ['cloud', 'local-gpu', 'local-cpu']:
raise ValueError(f"Invalid mode: {mode}")
# 'local-cpu' pins the hardware to CPU; the other modes use "gpu" only when
# torch reports CUDA as available, otherwise "cpu".
if mode == 'local-cpu':
hardware = 'cpu'
else:
hardware = "gpu" if torch.cuda.is_available() else "cpu"
task = "both"
# NOTE(review): this assignment repeats the CUDA check without consulting
# `mode`; read top to bottom it would override the selection made just above.
# Presumably a line from the pre-change version -- confirm.
hardware = "gpu" if torch.cuda.is_available() else "cpu"
# Load the jailbreak guard model using the config entry for the chosen hardware.
jailbreak_model = load_guard_model(
guard_model_config["jailbreak"][hardware], hardware
)
# When the arch config has a "prompt_guards" section, `task` becomes the first
# key found under its "input_guards" mapping.
if "prompt_guards" in config.keys():
task = list(config["prompt_guards"]["input_guards"].keys())[0]
guard_handler = GuardHandler(toxic_model=None, jailbreak_model=jailbreak_model)
# NOTE(review): the hardware / jailbreak_model / guard_handler sequence appears
# a second time below; presumably one copy is the old side of the diff and the
# other the new side -- confirm which one the file actually contains.
hardware = "gpu" if torch.cuda.is_available() else "cpu"
jailbreak_model = load_guard_model(
guard_model_config["jailbreak"][hardware], hardware
)
# No toxic model is loaded here; None is passed through to GuardHandler.
toxic_model = None
guard_handler = GuardHandler(toxic_model=toxic_model, jailbreak_model=jailbreak_model)
app = FastAPI()