mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
feedback
This commit is contained in:
parent
b15736a7fd
commit
cd22c71690
3 changed files with 53 additions and 43 deletions
|
|
@ -1,7 +1,7 @@
|
|||
import json
|
||||
import hashlib
|
||||
import app.commons.constants as const
|
||||
|
||||
import random
|
||||
from fastapi import Response
|
||||
from pydantic import BaseModel
|
||||
from app.commons.utilities import get_model_server_logger
|
||||
|
|
@ -64,7 +64,9 @@ def process_messages(history: list[Message]):
|
|||
return updated_history
|
||||
|
||||
|
||||
async def chat_completion(req: ChatMessage, res: Response):
|
||||
async def chat_completion(
|
||||
req: ChatMessage, res: Response, prefill_enabled: bool = True
|
||||
):
|
||||
logger.info("starting request")
|
||||
|
||||
tools_encoded = const.arch_function_hanlder._format_system(req.tools)
|
||||
|
|
@ -81,55 +83,61 @@ async def chat_completion(req: ChatMessage, res: Response):
|
|||
f"model_server => arch_function: {client_model_name}, messages: {json.dumps(messages)}"
|
||||
)
|
||||
|
||||
try:
|
||||
resp = const.arch_function_client.chat.completions.create(
|
||||
messages=messages,
|
||||
model=client_model_name,
|
||||
stream=True,
|
||||
extra_body=const.arch_function_generation_params,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"model_server <= arch_function: error: {e}")
|
||||
raise
|
||||
|
||||
# Retrieve the first token, handling the Stream object carefully
|
||||
first_token_content = ""
|
||||
try:
|
||||
while True:
|
||||
first_token = next(resp) # Synchronously retrieve tokens
|
||||
first_token_content = first_token.choices[
|
||||
|
||||
if prefill_enabled:
|
||||
try:
|
||||
resp = const.arch_function_client.chat.completions.create(
|
||||
messages=messages,
|
||||
model=client_model_name,
|
||||
stream=True,
|
||||
extra_body=const.arch_function_generation_params,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"model_server <= arch_function: error: {e}")
|
||||
raise
|
||||
first_token_content = ""
|
||||
for token in resp:
|
||||
first_token_content = token.choices[
|
||||
0
|
||||
].delta.content.strip() # Clean up the content
|
||||
if first_token_content: # Break if it's non-empty
|
||||
break
|
||||
except StopIteration:
|
||||
print("No non-empty tokens found.")
|
||||
return None
|
||||
|
||||
# Check if the first token requires tool call handling
|
||||
if first_token_content != "<tool_call>":
|
||||
# Engage pre-filling response if no tool call is indicated
|
||||
logger.info("Tool call is not found! Engage pre filling")
|
||||
messages.append({"role": "assistant", "content": "Sure!"})
|
||||
# Check if the first token requires tool call handling
|
||||
if first_token_content != "<tool_call>":
|
||||
# Engage pre-filling response if no tool call is indicated
|
||||
resp.close()
|
||||
logger.info("Tool call is not found! Engage pre filling")
|
||||
prefill_content = random.choice(const.prefill_list)
|
||||
messages.append({"role": "assistant", "content": prefill_content})
|
||||
|
||||
# Send a new completion request with the updated messages
|
||||
pre_fill_resp = const.arch_function_client.chat.completions.create(
|
||||
messages=messages,
|
||||
model=client_model_name,
|
||||
stream=False,
|
||||
extra_body=const.arch_function_generation_params,
|
||||
)
|
||||
full_response = pre_fill_resp.choices[0].message.content
|
||||
else:
|
||||
# Initialize full response and iterate over tokens to gather the full response
|
||||
full_response = "<tool_call>"
|
||||
try:
|
||||
while True:
|
||||
token = next(resp) # Retrieve each token synchronously
|
||||
# Send a new completion request with the updated messages
|
||||
pre_fill_resp = const.arch_function_client.chat.completions.create(
|
||||
messages=messages,
|
||||
model=client_model_name,
|
||||
stream=False,
|
||||
extra_body=const.arch_function_generation_params,
|
||||
)
|
||||
full_response = pre_fill_resp.choices[0].message.content
|
||||
else:
|
||||
# Initialize full response and iterate over tokens to gather the full response
|
||||
full_response = first_token_content
|
||||
for token in resp:
|
||||
if hasattr(token.choices[0].delta, "content"):
|
||||
full_response += token.choices[0].delta.content
|
||||
except StopIteration:
|
||||
pass # End of stream
|
||||
else:
|
||||
try:
|
||||
resp = const.arch_function_client.chat.completions.create(
|
||||
messages=messages,
|
||||
model=client_model_name,
|
||||
stream=False,
|
||||
extra_body=const.arch_function_generation_params,
|
||||
)
|
||||
full_response = resp.choices[0].message.content
|
||||
except Exception as e:
|
||||
logger.error(f"model_server <= arch_function: error: {e}")
|
||||
raise
|
||||
|
||||
tool_calls = const.arch_function_hanlder.extract_tool_calls(full_response)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue