import json
import logging
import os
from typing import List, Optional, Tuple

import gradio as gr
from dotenv import load_dotenv
from openai import OpenAI

from common import format_log, get_llm_models, get_prompt_targets, process_stream_chunk

load_dotenv()

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)
log = logging.getLogger(__name__)

CHAT_COMPLETION_ENDPOINT = os.getenv("CHAT_COMPLETION_ENDPOINT")
log.info(f"CHAT_COMPLETION_ENDPOINT: {CHAT_COMPLETION_ENDPOINT}")

CSS_STYLE = """
.json-container {
    height: 95vh !important;
    overflow-y: auto !important;
}
.chatbot {
    height: calc(95vh - 100px) !important;
    overflow-y: auto !important;
}
footer {visibility: hidden}
"""


def chat(
    query: Optional[str],
    conversation: Optional[List[Tuple[str, str]]],
    history: List[dict],
    debug_output: str,
    model_selector: str,
):
    """Stream a chat completion from the gateway and update the Gradio UI state."""
    history.append({"role": "user", "content": query})

    if debug_output is None:
        debug_output = ""

    try:
        headers = {}
        if model_selector:
            # Hint the gateway to route this request to a specific LLM provider.
            headers["x-arch-llm-provider-hint"] = model_selector
        client = OpenAI(
            api_key="None",
            base_url=CHAT_COMPLETION_ENDPOINT,
            default_headers=headers,
        )
        response = client.chat.completions.create(
            # The gateway selects the model from the arch_config file.
            model="None",
            messages=history,
            temperature=1.0,
            stream=True,
        )
    except Exception as e:
        # Remove the last user message so a failed request doesn't pollute history.
        history.pop()
        log.error("Error calling gateway API: {}".format(e))
        raise gr.Error("Error calling gateway API: {}".format(e))

    conversation.append((query, ""))

    model_is_set = False
    for chunk in response:
        tokens = process_stream_chunk(chunk, history)
        if tokens and not model_is_set:
            model_is_set = True
            model = history[-1]["model"]
            debug_output = debug_output + "\n" + format_log(f"model: {model}")
        if tokens:
            conversation[-1] = (
                conversation[-1][0],
                conversation[-1][1] + tokens,
            )
            yield "", conversation, history, debug_output, model_selector

    # Rewrite the assistant response into the format Arch-FC 1.1 expects:
    #     {
    #         "response": "",
    #     }
    # with the entire block encoded as ```json\n{json_encoded_content}\n```.
    if not history[-1]["model"].startswith("Arch"):
        assistant_response = {
            "response": history[-1]["content"],
        }
        history[-1]["content"] = "```json\n{}\n```".format(
            json.dumps(assistant_response)
        )

    log.info("history: {}".format(json.dumps(history)))


def main():
    with gr.Blocks(
        theme=gr.themes.Default(
            font_mono=[gr.themes.GoogleFont("IBM Plex Mono"), "Arial", "sans-serif"]
        ),
        fill_height=True,
        css=CSS_STYLE,
    ) as demo:
        with gr.Row(equal_height=True):
            history = gr.State([])

            with gr.Column(scale=1):
                with gr.Accordion("See available tools", open=False):
                    with gr.Column(scale=1):
                        gr.JSON(
                            value=get_prompt_targets(),
                            show_indices=False,
                            elem_classes="json-container",
                            min_height="50vh",
                        )
                model_selector_textbox = gr.Dropdown(
                    get_llm_models(),
                    label="override model",
                    elem_classes="dropdown",
                )
                debug_output = gr.TextArea(
                    label="debug output",
                    elem_classes="debug_output",
                )

            with gr.Column(scale=2):
                chatbot = gr.Chatbot(
                    label="Arch Chatbot",
                    elem_classes="chatbot",
                )
                textbox = gr.Textbox(
                    show_label=False,
                    placeholder="Enter text and press enter",
                    autofocus=True,
                    elem_classes="textbox",
                )
                textbox.submit(
                    chat,
                    [textbox, chatbot, history, debug_output, model_selector_textbox],
                    [textbox, chatbot, history, debug_output, model_selector_textbox],
                )

    demo.launch(server_name="0.0.0.0", server_port=8080, show_error=True, debug=True)


if __name__ == "__main__":
    main()