mirror of
https://github.com/katanemo/plano.git
synced 2026-06-26 15:39:40 +02:00
parent
dd8c43a392
commit
d5d79256b0
4 changed files with 12 additions and 21 deletions
|
|
@ -1,10 +0,0 @@
|
||||||
{
|
|
||||||
"toxic":{
|
|
||||||
"cpu": "katanemolabs/toxic_ovn_4bit",
|
|
||||||
"gpu": "katanemolabs/Bolt-Toxic-v1-eetq"
|
|
||||||
},
|
|
||||||
"jailbreak":{
|
|
||||||
"cpu": "katanemolabs/jailbreak_ovn_4bit",
|
|
||||||
"gpu": "katanemolabs/Bolt-Guard-EEtq"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
6
model_server/app/guard_model_config.yaml
Normal file
6
model_server/app/guard_model_config.yaml
Normal file
|
|
@ -0,0 +1,6 @@
|
||||||
|
toxic:
|
||||||
|
cpu: "katanemolabs/toxic_ovn_4bit"
|
||||||
|
gpu: "katanemolabs/Bolt-Toxic-v1-eetq"
|
||||||
|
jailbreak:
|
||||||
|
cpu: "katanemolabs/jailbreak_ovn_4bit"
|
||||||
|
gpu: "katanemolabs/Bolt-Guard-EEtq"
|
||||||
|
|
@ -1,4 +1,3 @@
|
||||||
import random
|
|
||||||
from fastapi import FastAPI, Response, HTTPException
|
from fastapi import FastAPI, Response, HTTPException
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from load_models import (
|
from load_models import (
|
||||||
|
|
@ -8,17 +7,11 @@ from load_models import (
|
||||||
load_zero_shot_models,
|
load_zero_shot_models,
|
||||||
)
|
)
|
||||||
from utils import GuardHandler, split_text_into_chunks
|
from utils import GuardHandler, split_text_into_chunks
|
||||||
import json
|
|
||||||
import string
|
|
||||||
import torch
|
import torch
|
||||||
import yaml
|
import yaml
|
||||||
from datetime import datetime, date, timedelta, timezone
|
|
||||||
import string
|
import string
|
||||||
import pandas as pd
|
import time
|
||||||
from load_models import load_sql
|
|
||||||
import logging
|
import logging
|
||||||
from dateparser import parse
|
|
||||||
from network_data_generator import convert_to_ago_format, load_params
|
|
||||||
|
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
|
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
|
||||||
|
|
@ -31,8 +24,8 @@ zero_shot_models = load_zero_shot_models()
|
||||||
|
|
||||||
with open("/root/bolt_config.yaml", "r") as file:
|
with open("/root/bolt_config.yaml", "r") as file:
|
||||||
config = yaml.safe_load(file)
|
config = yaml.safe_load(file)
|
||||||
with open("guard_model_config.json") as f:
|
with open("guard_model_config.yaml") as f:
|
||||||
guard_model_config = json.load(f)
|
guard_model_config = yaml.safe_load(f)
|
||||||
|
|
||||||
if "prompt_guards" in config.keys():
|
if "prompt_guards" in config.keys():
|
||||||
if len(config["prompt_guards"]["input_guards"]) == 2:
|
if len(config["prompt_guards"]["input_guards"]) == 2:
|
||||||
|
|
@ -147,6 +140,7 @@ async def guard(req: GuardRequest, res: Response):
|
||||||
"jailbreak_verdict": jailbreak_verdict,
|
"jailbreak_verdict": jailbreak_verdict,
|
||||||
"""
|
"""
|
||||||
max_words = 300
|
max_words = 300
|
||||||
|
start = time.time()
|
||||||
if req.task in ["both", "toxic", "jailbreak"]:
|
if req.task in ["both", "toxic", "jailbreak"]:
|
||||||
guard_handler.task = req.task
|
guard_handler.task = req.task
|
||||||
if len(req.input.split()) < max_words:
|
if len(req.input.split()) < max_words:
|
||||||
|
|
@ -194,6 +188,8 @@ async def guard(req: GuardRequest, res: Response):
|
||||||
final_result[f"{task}_prob"].append(
|
final_result[f"{task}_prob"].append(
|
||||||
result_chunk[f"{task}_prob"].item()
|
result_chunk[f"{task}_prob"].item()
|
||||||
)
|
)
|
||||||
|
end = time.time()
|
||||||
|
logger.info(f"Time taken for Guard: {end - start}")
|
||||||
return final_result
|
return final_result
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -124,5 +124,4 @@ class GuardHandler:
|
||||||
f"{self.task}_verdict": verdict,
|
f"{self.task}_verdict": verdict,
|
||||||
f"{self.task}_sentence": sentence,
|
f"{self.task}_sentence": sentence,
|
||||||
}
|
}
|
||||||
print("Guard time : ", result_dict["time"])
|
|
||||||
return result_dict
|
return result_dict
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue