mirror of
https://github.com/katanemo/plano.git
synced 2026-05-01 11:56:29 +02:00
Add workflow logic for weather forecast demo (#24)
This commit is contained in:
parent
7ef68eccfb
commit
33f9dd22e6
32 changed files with 1902 additions and 459 deletions
|
|
@ -1,8 +1,11 @@
|
|||
import random
|
||||
from fastapi import FastAPI, Response, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from load_transformers import load_transformers
|
||||
from load_models import load_ner_models, load_transformers
|
||||
from datetime import date, timedelta
|
||||
|
||||
transformers = load_transformers()
|
||||
ner_models = load_ner_models()
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
|
@ -10,6 +13,12 @@ class EmbeddingRequest(BaseModel):
|
|||
input: str
|
||||
model: str
|
||||
|
||||
@app.get("/healthz")
|
||||
async def healthz():
|
||||
return {
|
||||
"status": "ok"
|
||||
}
|
||||
|
||||
@app.get("/models")
|
||||
async def models():
|
||||
models = []
|
||||
|
|
@ -27,7 +36,7 @@ async def models():
|
|||
|
||||
@app.post("/embeddings")
|
||||
async def embedding(req: EmbeddingRequest, res: Response):
|
||||
if not req.model in transformers:
|
||||
if req.model not in transformers:
|
||||
raise HTTPException(status_code=400, detail="unknown model: " + req.model)
|
||||
|
||||
embeddings = transformers[req.model].encode([req.input])
|
||||
|
|
@ -51,3 +60,48 @@ async def embedding(req: EmbeddingRequest, res: Response):
|
|||
"object": "list",
|
||||
"usage": usage
|
||||
}
|
||||
|
||||
class NERRequest(BaseModel):
|
||||
input: str
|
||||
labels: list[str]
|
||||
model: str
|
||||
|
||||
|
||||
@app.post("/ner")
|
||||
async def ner(req: NERRequest, res: Response):
|
||||
if req.model not in ner_models:
|
||||
raise HTTPException(status_code=400, detail="unknown model: " + req.model)
|
||||
|
||||
model = ner_models[req.model]
|
||||
entities = model.predict_entities(req.input, req.labels)
|
||||
|
||||
return {
|
||||
"data": entities,
|
||||
"model": req.model,
|
||||
"object": "list",
|
||||
}
|
||||
|
||||
class WeatherRequest(BaseModel):
|
||||
city: str
|
||||
|
||||
|
||||
@app.post("/weather")
|
||||
async def weather(req: WeatherRequest, res: Response):
|
||||
|
||||
weather_forecast = {
|
||||
"city": req.city,
|
||||
"temperature": [],
|
||||
"unit": "F",
|
||||
}
|
||||
for i in range(7):
|
||||
min_temp = random.randrange(50,90)
|
||||
max_temp = random.randrange(min_temp+5, min_temp+20)
|
||||
weather_forecast["temperature"].append({
|
||||
"date": str(date.today() + timedelta(days=i)),
|
||||
"temperature": {
|
||||
"min": min_temp,
|
||||
"max": max_temp
|
||||
}
|
||||
})
|
||||
|
||||
return weather_forecast
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue