Updated hr_agent to be full stack: gradio + fastAPI (#235)

* commiting to remove

* fix

* updating hr_agent

---------

Co-authored-by: Salman Paracha <salmanparacha@MacBook-Pro-261.local>
Co-authored-by: Adil Hafeez <adil@katanemo.com>
This commit is contained in:
Salman Paracha 2024-10-30 15:05:34 -07:00 committed by GitHub
parent bb9a774a72
commit bb882fb59b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 157 additions and 71 deletions

View file

@ -834,7 +834,7 @@ impl StreamContext {
); );
debug!( debug!(
"archgw => api call, endpoint: {}/{}, body: {}", "archgw => api call, endpoint: {}{}, body: {}",
endpoint.name.as_str(), endpoint.name.as_str(),
path, path,
tool_params_json_str tool_params_json_str

View file

@ -5,16 +5,12 @@ FROM base AS builder
WORKDIR /src WORKDIR /src
COPY requirements.txt /src/ COPY requirements.txt /src/
COPY workforce_data.json /src/
RUN pip install --prefix=/runtime --force-reinstall -r requirements.txt RUN pip install --prefix=/runtime --force-reinstall -r requirements.txt
COPY . /src
FROM python:3.10-slim AS output FROM python:3.10-slim AS output
COPY --from=builder /runtime /usr/local COPY --from=builder /runtime /usr/local
COPY . /app
WORKDIR /app WORKDIR /app
COPY . /app
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "80", "--log-level", "info"] CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "80", "--log-level", "info"]

View file

@ -27,12 +27,6 @@ system_prompt: |
You are a Workforce assistant that helps on workforce planning and HR decision makers with reporting and workfoce planning. NOTHING ELSE. When you get data in json format, offer some summary but don't be too verbose. You are a Workforce assistant that helps on workforce planning and HR decision makers with reporting and workfoce planning. NOTHING ELSE. When you get data in json format, offer some summary but don't be too verbose.
prompt_targets: prompt_targets:
- name: hr_qa
endpoint:
name: app_server
path: /agent/hr_qa
description: Handle general Q/A related to HR.
default: true
- name: workforce - name: workforce
description: Get workforce data like headcount and satisfacton levels by region and staffing type description: Get workforce data like headcount and satisfacton levels by region and staffing type
endpoint: endpoint:
@ -47,10 +41,10 @@ prompt_targets:
type: str type: str
required: true required: true
description: Geographical region for which you want workforce data like asia, europe, americas. description: Geographical region for which you want workforce data like asia, europe, americas.
- name: point_in_time - name: data_snapshot_days_ago
type: int type: int
required: false required: false
description: the point in time for which to retrieve data. For e.g 0 days ago, 30 days ago, etc. description: the snapshot day for which you want workforce data.
- name: slack_message - name: slack_message
endpoint: endpoint:
name: app_server name: app_server

View file

@ -4,22 +4,14 @@ services:
context: . context: .
environment: environment:
- SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN:-None} - SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN:-None}
- OPENAI_API_KEY=${OPENAI_API_KEY:?error}
- CHAT_COMPLETION_ENDPOINT=http://host.docker.internal:10000/v1
volumes:
- ./arch_config.yaml:/app/arch_config.yaml
- ../shared/chatbot_ui/common.py:/app/common.py
ports: ports:
- "18083:80" - "18083:80"
healthcheck: healthcheck:
test: ["CMD", "curl" ,"http://localhost:80/healthz"] test: ["CMD", "curl" ,"http://localhost:80/healthz"]
interval: 5s interval: 5s
retries: 20 retries: 20
chatbot_ui:
build:
context: ../shared/chatbot_ui
ports:
- "18080:8080"
environment:
- OPENAI_API_KEY=${OPENAI_API_KEY:?error}
- CHAT_COMPLETION_ENDPOINT=http://host.docker.internal:10000/v1
extra_hosts:
- "host.docker.internal:host-gateway"
volumes:
- ./arch_config.yaml:/app/arch_config.yaml

View file

@ -1,20 +1,31 @@
import os import os
import json import json
import pandas as pd import pandas as pd
import gradio as gr
import logging
from fastapi import FastAPI, HTTPException from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing import Optional
from enum import Enum from enum import Enum
from typing import List, Optional, Tuple
from slack_sdk import WebClient from slack_sdk import WebClient
from slack_sdk.errors import SlackApiError from slack_sdk.errors import SlackApiError
from openai import OpenAI
from common import create_gradio_app
app = FastAPI() app = FastAPI()
workforce_data_df = None workforce_data_df = None
demo_description = """This demo showcases how the **Arch** can be used to build an
HR agent to manage workforce-related inquiries, workforce planning, and communication via Slack.
It intelligently routes incoming prompts to the correct targets, providing concise and useful responses
tailored for HR and workforce decision-making. """
with open("workforce_data.json") as file: with open("workforce_data.json") as file:
workforce_data = json.load(file) workforce_data = json.load(file)
workforce_data_df = pd.json_normalize( workforce_data_df = pd.json_normalize(
workforce_data, record_path=["regions"], meta=["point_in_time", "satisfaction"] workforce_data,
record_path=["regions"],
meta=["data_snapshot_days_ago", "satisfaction"],
) )
@ -22,7 +33,7 @@ with open("workforce_data.json") as file:
class WorkforceRequset(BaseModel): class WorkforceRequset(BaseModel):
region: str region: str
staffing_type: str staffing_type: str
point_in_time: Optional[int] = None data_snapshot_days_ago: Optional[int] = None
class SlackRequest(BaseModel): class SlackRequest(BaseModel):
@ -36,25 +47,6 @@ class WorkforceResponse(BaseModel):
satisfaction: float satisfaction: float
# Post method for device summary
@app.post("/agent/workforce")
def get_workforce(request: WorkforceRequset):
"""
Endpoint to workforce data by region, staffing type at a given point in time.
"""
region = request.region.lower()
staffing_type = request.staffing_type.lower()
point_in_time = request.point_in_time if request.point_in_time else 0
response = {
"region": region,
"staffing_type": f"Staffing agency: {staffing_type}",
"headcount": f"Headcount: {int(workforce_data_df[(workforce_data_df['region']==region) & (workforce_data_df['point_in_time']==point_in_time)][staffing_type].values[0])}",
"satisfaction": f"Satisifaction: {float(workforce_data_df[(workforce_data_df['region']==region) & (workforce_data_df['point_in_time']==point_in_time)]['satisfaction'].values[0])}",
}
return response
@app.post("/agent/slack_message") @app.post("/agent/slack_message")
def send_slack_message(request: SlackRequest): def send_slack_message(request: SlackRequest):
""" """
@ -80,27 +72,38 @@ def send_slack_message(request: SlackRequest):
print(f"Error sending message: {e.response['error']}") print(f"Error sending message: {e.response['error']}")
@app.post("/agent/hr_qa") # Post method for device summary
async def general_hr_qa(): @app.post("/agent/workforce")
def get_workforce(request: WorkforceRequset):
""" """
This method handles Q/A related to general issues in HR. Endpoint to workforce data by region, staffing type at a given point in time.
It forwards the conversation to the OpenAI client via a local proxy and returns the response.
""" """
return { region = request.region.lower()
"choices": [ staffing_type = request.staffing_type.lower()
{ data_snapshot_days_ago = (
"message": { request.data_snapshot_days_ago
"role": "assistant", if request.data_snapshot_days_ago
"content": "I am a helpful HR agent, and I can help you plan for workforce related questions", else 0 # this param is not required.
}, )
"finish_reason": "completed",
"index": 0,
}
],
"model": "hr_agent",
"usage": {"completion_tokens": 0},
}
response = {
"region": region,
"staffing_type": f"Staffing agency: {staffing_type}",
"headcount": f"Headcount: {int(workforce_data_df[(workforce_data_df['region']==region) & (workforce_data_df['data_snapshot_days_ago']==data_snapshot_days_ago)][staffing_type].values[0])}",
"satisfaction": f"Satisifaction: {float(workforce_data_df[(workforce_data_df['region']==region) & (workforce_data_df['data_snapshot_days_ago']==data_snapshot_days_ago)]['satisfaction'].values[0])}",
}
return response
CHAT_COMPLETION_ENDPOINT = os.getenv("CHAT_COMPLETION_ENDPOINT")
client = OpenAI(
api_key="--",
base_url=CHAT_COMPLETION_ENDPOINT,
)
gr.mount_gradio_app(
app, create_gradio_app(demo_description, client), path="/agent/chat"
)
if __name__ == "__main__": if __name__ == "__main__":
app.run(debug=True) app.run(debug=True)

View file

@ -1,6 +1,13 @@
fastapi fastapi
uvicorn uvicorn
pydantic
slack-sdk slack-sdk
typing typing
pandas pandas
gradio==5.3.0
async_timeout==4.0.3
loguru==0.7.2
asyncio==3.4.3
httpx==0.27.0
python-dotenv==1.0.1
pydantic==2.8.2
openai==1.51.0

View file

@ -1,6 +1,6 @@
[ [
{ {
"point_in_time": 0, "data_snapshot_days_ago": 0,
"regions": [ "regions": [
{ "region": "asia", "contract": 100, "fte": 150, "agency": 2000 }, { "region": "asia", "contract": 100, "fte": 150, "agency": 2000 },
{ "region": "europe", "contract": 80, "fte": 120, "agency": 2500 }, { "region": "europe", "contract": 80, "fte": 120, "agency": 2500 },
@ -9,7 +9,7 @@
"satisfaction": 3.5 "satisfaction": 3.5
}, },
{ {
"point_in_time": 30, "data_snapshot_days_ago": 30,
"regions": [ "regions": [
{ "region": "asia", "contract": 110, "fte": 155, "agency": 1000 }, { "region": "asia", "contract": 110, "fte": 155, "agency": 1000 },
{ "region": "europe", "contract": 85, "fte": 130, "agency": 1600 }, { "region": "europe", "contract": 85, "fte": 130, "agency": 1600 },
@ -18,7 +18,7 @@
"satisfaction": 4.0 "satisfaction": 4.0
}, },
{ {
"point_in_time": 60, "data_snapshot_days_ago": 60,
"regions": [ "regions": [
{ "region": "asia", "contract": 115, "fte": 160, "agency": 500 }, { "region": "asia", "contract": 115, "fte": 160, "agency": 500 },
{ "region": "europe", "contract": 90, "fte": 140, "agency": 700 }, { "region": "europe", "contract": 90, "fte": 140, "agency": 700 },

View file

@ -2,6 +2,9 @@ import json
import logging import logging
import os import os
import yaml import yaml
import gradio as gr
from typing import List, Optional, Tuple
from functools import partial
logging.basicConfig( logging.basicConfig(
level=logging.INFO, level=logging.INFO,
@ -10,6 +13,97 @@ logging.basicConfig(
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
GRADIO_CSS_STYLE = """
.json-container {
height: 95vh !important;
overflow-y: auto !important;
}
.chatbot {
height: calc(95vh - 100px) !important;
overflow-y: auto !important;
}
footer {visibility: hidden}
"""
def chat(
query: Optional[str],
conversation: Optional[List[Tuple[str, str]]],
history: List[dict],
client,
):
history.append({"role": "user", "content": query})
try:
response = client.chat.completions.create(
# we select model from arch_config file
model="--",
messages=history,
temperature=1.0,
stream=True,
)
except Exception as e:
# remove last user message in case of exception
history.pop()
log.info("Error calling gateway API: {}".format(e))
raise gr.Error("Error calling gateway API: {}".format(e))
conversation.append((query, ""))
for chunk in response:
tokens = process_stream_chunk(chunk, history)
if tokens:
conversation[-1] = (
conversation[-1][0],
conversation[-1][1] + tokens,
)
yield "", conversation, history
def create_gradio_app(demo_description, client):
with gr.Blocks(
theme=gr.themes.Default(
font_mono=[gr.themes.GoogleFont("IBM Plex Mono"), "Arial", "sans-serif"]
),
fill_height=True,
css=GRADIO_CSS_STYLE,
) as demo:
with gr.Row(equal_height=True):
history = gr.State([])
with gr.Column(scale=1):
gr.Markdown(demo_description),
with gr.Accordion("Available Tools/APIs", open=True):
with gr.Column(scale=1):
gr.JSON(
value=get_prompt_targets(),
show_indices=False,
elem_classes="json-container",
min_height="80vh",
)
with gr.Column(scale=2):
chatbot = gr.Chatbot(
label="Arch Chatbot",
elem_classes="chatbot",
)
textbox = gr.Textbox(
show_label=False,
placeholder="Enter text and press enter",
autofocus=True,
elem_classes="textbox",
)
chat_with_client = partial(chat, client=client)
textbox.submit(
chat_with_client,
[textbox, chatbot, history],
[textbox, chatbot, history],
)
return demo
def process_stream_chunk(chunk, history): def process_stream_chunk(chunk, history):
delta = chunk.choices[0].delta delta = chunk.choices[0].delta