fix comments

This commit is contained in:
Adil Hafeez 2024-10-28 21:16:53 -07:00
parent eaa99259ad
commit 676617b41b

View file

@ -87,7 +87,7 @@ def get_prompt_targets():
return None
def chat(query: Optional[str], messages: Optional[List[Tuple[str, str]]], state):
def chat(query: Optional[str], conversation: Optional[List[Tuple[str, str]]], state):
if "history" not in state:
state["history"] = []
@ -119,7 +119,7 @@ def chat(query: Optional[str], messages: Optional[List[Tuple[str, str]]], state)
if STREAM_RESPONSE:
response = raw_response.parse()
history.append({"role": "assistant", "content": "", "model": ""})
messages.append((query, ""))
conversation.append((query, ""))
# for gradio UI we don't want to show raw tool calls and messages from developer application
# so we're filtering those out
history_view = [h for h in history if h["role"] != "tool" and "content" in h]
@ -128,10 +128,9 @@ def chat(query: Optional[str], messages: Optional[List[Tuple[str, str]]], state)
print("chunk: " + str(chunk.to_dict()))
if len(chunk.choices) > 0:
if chunk.choices[0].delta.role:
print("role (hist): " + chunk.choices[0].delta.role)
print("role (resp): " + chunk.choices[0].delta.role)
# create new history item if role changes
# this is likely due to arch tool call and api response
if history[-1]["role"] != chunk.choices[0].delta.role:
print("creating new history item: " + str(chunk.choices[0]))
history.append(
{
"role": chunk.choices[0].delta.role,
@ -152,12 +151,12 @@ def chat(query: Optional[str], messages: Optional[List[Tuple[str, str]]], state)
history[-1]["tool_calls"] = chunk.choices[0].delta.tool_calls
if history[-1]["role"] != "tool":
if chunk.model and chunk.choices[0].delta.content:
messages[-1] = (
messages[-1][0],
messages[-1][1] + chunk.choices[0].delta.content,
if chunk.model and chunk.choices[0].delta.content != "":
conversation[-1] = (
conversation[-1][0],
conversation[-1][1] + chunk.choices[0].delta.content,
)
yield "", messages, state
yield "", conversation, state
else:
log.error(f"raw_response: {raw_response.text}")
response = raw_response.parse()
@ -181,12 +180,12 @@ def chat(query: Optional[str], messages: Optional[List[Tuple[str, str]]], state)
# so we're filtering those out
history_view = [h for h in history if h["role"] != "tool" and "content" in h]
messages = [
conversation = [
(history_view[i]["content"], history_view[i + 1]["content"])
for i in range(0, len(history_view) - 1, 2)
]
yield "", messages, state
yield "", conversation, state
def main():