mirror of
https://github.com/katanemo/plano.git
synced 2026-04-28 10:26:36 +02:00
145 lines
3.6 KiB
Python
145 lines
3.6 KiB
Python
import random
|
|
from fastapi import FastAPI, Response, HTTPException
|
|
from pydantic import BaseModel
|
|
from load_models import load_ner_models, load_transformers, load_zero_shot_models
|
|
from datetime import date, timedelta
|
|
import string
|
|
|
|
# Model registries, loaded once when this module is imported. Each registry is
# looked up by the model name that clients send in a request's ``model`` field.
transformers = load_transformers()  # backs /models and /embeddings
ner_models = load_ner_models()  # backs /ner
zero_shot_models = load_zero_shot_models()  # backs /zeroshot

# Application object on which the route decorators in this module register.
app = FastAPI()
|
|
|
|
class EmbeddingRequest(BaseModel):
    """Request body for ``POST /embeddings``."""

    # Text to embed.
    input: str
    # Key of the embedding model to use; must be present in ``transformers``.
    model: str
|
|
|
|
@app.get("/healthz")
async def healthz():
    """Health-check endpoint; unconditionally reports an ``ok`` status."""
    return {"status": "ok"}
|
|
|
|
@app.get("/models")
async def models():
    """List the ids of the models held in the ``transformers`` registry.

    Returns:
        A dict with ``"object": "list"`` and a ``"data"`` list holding one
        ``{"id": <name>, "object": "model"}`` entry per registered model.
    """
    entries = [{"id": name, "object": "model"} for name in transformers]
    return {"data": entries, "object": "list"}
|
|
|
|
@app.post("/embeddings")
async def embedding(req: EmbeddingRequest, res: Response):
    """Embed ``req.input`` with the model named by ``req.model``.

    Args:
        req: Input text and the key of the embedding model to use.
        res: Response object supplied by the framework; not used here.

    Returns:
        A dict with the embedding entries under ``"data"``, the model name,
        ``"object": "list"`` and a ``"usage"`` section.

    Raises:
        HTTPException: 400 when ``req.model`` is not in ``transformers``.
    """
    if req.model not in transformers:
        raise HTTPException(status_code=400, detail="unknown model: " + req.model)

    # The encoder receives a one-element batch; every row it returns becomes
    # one entry, indexed by its position in the result.
    vectors = transformers[req.model].encode([req.input]).tolist()
    items = [
        {"object": "embedding", "embedding": vector, "index": position}
        for position, vector in enumerate(vectors)
    ]

    return {
        "data": items,
        "model": req.model,
        "object": "list",
        # Token counts are not computed; both fields are always zero.
        "usage": {"prompt_tokens": 0, "total_tokens": 0},
    }
|
|
|
|
class NERRequest(BaseModel):
    """Request body for ``POST /ner``."""

    # Text in which to look for entities.
    input: str
    # Entity labels handed to the model's ``predict_entities`` call.
    labels: list[str]
    # Key of the NER model to use; must be present in ``ner_models``.
    model: str
|
|
|
|
|
|
@app.post("/ner")
async def ner(req: NERRequest, res: Response):
    """Run entity prediction on ``req.input`` for the requested labels.

    Args:
        req: Input text, candidate entity labels and the NER model key.
        res: Response object supplied by the framework; not used here.

    Returns:
        A dict with the model's ``predict_entities`` result under ``"data"``,
        the model name and ``"object": "list"``.

    Raises:
        HTTPException: 400 when ``req.model`` is not in ``ner_models``.
    """
    if req.model not in ner_models:
        raise HTTPException(status_code=400, detail="unknown model: " + req.model)

    found = ner_models[req.model].predict_entities(req.input, req.labels)
    return {"data": found, "model": req.model, "object": "list"}
|
|
|
|
class ZeroShotRequest(BaseModel):
    """Request body for ``POST /zeroshot``."""

    # Text to classify.
    input: str
    # Candidate class labels to score the text against.
    labels: list[str]
    # Key of the classifier to use; must be present in ``zero_shot_models``.
    model: str
|
|
|
|
|
|
# Maps every character in ``string.punctuation`` to a single space. Built once
# at import time because it never changes between calls.
_PUNCTUATION_TO_SPACE = str.maketrans(
    string.punctuation, " " * len(string.punctuation)
)


