Improve prompt target intent matching (#51)

This commit is contained in:
Adil Hafeez 2024-09-16 19:20:07 -07:00 committed by GitHub
parent 8565462ec4
commit 9e50957f22
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 461 additions and 415 deletions

View file

@ -1,6 +1,7 @@
import os
import sentence_transformers
from gliner import GLiNER
from transformers import pipeline
def load_transformers(models = os.getenv("MODELS", "BAAI/bge-large-en-v1.5")):
transformers = {}
@ -17,3 +18,11 @@ def load_ner_models(models = os.getenv("NER_MODELS", "urchade/gliner_large-v2.1"
ner_models[model] = GLiNER.from_pretrained(model)
return ner_models
def load_zero_shot_models(models = os.getenv("ZERO_SHOT_MODELS", "tasksource/deberta-base-long-nli")):
zero_shot_models = {}
for model in models.split(','):
zero_shot_models[model] = pipeline("zero-shot-classification",model=model)
return zero_shot_models

View file

@ -1,11 +1,13 @@
import random
from fastapi import FastAPI, Response, HTTPException
from pydantic import BaseModel
from load_models import load_ner_models, load_transformers
from load_models import load_ner_models, load_transformers, load_zero_shot_models
from datetime import date, timedelta
import string
transformers = load_transformers()
ner_models = load_ner_models()
zero_shot_models = load_zero_shot_models()
app = FastAPI()
@ -81,6 +83,42 @@ async def ner(req: NERRequest, res: Response):
"object": "list",
}
class ZeroShotRequest(BaseModel):
input: str
labels: list[str]
model: str
def remove_punctuations(s, lower=True):
s = s.translate(str.maketrans(string.punctuation, " " * len(string.punctuation)))
s = " ".join(s.split())
if lower:
s = s.lower()
return s
@app.post("/zeroshot")
async def zeroshot(req: ZeroShotRequest, res: Response):
if req.model not in zero_shot_models:
raise HTTPException(status_code=400, detail="unknown model: " + req.model)
classifier = zero_shot_models[req.model]
labels_without_punctuations = [remove_punctuations(label) for label in req.labels]
predicted_classes = classifier(req.input, candidate_labels=labels_without_punctuations, multi_label=True)
label_map = dict(zip(labels_without_punctuations, req.labels))
orig_map = [label_map[label] for label in predicted_classes["labels"]]
final_scores = dict(zip(orig_map, predicted_classes["scores"]))
predicted_class = label_map[predicted_classes["labels"][0]]
return {
"predicted_class": predicted_class,
"predicted_class_score": final_scores[predicted_class],
"scores": final_scores,
"model": req.model,
}
class WeatherRequest(BaseModel):
city: str