remove guard config json (#70)

* remove guard config json

* formating
This commit is contained in:
Co Tran 2024-09-24 13:33:31 -07:00 committed by GitHub
parent dd8c43a392
commit d5d79256b0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 12 additions and 21 deletions

View file

@ -1,4 +1,3 @@
import random
from fastapi import FastAPI, Response, HTTPException
from pydantic import BaseModel
from load_models import (
@ -8,17 +7,11 @@ from load_models import (
load_zero_shot_models,
)
from utils import GuardHandler, split_text_into_chunks
import json
import string
import torch
import yaml
from datetime import datetime, date, timedelta, timezone
import string
import pandas as pd
from load_models import load_sql
import time
import logging
from dateparser import parse
from network_data_generator import convert_to_ago_format, load_params
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
@ -31,8 +24,8 @@ zero_shot_models = load_zero_shot_models()
with open("/root/bolt_config.yaml", "r") as file:
config = yaml.safe_load(file)
with open("guard_model_config.json") as f:
guard_model_config = json.load(f)
with open("guard_model_config.yaml") as f:
guard_model_config = yaml.safe_load(f)
if "prompt_guards" in config.keys():
if len(config["prompt_guards"]["input_guards"]) == 2:
@ -147,6 +140,7 @@ async def guard(req: GuardRequest, res: Response):
"jailbreak_verdict": jailbreak_verdict,
"""
max_words = 300
start = time.time()
if req.task in ["both", "toxic", "jailbreak"]:
guard_handler.task = req.task
if len(req.input.split()) < max_words:
@ -194,6 +188,8 @@ async def guard(req: GuardRequest, res: Response):
final_result[f"{task}_prob"].append(
result_chunk[f"{task}_prob"].item()
)
end = time.time()
logger.info(f"Time taken for Guard: {end - start}")
return final_result