def remove_punctuations(s: str, lower: bool = True) -> str:
    """Replace ASCII punctuation with spaces and collapse whitespace.

    Args:
        s: Text to normalize.
        lower: When true (the default), also lowercase the result.

    Returns:
        ``s`` with each ``string.punctuation`` character turned into a space,
        runs of whitespace squeezed to single spaces, and leading/trailing
        whitespace removed.
    """
    # split()/join() both trims the ends and collapses the runs of spaces
    # left behind where punctuation used to be.
    collapsed = " ".join(s.translate(_PUNCTUATION_TO_SPACE).split())
    return collapsed.lower() if lower else collapsed
|
|
|
|
|
|
@app.post("/zeroshot")
async def zeroshot(req: ZeroShotRequest, res: Response):
    """Score ``req.input`` against ``req.labels`` with a zero-shot classifier.

    Labels are normalized with ``remove_punctuations`` before being passed to
    the classifier; the response reports scores under the caller's original
    label strings.

    Args:
        req: Input text, candidate labels and the classifier key.
        res: Response object supplied by the framework; not used here.

    Returns:
        A dict with the top label (``"predicted_class"``), its score
        (``"predicted_class_score"``), the per-label ``"scores"`` mapping and
        the model name.

    Raises:
        HTTPException: 400 when ``req.model`` is not in ``zero_shot_models``
            or when ``req.labels`` is empty.
    """
    if req.model not in zero_shot_models:
        raise HTTPException(status_code=400, detail="unknown model: " + req.model)
    # Without this guard an empty label list fails further down (in the
    # classifier or when indexing its first label) and surfaces as a 500.
    if not req.labels:
        raise HTTPException(status_code=400, detail="labels must not be empty")

    # Several original labels can normalize to the same string (for example
    # "sci-fi" and "sci fi"). Keep all of them so none is dropped from the
    # scores when mapping the classifier output back.
    originals_by_normalized = {}
    for label in req.labels:
        originals_by_normalized.setdefault(remove_punctuations(label), []).append(label)

    classifier = zero_shot_models[req.model]
    # Each distinct normalized label is sent once; with multi_label=True the
    # labels are scored independently, so deduplicating does not shift scores.
    predicted_classes = classifier(
        req.input,
        candidate_labels=list(originals_by_normalized),
        multi_label=True,
    )

    # Every original label inherits the score of its normalized form, in the
    # order the classifier returned the labels.
    final_scores = {}
    for normalized, score in zip(predicted_classes["labels"], predicted_classes["scores"]):
        for original in originals_by_normalized[normalized]:
            final_scores[original] = score

    # The first returned label is reported as the prediction. If several
    # originals share that normalized form, report the last one submitted,
    # which is what the previous dict-based mapping produced.
    predicted_class = originals_by_normalized[predicted_classes["labels"][0]][-1]

    return {
        "predicted_class": predicted_class,
        "predicted_class_score": final_scores[predicted_class],
        "scores": final_scores,
        "model": req.model,
    }
|
|
|
|
|
|
class WeatherRequest(BaseModel):
    """Request body for ``POST /weather``."""

    # City name; echoed back unchanged in the forecast response.
    city: str
|
|
|
|
|
|
@app.post("/weather")
async def weather(req: WeatherRequest, res: Response):
    """Return a randomly generated seven-day forecast for ``req.city``.

    The temperatures are random numbers, not real weather data.

    Args:
        req: Request carrying the city name to echo back.
        res: Response object supplied by the framework; not used here.

    Returns:
        A dict with the city, ``"unit": "F"`` and a ``"temperature"`` list
        holding one ``{"date", "temperature": {"min", "max"}}`` entry per day,
        starting today.
    """
    daily = []
    for offset in range(7):
        # The daily maximum is drawn 5 to 19 degrees above the daily minimum.
        low = random.randrange(50, 90)
        high = random.randrange(low + 5, low + 20)
        forecast_day = date.today() + timedelta(days=offset)
        daily.append({
            "date": str(forecast_day),
            "temperature": {"min": low, "max": high},
        })

    return {"city": req.city, "temperature": daily, "unit": "F"}
|