diff --git a/crates/prompt_gateway/src/stream_context.rs b/crates/prompt_gateway/src/stream_context.rs
index 4bbd3fa6..a2fe0b2e 100644
--- a/crates/prompt_gateway/src/stream_context.rs
+++ b/crates/prompt_gateway/src/stream_context.rs
@@ -834,7 +834,7 @@ impl StreamContext {
         );
 
         debug!(
-            "archgw => api call, endpoint: {}/{}, body: {}",
+            "archgw => api call, endpoint: {}{}, body: {}",
             endpoint.name.as_str(),
             path,
             tool_params_json_str
diff --git a/demos/hr_agent/Dockerfile b/demos/hr_agent/Dockerfile
index c97fb497..427fe8a4 100644
--- a/demos/hr_agent/Dockerfile
+++ b/demos/hr_agent/Dockerfile
@@ -5,16 +5,12 @@ FROM base AS builder
 WORKDIR /src
 
 COPY requirements.txt /src/
-COPY workforce_data.json /src/
 
 RUN pip install --prefix=/runtime --force-reinstall -r requirements.txt
 
-COPY . /src
-
 FROM python:3.10-slim AS output
-
 COPY --from=builder /runtime /usr/local
-COPY . /app
 WORKDIR /app
+COPY . /app
 
 CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "80", "--log-level", "info"]
diff --git a/demos/hr_agent/arch_config.yaml b/demos/hr_agent/arch_config.yaml
index 71baa08e..b926f7ad 100644
--- a/demos/hr_agent/arch_config.yaml
+++ b/demos/hr_agent/arch_config.yaml
@@ -27,12 +27,6 @@ system_prompt: |
   You are a Workforce assistant that helps on workforce planning and HR decision makers with reporting and workfoce planning. NOTHING ELSE.
   When you get data in json format, offer some summary but don't be too verbose.
 prompt_targets:
-  - name: hr_qa
-    endpoint:
-      name: app_server
-      path: /agent/hr_qa
-    description: Handle general Q/A related to HR.
-    default: true
   - name: workforce
     description: Get workforce data like headcount and satisfacton levels by region and staffing type
     endpoint:
@@ -47,10 +41,10 @@ prompt_targets:
         type: str
         required: true
         description: Geographical region for which you want workforce data like asia, europe, americas.
-      - name: point_in_time
+      - name: data_snapshot_days_ago
         type: int
         required: false
-        description: the point in time for which to retrieve data. For e.g 0 days ago, 30 days ago, etc.
+        description: Number of days ago for the workforce data snapshot, e.g. 0 for the latest data, 30 for a month ago.
   - name: slack_message
     endpoint:
       name: app_server
diff --git a/demos/hr_agent/docker-compose.yaml b/demos/hr_agent/docker-compose.yaml
index 1067ba6c..449668b2 100644
--- a/demos/hr_agent/docker-compose.yaml
+++ b/demos/hr_agent/docker-compose.yaml
@@ -4,22 +4,14 @@ services:
       context: .
     environment:
       - SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN:-None}
+      - OPENAI_API_KEY=${OPENAI_API_KEY:?error}
+      - CHAT_COMPLETION_ENDPOINT=http://host.docker.internal:10000/v1
+    volumes:
+      - ./arch_config.yaml:/app/arch_config.yaml
+      - ../shared/chatbot_ui/common.py:/app/common.py
     ports:
       - "18083:80"
     healthcheck:
       test: ["CMD", "curl" ,"http://localhost:80/healthz"]
       interval: 5s
       retries: 20
-
-  chatbot_ui:
-    build:
-      context: ../shared/chatbot_ui
-    ports:
-      - "18080:8080"
-    environment:
-      - OPENAI_API_KEY=${OPENAI_API_KEY:?error}
-      - CHAT_COMPLETION_ENDPOINT=http://host.docker.internal:10000/v1
-    extra_hosts:
-      - "host.docker.internal:host-gateway"
-    volumes:
-      - ./arch_config.yaml:/app/arch_config.yaml
diff --git a/demos/hr_agent/main.py b/demos/hr_agent/main.py
index 3a1a14e6..9c49d120 100644
--- a/demos/hr_agent/main.py
+++ b/demos/hr_agent/main.py
@@ -1,20 +1,31 @@
 import os
 import json
 import pandas as pd
+import gradio as gr
+import logging
+
 from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel, Field
-from typing import Optional
 from enum import Enum
+from typing import List, Optional, Tuple
 from slack_sdk import WebClient
 from slack_sdk.errors import SlackApiError
+from openai import OpenAI
+from common import create_gradio_app
 
 app = FastAPI()
 workforce_data_df = None
 
+demo_description = """This demo showcases how **Arch** can be used to build an
+HR agent to manage workforce-related inquiries, workforce planning, and communication via Slack.
+It intelligently routes incoming prompts to the correct targets, providing concise and useful responses
+tailored for HR and workforce decision-making. """
 with open("workforce_data.json") as file:
     workforce_data = json.load(file)
     workforce_data_df = pd.json_normalize(
-        workforce_data, record_path=["regions"], meta=["point_in_time", "satisfaction"]
+        workforce_data,
+        record_path=["regions"],
+        meta=["data_snapshot_days_ago", "satisfaction"],
     )
 
 
@@ -22,7 +33,7 @@ with open("workforce_data.json") as file:
 class WorkforceRequset(BaseModel):
     region: str
     staffing_type: str
-    point_in_time: Optional[int] = None
+    data_snapshot_days_ago: Optional[int] = None
 
 
 class SlackRequest(BaseModel):
@@ -36,25 +47,6 @@ class WorkforceResponse(BaseModel):
     satisfaction: float
 
 
-# Post method for device summary
-@app.post("/agent/workforce")
-def get_workforce(request: WorkforceRequset):
-    """
-    Endpoint to workforce data by region, staffing type at a given point in time.
-    """
-    region = request.region.lower()
-    staffing_type = request.staffing_type.lower()
-    point_in_time = request.point_in_time if request.point_in_time else 0
-
-    response = {
-        "region": region,
-        "staffing_type": f"Staffing agency: {staffing_type}",
-        "headcount": f"Headcount: {int(workforce_data_df[(workforce_data_df['region']==region) & (workforce_data_df['point_in_time']==point_in_time)][staffing_type].values[0])}",
-        "satisfaction": f"Satisifaction: {float(workforce_data_df[(workforce_data_df['region']==region) & (workforce_data_df['point_in_time']==point_in_time)]['satisfaction'].values[0])}",
-    }
-    return response
-
-
 @app.post("/agent/slack_message")
 def send_slack_message(request: SlackRequest):
     """
@@ -80,27 +72,38 @@ def send_slack_message(request: SlackRequest):
         print(f"Error sending message: {e.response['error']}")
 
 
-@app.post("/agent/hr_qa")
-async def general_hr_qa():
+# Post method for workforce data
+@app.post("/agent/workforce")
+def get_workforce(request: WorkforceRequset):
     """
-    This method handles Q/A related to general issues in HR.
-    It forwards the conversation to the OpenAI client via a local proxy and returns the response.
+    Endpoint to retrieve workforce data by region and staffing type for a given snapshot (days ago).
     """
-    return {
-        "choices": [
-            {
-                "message": {
-                    "role": "assistant",
-                    "content": "I am a helpful HR agent, and I can help you plan for workforce related questions",
-                },
-                "finish_reason": "completed",
-                "index": 0,
-            }
-        ],
-        "model": "hr_agent",
-        "usage": {"completion_tokens": 0},
-    }
+    region = request.region.lower()
+    staffing_type = request.staffing_type.lower()
+    data_snapshot_days_ago = (
+        request.data_snapshot_days_ago
+        if request.data_snapshot_days_ago
+        else 0  # optional param; default to the latest snapshot
+    )
+    response = {
+        "region": region,
+        "staffing_type": f"Staffing agency: {staffing_type}",
+        "headcount": f"Headcount: {int(workforce_data_df[(workforce_data_df['region']==region) & (workforce_data_df['data_snapshot_days_ago']==data_snapshot_days_ago)][staffing_type].values[0])}",
+        "satisfaction": f"Satisfaction: {float(workforce_data_df[(workforce_data_df['region']==region) & (workforce_data_df['data_snapshot_days_ago']==data_snapshot_days_ago)]['satisfaction'].values[0])}",
+    }
+    return response
+
+
+CHAT_COMPLETION_ENDPOINT = os.getenv("CHAT_COMPLETION_ENDPOINT")
+client = OpenAI(
+    api_key="--",
+    base_url=CHAT_COMPLETION_ENDPOINT,
+)
+
+gr.mount_gradio_app(
+    app, create_gradio_app(demo_description, client), path="/agent/chat"
+)
 
 
 if __name__ == "__main__":
     app.run(debug=True)
diff --git a/demos/hr_agent/requirements.txt b/demos/hr_agent/requirements.txt
index 3068f77e..9a108c37 100644
--- a/demos/hr_agent/requirements.txt
+++ b/demos/hr_agent/requirements.txt
@@ -1,6 +1,13 @@
 fastapi
 uvicorn
-pydantic
 slack-sdk
 typing
 pandas
+gradio==5.3.0
+async_timeout==4.0.3
+loguru==0.7.2
+asyncio==3.4.3
+httpx==0.27.0
+python-dotenv==1.0.1
+pydantic==2.8.2
+openai==1.51.0
diff --git a/demos/hr_agent/workforce_data.json b/demos/hr_agent/workforce_data.json
index e28654d7..4f2279ba 100644
--- a/demos/hr_agent/workforce_data.json
+++ b/demos/hr_agent/workforce_data.json
@@ -1,6 +1,6 @@
 [
   {
-    "point_in_time": 0,
+    "data_snapshot_days_ago": 0,
     "regions": [
       { "region": "asia", "contract": 100, "fte": 150, "agency": 2000 },
       { "region": "europe", "contract": 80, "fte": 120, "agency": 2500 },
@@ -9,7 +9,7 @@
     "satisfaction": 3.5
   },
   {
-    "point_in_time": 30,
+    "data_snapshot_days_ago": 30,
     "regions": [
       { "region": "asia", "contract": 110, "fte": 155, "agency": 1000 },
       { "region": "europe", "contract": 85, "fte": 130, "agency": 1600 },
@@ -18,7 +18,7 @@
     "satisfaction": 4.0
   },
   {
-    "point_in_time": 60,
+    "data_snapshot_days_ago": 60,
     "regions": [
       { "region": "asia", "contract": 115, "fte": 160, "agency": 500 },
       { "region": "europe", "contract": 90, "fte": 140, "agency": 700 },
diff --git a/demos/shared/chatbot_ui/common.py b/demos/shared/chatbot_ui/common.py
index 3fd5c265..4401e266 100644
--- a/demos/shared/chatbot_ui/common.py
+++ b/demos/shared/chatbot_ui/common.py
@@ -2,6 +2,9 @@ import json
 import logging
 import os
 import yaml
+import gradio as gr
+from typing import List, Optional, Tuple
+from functools import partial
 
 logging.basicConfig(
     level=logging.INFO,
@@ -10,6 +13,97 @@ logging.basicConfig(
 log = logging.getLogger(__name__)
 
 
+GRADIO_CSS_STYLE = """
+.json-container {
+    height: 95vh !important;
+    overflow-y: auto !important;
+}
+.chatbot {
+    height: calc(95vh - 100px) !important;
+    overflow-y: auto !important;
+}
+footer {visibility: hidden}
+"""
+
+
+def chat(
+    query: Optional[str],
+    conversation: Optional[List[Tuple[str, str]]],
+    history: List[dict],
+    client,
+):
+    history.append({"role": "user", "content": query})
+
+    try:
+        response = client.chat.completions.create(
+            # we select model from arch_config file
+            model="--",
+            messages=history,
+            temperature=1.0,
+            stream=True,
+        )
+    except Exception as e:
+        # remove last user message in case of exception
+        history.pop()
+        log.info("Error calling gateway API: {}".format(e))
+        raise gr.Error("Error calling gateway API: {}".format(e))
+
+    conversation.append((query, ""))
+
+    for chunk in response:
+        tokens = process_stream_chunk(chunk, history)
+        if tokens:
+            conversation[-1] = (
+                conversation[-1][0],
+                conversation[-1][1] + tokens,
+            )
+
+        yield "", conversation, history
+
+
+def create_gradio_app(demo_description, client):
+    with gr.Blocks(
+        theme=gr.themes.Default(
+            font_mono=[gr.themes.GoogleFont("IBM Plex Mono"), "Arial", "sans-serif"]
+        ),
+        fill_height=True,
+        css=GRADIO_CSS_STYLE,
+    ) as demo:
+        with gr.Row(equal_height=True):
+            history = gr.State([])
+
+            with gr.Column(scale=1):
+                gr.Markdown(demo_description)
+                with gr.Accordion("Available Tools/APIs", open=True):
+                    with gr.Column(scale=1):
+                        gr.JSON(
+                            value=get_prompt_targets(),
+                            show_indices=False,
+                            elem_classes="json-container",
+                            min_height="80vh",
+                        )
+
+            with gr.Column(scale=2):
+                chatbot = gr.Chatbot(
+                    label="Arch Chatbot",
+                    elem_classes="chatbot",
+                )
+                textbox = gr.Textbox(
+                    show_label=False,
+                    placeholder="Enter text and press enter",
+                    autofocus=True,
+                    elem_classes="textbox",
+                )
+
+        chat_with_client = partial(chat, client=client)
+        textbox.submit(
+            chat_with_client,
+            [textbox, chatbot, history],
+            [textbox, chatbot, history],
+        )
+
+    return demo
+
 def process_stream_chunk(chunk, history):
     delta = chunk.choices[0].delta
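
For reference, a minimal sketch (not part of the diff) of how the renamed snapshot parameter could be exercised against the workforce endpoint once the demo is running. It assumes the hr_agent container is up and reachable on localhost:18083, as mapped in docker-compose.yaml, and uses httpx, which the updated requirements.txt pins.

# Hypothetical smoke test for /agent/workforce after the
# point_in_time -> data_snapshot_days_ago rename; not part of this diff.
import httpx

payload = {
    "region": "asia",
    "staffing_type": "fte",
    "data_snapshot_days_ago": 30,  # optional; omitting it falls back to 0 (latest snapshot)
}

# The hr_agent service maps host port 18083 to the app's port 80.
resp = httpx.post("http://localhost:18083/agent/workforce", json=payload, timeout=10)
resp.raise_for_status()
print(resp.json())  # e.g. {"region": "asia", "staffing_type": "Staffing agency: fte", ...}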