remove guard config json (#70)

* remove guard config json

* formating
This commit is contained in:
Co Tran 2024-09-24 13:33:31 -07:00 committed by GitHub
parent dd8c43a392
commit d5d79256b0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 12 additions and 21 deletions

View file

@ -1,10 +0,0 @@
{
"toxic":{
"cpu": "katanemolabs/toxic_ovn_4bit",
"gpu": "katanemolabs/Bolt-Toxic-v1-eetq"
},
"jailbreak":{
"cpu": "katanemolabs/jailbreak_ovn_4bit",
"gpu": "katanemolabs/Bolt-Guard-EEtq"
}
}

View file

@ -0,0 +1,6 @@
toxic:
cpu: "katanemolabs/toxic_ovn_4bit"
gpu: "katanemolabs/Bolt-Toxic-v1-eetq"
jailbreak:
cpu: "katanemolabs/jailbreak_ovn_4bit"
gpu: "katanemolabs/Bolt-Guard-EEtq"

View file

@ -1,4 +1,3 @@
import random
from fastapi import FastAPI, Response, HTTPException from fastapi import FastAPI, Response, HTTPException
from pydantic import BaseModel from pydantic import BaseModel
from load_models import ( from load_models import (
@ -8,17 +7,11 @@ from load_models import (
load_zero_shot_models, load_zero_shot_models,
) )
from utils import GuardHandler, split_text_into_chunks from utils import GuardHandler, split_text_into_chunks
import json
import string
import torch import torch
import yaml import yaml
from datetime import datetime, date, timedelta, timezone
import string import string
import pandas as pd import time
from load_models import load_sql
import logging import logging
from dateparser import parse
from network_data_generator import convert_to_ago_format, load_params
logging.basicConfig( logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
@ -31,8 +24,8 @@ zero_shot_models = load_zero_shot_models()
with open("/root/bolt_config.yaml", "r") as file: with open("/root/bolt_config.yaml", "r") as file:
config = yaml.safe_load(file) config = yaml.safe_load(file)
with open("guard_model_config.json") as f: with open("guard_model_config.yaml") as f:
guard_model_config = json.load(f) guard_model_config = yaml.safe_load(f)
if "prompt_guards" in config.keys(): if "prompt_guards" in config.keys():
if len(config["prompt_guards"]["input_guards"]) == 2: if len(config["prompt_guards"]["input_guards"]) == 2:
@ -147,6 +140,7 @@ async def guard(req: GuardRequest, res: Response):
"jailbreak_verdict": jailbreak_verdict, "jailbreak_verdict": jailbreak_verdict,
""" """
max_words = 300 max_words = 300
start = time.time()
if req.task in ["both", "toxic", "jailbreak"]: if req.task in ["both", "toxic", "jailbreak"]:
guard_handler.task = req.task guard_handler.task = req.task
if len(req.input.split()) < max_words: if len(req.input.split()) < max_words:
@ -194,6 +188,8 @@ async def guard(req: GuardRequest, res: Response):
final_result[f"{task}_prob"].append( final_result[f"{task}_prob"].append(
result_chunk[f"{task}_prob"].item() result_chunk[f"{task}_prob"].item()
) )
end = time.time()
logger.info(f"Time taken for Guard: {end - start}")
return final_result return final_result

View file

@ -124,5 +124,4 @@ class GuardHandler:
f"{self.task}_verdict": verdict, f"{self.task}_verdict": verdict,
f"{self.task}_sentence": sentence, f"{self.task}_sentence": sentence,
} }
print("Guard time : ", result_dict["time"])
return result_dict return result_dict