plano/demos/shared/chatbot_ui/run_stream.py
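
"""Gradio chat UI for the Arch gateway demos.

Streams chat completions from the gateway's OpenAI-compatible endpoint and
renders them in a two-column Gradio Blocks layout.
"""
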
import json
import logging
import os
from typing import List, Optional, Tuple

import gradio as gr
from dotenv import load_dotenv
from openai import OpenAI

from common import format_log, get_llm_models, get_prompt_targets, process_stream_chunk

load_dotenv()
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)
log = logging.getLogger(__name__)
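
# Base URL of the gateway's OpenAI-compatible chat completions API.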
CHAT_COMPLETION_ENDPOINT = os.getenv("CHAT_COMPLETION_ENDPOINT")
log.info(f"CHAT_COMPLETION_ENDPOINT: {CHAT_COMPLETION_ENDPOINT}")

CSS_STYLE = """
.json-container {
    height: 95vh !important;
    overflow-y: auto !important;
}
.chatbot {
    height: calc(95vh - 100px) !important;
    overflow-y: auto !important;
}
footer {visibility: hidden}
"""


def chat(
    query: Optional[str],
    conversation: Optional[List[Tuple[str, str]]],
    history: List[dict],
    debug_output: str,
    model_selector: str,
):
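    """Stream one chat turn, yielding incremental updates for the UI components."""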
    history.append({"role": "user", "content": query})
    if debug_output is None:
        debug_output = ""
    try:
        headers = {}
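        # Optional provider override, forwarded to the gateway as a hint header.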
        if model_selector and model_selector != "":
            headers["x-arch-llm-provider-hint"] = model_selector
        client = OpenAI(
            api_key="None",
            base_url=CHAT_COMPLETION_ENDPOINT,
            default_headers=headers,
        )
        response = client.chat.completions.create(
            # we select model from arch_config file
            model="None",
            messages=history,
            temperature=1.0,
            stream=True,
        )
    except Exception as e:
        # remove the last user message in case of an exception
        history.pop()
        log.error("Error calling gateway API: {}".format(e))
        raise gr.Error("Error calling gateway API: {}".format(e))
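
    # Placeholder entry that the streaming loop fills in token by token.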
    conversation.append((query, ""))
    model_is_set = False
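    # Stream tokens as they arrive; log the serving model to the debug pane once it is known.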
    for chunk in response:
        tokens = process_stream_chunk(chunk, history)
        if tokens and not model_is_set:
            model_is_set = True
            model = history[-1]["model"]
            debug_output = debug_output + "\n" + format_log(f"model: {model}")
        if tokens:
            conversation[-1] = (
                conversation[-1][0],
                conversation[-1][1] + tokens,
            )
            yield "", conversation, history, debug_output, model_selector

    # Update the assistant response to the format arch-fc 1.1 expects:
    # {
    #     "response": "<assistant response>",
    # }
    # with the whole block wrapped in ```json\n{json_encoded_content}\n```
    if not history[-1]["model"].startswith("Arch"):
        assistant_response = {
            "response": history[-1]["content"],
        }
        history[-1]["content"] = "```json\n{}\n```".format(
            json.dumps(assistant_response)
        )
    log.info("history: {}".format(json.dumps(history)))


def main():
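    """Build the Gradio Blocks UI and launch the demo server."""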
    with gr.Blocks(
        theme=gr.themes.Default(
            font_mono=[gr.themes.GoogleFont("IBM Plex Mono"), "Arial", "sans-serif"]
        ),
        fill_height=True,
        css=CSS_STYLE,
    ) as demo:
        with gr.Row(equal_height=True):
            history = gr.State([])
            with gr.Column(scale=1):
                with gr.Accordion("See available tools", open=False):
                    with gr.Column(scale=1):
                        gr.JSON(
                            value=get_prompt_targets(),
                            show_indices=False,
                            elem_classes="json-container",
                            min_height="50vh",
                        )
                model_selector_textbox = gr.Dropdown(
                    get_llm_models(),
                    label="override model",
                    elem_classes="dropdown",
                )
                debug_output = gr.TextArea(
                    label="debug output",
                    elem_classes="debug_output",
                )
            with gr.Column(scale=2):
                chatbot = gr.Chatbot(
                    label="Arch Chatbot",
                    elem_classes="chatbot",
                )
                textbox = gr.Textbox(
                    show_label=False,
                    placeholder="Enter text and press enter",
                    autofocus=True,
                    elem_classes="textbox",
                )
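                # Inputs and outputs mirror the tuple yielded by chat().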
                textbox.submit(
                    chat,
                    [textbox, chatbot, history, debug_output, model_selector_textbox],
                    [textbox, chatbot, history, debug_output, model_selector_textbox],
                )
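
    # Serve on all interfaces so the demo is reachable outside localhost (e.g. in a container).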
    demo.launch(server_name="0.0.0.0", server_port=8080, show_error=True, debug=True)


if __name__ == "__main__":
    main()