mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
fix more
This commit is contained in:
parent
87805bce1d
commit
91df1e7941
2 changed files with 34 additions and 29 deletions
|
|
@ -13,6 +13,28 @@ log = logging.getLogger(__name__)
|
||||||
ARCH_STATE_HEADER = "x-arch-state"
|
ARCH_STATE_HEADER = "x-arch-state"
|
||||||
|
|
||||||
|
|
||||||
|
def process_stream_chunk(chunk, history):
|
||||||
|
delta = chunk.choices[0].delta
|
||||||
|
if delta.role and delta.role != history[-1]["role"]:
|
||||||
|
# create new history item if role changes
|
||||||
|
# this is likely due to arch tool call and api response
|
||||||
|
history.append({"role": delta.role})
|
||||||
|
|
||||||
|
history[-1]["model"] = chunk.model
|
||||||
|
# append tool calls to history if there are any in the chunk
|
||||||
|
if delta.tool_calls:
|
||||||
|
history[-1]["tool_calls"] = delta.tool_calls
|
||||||
|
|
||||||
|
if delta.content:
|
||||||
|
# append content to the last history item
|
||||||
|
history[-1]["content"] = history[-1].get("content", "") + delta.content
|
||||||
|
# yield content if it is from assistant
|
||||||
|
if history[-1]["role"] == "assistant":
|
||||||
|
return delta.content
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def get_arch_messages(response_json):
|
def get_arch_messages(response_json):
|
||||||
arch_messages = []
|
arch_messages = []
|
||||||
if response_json and "metadata" in response_json:
|
if response_json and "metadata" in response_json:
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,7 @@ from typing import List, Optional, Tuple
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
from common import get_prompt_targets
|
from common import get_prompt_targets, process_stream_chunk
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
|
|
@ -42,11 +42,11 @@ client = OpenAI(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def chat(query: Optional[str], conversation: Optional[List[Tuple[str, str]]], state):
|
def chat(
|
||||||
if "history" not in state:
|
query: Optional[str],
|
||||||
state["history"] = []
|
conversation: Optional[List[Tuple[str, str]]],
|
||||||
|
history: List[dict],
|
||||||
history = state.get("history")
|
):
|
||||||
history.append({"role": "user", "content": query})
|
history.append({"role": "user", "content": query})
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
@ -66,31 +66,14 @@ def chat(query: Optional[str], conversation: Optional[List[Tuple[str, str]]], st
|
||||||
conversation.append((query, ""))
|
conversation.append((query, ""))
|
||||||
|
|
||||||
for chunk in response:
|
for chunk in response:
|
||||||
message = chunk.choices[0].delta
|
tokens = process_stream_chunk(chunk, history)
|
||||||
if message.role and message.role != history[-1]["role"]:
|
if tokens:
|
||||||
# create new history item if role changes
|
|
||||||
# this is likely due to arch tool call and api response
|
|
||||||
history.append(
|
|
||||||
{
|
|
||||||
"role": message.role,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
history[-1]["model"] = chunk.model
|
|
||||||
if message.tool_calls:
|
|
||||||
history[-1]["tool_calls"] = message.tool_calls
|
|
||||||
|
|
||||||
if message.content:
|
|
||||||
history[-1]["content"] = history[-1].get("content", "") + message.content
|
|
||||||
|
|
||||||
# message.content is none for tool calls
|
|
||||||
# when "role = tool" content would contain api call response
|
|
||||||
if message.content and history[-1]["role"] != "tool":
|
|
||||||
conversation[-1] = (
|
conversation[-1] = (
|
||||||
conversation[-1][0],
|
conversation[-1][0],
|
||||||
conversation[-1][1] + message.content,
|
conversation[-1][1] + tokens,
|
||||||
)
|
)
|
||||||
yield "", conversation, state
|
|
||||||
|
yield "", conversation, history
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
|
@ -102,7 +85,7 @@ def main():
|
||||||
css=CSS_STYLE,
|
css=CSS_STYLE,
|
||||||
) as demo:
|
) as demo:
|
||||||
with gr.Row(equal_height=True):
|
with gr.Row(equal_height=True):
|
||||||
state = gr.State({})
|
state = gr.State([])
|
||||||
|
|
||||||
with gr.Column(scale=1):
|
with gr.Column(scale=1):
|
||||||
with gr.Accordion("See available tools", open=False):
|
with gr.Accordion("See available tools", open=False):
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue