From 91df1e79412a5214ea2b8411bfeb89d00a13db94 Mon Sep 17 00:00:00 2001 From: Adil Hafeez Date: Tue, 29 Oct 2024 11:15:07 -0700 Subject: [PATCH] fix more --- chatbot_ui/common.py | 22 +++++++++++++++++++++ chatbot_ui/run_stream.py | 41 ++++++++++++---------------------------- 2 files changed, 34 insertions(+), 29 deletions(-) diff --git a/chatbot_ui/common.py b/chatbot_ui/common.py index 5fd33c5f..64619412 100644 --- a/chatbot_ui/common.py +++ b/chatbot_ui/common.py @@ -13,6 +13,28 @@ log = logging.getLogger(__name__) ARCH_STATE_HEADER = "x-arch-state" +def process_stream_chunk(chunk, history): + delta = chunk.choices[0].delta + if delta.role and delta.role != history[-1]["role"]: + # create new history item if role changes + # this is likely due to arch tool call and api response + history.append({"role": delta.role}) + + history[-1]["model"] = chunk.model + # append tool calls to history if there are any in the chunk + if delta.tool_calls: + history[-1]["tool_calls"] = delta.tool_calls + + if delta.content: + # append content to the last history item + history[-1]["content"] = history[-1].get("content", "") + delta.content + # yield content if it is from assistant + if history[-1]["role"] == "assistant": + return delta.content + + return None + + def get_arch_messages(response_json): arch_messages = [] if response_json and "metadata" in response_json: diff --git a/chatbot_ui/run_stream.py b/chatbot_ui/run_stream.py index bed5cdc0..56c8ed08 100644 --- a/chatbot_ui/run_stream.py +++ b/chatbot_ui/run_stream.py @@ -8,7 +8,7 @@ from typing import List, Optional, Tuple from openai import OpenAI from dotenv import load_dotenv -from common import get_prompt_targets +from common import get_prompt_targets, process_stream_chunk load_dotenv() @@ -42,11 +42,11 @@ client = OpenAI( ) -def chat(query: Optional[str], conversation: Optional[List[Tuple[str, str]]], state): - if "history" not in state: - state["history"] = [] - - history = state.get("history") +def chat( + query: Optional[str], + conversation: Optional[List[Tuple[str, str]]], + history: List[dict], +): history.append({"role": "user", "content": query}) try: @@ -66,31 +66,14 @@ def chat(query: Optional[str], conversation: Optional[List[Tuple[str, str]]], st conversation.append((query, "")) for chunk in response: - message = chunk.choices[0].delta - if message.role and message.role != history[-1]["role"]: - # create new history item if role changes - # this is likely due to arch tool call and api response - history.append( - { - "role": message.role, - } - ) - - history[-1]["model"] = chunk.model - if message.tool_calls: - history[-1]["tool_calls"] = message.tool_calls - - if message.content: - history[-1]["content"] = history[-1].get("content", "") + message.content - - # message.content is none for tool calls - # when "role = tool" content would contain api call response - if message.content and history[-1]["role"] != "tool": + tokens = process_stream_chunk(chunk, history) + if tokens: conversation[-1] = ( conversation[-1][0], - conversation[-1][1] + message.content, + conversation[-1][1] + tokens, ) - yield "", conversation, state + + yield "", conversation, history def main(): @@ -102,7 +85,7 @@ def main(): css=CSS_STYLE, ) as demo: with gr.Row(equal_height=True): - state = gr.State({}) + state = gr.State([]) with gr.Column(scale=1): with gr.Accordion("See available tools", open=False):