This commit is contained in:
Adil Hafeez 2024-10-29 11:15:07 -07:00
parent 87805bce1d
commit 91df1e7941
2 changed files with 34 additions and 29 deletions

View file

@ -13,6 +13,28 @@ log = logging.getLogger(__name__)
ARCH_STATE_HEADER = "x-arch-state" ARCH_STATE_HEADER = "x-arch-state"
def process_stream_chunk(chunk, history):
delta = chunk.choices[0].delta
if delta.role and delta.role != history[-1]["role"]:
# create new history item if role changes
# this is likely due to arch tool call and api response
history.append({"role": delta.role})
history[-1]["model"] = chunk.model
# append tool calls to history if there are any in the chunk
if delta.tool_calls:
history[-1]["tool_calls"] = delta.tool_calls
if delta.content:
# append content to the last history item
history[-1]["content"] = history[-1].get("content", "") + delta.content
# yield content if it is from assistant
if history[-1]["role"] == "assistant":
return delta.content
return None
def get_arch_messages(response_json): def get_arch_messages(response_json):
arch_messages = [] arch_messages = []
if response_json and "metadata" in response_json: if response_json and "metadata" in response_json:

View file

@ -8,7 +8,7 @@ from typing import List, Optional, Tuple
from openai import OpenAI from openai import OpenAI
from dotenv import load_dotenv from dotenv import load_dotenv
from common import get_prompt_targets from common import get_prompt_targets, process_stream_chunk
load_dotenv() load_dotenv()
@ -42,11 +42,11 @@ client = OpenAI(
) )
def chat(query: Optional[str], conversation: Optional[List[Tuple[str, str]]], state): def chat(
if "history" not in state: query: Optional[str],
state["history"] = [] conversation: Optional[List[Tuple[str, str]]],
history: List[dict],
history = state.get("history") ):
history.append({"role": "user", "content": query}) history.append({"role": "user", "content": query})
try: try:
@ -66,31 +66,14 @@ def chat(query: Optional[str], conversation: Optional[List[Tuple[str, str]]], st
conversation.append((query, "")) conversation.append((query, ""))
for chunk in response: for chunk in response:
message = chunk.choices[0].delta tokens = process_stream_chunk(chunk, history)
if message.role and message.role != history[-1]["role"]: if tokens:
# create new history item if role changes
# this is likely due to arch tool call and api response
history.append(
{
"role": message.role,
}
)
history[-1]["model"] = chunk.model
if message.tool_calls:
history[-1]["tool_calls"] = message.tool_calls
if message.content:
history[-1]["content"] = history[-1].get("content", "") + message.content
# message.content is none for tool calls
# when "role = tool" content would contain api call response
if message.content and history[-1]["role"] != "tool":
conversation[-1] = ( conversation[-1] = (
conversation[-1][0], conversation[-1][0],
conversation[-1][1] + message.content, conversation[-1][1] + tokens,
) )
yield "", conversation, state
yield "", conversation, history
def main(): def main():
@ -102,7 +85,7 @@ def main():
css=CSS_STYLE, css=CSS_STYLE,
) as demo: ) as demo:
with gr.Row(equal_height=True): with gr.Row(equal_height=True):
state = gr.State({}) state = gr.State([])
with gr.Column(scale=1): with gr.Column(scale=1):
with gr.Accordion("See available tools", open=False): with gr.Accordion("See available tools", open=False):