This commit is contained in:
Adil Hafeez 2024-10-29 11:15:07 -07:00
parent 87805bce1d
commit 91df1e7941
2 changed files with 34 additions and 29 deletions

View file

@ -13,6 +13,28 @@ log = logging.getLogger(__name__)
ARCH_STATE_HEADER = "x-arch-state"
def process_stream_chunk(chunk, history):
delta = chunk.choices[0].delta
if delta.role and delta.role != history[-1]["role"]:
# create new history item if role changes
# this is likely due to arch tool call and api response
history.append({"role": delta.role})
history[-1]["model"] = chunk.model
# append tool calls to history if there are any in the chunk
if delta.tool_calls:
history[-1]["tool_calls"] = delta.tool_calls
if delta.content:
# append content to the last history item
history[-1]["content"] = history[-1].get("content", "") + delta.content
# yield content if it is from assistant
if history[-1]["role"] == "assistant":
return delta.content
return None
def get_arch_messages(response_json):
arch_messages = []
if response_json and "metadata" in response_json:

View file

@ -8,7 +8,7 @@ from typing import List, Optional, Tuple
from openai import OpenAI
from dotenv import load_dotenv
from common import get_prompt_targets
from common import get_prompt_targets, process_stream_chunk
load_dotenv()
@ -42,11 +42,11 @@ client = OpenAI(
)
def chat(query: Optional[str], conversation: Optional[List[Tuple[str, str]]], state):
if "history" not in state:
state["history"] = []
history = state.get("history")
def chat(
query: Optional[str],
conversation: Optional[List[Tuple[str, str]]],
history: List[dict],
):
history.append({"role": "user", "content": query})
try:
@ -66,31 +66,14 @@ def chat(query: Optional[str], conversation: Optional[List[Tuple[str, str]]], st
conversation.append((query, ""))
for chunk in response:
message = chunk.choices[0].delta
if message.role and message.role != history[-1]["role"]:
# create new history item if role changes
# this is likely due to arch tool call and api response
history.append(
{
"role": message.role,
}
)
history[-1]["model"] = chunk.model
if message.tool_calls:
history[-1]["tool_calls"] = message.tool_calls
if message.content:
history[-1]["content"] = history[-1].get("content", "") + message.content
# message.content is none for tool calls
# when "role = tool" content would contain api call response
if message.content and history[-1]["role"] != "tool":
tokens = process_stream_chunk(chunk, history)
if tokens:
conversation[-1] = (
conversation[-1][0],
conversation[-1][1] + message.content,
conversation[-1][1] + tokens,
)
yield "", conversation, state
yield "", conversation, history
def main():
@ -102,7 +85,7 @@ def main():
css=CSS_STYLE,
) as demo:
with gr.Row(equal_height=True):
state = gr.State({})
state = gr.State([])
with gr.Column(scale=1):
with gr.Accordion("See available tools", open=False):