Updated hr_agent to be full stack: gradio + fastAPI (#235)

* committing to remove

* fix

* updating hr_agent

---------

Co-authored-by: Salman Paracha <salmanparacha@MacBook-Pro-261.local>
Co-authored-by: Adil Hafeez <adil@katanemo.com>
This commit is contained in:
Salman Paracha 2024-10-30 15:05:34 -07:00 committed by GitHub
parent bb9a774a72
commit bb882fb59b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 157 additions and 71 deletions

View file

@ -834,7 +834,7 @@ impl StreamContext {
);
debug!(
"archgw => api call, endpoint: {}/{}, body: {}",
"archgw => api call, endpoint: {}{}, body: {}",
endpoint.name.as_str(),
path,
tool_params_json_str

View file

@ -5,16 +5,12 @@ FROM base AS builder
WORKDIR /src
COPY requirements.txt /src/
COPY workforce_data.json /src/
RUN pip install --prefix=/runtime --force-reinstall -r requirements.txt
COPY . /src
FROM python:3.10-slim AS output
COPY --from=builder /runtime /usr/local
COPY . /app
WORKDIR /app
COPY . /app
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "80", "--log-level", "info"]

View file

@ -27,12 +27,6 @@ system_prompt: |
You are a Workforce assistant that helps HR decision makers with reporting and workforce planning. NOTHING ELSE. When you get data in json format, offer some summary but don't be too verbose.
prompt_targets:
- name: hr_qa
endpoint:
name: app_server
path: /agent/hr_qa
description: Handle general Q/A related to HR.
default: true
- name: workforce
description: Get workforce data like headcount and satisfaction levels by region and staffing type
endpoint:
@ -47,10 +41,10 @@ prompt_targets:
type: str
required: true
description: Geographical region for which you want workforce data like asia, europe, americas.
- name: point_in_time
- name: data_snapshot_days_ago
type: int
required: false
description: the point in time for which to retrieve data. For e.g 0 days ago, 30 days ago, etc.
description: the snapshot day for which you want workforce data.
- name: slack_message
endpoint:
name: app_server

View file

@ -4,22 +4,14 @@ services:
context: .
environment:
- SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN:-None}
- OPENAI_API_KEY=${OPENAI_API_KEY:?error}
- CHAT_COMPLETION_ENDPOINT=http://host.docker.internal:10000/v1
volumes:
- ./arch_config.yaml:/app/arch_config.yaml
- ../shared/chatbot_ui/common.py:/app/common.py
ports:
- "18083:80"
healthcheck:
test: ["CMD", "curl" ,"http://localhost:80/healthz"]
interval: 5s
retries: 20
chatbot_ui:
build:
context: ../shared/chatbot_ui
ports:
- "18080:8080"
environment:
- OPENAI_API_KEY=${OPENAI_API_KEY:?error}
- CHAT_COMPLETION_ENDPOINT=http://host.docker.internal:10000/v1
extra_hosts:
- "host.docker.internal:host-gateway"
volumes:
- ./arch_config.yaml:/app/arch_config.yaml

View file

@ -1,20 +1,31 @@
import os
import json
import pandas as pd
import gradio as gr
import logging
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from typing import Optional
from enum import Enum
from typing import List, Optional, Tuple
from slack_sdk import WebClient
from slack_sdk.errors import SlackApiError
from openai import OpenAI
from common import create_gradio_app
app = FastAPI()
workforce_data_df = None
demo_description = """This demo showcases how the **Arch** can be used to build an
HR agent to manage workforce-related inquiries, workforce planning, and communication via Slack.
It intelligently routes incoming prompts to the correct targets, providing concise and useful responses
tailored for HR and workforce decision-making. """
with open("workforce_data.json") as file:
workforce_data = json.load(file)
workforce_data_df = pd.json_normalize(
workforce_data, record_path=["regions"], meta=["point_in_time", "satisfaction"]
workforce_data,
record_path=["regions"],
meta=["data_snapshot_days_ago", "satisfaction"],
)
@ -22,7 +33,7 @@ with open("workforce_data.json") as file:
class WorkforceRequset(BaseModel):
region: str
staffing_type: str
point_in_time: Optional[int] = None
data_snapshot_days_ago: Optional[int] = None
class SlackRequest(BaseModel):
@ -36,25 +47,6 @@ class WorkforceResponse(BaseModel):
satisfaction: float
# Post method for workforce data (comment fixed: previously said "device summary",
# a copy-paste leftover from another demo).
@app.post("/agent/workforce")
def get_workforce(request: WorkforceRequset):
    """
    Return workforce data (headcount and satisfaction) for a region and
    staffing type at a given point in time.

    Args:
        request: WorkforceRequset with `region`, `staffing_type`, and an
            optional `point_in_time` (days ago; defaults to 0, i.e. the
            most recent snapshot).

    Raises:
        HTTPException: 404 when no snapshot matches the requested
            region/point_in_time combination.
    """
    region = request.region.lower()
    staffing_type = request.staffing_type.lower()
    # Optional param: treat a missing/None value as "now" (0 days ago).
    point_in_time = request.point_in_time if request.point_in_time else 0

    # Select the snapshot row matching both region and point in time; doing
    # the filter once avoids repeating the boolean mask for each field.
    rows = workforce_data_df[
        (workforce_data_df["region"] == region)
        & (workforce_data_df["point_in_time"] == point_in_time)
    ]
    if rows.empty:
        # Explicit 404 instead of the previous unhandled IndexError (HTTP 500).
        raise HTTPException(
            status_code=404,
            detail=f"No workforce data for region={region!r} at point_in_time={point_in_time}",
        )

    # NOTE(review): an unknown staffing_type still raises a KeyError here;
    # assumed to be constrained upstream by the prompt target schema — confirm.
    response = {
        "region": region,
        "staffing_type": f"Staffing agency: {staffing_type}",
        "headcount": f"Headcount: {int(rows[staffing_type].values[0])}",
        # Fixed user-facing typo: "Satisifaction" -> "Satisfaction".
        "satisfaction": f"Satisfaction: {float(rows['satisfaction'].values[0])}",
    }
    return response
@app.post("/agent/slack_message")
def send_slack_message(request: SlackRequest):
"""
@ -80,27 +72,38 @@ def send_slack_message(request: SlackRequest):
print(f"Error sending message: {e.response['error']}")
@app.post("/agent/hr_qa")
async def general_hr_qa():
# Post method for device summary
@app.post("/agent/workforce")
def get_workforce(request: WorkforceRequset):
"""
This method handles Q/A related to general issues in HR.
It forwards the conversation to the OpenAI client via a local proxy and returns the response.
Endpoint to workforce data by region, staffing type at a given point in time.
"""
return {
"choices": [
{
"message": {
"role": "assistant",
"content": "I am a helpful HR agent, and I can help you plan for workforce related questions",
},
"finish_reason": "completed",
"index": 0,
}
],
"model": "hr_agent",
"usage": {"completion_tokens": 0},
}
region = request.region.lower()
staffing_type = request.staffing_type.lower()
data_snapshot_days_ago = (
request.data_snapshot_days_ago
if request.data_snapshot_days_ago
else 0 # this param is not required.
)
response = {
"region": region,
"staffing_type": f"Staffing agency: {staffing_type}",
"headcount": f"Headcount: {int(workforce_data_df[(workforce_data_df['region']==region) & (workforce_data_df['data_snapshot_days_ago']==data_snapshot_days_ago)][staffing_type].values[0])}",
"satisfaction": f"Satisifaction: {float(workforce_data_df[(workforce_data_df['region']==region) & (workforce_data_df['data_snapshot_days_ago']==data_snapshot_days_ago)]['satisfaction'].values[0])}",
}
return response
# Route chat completions through the Arch gateway (OpenAI-compatible endpoint).
CHAT_COMPLETION_ENDPOINT = os.getenv("CHAT_COMPLETION_ENDPOINT")
client = OpenAI(
    # The gateway handles real authentication; a placeholder key suffices here.
    api_key="--",
    base_url=CHAT_COMPLETION_ENDPOINT,
)

# Serve the Gradio chat UI from the same FastAPI process at /agent/chat,
# making the agent full stack (API + UI) in a single app.
gr.mount_gradio_app(
    app, create_gradio_app(demo_description, client), path="/agent/chat"
)

if __name__ == "__main__":
    # FIX: `app.run(debug=True)` is the Flask API — FastAPI apps have no
    # `.run()` and this line raised AttributeError. Serve the ASGI app with
    # uvicorn instead, matching the Dockerfile CMD (uvicorn is already a
    # declared dependency in requirements.txt).
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=80)

View file

@ -1,6 +1,13 @@
fastapi
uvicorn
pydantic
slack-sdk
typing
pandas
gradio==5.3.0
async_timeout==4.0.3
loguru==0.7.2
asyncio==3.4.3
httpx==0.27.0
python-dotenv==1.0.1
pydantic==2.8.2
openai==1.51.0

View file

@ -1,6 +1,6 @@
[
{
"point_in_time": 0,
"data_snapshot_days_ago": 0,
"regions": [
{ "region": "asia", "contract": 100, "fte": 150, "agency": 2000 },
{ "region": "europe", "contract": 80, "fte": 120, "agency": 2500 },
@ -9,7 +9,7 @@
"satisfaction": 3.5
},
{
"point_in_time": 30,
"data_snapshot_days_ago": 30,
"regions": [
{ "region": "asia", "contract": 110, "fte": 155, "agency": 1000 },
{ "region": "europe", "contract": 85, "fte": 130, "agency": 1600 },
@ -18,7 +18,7 @@
"satisfaction": 4.0
},
{
"point_in_time": 60,
"data_snapshot_days_ago": 60,
"regions": [
{ "region": "asia", "contract": 115, "fte": 160, "agency": 500 },
{ "region": "europe", "contract": 90, "fte": 140, "agency": 700 },

View file

@ -2,6 +2,9 @@ import json
import logging
import os
import yaml
import gradio as gr
from typing import List, Optional, Tuple
from functools import partial
logging.basicConfig(
level=logging.INFO,
@ -10,6 +13,97 @@ logging.basicConfig(
log = logging.getLogger(__name__)
GRADIO_CSS_STYLE = """
.json-container {
height: 95vh !important;
overflow-y: auto !important;
}
.chatbot {
height: calc(95vh - 100px) !important;
overflow-y: auto !important;
}
footer {visibility: hidden}
"""
def chat(
    query: Optional[str],
    conversation: Optional[List[Tuple[str, str]]],
    history: List[dict],
    client,
):
    """Stream a chat completion through the gateway into the Gradio UI.

    Generator wired to `textbox.submit`: appends the user turn to `history`
    (OpenAI-format role/content dicts), streams the model response, and yields
    `(textbox_value, conversation, history)` after each chunk so the chatbot
    updates incrementally. The empty first element clears the textbox.

    Args:
        query: The user's new message text.
        conversation: Gradio chatbot pairs [(user, assistant), ...].
        history: Full OpenAI-format message history, mutated in place.
        client: OpenAI-compatible client pointed at the Arch gateway.

    Raises:
        gr.Error: surfaced to the UI when the gateway call fails.
    """
    history.append({"role": "user", "content": query})

    try:
        response = client.chat.completions.create(
            # we select model from arch_config file
            model="--",
            messages=history,
            temperature=1.0,
            stream=True,
        )
    except Exception as e:
        # remove last user message in case of exception, so a retry does not
        # leave a dangling user turn in the history
        history.pop()
        log.info("Error calling gateway API: {}".format(e))
        raise gr.Error("Error calling gateway API: {}".format(e))

    # Open a new (user, "") pair; assistant text is filled in as chunks arrive.
    conversation.append((query, ""))
    for chunk in response:
        # process_stream_chunk extracts delta tokens and updates `history`
        # (definition continues past this view).
        tokens = process_stream_chunk(chunk, history)
        if tokens:
            conversation[-1] = (
                conversation[-1][0],
                conversation[-1][1] + tokens,
            )
        # Yield every chunk (even empty ones) to keep the UI responsive.
        yield "", conversation, history
def create_gradio_app(demo_description, client):
    """Build the shared Gradio chat UI for a demo agent.

    Layout: a left column with the demo description and a JSON view of the
    configured prompt targets, and a right column with the chatbot and input
    textbox wired to the streaming `chat` handler.

    Args:
        demo_description: Markdown text shown in the left column.
        client: OpenAI-compatible client bound into the chat handler.

    Returns:
        The assembled `gr.Blocks` app (not launched; caller mounts/launches it).
    """
    with gr.Blocks(
        theme=gr.themes.Default(
            font_mono=[gr.themes.GoogleFont("IBM Plex Mono"), "Arial", "sans-serif"]
        ),
        fill_height=True,
        css=GRADIO_CSS_STYLE,
    ) as demo:
        with gr.Row(equal_height=True):
            # Per-session OpenAI-format message history (role/content dicts).
            history = gr.State([])

            with gr.Column(scale=1):
                # FIX: removed the stray trailing comma that wrapped this call
                # in a pointless 1-tuple expression.
                gr.Markdown(demo_description)
                with gr.Accordion("Available Tools/APIs", open=True):
                    with gr.Column(scale=1):
                        gr.JSON(
                            value=get_prompt_targets(),
                            show_indices=False,
                            elem_classes="json-container",
                            min_height="80vh",
                        )

            with gr.Column(scale=2):
                chatbot = gr.Chatbot(
                    label="Arch Chatbot",
                    elem_classes="chatbot",
                )
                textbox = gr.Textbox(
                    show_label=False,
                    placeholder="Enter text and press enter",
                    autofocus=True,
                    elem_classes="textbox",
                )

        # Bind the shared client into the handler; submit streams updates back
        # into (textbox, chatbot, history).
        chat_with_client = partial(chat, client=client)
        textbox.submit(
            chat_with_client,
            [textbox, chatbot, history],
            [textbox, chatbot, history],
        )

    return demo
def process_stream_chunk(chunk, history):
delta = chunk.choices[0].delta