Update chatbot UI and update hallucination check (#218)

* update chatbot UI

* Update docker-compose for demos

* Fix bugs

* fix for emtadata (#219)

* fix for emtadata

* fix

* revert

* merge main

---------

Co-authored-by: CTran <cotran2@utexas.edu>
This commit is contained in:
Shuguang Chen 2024-10-24 14:11:53 -07:00 committed by GitHub
parent 05f0491f76
commit 5f3aff4922
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 129 additions and 34 deletions

View file

@ -1,8 +1,11 @@
import json
import os
from openai import OpenAI, DefaultHttpxClient
import gradio as gr
import logging
import yaml
import gradio as gr
from typing import List, Optional, Tuple
from openai import OpenAI, DefaultHttpxClient
from dotenv import load_dotenv
load_dotenv()
@ -15,9 +18,22 @@ logging.basicConfig(
log = logging.getLogger(__name__)
CHAT_COMPLETION_ENDPOINT = os.getenv("CHAT_COMPLETION_ENDPOINT")
ARCH_STATE_HEADER = "x-arch-state"
log.info(f"CHAT_COMPLETION_ENDPOINT: {CHAT_COMPLETION_ENDPOINT}")
ARCH_STATE_HEADER = "x-arch-state"
CSS_STYLE = """
.json-container {
height: 95vh !important;
overflow-y: auto !important;
}
.chatbot {
height: calc(95vh - 100px) !important;
overflow-y: auto !important;
}
footer {visibility: hidden}
"""
client = OpenAI(
api_key="--",
base_url=CHAT_COMPLETION_ENDPOINT,
@ -25,11 +41,56 @@ client = OpenAI(
)
def predict(message, state):
def convert_prompt_target_to_openai_format(target):
tool = {
"description": target["description"],
"parameters": {"type": "object", "properties": {}, "required": []},
}
if "parameters" in target:
for param_info in target["parameters"]:
parameter = {
"type": param_info["type"],
"description": param_info["description"],
}
for key in ["default", "format", "enum", "items", "minimum", "maximum"]:
if key in param_info:
parameter[key] = param_info[key]
tool["parameters"]["properties"][param_info["name"]] = parameter
required = param_info.get("required", False)
if required:
tool["parameters"]["required"].append(param_info["name"])
return {"name": target["name"], "info": tool}
def get_prompt_targets():
try:
with open("arch_config.yaml", "r") as file:
config = yaml.safe_load(file)
available_tools = []
for target in config["prompt_targets"]:
if not target.get("default", False):
available_tools.append(
convert_prompt_target_to_openai_format(target)
)
return {tool["name"]: tool["info"] for tool in available_tools}
except Exception as e:
log.info(e)
return None
def chat(query: Optional[str], conversation: Optional[List[Tuple[str, str]]], state):
if "history" not in state:
state["history"] = []
history = state.get("history")
history.append({"role": "user", "content": message})
history.append({"role": "user", "content": query})
log.info(f"history: {history}")
# Custom headers
@ -58,7 +119,8 @@ def predict(message, state):
# extract arch_state from metadata and store it in gradio session state
# this state must be passed back to the gateway in the next request
response_json = json.loads(raw_response.text)
if response_json:
log.info(response_json)
if response_json and "metadata" in response_json:
# load arch_state from metadata
arch_state_str = response_json.get("metadata", {}).get(ARCH_STATE_HEADER, "{}")
# parse arch_state into json object
@ -78,25 +140,53 @@ def predict(message, state):
# for gradio UI we don't want to show raw tool calls and messages from developer application
# so we're filtering those out
history_view = [h for h in history if h["role"] != "tool" and "content" in h]
messages = [
(history_view[i]["content"], history_view[i + 1]["content"])
for i in range(0, len(history_view) - 1, 2)
]
return messages, state
return "", messages, state
with gr.Blocks(fill_height=True, css="footer {visibility: hidden}") as demo:
print("Starting Demo...")
chatbot = gr.Chatbot(label="Arch Chatbot", scale=1)
state = gr.State({})
with gr.Row():
txt = gr.Textbox(
show_label=False,
placeholder="Enter text and press enter",
scale=1,
autofocus=True,
)
def main():
with gr.Blocks(
theme=gr.themes.Default(
font_mono=[gr.themes.GoogleFont("IBM Plex Mono"), "Arial", "sans-serif"]
),
fill_height=True,
css=CSS_STYLE,
) as demo:
with gr.Row(equal_height=True):
state = gr.State({})
txt.submit(predict, [txt, state], [chatbot, state])
with gr.Column(scale=4):
gr.JSON(
value=get_prompt_targets(),
open=True,
show_indices=False,
label="Available Tools",
scale=1,
min_height="95vh",
elem_classes="json-container",
)
with gr.Column(scale=6):
chatbot = gr.Chatbot(
label="Arch Chatbot",
scale=1,
elem_classes="chatbot",
)
textbox = gr.Textbox(
show_label=False,
placeholder="Enter text and press enter",
scale=1,
autofocus=True,
)
demo.launch(server_name="0.0.0.0", server_port=8080, show_error=True, debug=True)
textbox.submit(chat, [textbox, chatbot, state], [textbox, chatbot, state])
demo.launch(server_name="0.0.0.0", server_port=8080, show_error=True, debug=True)
if __name__ == "__main__":
main()