mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
fix more
This commit is contained in:
parent
87805bce1d
commit
91df1e7941
2 changed files with 34 additions and 29 deletions
|
|
@ -13,6 +13,28 @@ log = logging.getLogger(__name__)
|
|||
ARCH_STATE_HEADER = "x-arch-state"
|
||||
|
||||
|
||||
def process_stream_chunk(chunk, history):
|
||||
delta = chunk.choices[0].delta
|
||||
if delta.role and delta.role != history[-1]["role"]:
|
||||
# create new history item if role changes
|
||||
# this is likely due to arch tool call and api response
|
||||
history.append({"role": delta.role})
|
||||
|
||||
history[-1]["model"] = chunk.model
|
||||
# append tool calls to history if there are any in the chunk
|
||||
if delta.tool_calls:
|
||||
history[-1]["tool_calls"] = delta.tool_calls
|
||||
|
||||
if delta.content:
|
||||
# append content to the last history item
|
||||
history[-1]["content"] = history[-1].get("content", "") + delta.content
|
||||
# yield content if it is from assistant
|
||||
if history[-1]["role"] == "assistant":
|
||||
return delta.content
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_arch_messages(response_json):
|
||||
arch_messages = []
|
||||
if response_json and "metadata" in response_json:
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from typing import List, Optional, Tuple
|
|||
from openai import OpenAI
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from common import get_prompt_targets
|
||||
from common import get_prompt_targets, process_stream_chunk
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
|
@ -42,11 +42,11 @@ client = OpenAI(
|
|||
)
|
||||
|
||||
|
||||
def chat(query: Optional[str], conversation: Optional[List[Tuple[str, str]]], state):
|
||||
if "history" not in state:
|
||||
state["history"] = []
|
||||
|
||||
history = state.get("history")
|
||||
def chat(
|
||||
query: Optional[str],
|
||||
conversation: Optional[List[Tuple[str, str]]],
|
||||
history: List[dict],
|
||||
):
|
||||
history.append({"role": "user", "content": query})
|
||||
|
||||
try:
|
||||
|
|
@ -66,31 +66,14 @@ def chat(query: Optional[str], conversation: Optional[List[Tuple[str, str]]], st
|
|||
conversation.append((query, ""))
|
||||
|
||||
for chunk in response:
|
||||
message = chunk.choices[0].delta
|
||||
if message.role and message.role != history[-1]["role"]:
|
||||
# create new history item if role changes
|
||||
# this is likely due to arch tool call and api response
|
||||
history.append(
|
||||
{
|
||||
"role": message.role,
|
||||
}
|
||||
)
|
||||
|
||||
history[-1]["model"] = chunk.model
|
||||
if message.tool_calls:
|
||||
history[-1]["tool_calls"] = message.tool_calls
|
||||
|
||||
if message.content:
|
||||
history[-1]["content"] = history[-1].get("content", "") + message.content
|
||||
|
||||
# message.content is none for tool calls
|
||||
# when "role = tool" content would contain api call response
|
||||
if message.content and history[-1]["role"] != "tool":
|
||||
tokens = process_stream_chunk(chunk, history)
|
||||
if tokens:
|
||||
conversation[-1] = (
|
||||
conversation[-1][0],
|
||||
conversation[-1][1] + message.content,
|
||||
conversation[-1][1] + tokens,
|
||||
)
|
||||
yield "", conversation, state
|
||||
|
||||
yield "", conversation, history
|
||||
|
||||
|
||||
def main():
|
||||
|
|
@ -102,7 +85,7 @@ def main():
|
|||
css=CSS_STYLE,
|
||||
) as demo:
|
||||
with gr.Row(equal_height=True):
|
||||
state = gr.State({})
|
||||
state = gr.State([])
|
||||
|
||||
with gr.Column(scale=1):
|
||||
with gr.Accordion("See available tools", open=False):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue