# plano/chatbot_ui/run_stream.py
# Gradio chat UI that streams chat completions from the gateway endpoint.

import json
import os
import logging
import yaml
import gradio as gr
from typing import List, Optional, Tuple
from openai import OpenAI
from dotenv import load_dotenv
from common import get_prompt_targets, process_stream_chunk
# Load variables from a local .env file before any os.getenv() lookups below.
load_dotenv()

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)
log = logging.getLogger(__name__)

# Base URL that the OpenAI-compatible client below sends requests to.
CHAT_COMPLETION_ENDPOINT = os.getenv("CHAT_COMPLETION_ENDPOINT")
log.info("CHAT_COMPLETION_ENDPOINT: %s", CHAT_COMPLETION_ENDPOINT)

# Custom CSS handed to gr.Blocks: sizes the JSON panel and the chat panel
# relative to the viewport and hides the page footer.
CSS_STYLE = """
.json-container {
    height: 95vh !important;
    overflow-y: auto !important;
}
.chatbot {
    height: calc(95vh - 100px) !important;
    overflow-y: auto !important;
}
footer {visibility: hidden}
"""

# Shared client used by chat(). The api_key is a placeholder value.
# NOTE(review): presumably the gateway supplies the real upstream
# credentials - confirm against the gateway configuration.
client = OpenAI(
    api_key="--",
    base_url=CHAT_COMPLETION_ENDPOINT,
)
def chat(
    query: Optional[str],
    conversation: Optional[List[Tuple[str, str]]],
    history: List[dict],
):
    """Stream one assistant reply for ``query`` and update the chat state.

    This is a generator meant to be wired to a Gradio event: every yielded
    tuple maps onto the ``[textbox, chatbot, history]`` outputs.

    Args:
        query: Text the user submitted.
        conversation: ``(user, assistant)`` pairs displayed by the chatbot
            widget. ``None`` is treated as an empty conversation.
        history: Message dicts sent to the gateway as ``messages``. The user
            turn is appended in place, and removed again if the request
            cannot be started.

    Yields:
        ``("", conversation, history)`` - an empty string to clear the
        textbox, the conversation with the reply accumulated so far, and
        the message history.

    Raises:
        gr.Error: If the streaming request to the gateway fails to start.
    """
    if conversation is None:
        conversation = []

    history.append({"role": "user", "content": query})
    try:
        response = client.chat.completions.create(
            # we select model from arch_config file
            model="--",
            messages=history,
            temperature=1.0,
            stream=True,
        )
    except Exception as e:
        # remove last user message in case of exception
        history.pop()
        # log.exception records the failure at ERROR level with traceback.
        log.exception("Error calling gateway API: %s", e)
        raise gr.Error("Error calling gateway API: {}".format(e)) from e

    reply = ""
    conversation.append((query, reply))

    for chunk in response:
        # process_stream_chunk (from common, not shown here) also receives
        # history; presumably it records the assistant output there.
        tokens = process_stream_chunk(chunk, history)
        if tokens:
            reply += tokens
            conversation[-1] = (query, reply)
            yield "", conversation, history

    # Always emit the final state so the UI is updated (textbox cleared,
    # user turn shown) even when the stream produced no token chunks.
    yield "", conversation, history
def main():
    """Build the chat page and serve it on 0.0.0.0:8080.

    Layout: a single row holding a collapsible JSON view of the prompt
    targets on the left and the chatbot plus input textbox on the right.
    Submitting the textbox runs ``chat`` with the textbox, chatbot and
    history state as both inputs and outputs.
    """
    mono_fonts = [gr.themes.GoogleFont("IBM Plex Mono"), "Arial", "sans-serif"]
    page_theme = gr.themes.Default(font_mono=mono_fonts)

    with gr.Blocks(theme=page_theme, fill_height=True, css=CSS_STYLE) as demo:
        with gr.Row(equal_height=True):
            # Per-session message history passed to and returned from chat().
            history = gr.State([])

            # Left side: collapsible panel listing the available tools.
            with gr.Column(scale=1):
                with gr.Accordion("See available tools", open=False), gr.Column(
                    scale=1
                ):
                    gr.JSON(
                        value=get_prompt_targets(),
                        show_indices=False,
                        elem_classes="json-container",
                        min_height="95vh",
                    )

            # Right side: conversation view and the input box.
            with gr.Column(scale=2):
                chatbot = gr.Chatbot(label="Arch Chatbot", elem_classes="chatbot")
                textbox = gr.Textbox(
                    show_label=False,
                    placeholder="Enter text and press enter",
                    autofocus=True,
                    elem_classes="textbox",
                )

            # The same three components feed chat() and receive its yields.
            chat_io = [textbox, chatbot, history]
            textbox.submit(chat, inputs=chat_io, outputs=list(chat_io))

    demo.launch(
        server_name="0.0.0.0",
        server_port=8080,
        show_error=True,
        debug=True,
    )
# Start the UI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()