diff --git a/chatbot_ui/Dockerfile b/chatbot_ui/Dockerfile index c7ebf504..32101bf5 100644 --- a/chatbot_ui/Dockerfile +++ b/chatbot_ui/Dockerfile @@ -13,6 +13,6 @@ FROM python:3.10-slim AS output COPY --from=builder /runtime /usr/local WORKDIR /app -COPY /run_stream.py . +COPY *.py . CMD ["python", "run_stream.py"] diff --git a/chatbot_ui/arch_util.py b/chatbot_ui/arch_util.py deleted file mode 100644 index 567640e5..00000000 --- a/chatbot_ui/arch_util.py +++ /dev/null @@ -1,20 +0,0 @@ -import json - - -ARCH_STATE_HEADER = "x-arch-state" - - -def get_arch_messages(response_json): - arch_messages = [] - if response_json and "metadata" in response_json: - # load arch_state from metadata - arch_state_str = response_json.get("metadata", {}).get(ARCH_STATE_HEADER, "{}") - # parse arch_state into json object - arch_state = json.loads(arch_state_str) - # load messages from arch_state - arch_messages_str = arch_state.get("messages", "[]") - # parse messages into json object - arch_messages = json.loads(arch_messages_str) - # append messages from arch gateway to history - return arch_messages - return [] diff --git a/chatbot_ui/common.py b/chatbot_ui/common.py new file mode 100644 index 00000000..19111668 --- /dev/null +++ b/chatbot_ui/common.py @@ -0,0 +1,72 @@ +import json +import logging +import yaml + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(levelname)s - %(message)s", +) + +log = logging.getLogger(__name__) + +ARCH_STATE_HEADER = "x-arch-state" + + +def get_arch_messages(response_json): + arch_messages = [] + if response_json and "metadata" in response_json: + # load arch_state from metadata + arch_state_str = response_json.get("metadata", {}).get(ARCH_STATE_HEADER, "{}") + # parse arch_state into json object + arch_state = json.loads(arch_state_str) + # load messages from arch_state + arch_messages_str = arch_state.get("messages", "[]") + # parse messages into json object + arch_messages = json.loads(arch_messages_str) + # append messages from arch gateway to history + return arch_messages + return [] + + +def convert_prompt_target_to_openai_format(target): + tool = { + "description": target["description"], + "parameters": {"type": "object", "properties": {}, "required": []}, + } + + if "parameters" in target: + for param_info in target["parameters"]: + parameter = { + "type": param_info["type"], + "description": param_info["description"], + } + + for key in ["default", "format", "enum", "items", "minimum", "maximum"]: + if key in param_info: + parameter[key] = param_info[key] + + tool["parameters"]["properties"][param_info["name"]] = parameter + + required = param_info.get("required", False) + if required: + tool["parameters"]["required"].append(param_info["name"]) + + return {"name": target["name"], "info": tool} + + +def get_prompt_targets(): + try: + with open(os.getenv("ARCH_CONFIG", "arch_config.yaml"), "r") as file: + config = yaml.safe_load(file) + + available_tools = [] + for target in config["prompt_targets"]: + if not target.get("default", False): + available_tools.append( + convert_prompt_target_to_openai_format(target) + ) + + return {tool["name"]: tool["info"] for tool in available_tools} + except Exception as e: + log.info(e) + return None diff --git a/chatbot_ui/run.py b/chatbot_ui/run.py index 2c1ca620..7f8e2227 100644 --- a/chatbot_ui/run.py +++ b/chatbot_ui/run.py @@ -2,7 +2,7 @@ import json import os import logging import yaml -from arch_util import get_arch_messages +from common import get_arch_messages import gradio as gr from typing import List, Optional, Tuple diff --git a/chatbot_ui/run_stream.py b/chatbot_ui/run_stream.py index c9f799f0..798a2d5e 100644 --- a/chatbot_ui/run_stream.py +++ b/chatbot_ui/run_stream.py @@ -8,6 +8,8 @@ from typing import List, Optional, Tuple from openai import OpenAI from dotenv import load_dotenv +from common import get_prompt_targets + load_dotenv() @@ -40,50 +42,6 @@ client = OpenAI( ) -def convert_prompt_target_to_openai_format(target): - tool = { - "description": target["description"], - "parameters": {"type": "object", "properties": {}, "required": []}, - } - - if "parameters" in target: - for param_info in target["parameters"]: - parameter = { - "type": param_info["type"], - "description": param_info["description"], - } - - for key in ["default", "format", "enum", "items", "minimum", "maximum"]: - if key in param_info: - parameter[key] = param_info[key] - - tool["parameters"]["properties"][param_info["name"]] = parameter - - required = param_info.get("required", False) - if required: - tool["parameters"]["required"].append(param_info["name"]) - - return {"name": target["name"], "info": tool} - - -def get_prompt_targets(): - try: - with open(os.getenv("ARCH_CONFIG", "arch_config.yaml"), "r") as file: - config = yaml.safe_load(file) - - available_tools = [] - for target in config["prompt_targets"]: - if not target.get("default", False): - available_tools.append( - convert_prompt_target_to_openai_format(target) - ) - - return {tool["name"]: tool["info"] for tool in available_tools} - except Exception as e: - log.info(e) - return None - - def chat(query: Optional[str], conversation: Optional[List[Tuple[str, str]]], state): if "history" not in state: state["history"] = [] @@ -105,7 +63,6 @@ def chat(query: Optional[str], conversation: Optional[List[Tuple[str, str]]], st log.info("Error calling gateway API: {}".format(e)) raise gr.Error("Error calling gateway API: {}".format(e)) - history.append({"role": "assistant", "content": "", "model": ""}) conversation.append((query, "")) for chunk in response: