Update chatbot UI and update hallucination check (#218)

* update chatbot UI

* Update docker-compose for demos

* Fix bugs

* fix for emtadata (#219)

* fix for emtadata

* fix

* revert

* merge main

---------

Co-authored-by: CTran <cotran2@utexas.edu>
This commit is contained in:
Shuguang Chen 2024-10-24 14:11:53 -07:00 committed by GitHub
parent 05f0491f76
commit 5f3aff4922
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 129 additions and 34 deletions

View file

@ -1,8 +1,11 @@
import json
import os
from openai import OpenAI, DefaultHttpxClient
import gradio as gr
import logging
import yaml
import gradio as gr
from typing import List, Optional, Tuple
from openai import OpenAI, DefaultHttpxClient
from dotenv import load_dotenv
load_dotenv()
@ -15,9 +18,22 @@ logging.basicConfig(
log = logging.getLogger(__name__)
CHAT_COMPLETION_ENDPOINT = os.getenv("CHAT_COMPLETION_ENDPOINT")
ARCH_STATE_HEADER = "x-arch-state"
log.info(f"CHAT_COMPLETION_ENDPOINT: {CHAT_COMPLETION_ENDPOINT}")
ARCH_STATE_HEADER = "x-arch-state"
CSS_STYLE = """
.json-container {
height: 95vh !important;
overflow-y: auto !important;
}
.chatbot {
height: calc(95vh - 100px) !important;
overflow-y: auto !important;
}
footer {visibility: hidden}
"""
client = OpenAI(
api_key="--",
base_url=CHAT_COMPLETION_ENDPOINT,
@ -25,11 +41,56 @@ client = OpenAI(
)
def predict(message, state):
def convert_prompt_target_to_openai_format(target):
tool = {
"description": target["description"],
"parameters": {"type": "object", "properties": {}, "required": []},
}
if "parameters" in target:
for param_info in target["parameters"]:
parameter = {
"type": param_info["type"],
"description": param_info["description"],
}
for key in ["default", "format", "enum", "items", "minimum", "maximum"]:
if key in param_info:
parameter[key] = param_info[key]
tool["parameters"]["properties"][param_info["name"]] = parameter
required = param_info.get("required", False)
if required:
tool["parameters"]["required"].append(param_info["name"])
return {"name": target["name"], "info": tool}
def get_prompt_targets():
try:
with open("arch_config.yaml", "r") as file:
config = yaml.safe_load(file)
available_tools = []
for target in config["prompt_targets"]:
if not target.get("default", False):
available_tools.append(
convert_prompt_target_to_openai_format(target)
)
return {tool["name"]: tool["info"] for tool in available_tools}
except Exception as e:
log.info(e)
return None
def chat(query: Optional[str], conversation: Optional[List[Tuple[str, str]]], state):
if "history" not in state:
state["history"] = []
history = state.get("history")
history.append({"role": "user", "content": message})
history.append({"role": "user", "content": query})
log.info(f"history: {history}")
# Custom headers
@ -58,7 +119,8 @@ def predict(message, state):
# extract arch_state from metadata and store it in gradio session state
# this state must be passed back to the gateway in the next request
response_json = json.loads(raw_response.text)
if response_json:
log.info(response_json)
if response_json and "metadata" in response_json:
# load arch_state from metadata
arch_state_str = response_json.get("metadata", {}).get(ARCH_STATE_HEADER, "{}")
# parse arch_state into json object
@ -78,25 +140,53 @@ def predict(message, state):
# for gradio UI we don't want to show raw tool calls and messages from developer application
# so we're filtering those out
history_view = [h for h in history if h["role"] != "tool" and "content" in h]
messages = [
(history_view[i]["content"], history_view[i + 1]["content"])
for i in range(0, len(history_view) - 1, 2)
]
return messages, state
return "", messages, state
with gr.Blocks(fill_height=True, css="footer {visibility: hidden}") as demo:
print("Starting Demo...")
chatbot = gr.Chatbot(label="Arch Chatbot", scale=1)
state = gr.State({})
with gr.Row():
txt = gr.Textbox(
show_label=False,
placeholder="Enter text and press enter",
scale=1,
autofocus=True,
)
def main():
with gr.Blocks(
theme=gr.themes.Default(
font_mono=[gr.themes.GoogleFont("IBM Plex Mono"), "Arial", "sans-serif"]
),
fill_height=True,
css=CSS_STYLE,
) as demo:
with gr.Row(equal_height=True):
state = gr.State({})
txt.submit(predict, [txt, state], [chatbot, state])
with gr.Column(scale=4):
gr.JSON(
value=get_prompt_targets(),
open=True,
show_indices=False,
label="Available Tools",
scale=1,
min_height="95vh",
elem_classes="json-container",
)
with gr.Column(scale=6):
chatbot = gr.Chatbot(
label="Arch Chatbot",
scale=1,
elem_classes="chatbot",
)
textbox = gr.Textbox(
show_label=False,
placeholder="Enter text and press enter",
scale=1,
autofocus=True,
)
demo.launch(server_name="0.0.0.0", server_port=8080, show_error=True, debug=True)
textbox.submit(chat, [textbox, chatbot, state], [textbox, chatbot, state])
demo.launch(server_name="0.0.0.0", server_port=8080, show_error=True, debug=True)
if __name__ == "__main__":
main()

View file

@ -1,4 +1,4 @@
gradio==4.43.0
gradio==5.3.0
async_timeout==4.0.3
loguru==0.7.2
asyncio==3.4.3

View file

@ -25,3 +25,4 @@ pub const ARCH_INTERNAL_CLUSTER_NAME: &str = "arch_internal";
pub const ARCH_UPSTREAM_HOST_HEADER: &str = "x-arch-upstream";
pub const ARCH_LLM_UPSTREAM_LISTENER: &str = "arch_llm_listener";
pub const ARCH_MODEL_PREFIX: &str = "Arch";
pub const HALLUCINATION_TEMPLATE: &str = "It seems Im missing some information. Could you provide the following details ";

View file

@ -1,6 +1,6 @@
use common::{
common_types::open_ai::Message,
consts::{ARCH_MODEL_PREFIX, ASSISTANT_ROLE, USER_ROLE},
consts::{ARCH_MODEL_PREFIX, USER_ROLE, HALLUCINATION_TEMPLATE},
};
pub fn extract_messages_for_hallucination(messages: &Vec<Message>) -> Vec<String> {
@ -18,9 +18,11 @@ pub fn extract_messages_for_hallucination(messages: &Vec<Message>) -> Vec<String
for message in messages.iter().rev() {
if let Some(model) = message.model.as_ref() {
if !model.starts_with(ARCH_MODEL_PREFIX) {
if message.role == ASSISTANT_ROLE {
break;
}
if let Some(content) = &message.content {
if !content.starts_with(HALLUCINATION_TEMPLATE) {
break;
}
}
}
}
if message.role == USER_ROLE {

View file

@ -12,12 +12,7 @@ use common::common_types::{
};
use common::configuration::{Overrides, PromptGuards, PromptTarget};
use common::consts::{
ARCH_FC_INTERNAL_HOST, ARCH_FC_MODEL_NAME, ARCH_FC_REQUEST_TIMEOUT_MS,
ARCH_INTERNAL_CLUSTER_NAME, MESSAGES_KEY, ARCH_MODEL_PREFIX, ARCH_STATE_HEADER,
ARCH_UPSTREAM_HOST_HEADER, DEFAULT_EMBEDDING_MODEL, DEFAULT_HALLUCINATED_THRESHOLD,
DEFAULT_INTENT_MODEL, DEFAULT_PROMPT_TARGET_THRESHOLD, EMBEDDINGS_INTERNAL_HOST,
HALLUCINATION_INTERNAL_HOST, REQUEST_ID_HEADER, SYSTEM_ROLE, TOOL_ROLE, USER_ROLE,
ZEROSHOT_INTERNAL_HOST,
ARCH_FC_INTERNAL_HOST, ARCH_FC_MODEL_NAME, ARCH_FC_REQUEST_TIMEOUT_MS, ARCH_INTERNAL_CLUSTER_NAME, MESSAGES_KEY, ARCH_MODEL_PREFIX, ARCH_STATE_HEADER, ARCH_UPSTREAM_HOST_HEADER, ASSISTANT_ROLE, DEFAULT_EMBEDDING_MODEL, HALLUCINATION_TEMPLATE, DEFAULT_HALLUCINATED_THRESHOLD, DEFAULT_INTENT_MODEL, DEFAULT_PROMPT_TARGET_THRESHOLD, EMBEDDINGS_INTERNAL_HOST, HALLUCINATION_INTERNAL_HOST, REQUEST_ID_HEADER, SYSTEM_ROLE, TOOL_ROLE, USER_ROLE, ZEROSHOT_INTERNAL_HOST
};
use common::embeddings::{
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
@ -328,12 +323,11 @@ impl StreamContext {
if !keys_with_low_score.is_empty() {
let response =
"It seems Im missing some information. Could you provide the following details: "
.to_string()
HALLUCINATION_TEMPLATE.to_string()
+ &keys_with_low_score.join(", ")
+ " ?";
let message = Message {
role: SYSTEM_ROLE.to_string(),
role: ASSISTANT_ROLE.to_string(),
content: Some(response),
model: Some(ARCH_FC_MODEL_NAME.to_string()),
tool_calls: None,

View file

@ -20,6 +20,8 @@ services:
- CHAT_COMPLETION_ENDPOINT=http://host.docker.internal:10000/v1 #this is only because we are running the sample app in the same docker container environemtn as archgw
extra_hosts:
- "host.docker.internal:host-gateway"
volumes:
- ./arch_config.yaml:/app/arch_config.yaml
opentelemetry:
build:

View file

@ -21,3 +21,5 @@ services:
- CHAT_COMPLETION_ENDPOINT=http://host.docker.internal:10000/v1
extra_hosts:
- "host.docker.internal:host-gateway"
volumes:
- ./arch_config.yaml:/app/arch_config.yaml

View file

@ -20,3 +20,5 @@ services:
- CHAT_COMPLETION_ENDPOINT=http://host.docker.internal:10000/v1
extra_hosts:
- "host.docker.internal:host-gateway"
volumes:
- ./arch_config.yaml:/app/arch_config.yaml

View file

@ -21,3 +21,5 @@ services:
- CHAT_COMPLETION_ENDPOINT=http://host.docker.internal:10000/v1
extra_hosts:
- "host.docker.internal:host-gateway"
volumes:
- ./arch_config.yaml:/app/arch_config.yaml