mirror of
https://github.com/katanemo/plano.git
synced 2026-04-25 00:36:34 +02:00
move demo functions out of model_server (#67)
* pending * remove * fix docker build
This commit is contained in:
parent
ca5c9e4824
commit
31f26ef7ac
9 changed files with 122 additions and 50 deletions
|
|
@ -1,6 +1,7 @@
|
|||
import os
|
||||
from openai import OpenAI
|
||||
import gradio as gr
|
||||
import logging as log
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
|
@ -9,6 +10,8 @@ OPEN_API_KEY=os.getenv("OPENAI_API_KEY")
|
|||
CHAT_COMPLETION_ENDPOINT = os.getenv("CHAT_COMPLETION_ENDPOINT")
|
||||
MODEL_NAME = os.getenv("MODEL_NAME", "gpt-3.5-turbo")
|
||||
|
||||
log.info("CHAT_COMPLETION_ENDPOINT: ", CHAT_COMPLETION_ENDPOINT)
|
||||
|
||||
client = OpenAI(api_key=OPEN_API_KEY, base_url=CHAT_COMPLETION_ENDPOINT)
|
||||
|
||||
def predict(message, history):
|
||||
|
|
@ -17,16 +20,20 @@ def predict(message, history):
|
|||
# history_openai_format.append({"role": "user", "content": human })
|
||||
# history_openai_format.append({"role": "assistant", "content":assistant})
|
||||
history.append({"role": "user", "content": message})
|
||||
log.info("CHAT_COMPLETION_ENDPOINT: ", CHAT_COMPLETION_ENDPOINT)
|
||||
log.info("history: ", history)
|
||||
|
||||
try:
|
||||
response = client.chat.completions.create(model='gpt-3.5-turbo',
|
||||
response = client.chat.completions.create(model=MODEL_NAME,
|
||||
messages= history,
|
||||
temperature=1.0
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.info(e)
|
||||
# remove last user message in case of exception
|
||||
history.pop()
|
||||
log.info("CHAT_COMPLETION_ENDPOINT: ", CHAT_COMPLETION_ENDPOINT)
|
||||
log.info("Error with OpenAI API: {}".format(e.message))
|
||||
raise gr.Error("Error with OpenAI API: {}".format(e.message))
|
||||
|
||||
# for chunk in response:
|
||||
|
|
@ -52,4 +59,4 @@ with gr.Blocks(fill_height=True, css="footer {visibility: hidden}") as demo:
|
|||
|
||||
txt.submit(predict, [txt, state], [chatbot, state])
|
||||
|
||||
demo.launch(server_name="0.0.0.0", server_port=8080, show_error=True)
|
||||
demo.launch(server_name="0.0.0.0", server_port=8080, show_error=True, debug=True)
|
||||
|
|
|
|||
16
demos/function_calling/api_server/.vscode/launch.json
vendored
Normal file
16
demos/function_calling/api_server/.vscode/launch.json
vendored
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
{
    // Use IntelliSense to learn about possible attributes.
    // Hover to view descriptions of existing attributes.
    // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
    "version": "0.2.0",
    "configurations": [
        {
            "name": "function-calling api server",
            "cwd": "${workspaceFolder}/app",
            "type": "debugpy",
            "request": "launch",
            "module": "uvicorn",
            "args": ["main:app", "--reload", "--port", "8001"]
        }
    ]
}
|
||||
19
demos/function_calling/api_server/Dockerfile
Normal file
19
demos/function_calling/api_server/Dockerfile
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
# Multi-stage build: install dependencies in a full python:3 image, then
# copy only the installed runtime and the app into a slim final image.
FROM python:3 AS base

FROM base AS builder

WORKDIR /src

# Copy requirements first so source-only edits don't invalidate the pip layer.
COPY requirements.txt /src/
# --prefix=/runtime isolates the installed packages for the later COPY --from.
RUN pip install --prefix=/runtime --force-reinstall -r requirements.txt

COPY . /src

# Final stage: slim base + installed packages + application code only.
FROM python:3-slim AS output

# Installed packages land under /usr/local, where the slim image's
# interpreter picks them up.
COPY --from=builder /runtime /usr/local

COPY /app /app
WORKDIR /app

# Serve the FastAPI app on all interfaces; container port 80.
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "80"]
|
||||
58
demos/function_calling/api_server/app/main.py
Normal file
58
demos/function_calling/api_server/app/main.py
Normal file
|
|
@ -0,0 +1,58 @@
|
|||
# Demo API server: mock "function calling" endpoints (weather forecast,
# insurance claim lookup) backed by random data.
import random
from fastapi import FastAPI, Response
# NOTE(review): only `date` and `timedelta` are used below; `datetime` and
# `timezone` appear unused in this module — confirm before removing.
from datetime import datetime, date, timedelta, timezone
import logging
from pydantic import BaseModel

# Timestamped INFO-level logging for the whole process.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

app = FastAPI()
|
||||
|
||||
@app.get("/healthz")
async def healthz():
    """Liveness probe: report that the service is up."""
    return {"status": "ok"}
|
||||
|
||||
class WeatherRequest(BaseModel):
    """Request body for POST /weather."""
    # Free-form city name; echoed back verbatim in the response.
    city: str
|
||||
|
||||
|
||||
@app.post("/weather")
async def weather(req: WeatherRequest, res: Response):
    """Return a mock 7-day forecast for the requested city.

    Demo data only: each day's minimum temperature is drawn from
    [50, 90) degrees F and the maximum is 5-19 degrees above it.
    `res` is accepted for signature parity with the other endpoints
    but is not used.
    """
    # Fix "today" once so all seven entries stay consistent even if the
    # request straddles a midnight boundary.
    today = date.today()

    weather_forecast = {
        "city": req.city,
        "temperature": [],
        "unit": "F",
    }
    for i in range(7):
        min_temp = random.randrange(50, 90)
        max_temp = random.randrange(min_temp + 5, min_temp + 20)
        weather_forecast["temperature"].append({
            "date": str(today + timedelta(days=i)),
            "temperature": {
                "min": min_temp,
                "max": max_temp
            }
        })

    return weather_forecast
|
||||
|
||||
|
||||
class InsuranceClaimDetailsRequest(BaseModel):
    """Request body for POST /insurance_claim_details."""
    # Opaque policy identifier; echoed back verbatim in the response.
    policy_number: str
|
||||
|
||||
@app.post("/insurance_claim_details")
async def insurance_claim_details(req: InsuranceClaimDetailsRequest, res: Response):
    """Return mock claim details for the given policy number (demo data)."""
    # Random demo values: amount in [1000, 10000), filed 1-29 days ago.
    amount = random.randrange(1000, 10000)
    days_ago = random.randrange(1, 30)
    return {
        "policy_number": req.policy_number,
        "claim_status": "Approved",
        "claim_amount": amount,
        "claim_date": str(date.today() - timedelta(days=days_ago)),
        "claim_reason": "Car Accident",
    }
|
||||
2
demos/function_calling/api_server/requirements.txt
Normal file
2
demos/function_calling/api_server/requirements.txt
Normal file
|
|
@ -0,0 +1,2 @@
|
|||
fastapi
|
||||
uvicorn
|
||||
|
|
@ -27,7 +27,7 @@ prompt_targets:
|
|||
- name: units
|
||||
description: The units in which the weather forecast is requested.
|
||||
endpoint:
|
||||
cluster: weatherhost
|
||||
cluster: api_server
|
||||
path: /weather
|
||||
system_prompt: |
|
||||
You are a helpful weather forecaster. Use weather data that is provided to you. Please follow the following guidelines when responding to user queries:
|
||||
|
|
@ -45,12 +45,8 @@ prompt_targets:
|
|||
type: string
|
||||
default: "false"
|
||||
endpoint:
|
||||
cluster: weatherhost
|
||||
cluster: api_server
|
||||
path: /insurance_claim_details
|
||||
system_prompt: |
|
||||
You are a helpful insurance claim details provider. Use insurance claim data that is provided to you. Please follow the following guidelines when responding to user queries:
|
||||
- Use policy number to retrieve insurance claim details
|
||||
|
||||
clusters:
|
||||
weatherhost:
|
||||
address: model_server
|
||||
|
|
|
|||
|
|
@ -59,6 +59,17 @@ services:
|
|||
# uncomment following line to use ollama endpoint that is hosted by docker
|
||||
# - OLLAMA_ENDPOINT=ollama
|
||||
|
||||
api_server:
|
||||
build:
|
||||
context: api_server
|
||||
dockerfile: Dockerfile
|
||||
ports:
|
||||
- "18083:80"
|
||||
healthcheck:
|
||||
test: ["CMD", "curl" ,"http://localhost:80/healthz"]
|
||||
interval: 5s
|
||||
retries: 20
|
||||
|
||||
ollama:
|
||||
image: ollama/ollama
|
||||
container_name: ollama
|
||||
|
|
|
|||
|
|
@ -32,6 +32,10 @@
|
|||
"name": "demos/function_calling",
|
||||
"path": "./demos/function_calling",
|
||||
},
|
||||
{
|
||||
"name": "demos/function_calling/api_server",
|
||||
"path": "./demos/function_calling/api_server",
|
||||
},
|
||||
],
|
||||
"settings": {}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -127,32 +127,6 @@ async def zeroshot(req: ZeroShotRequest, res: Response):
|
|||
}
|
||||
|
||||
|
||||
class WeatherRequest(BaseModel):
|
||||
city: str
|
||||
|
||||
|
||||
@app.post("/weather")
|
||||
async def weather(req: WeatherRequest, res: Response):
|
||||
|
||||
weather_forecast = {
|
||||
"city": req.city,
|
||||
"temperature": [],
|
||||
"unit": "F",
|
||||
}
|
||||
for i in range(7):
|
||||
min_temp = random.randrange(50,90)
|
||||
max_temp = random.randrange(min_temp+5, min_temp+20)
|
||||
weather_forecast["temperature"].append({
|
||||
"date": str(date.today() + timedelta(days=i)),
|
||||
"temperature": {
|
||||
"min": min_temp,
|
||||
"max": max_temp
|
||||
}
|
||||
})
|
||||
|
||||
return weather_forecast
|
||||
|
||||
|
||||
'''
|
||||
*****
|
||||
Adding new functions to test the usecases - Sampreeth
|
||||
|
|
@ -309,21 +283,6 @@ async def interface_down_packet_drop(req: PacketDropCorrelationRequest, res: Res
|
|||
logger.info(f"Correlated Packet Drop Data: {correlated_data}")
|
||||
|
||||
return correlated_data.to_dict(orient='records')
|
||||
class InsuranceClaimDetailsRequest(BaseModel):
|
||||
policy_number: str
|
||||
|
||||
@app.post("/insurance_claim_details")
|
||||
async def insurance_claim_details(req: InsuranceClaimDetailsRequest, res: Response):
|
||||
|
||||
claim_details = {
|
||||
"policy_number": req.policy_number,
|
||||
"claim_status": "Approved",
|
||||
"claim_amount": random.randrange(1000, 10000),
|
||||
"claim_date": str(date.today() - timedelta(days=random.randrange(1, 30))),
|
||||
"claim_reason": "Car Accident",
|
||||
}
|
||||
|
||||
return claim_details
|
||||
|
||||
|
||||
class FlowPacketErrorCorrelationRequest(BaseModel):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue