Cotran/onnx conversion (#145)

* onnx replacement

* onnx conversion for nli and embedding model

* fix naming

* fix naming

* fix naming

* pin version
Co Tran 2024-10-08 14:37:48 -07:00 committed by GitHub
parent b30ad791f7
commit 80d2229053
7 changed files with 61 additions and 42 deletions


@@ -1,5 +1,5 @@
-pub const DEFAULT_EMBEDDING_MODEL: &str = "BAAI/bge-large-en-v1.5";
-pub const DEFAULT_INTENT_MODEL: &str = "tasksource/deberta-base-long-nli";
+pub const DEFAULT_EMBEDDING_MODEL: &str = "katanemo/bge-large-en-v1.5-onnx";
+pub const DEFAULT_INTENT_MODEL: &str = "katanemo/deberta-base-nli-onnx";
 pub const DEFAULT_PROMPT_TARGET_THRESHOLD: f64 = 0.8;
 pub const DEFAULT_HALLUCINATED_THRESHOLD: f64 = 0.1;
 pub const RATELIMIT_SELECTOR_HEADER_KEY: &str = "x-arch-ratelimit-selector";
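Note: the new katanemo/*-onnx defaults point at pre-exported ONNX checkpoints, presumably produced from the original models. A minimal export sketch with optimum (the output directory is illustrative, not part of this commit):

    from optimum.onnxruntime import ORTModelForFeatureExtraction
    from transformers import AutoTokenizer

    # export=True converts the PyTorch checkpoint to ONNX on load;
    # save_pretrained() then writes model.onnx plus config/tokenizer files.
    model = ORTModelForFeatureExtraction.from_pretrained(
        "BAAI/bge-large-en-v1.5", export=True
    )
    tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-large-en-v1.5")
    model.save_pretrained("bge-large-en-v1.5-onnx")  # illustrative output dir
    tokenizer.save_pretrained("bge-large-en-v1.5-onnx")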


@@ -15,7 +15,7 @@ WORKDIR /src
 # specify list of models that will go into the image as a comma separated list
 # following models have been tested to work with this image
 # "sentence-transformers/all-MiniLM-L6-v2,sentence-transformers/all-mpnet-base-v2,thenlper/gte-base,thenlper/gte-large,thenlper/gte-small"
-ENV MODELS="BAAI/bge-large-en-v1.5"
+ENV MODELS="katanemo/bge-large-en-v1.5-onnx"
 COPY ./app ./app
 COPY ./app/guard_model_config.yaml .


@@ -45,7 +45,7 @@ RUN if command -v nvcc >/dev/null 2>&1; then \
 COPY . /src
 # Specify list of models that will go into the image as a comma separated list
-ENV MODELS="BAAI/bge-large-en-v1.5"
+ENV MODELS="katanemo/bge-large-en-v1.5-onnx"
 ENV DEBIAN_FRONTEND=noninteractive
 COPY /app /app


@@ -1,3 +1,3 @@
 jailbreak:
-  cpu: "katanemolabs/Arch-Guard-cpu"
-  gpu: "katanemolabs/Arch-Guard-gpu"
+  cpu: "katanemo/Arch-Guard-cpu"
+  gpu: "katanemo/Arch-Guard-gpu"


@@ -3,6 +3,7 @@ import sentence_transformers
 from transformers import AutoTokenizer, pipeline
 import sqlite3
 import torch
+from optimum.onnxruntime import ORTModelForFeatureExtraction, ORTModelForSequenceClassification  # type: ignore
 
 
 def get_device():
@@ -16,13 +17,14 @@ def get_device():
     return device
 
 
-def load_transformers(models=os.getenv("MODELS", "BAAI/bge-large-en-v1.5")):
+def load_transformers(model_name=os.getenv("MODELS", "katanemo/bge-large-en-v1.5-onnx")):
     transformers = {}
     device = get_device()
-    for model in models.split(","):
-        transformers[model] = sentence_transformers.SentenceTransformer(
-            model, device=device
-        )
+    transformers["tokenizer"] = AutoTokenizer.from_pretrained(model_name)
+    transformers["model"] = ORTModelForFeatureExtraction.from_pretrained(
+        model_name, device_map=device
+    )
+    transformers["model_name"] = model_name
 
     return transformers
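The returned dict is consumed the way the /embeddings route below does; a standalone sketch of that flow (pooling and normalization copied from the handler):

    import torch
    from app.load_models import load_transformers

    transformers = load_transformers()
    encoded = transformers["tokenizer"](
        "what is the weather today?", padding=True, truncation=True, return_tensors="pt"
    )
    out = transformers["model"](**encoded)   # ONNX Runtime forward pass
    cls = out[0][:, 0]                       # CLS-token pooling over last_hidden_state
    emb = torch.nn.functional.normalize(cls, p=2, dim=1)  # unit-length embedding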
@@ -31,16 +33,16 @@ def load_guard_model(
     model_name,
     hardware_config="cpu",
 ):
-    guard_mode = {}
-    guard_mode["tokenizer"] = AutoTokenizer.from_pretrained(
+    guard_model = {}
+    guard_model["tokenizer"] = AutoTokenizer.from_pretrained(
         model_name, trust_remote_code=True
     )
-    guard_mode["model_name"] = model_name
+    guard_model["model_name"] = model_name
     if hardware_config == "cpu":
         from optimum.intel import OVModelForSequenceClassification
 
         device = "cpu"
-        guard_mode["model"] = OVModelForSequenceClassification.from_pretrained(
+        guard_model["model"] = OVModelForSequenceClassification.from_pretrained(
             model_name, device_map=device, low_cpu_mem_usage=True
         )
     elif hardware_config == "gpu":
@@ -48,25 +50,34 @@ def load_guard_model(
         import torch
 
         device = "cuda" if torch.cuda.is_available() else "cpu"
-        guard_mode["model"] = AutoModelForSequenceClassification.from_pretrained(
+        guard_model["model"] = AutoModelForSequenceClassification.from_pretrained(
             model_name, device_map=device, low_cpu_mem_usage=True
         )
-    guard_mode["device"] = device
-    guard_mode["hardware_config"] = hardware_config
-    return guard_mode
+    guard_model["device"] = device
+    guard_model["hardware_config"] = hardware_config
+    return guard_model
 
 
 def load_zero_shot_models(
-    models=os.getenv("ZERO_SHOT_MODELS", "tasksource/deberta-base-long-nli")
+    model_name=os.getenv("ZERO_SHOT_MODELS", "katanemo/deberta-base-nli-onnx")
 ):
-    zero_shot_models = {}
+    zero_shot_model = {}
     device = get_device()
-    for model in models.split(","):
-        zero_shot_models[model] = pipeline(
-            "zero-shot-classification", model=model, device=device
-        )
-    return zero_shot_models
+    zero_shot_model["model"] = ORTModelForSequenceClassification.from_pretrained(
+        model_name
+    )
+    zero_shot_model["tokenizer"] = AutoTokenizer.from_pretrained(model_name)
+    # create pipeline
+    zero_shot_model["pipeline"] = pipeline(
+        "zero-shot-classification",
+        model=zero_shot_model["model"],
+        tokenizer=zero_shot_model["tokenizer"],
+        device=device,
+    )
+    zero_shot_model["model_name"] = model_name
+    return zero_shot_model
 
 
 if __name__ == "__main__":
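The "pipeline" entry is what the /zeroshot and /hallucination routes below call into; a quick usage sketch (input and labels are illustrative):

    from app.load_models import load_zero_shot_models

    zero_shot_models = load_zero_shot_models()
    result = zero_shot_models["pipeline"](
        "book me a flight to Paris",             # illustrative input
        candidate_labels=["travel", "weather"],  # illustrative labels
        multi_label=True,
    )
    # result["labels"] is sorted by descending score; result["scores"] aligns with it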


@@ -1,4 +1,3 @@
-import os
 from fastapi import FastAPI, Response, HTTPException
 from pydantic import BaseModel
 from app.load_models import (
@@ -53,21 +52,25 @@ async def healthz():
 async def models():
     models = []
-    for model in transformers.keys():
-        models.append({"id": model, "object": "model"})
+    models.append({"id": transformers["model_name"], "object": "model"})
 
     return {"data": models, "object": "list"}
 
 
 @app.post("/embeddings")
 async def embedding(req: EmbeddingRequest, res: Response):
-    if req.model not in transformers:
+    if req.model != transformers["model_name"]:
         raise HTTPException(status_code=400, detail="unknown model: " + req.model)
 
     start = time.time()
-    embeddings = transformers[req.model].encode([req.input])
-    logger.info(f"Embedding Call Complete Time: {time.time()-start}")
+    encoded_input = transformers["tokenizer"](
+        req.input, padding=True, truncation=True, return_tensors="pt"
+    )
+    embeddings = transformers["model"](**encoded_input)
+    embeddings = embeddings[0][:, 0]
+    # normalize embeddings
+    embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1).detach().numpy()
+    print(f"Embedding Call Complete Time: {time.time()-start}")
 
     data = []
     for embedding in embeddings.tolist():
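Since the ORT model returns raw hidden states rather than pooled sentence vectors, the handler now pools explicitly: embeddings[0][:, 0] takes the CLS-token vector from last_hidden_state, then L2-normalizes it, matching the pooling bge models use under sentence-transformers. A hypothetical client call (server address assumed, field names from EmbeddingRequest):

    import requests  # hypothetical client sketch

    resp = requests.post(
        "http://localhost:8000/embeddings",  # assumed address
        json={
            "model": "katanemo/bge-large-en-v1.5-onnx",  # must match model_name
            "input": "what is the weather today?",
        },
    )
    print(resp.json())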
@@ -165,11 +168,13 @@ def remove_punctuations(s, lower=True):
 @app.post("/zeroshot")
 async def zeroshot(req: ZeroShotRequest, res: Response):
-    if req.model not in zero_shot_models:
+    logger.info(f"zero-shot request: {req}")
+    if req.model != zero_shot_models["model_name"]:
         raise HTTPException(status_code=400, detail="unknown model: " + req.model)
 
-    classifier = zero_shot_models[req.model]
+    classifier = zero_shot_models["pipeline"]
     labels_without_punctuations = [remove_punctuations(label) for label in req.labels]
+    start = time.time()
     predicted_classes = classifier(
         req.input, candidate_labels=labels_without_punctuations, multi_label=True
     )
@@ -178,6 +183,7 @@ async def zeroshot(req: ZeroShotRequest, res: Response):
     orig_map = [label_map[label] for label in predicted_classes["labels"]]
     final_scores = dict(zip(orig_map, predicted_classes["scores"]))
     predicted_class = label_map[predicted_classes["labels"][0]]
+    logger.info(f"zero-shot taking {time.time()-start} seconds")
 
     return {
         "predicted_class": predicted_class,
@@ -201,10 +207,11 @@ async def hallucination(req: HallucinationRequest, res: Response):
     example {"name": "John", "age": "25"}
     prompt: input prompt from the user
     """
-    if req.model not in zero_shot_models:
+    if req.model != zero_shot_models["model_name"]:
         raise HTTPException(status_code=400, detail="unknown model: " + req.model)
-    classifier = zero_shot_models[req.model]
+    start = time.time()
+    classifier = zero_shot_models["pipeline"]
     candidate_labels = [f"{k} is {v}" for k, v in req.parameters.items()]
     hypothesis_template = "{}"
     result = classifier(
@@ -215,7 +222,9 @@ async def hallucination(req: HallucinationRequest, res: Response):
     )
     result_score = result["scores"]
     result_params = {k[0]: s for k, s in zip(req.parameters.items(), result_score)}
-    logger.info(f"hallucination result: {result_params}")
+    logger.info(
+        f"hallucination result: {result_params}, taking {time.time()-start} seconds"
+    )
 
     return {
         "params_scores": result_params,

@@ -8,13 +8,12 @@ pyyaml==6.0.2
 accelerate
 psutil==6.0.0
 # guard inference packages
-optimum-intel
-openvino
+optimum-intel==1.19.0
+openvino==2024.4.0
 psutil
-pandas
 dateparser
 openai==1.50.2
 pandas
 tf-keras
-onnx
+onnx==1.17.0
 pytest