mirror of
https://github.com/katanemo/plano.git
synced 2026-05-09 07:42:43 +02:00
Updated hr_agent to be full stack: gradio + fastAPI (#235)
* commiting to remove * fix * updating hr_agent --------- Co-authored-by: Salman Paracha <salmanparacha@MacBook-Pro-261.local> Co-authored-by: Adil Hafeez <adil@katanemo.com>
This commit is contained in:
parent
bb9a774a72
commit
bb882fb59b
8 changed files with 157 additions and 71 deletions
|
|
@ -834,7 +834,7 @@ impl StreamContext {
|
|||
);
|
||||
|
||||
debug!(
|
||||
"archgw => api call, endpoint: {}/{}, body: {}",
|
||||
"archgw => api call, endpoint: {}{}, body: {}",
|
||||
endpoint.name.as_str(),
|
||||
path,
|
||||
tool_params_json_str
|
||||
|
|
|
|||
|
|
@ -5,16 +5,12 @@ FROM base AS builder
|
|||
WORKDIR /src
|
||||
|
||||
COPY requirements.txt /src/
|
||||
COPY workforce_data.json /src/
|
||||
RUN pip install --prefix=/runtime --force-reinstall -r requirements.txt
|
||||
|
||||
COPY . /src
|
||||
|
||||
FROM python:3.10-slim AS output
|
||||
|
||||
COPY --from=builder /runtime /usr/local
|
||||
|
||||
COPY . /app
|
||||
WORKDIR /app
|
||||
COPY . /app
|
||||
|
||||
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "80", "--log-level", "info"]
|
||||
|
|
|
|||
|
|
@ -27,12 +27,6 @@ system_prompt: |
|
|||
You are a Workforce assistant that helps on workforce planning and HR decision makers with reporting and workfoce planning. NOTHING ELSE. When you get data in json format, offer some summary but don't be too verbose.
|
||||
|
||||
prompt_targets:
|
||||
- name: hr_qa
|
||||
endpoint:
|
||||
name: app_server
|
||||
path: /agent/hr_qa
|
||||
description: Handle general Q/A related to HR.
|
||||
default: true
|
||||
- name: workforce
|
||||
description: Get workforce data like headcount and satisfacton levels by region and staffing type
|
||||
endpoint:
|
||||
|
|
@ -47,10 +41,10 @@ prompt_targets:
|
|||
type: str
|
||||
required: true
|
||||
description: Geographical region for which you want workforce data like asia, europe, americas.
|
||||
- name: point_in_time
|
||||
- name: data_snapshot_days_ago
|
||||
type: int
|
||||
required: false
|
||||
description: the point in time for which to retrieve data. For e.g 0 days ago, 30 days ago, etc.
|
||||
description: the snapshot day for which you want workforce data.
|
||||
- name: slack_message
|
||||
endpoint:
|
||||
name: app_server
|
||||
|
|
|
|||
|
|
@ -4,22 +4,14 @@ services:
|
|||
context: .
|
||||
environment:
|
||||
- SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN:-None}
|
||||
- OPENAI_API_KEY=${OPENAI_API_KEY:?error}
|
||||
- CHAT_COMPLETION_ENDPOINT=http://host.docker.internal:10000/v1
|
||||
volumes:
|
||||
- ./arch_config.yaml:/app/arch_config.yaml
|
||||
- ../shared/chatbot_ui/common.py:/app/common.py
|
||||
ports:
|
||||
- "18083:80"
|
||||
healthcheck:
|
||||
test: ["CMD", "curl" ,"http://localhost:80/healthz"]
|
||||
interval: 5s
|
||||
retries: 20
|
||||
|
||||
chatbot_ui:
|
||||
build:
|
||||
context: ../shared/chatbot_ui
|
||||
ports:
|
||||
- "18080:8080"
|
||||
environment:
|
||||
- OPENAI_API_KEY=${OPENAI_API_KEY:?error}
|
||||
- CHAT_COMPLETION_ENDPOINT=http://host.docker.internal:10000/v1
|
||||
extra_hosts:
|
||||
- "host.docker.internal:host-gateway"
|
||||
volumes:
|
||||
- ./arch_config.yaml:/app/arch_config.yaml
|
||||
|
|
|
|||
|
|
@ -1,20 +1,31 @@
|
|||
import os
|
||||
import json
|
||||
import pandas as pd
|
||||
import gradio as gr
|
||||
import logging
|
||||
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional
|
||||
from enum import Enum
|
||||
from typing import List, Optional, Tuple
|
||||
from slack_sdk import WebClient
|
||||
from slack_sdk.errors import SlackApiError
|
||||
from openai import OpenAI
|
||||
from common import create_gradio_app
|
||||
|
||||
app = FastAPI()
|
||||
workforce_data_df = None
|
||||
demo_description = """This demo showcases how the **Arch** can be used to build an
|
||||
HR agent to manage workforce-related inquiries, workforce planning, and communication via Slack.
|
||||
It intelligently routes incoming prompts to the correct targets, providing concise and useful responses
|
||||
tailored for HR and workforce decision-making. """
|
||||
|
||||
with open("workforce_data.json") as file:
|
||||
workforce_data = json.load(file)
|
||||
workforce_data_df = pd.json_normalize(
|
||||
workforce_data, record_path=["regions"], meta=["point_in_time", "satisfaction"]
|
||||
workforce_data,
|
||||
record_path=["regions"],
|
||||
meta=["data_snapshot_days_ago", "satisfaction"],
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -22,7 +33,7 @@ with open("workforce_data.json") as file:
|
|||
class WorkforceRequset(BaseModel):
|
||||
region: str
|
||||
staffing_type: str
|
||||
point_in_time: Optional[int] = None
|
||||
data_snapshot_days_ago: Optional[int] = None
|
||||
|
||||
|
||||
class SlackRequest(BaseModel):
|
||||
|
|
@ -36,25 +47,6 @@ class WorkforceResponse(BaseModel):
|
|||
satisfaction: float
|
||||
|
||||
|
||||
# Post method for device summary
|
||||
@app.post("/agent/workforce")
|
||||
def get_workforce(request: WorkforceRequset):
|
||||
"""
|
||||
Endpoint to workforce data by region, staffing type at a given point in time.
|
||||
"""
|
||||
region = request.region.lower()
|
||||
staffing_type = request.staffing_type.lower()
|
||||
point_in_time = request.point_in_time if request.point_in_time else 0
|
||||
|
||||
response = {
|
||||
"region": region,
|
||||
"staffing_type": f"Staffing agency: {staffing_type}",
|
||||
"headcount": f"Headcount: {int(workforce_data_df[(workforce_data_df['region']==region) & (workforce_data_df['point_in_time']==point_in_time)][staffing_type].values[0])}",
|
||||
"satisfaction": f"Satisifaction: {float(workforce_data_df[(workforce_data_df['region']==region) & (workforce_data_df['point_in_time']==point_in_time)]['satisfaction'].values[0])}",
|
||||
}
|
||||
return response
|
||||
|
||||
|
||||
@app.post("/agent/slack_message")
|
||||
def send_slack_message(request: SlackRequest):
|
||||
"""
|
||||
|
|
@ -80,27 +72,38 @@ def send_slack_message(request: SlackRequest):
|
|||
print(f"Error sending message: {e.response['error']}")
|
||||
|
||||
|
||||
@app.post("/agent/hr_qa")
|
||||
async def general_hr_qa():
|
||||
# POST endpoint that returns workforce data (headcount and satisfaction)
|
||||
@app.post("/agent/workforce")
|
||||
def get_workforce(request: WorkforceRequset):
|
||||
"""
|
||||
This method handles Q/A related to general issues in HR.
|
||||
It forwards the conversation to the OpenAI client via a local proxy and returns the response.
|
||||
Endpoint to workforce data by region, staffing type at a given point in time.
|
||||
"""
|
||||
return {
|
||||
"choices": [
|
||||
{
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "I am a helpful HR agent, and I can help you plan for workforce related questions",
|
||||
},
|
||||
"finish_reason": "completed",
|
||||
"index": 0,
|
||||
}
|
||||
],
|
||||
"model": "hr_agent",
|
||||
"usage": {"completion_tokens": 0},
|
||||
}
|
||||
region = request.region.lower()
|
||||
staffing_type = request.staffing_type.lower()
|
||||
data_snapshot_days_ago = (
|
||||
request.data_snapshot_days_ago
|
||||
if request.data_snapshot_days_ago
|
||||
else 0 # this param is not required.
|
||||
)
|
||||
|
||||
response = {
|
||||
"region": region,
|
||||
"staffing_type": f"Staffing agency: {staffing_type}",
|
||||
"headcount": f"Headcount: {int(workforce_data_df[(workforce_data_df['region']==region) & (workforce_data_df['data_snapshot_days_ago']==data_snapshot_days_ago)][staffing_type].values[0])}",
|
||||
"satisfaction": f"Satisifaction: {float(workforce_data_df[(workforce_data_df['region']==region) & (workforce_data_df['data_snapshot_days_ago']==data_snapshot_days_ago)]['satisfaction'].values[0])}",
|
||||
}
|
||||
return response
|
||||
|
||||
|
||||
# Base URL for the OpenAI-compatible chat-completions API, read from the
# environment; os.getenv returns None when the variable is unset.
CHAT_COMPLETION_ENDPOINT = os.getenv("CHAT_COMPLETION_ENDPOINT")
client = OpenAI(
    # Placeholder value, not a real key.
    # NOTE(review): presumably the endpoint behind base_url supplies the real
    # credentials — confirm.
    api_key="--",
    base_url=CHAT_COMPLETION_ENDPOINT,
)

# Serve the Gradio chat UI from this same FastAPI app under /agent/chat.
gr.mount_gradio_app(
    app, create_gradio_app(demo_description, client), path="/agent/chat"
)

if __name__ == "__main__":
    # NOTE(review): `app` is a FastAPI instance, and FastAPI does not provide
    # a Flask-style run(); executing this module directly would raise
    # AttributeError. The service appears to be started with
    # `uvicorn main:app` instead — consider uvicorn.run(app, ...) here.
    app.run(debug=True)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,13 @@
|
|||
fastapi
|
||||
uvicorn
|
||||
pydantic
|
||||
slack-sdk
|
||||
typing
|
||||
pandas
|
||||
gradio==5.3.0
|
||||
async_timeout==4.0.3
|
||||
loguru==0.7.2
|
||||
asyncio==3.4.3
|
||||
httpx==0.27.0
|
||||
python-dotenv==1.0.1
|
||||
pydantic==2.8.2
|
||||
openai==1.51.0
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
[
|
||||
{
|
||||
"point_in_time": 0,
|
||||
"data_snapshot_days_ago": 0,
|
||||
"regions": [
|
||||
{ "region": "asia", "contract": 100, "fte": 150, "agency": 2000 },
|
||||
{ "region": "europe", "contract": 80, "fte": 120, "agency": 2500 },
|
||||
|
|
@ -9,7 +9,7 @@
|
|||
"satisfaction": 3.5
|
||||
},
|
||||
{
|
||||
"point_in_time": 30,
|
||||
"data_snapshot_days_ago": 30,
|
||||
"regions": [
|
||||
{ "region": "asia", "contract": 110, "fte": 155, "agency": 1000 },
|
||||
{ "region": "europe", "contract": 85, "fte": 130, "agency": 1600 },
|
||||
|
|
@ -18,7 +18,7 @@
|
|||
"satisfaction": 4.0
|
||||
},
|
||||
{
|
||||
"point_in_time": 60,
|
||||
"data_snapshot_days_ago": 60,
|
||||
"regions": [
|
||||
{ "region": "asia", "contract": 115, "fte": 160, "agency": 500 },
|
||||
{ "region": "europe", "contract": 90, "fte": 140, "agency": 700 },
|
||||
|
|
|
|||
|
|
@ -2,6 +2,9 @@ import json
|
|||
import logging
|
||||
import os
|
||||
import yaml
|
||||
import gradio as gr
|
||||
from typing import List, Optional, Tuple
|
||||
from functools import partial
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
|
|
@ -10,6 +13,97 @@ logging.basicConfig(
|
|||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
GRADIO_CSS_STYLE = """
|
||||
.json-container {
|
||||
height: 95vh !important;
|
||||
overflow-y: auto !important;
|
||||
}
|
||||
.chatbot {
|
||||
height: calc(95vh - 100px) !important;
|
||||
overflow-y: auto !important;
|
||||
}
|
||||
footer {visibility: hidden}
|
||||
"""
|
||||
|
||||
|
||||
def chat(
    query: Optional[str],
    conversation: Optional[List[Tuple[str, str]]],
    history: List[dict],
    client,
):
    """Stream one chat turn through the gateway and update the UI state.

    Appends the user's query to ``history``, requests a streamed chat
    completion from ``client`` and yields the growing conversation each time
    new tokens arrive.

    Args:
        query: Text the user submitted.
        conversation: ``(user, assistant)`` pairs shown in the chatbot
            widget. A fresh list is started when ``None`` is passed.
        history: Message dicts sent to the gateway as ``messages``. Mutated
            in place: the user message is appended here, and the list is also
            handed to ``process_stream_chunk`` for every chunk (presumably so
            it can record the assistant reply — confirm against that helper).
        client: Object exposing ``chat.completions.create``.

    Yields:
        ``("", conversation, history)`` after every chunk that produced
        tokens. The empty string is the new value for the first output
        component.

    Raises:
        gr.Error: If the gateway call fails. The user message that was just
            appended is removed from ``history`` first so a retry does not
            duplicate it.
    """
    # The annotation allows None; guard so .append below cannot fail on it.
    if conversation is None:
        conversation = []

    history.append({"role": "user", "content": query})

    try:
        response = client.chat.completions.create(
            # we select model from arch_config file
            model="--",
            messages=history,
            temperature=1.0,
            stream=True,
        )
    except Exception as e:
        # remove last user message in case of exception
        history.pop()
        # Log as an error with the traceback, then surface it in the UI while
        # keeping the original exception as the cause.
        log.exception("Error calling gateway API: %s", e)
        raise gr.Error("Error calling gateway API: {}".format(e)) from e

    # Start the assistant side of this turn empty; it is filled in as chunks
    # stream back.
    conversation.append((query, ""))

    for chunk in response:
        tokens = process_stream_chunk(chunk, history)
        if tokens:
            user_text, partial_reply = conversation[-1]
            conversation[-1] = (user_text, partial_reply + tokens)

            yield "", conversation, history
|
||||
|
||||
|
||||
def create_gradio_app(demo_description, client):
    """Build the Gradio Blocks UI for the chat demo.

    The layout is a single row with two columns: a narrower one (scale=1)
    holding the demo description and a JSON view of the available prompt
    targets, and a wider one (scale=2) holding the chatbot and input textbox.

    Args:
        demo_description: Markdown text rendered at the top of the first
            column.
        client: Chat-completions client that is bound into the ``chat``
            callback.

    Returns:
        The constructed ``gr.Blocks`` object. It is not launched here.
    """
    with gr.Blocks(
        theme=gr.themes.Default(
            font_mono=[gr.themes.GoogleFont("IBM Plex Mono"), "Arial", "sans-serif"]
        ),
        fill_height=True,
        css=GRADIO_CSS_STYLE,
    ) as demo:
        with gr.Row(equal_height=True):
            # State component holding the message list that chat() receives
            # and returns as `history`.
            history = gr.State([])

            with gr.Column(scale=1):
                # NOTE(review): the trailing comma turns this statement into a
                # one-element tuple expression; harmless, but likely
                # unintended.
                gr.Markdown(demo_description),
                with gr.Accordion("Available Tools/APIs", open=True):
                    with gr.Column(scale=1):
                        # get_prompt_targets is defined outside this block;
                        # presumably it reads the prompt targets from the arch
                        # config file — confirm.
                        gr.JSON(
                            value=get_prompt_targets(),
                            show_indices=False,
                            elem_classes="json-container",
                            min_height="80vh",
                        )

            with gr.Column(scale=2):
                chatbot = gr.Chatbot(
                    label="Arch Chatbot",
                    elem_classes="chatbot",
                )
                textbox = gr.Textbox(
                    show_label=False,
                    placeholder="Enter text and press enter",
                    autofocus=True,
                    elem_classes="textbox",
                )
            # Bind the client so the callback only needs the three component
            # values supplied by Gradio.
            chat_with_client = partial(chat, client=client)

            # Inputs and outputs are the same three components: chat() yields
            # ("", conversation, history), so the textbox is cleared while the
            # chatbot and history are updated.
            textbox.submit(
                chat_with_client,
                [textbox, chatbot, history],
                [textbox, chatbot, history],
            )

    return demo
|
||||
|
||||
|
||||
def process_stream_chunk(chunk, history):
|
||||
delta = chunk.choices[0].delta
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue