diff --git a/chatbot_ui/app/run.py b/chatbot_ui/app/run.py index 02d6e01c..05a6a6db 100644 --- a/chatbot_ui/app/run.py +++ b/chatbot_ui/app/run.py @@ -1,8 +1,11 @@ import json import os -from openai import OpenAI, DefaultHttpxClient -import gradio as gr import logging +import yaml +import gradio as gr + +from typing import List, Optional, Tuple +from openai import OpenAI, DefaultHttpxClient from dotenv import load_dotenv load_dotenv() @@ -15,9 +18,22 @@ logging.basicConfig( log = logging.getLogger(__name__) CHAT_COMPLETION_ENDPOINT = os.getenv("CHAT_COMPLETION_ENDPOINT") -ARCH_STATE_HEADER = "x-arch-state" log.info(f"CHAT_COMPLETION_ENDPOINT: {CHAT_COMPLETION_ENDPOINT}") +ARCH_STATE_HEADER = "x-arch-state" + +CSS_STYLE = """ +.json-container { + height: 95vh !important; + overflow-y: auto !important; +} +.chatbot { + height: calc(95vh - 100px) !important; + overflow-y: auto !important; +} +footer {visibility: hidden} +""" + client = OpenAI( api_key="--", base_url=CHAT_COMPLETION_ENDPOINT, @@ -25,11 +41,56 @@ client = OpenAI( ) -def predict(message, state): +def convert_prompt_target_to_openai_format(target): + tool = { + "description": target["description"], + "parameters": {"type": "object", "properties": {}, "required": []}, + } + + if "parameters" in target: + for param_info in target["parameters"]: + parameter = { + "type": param_info["type"], + "description": param_info["description"], + } + + for key in ["default", "format", "enum", "items", "minimum", "maximum"]: + if key in param_info: + parameter[key] = param_info[key] + + tool["parameters"]["properties"][param_info["name"]] = parameter + + required = param_info.get("required", False) + if required: + tool["parameters"]["required"].append(param_info["name"]) + + return {"name": target["name"], "info": tool} + + +def get_prompt_targets(): + try: + with open("arch_config.yaml", "r") as file: + config = yaml.safe_load(file) + + available_tools = [] + for target in config["prompt_targets"]: + if not target.get("default", False): + available_tools.append( + convert_prompt_target_to_openai_format(target) + ) + + return {tool["name"]: tool["info"] for tool in available_tools} + except Exception as e: + log.info(e) + return None + + +def chat(query: Optional[str], conversation: Optional[List[Tuple[str, str]]], state): if "history" not in state: state["history"] = [] + history = state.get("history") - history.append({"role": "user", "content": message}) + history.append({"role": "user", "content": query}) log.info(f"history: {history}") # Custom headers @@ -58,7 +119,8 @@ def predict(message, state): # extract arch_state from metadata and store it in gradio session state # this state must be passed back to the gateway in the next request response_json = json.loads(raw_response.text) - if response_json: + log.info(response_json) + if response_json and "metadata" in response_json: # load arch_state from metadata arch_state_str = response_json.get("metadata", {}).get(ARCH_STATE_HEADER, "{}") # parse arch_state into json object @@ -78,25 +140,53 @@ def predict(message, state): # for gradio UI we don't want to show raw tool calls and messages from developer application # so we're filtering those out history_view = [h for h in history if h["role"] != "tool" and "content" in h] + messages = [ (history_view[i]["content"], history_view[i + 1]["content"]) for i in range(0, len(history_view) - 1, 2) ] - return messages, state + + return "", messages, state -with gr.Blocks(fill_height=True, css="footer {visibility: hidden}") as demo: - print("Starting Demo...") - chatbot = gr.Chatbot(label="Arch Chatbot", scale=1) - state = gr.State({}) - with gr.Row(): - txt = gr.Textbox( - show_label=False, - placeholder="Enter text and press enter", - scale=1, - autofocus=True, - ) +def main(): + with gr.Blocks( + theme=gr.themes.Default( + font_mono=[gr.themes.GoogleFont("IBM Plex Mono"), "Arial", "sans-serif"] + ), + fill_height=True, + css=CSS_STYLE, + ) as demo: + with gr.Row(equal_height=True): + state = gr.State({}) - txt.submit(predict, [txt, state], [chatbot, state]) + with gr.Column(scale=4): + gr.JSON( + value=get_prompt_targets(), + open=True, + show_indices=False, + label="Available Tools", + scale=1, + min_height="95vh", + elem_classes="json-container", + ) + with gr.Column(scale=6): + chatbot = gr.Chatbot( + label="Arch Chatbot", + scale=1, + elem_classes="chatbot", + ) + textbox = gr.Textbox( + show_label=False, + placeholder="Enter text and press enter", + scale=1, + autofocus=True, + ) -demo.launch(server_name="0.0.0.0", server_port=8080, show_error=True, debug=True) + textbox.submit(chat, [textbox, chatbot, state], [textbox, chatbot, state]) + + demo.launch(server_name="0.0.0.0", server_port=8080, show_error=True, debug=True) + + +if __name__ == "__main__": + main() diff --git a/chatbot_ui/requirements.txt b/chatbot_ui/requirements.txt index 60a107fe..b8e20cba 100644 --- a/chatbot_ui/requirements.txt +++ b/chatbot_ui/requirements.txt @@ -1,4 +1,4 @@ -gradio==4.43.0 +gradio==5.3.0 async_timeout==4.0.3 loguru==0.7.2 asyncio==3.4.3 diff --git a/crates/common/src/consts.rs b/crates/common/src/consts.rs index f4861dd7..81df31f8 100644 --- a/crates/common/src/consts.rs +++ b/crates/common/src/consts.rs @@ -25,3 +25,4 @@ pub const ARCH_INTERNAL_CLUSTER_NAME: &str = "arch_internal"; pub const ARCH_UPSTREAM_HOST_HEADER: &str = "x-arch-upstream"; pub const ARCH_LLM_UPSTREAM_LISTENER: &str = "arch_llm_listener"; pub const ARCH_MODEL_PREFIX: &str = "Arch"; +pub const HALLUCINATION_TEMPLATE: &str = "It seems I’m missing some information. Could you provide the following details "; diff --git a/crates/prompt_gateway/src/hallucination.rs b/crates/prompt_gateway/src/hallucination.rs index 62b119ac..c4425957 100644 --- a/crates/prompt_gateway/src/hallucination.rs +++ b/crates/prompt_gateway/src/hallucination.rs @@ -1,6 +1,6 @@ use common::{ common_types::open_ai::Message, - consts::{ARCH_MODEL_PREFIX, ASSISTANT_ROLE, USER_ROLE}, + consts::{ARCH_MODEL_PREFIX, USER_ROLE, HALLUCINATION_TEMPLATE}, }; pub fn extract_messages_for_hallucination(messages: &Vec) -> Vec { @@ -18,9 +18,11 @@ pub fn extract_messages_for_hallucination(messages: &Vec) -> Vec