diff --git a/model_server/app/guard_model_config.json b/model_server/app/guard_model_config.json
deleted file mode 100644
index a0ed0e39..00000000
--- a/model_server/app/guard_model_config.json
+++ /dev/null
@@ -1,10 +0,0 @@
-{
-    "toxic":{
-        "cpu": "katanemolabs/toxic_ovn_4bit",
-        "gpu": "katanemolabs/Bolt-Toxic-v1-eetq"
-    },
-    "jailbreak":{
-        "cpu": "katanemolabs/jailbreak_ovn_4bit",
-        "gpu": "katanemolabs/Bolt-Guard-EEtq"
-    }
-}
\ No newline at end of file
diff --git a/model_server/app/guard_model_config.yaml b/model_server/app/guard_model_config.yaml
new file mode 100644
index 00000000..842a54c6
--- /dev/null
+++ b/model_server/app/guard_model_config.yaml
@@ -0,0 +1,6 @@
+toxic:
+  cpu: "katanemolabs/toxic_ovn_4bit"
+  gpu: "katanemolabs/Bolt-Toxic-v1-eetq"
+jailbreak:
+  cpu: "katanemolabs/jailbreak_ovn_4bit"
+  gpu: "katanemolabs/Bolt-Guard-EEtq"
diff --git a/model_server/app/main.py b/model_server/app/main.py
index 4eea7b8a..2c83d769 100644
--- a/model_server/app/main.py
+++ b/model_server/app/main.py
@@ -1,4 +1,3 @@
-import random
 from fastapi import FastAPI, Response, HTTPException
 from pydantic import BaseModel
 from load_models import (
@@ -8,17 +7,11 @@ from load_models import (
     load_zero_shot_models,
 )
 from utils import GuardHandler, split_text_into_chunks
-import json
-import string
 import torch
 import yaml
-from datetime import datetime, date, timedelta, timezone
 import string
-import pandas as pd
-from load_models import load_sql
+import time
 import logging
-from dateparser import parse
-from network_data_generator import convert_to_ago_format, load_params
 logging.basicConfig(
     level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
 )
@@ -31,8 +24,8 @@ zero_shot_models = load_zero_shot_models()
 
 with open("/root/bolt_config.yaml", "r") as file:
     config = yaml.safe_load(file)
-with open("guard_model_config.json") as f:
-    guard_model_config = json.load(f)
+with open("guard_model_config.yaml") as f:
+    guard_model_config = yaml.safe_load(f)
 
 if "prompt_guards" in config.keys():
     if len(config["prompt_guards"]["input_guards"]) == 2:
@@ -147,6 +140,7 @@ async def guard(req: GuardRequest, res: Response):
         "jailbreak_verdict": jailbreak_verdict,
     """
     max_words = 300
+    start = time.time()
     if req.task in ["both", "toxic", "jailbreak"]:
         guard_handler.task = req.task
         if len(req.input.split()) < max_words:
@@ -194,6 +188,8 @@ async def guard(req: GuardRequest, res: Response):
                 final_result[f"{task}_prob"].append(
                     result_chunk[f"{task}_prob"].item()
                 )
+    end = time.time()
+    logger.info(f"Time taken for Guard: {end - start}")
     return final_result
 
 
diff --git a/model_server/app/utils.py b/model_server/app/utils.py
index 66c0d254..0af58f9c 100644
--- a/model_server/app/utils.py
+++ b/model_server/app/utils.py
@@ -124,5 +124,4 @@ class GuardHandler:
             f"{self.task}_verdict": verdict,
             f"{self.task}_sentence": sentence,
         }
-        print("Guard time : ", result_dict["time"])
         return result_dict
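
For context, a minimal sketch of how the new guard_model_config.yaml and the added timing might be consumed on the model server. The select_guard_model helper and the cpu/gpu choice via torch.cuda.is_available() are illustrative assumptions, not code taken from this diff:

import time

import torch
import yaml

# Load the per-task model map introduced above (task -> {cpu, gpu} model IDs).
with open("guard_model_config.yaml") as f:
    guard_model_config = yaml.safe_load(f)


def select_guard_model(task: str) -> str:
    # Hypothetical helper: prefer the GPU (EETQ) build when CUDA is available,
    # otherwise fall back to the 4-bit CPU build.
    device = "gpu" if torch.cuda.is_available() else "cpu"
    return guard_model_config[task][device]


start = time.time()
model_id = select_guard_model("jailbreak")
print(f"Selected {model_id} in {time.time() - start:.4f}s")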