plano/chatbot_ui/run_stream.py

137 lines
3.7 KiB
Python
Raw Normal View History

import json
import os
import logging
import yaml
import gradio as gr
from typing import List, Optional, Tuple
from openai import OpenAI
2024-09-19 17:48:50 -07:00
from dotenv import load_dotenv
2024-10-28 23:30:09 -07:00
from common import get_prompt_targets
2024-09-19 17:48:50 -07:00
# Load variables from a local .env file (if present) before the environment
# is read below.
load_dotenv()

logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
log = logging.getLogger(__name__)

# Base URL of the chat-completions endpoint; None when the variable is unset.
CHAT_COMPLETION_ENDPOINT = os.environ.get("CHAT_COMPLETION_ENDPOINT")
log.info("CHAT_COMPLETION_ENDPOINT: %s", CHAT_COMPLETION_ENDPOINT)

# Custom CSS handed to gr.Blocks: sizes the JSON and chatbot panes and hides
# the page footer.
CSS_STYLE = """
.json-container {
height: 95vh !important;
overflow-y: auto !important;
}
.chatbot {
height: calc(95vh - 100px) !important;
overflow-y: auto !important;
}
footer {visibility: hidden}
"""

# Shared client used by chat(). The api_key is a dummy value here.
# NOTE(review): presumably the gateway supplies the real credentials - confirm.
client = OpenAI(
    base_url=CHAT_COMPLETION_ENDPOINT,
    api_key="--",
)
2024-10-28 21:16:53 -07:00
def chat(query: Optional[str], conversation: Optional[List[Tuple[str, str]]], state):
    """Stream one chat turn through the gateway, updating the UI per chunk.

    This is a generator. The user's message is appended to the session
    history, the whole history is sent to the chat-completions endpoint with
    ``stream=True``, and each streamed delta is folded into both the history
    (what is sent on the next turn) and the visible conversation.

    Args:
        query: Text the user submitted.
        conversation: ``(user, bot)`` pairs shown in the chatbot widget.
            ``None`` is treated as an empty conversation.
        state: Per-session dict; the message list lives under ``"history"``.

    Yields:
        ``("", conversation, state)`` after every processed chunk. The empty
        string is the new value for the textbox output.

    Raises:
        gr.Error: If the request to the gateway cannot be started.
    """
    if conversation is None:
        conversation = []

    history = state.setdefault("history", [])
    history.append({"role": "user", "content": query})

    try:
        response = client.chat.completions.create(
            # we select model from arch_config file
            model="--",
            messages=history,
            temperature=1.0,
            stream=True,
        )
    except Exception as e:
        # remove last user message in case of exception
        history.pop()
        log.exception("Error calling gateway API: %s", e)
        raise gr.Error("Error calling gateway API: {}".format(e)) from e

    conversation.append((query, ""))

    # Everything in history up to this length belongs to the request. Response
    # data must never be merged into those entries.
    request_len = len(history)

    for chunk in response:
        # Some chunks (e.g. usage-only ones) carry no choices; nothing to merge.
        if not chunk.choices:
            continue
        delta = chunk.choices[0].delta

        role_changed = bool(delta.role) and delta.role != history[-1]["role"]
        if role_changed or len(history) == request_len:
            # create new history item if role changes
            # this is likely due to arch tool call and api response
            # A role-less first delta also opens a fresh entry (defaulting to
            # "assistant") so it cannot be appended to the user's own message.
            history.append({"role": delta.role or "assistant"})

        entry = history[-1]
        entry["model"] = chunk.model
        if delta.tool_calls:
            entry["tool_calls"] = delta.tool_calls

        if delta.content:
            entry["content"] = entry.get("content", "") + delta.content
            # message.content is none for tool calls
            # when "role = tool" content would contain api call response
            if entry["role"] != "tool":
                user_text, bot_text = conversation[-1]
                conversation[-1] = (user_text, bot_text + delta.content)

        yield "", conversation, state
def main():
    """Assemble the Gradio UI and serve it on 0.0.0.0:8080.

    The layout is a single row with two columns: a JSON viewer filled from
    get_prompt_targets() on the left, and the chatbot plus its input textbox
    on the right. Submitting the textbox runs chat().
    """
    mono_fonts = [gr.themes.GoogleFont("IBM Plex Mono"), "Arial", "sans-serif"]
    theme = gr.themes.Default(font_mono=mono_fonts)

    with gr.Blocks(theme=theme, fill_height=True, css=CSS_STYLE) as demo:
        with gr.Row(equal_height=True):
            # Per-session state; chat() keeps the message history in it.
            state = gr.State({})

            with gr.Column(scale=4):
                prompt_targets = get_prompt_targets()
                gr.JSON(
                    value=prompt_targets,
                    open=True,
                    show_indices=False,
                    label="Available Tools",
                    scale=1,
                    min_height="95vh",
                    elem_classes="json-container",
                )

            with gr.Column(scale=6):
                chatbot = gr.Chatbot(
                    label="Arch Chatbot",
                    scale=1,
                    elem_classes="chatbot",
                )
                textbox = gr.Textbox(
                    show_label=False,
                    placeholder="Enter text and press enter",
                    scale=1,
                    autofocus=True,
                )

            # chat() reads and writes the same three components.
            textbox.submit(
                fn=chat,
                inputs=[textbox, chatbot, state],
                outputs=[textbox, chatbot, state],
            )

    demo.launch(
        server_name="0.0.0.0",
        server_port=8080,
        show_error=True,
        debug=True,
    )
# Start the UI only when this file is run as a script, not when imported.
if __name__ == "__main__":
    main()