mirror of
https://github.com/katanemo/plano.git
synced 2026-04-26 09:16:24 +02:00
add prefill and test (#236)
* add prefill and test * fix stream * fix * feedback * address comments * update * add e2e test * fix e2e test * update fix * fix * address cmt * address cmt
This commit is contained in:
parent
f48489f7c0
commit
fb67788be0
7 changed files with 216 additions and 19 deletions
|
|
@ -1,21 +1,21 @@
|
|||
import json
|
||||
import hashlib
|
||||
import app.commons.constants as const
|
||||
|
||||
import random
|
||||
from fastapi import Response
|
||||
from pydantic import BaseModel
|
||||
from app.commons.utilities import get_model_server_logger
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
||||
logger = get_model_server_logger()
|
||||
|
||||
|
||||
class Message(BaseModel):
    """One OpenAI-style chat message.

    All fields are optional with benign defaults so a single model can
    represent plain user/assistant turns, assistant tool-call messages,
    and tool-result messages without extra subclasses.

    Note: the diff previously declared each field twice (a non-Optional
    version immediately shadowed by the Optional one); only the final
    Optional declarations are kept here, which is what Python would have
    resolved to anyway.
    """

    # Sender role, e.g. "user", "assistant", "system", or "tool".
    role: Optional[str] = ""
    # Text content; empty when the message only carries tool calls.
    content: Optional[str] = ""
    # Assistant-emitted tool invocations (OpenAI tool_calls schema).
    # NOTE: pydantic deep-copies field defaults per instance, so the
    # mutable [] default is safe here (unlike a plain function default).
    tool_calls: Optional[List[Dict[str, Any]]] = []
    # For role="tool" replies: id of the tool call being answered.
    tool_call_id: Optional[str] = ""
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
|
|
@ -23,6 +23,20 @@ class ChatMessage(BaseModel):
|
|||
tools: List[Dict[str, Any]]
|
||||
|
||||
|
||||
class Choice(BaseModel):
    """A single completion choice wrapping one assistant `Message`.

    Mirrors the `choices[i]` entry of an OpenAI chat-completion response.
    Field order is part of the serialized output and is kept as declared.
    """

    # The generated assistant message for this choice.
    message: Message
    # Why generation stopped; defaults to a normal "stop" finish.
    finish_reason: Optional[str] = "stop"
    # Position of this choice in the response's choices array.
    index: Optional[int] = 0
|
||||
|
||||
|
||||
class ChatCompletionResponse(BaseModel):
    """Top-level OpenAI-compatible chat-completion response envelope.

    Only `choices` is required; the metadata fields default to values
    appropriate for this model server (`Arch-Function`).
    """

    # The generated choices; callers read choices[0].message.
    choices: List[Choice]
    # Model identifier echoed back to the client.
    model: Optional[str] = "Arch-Function"
    # Creation timestamp as a string; empty when not populated upstream.
    created: Optional[str] = ""
    # Response id; empty when not populated upstream.
    id: Optional[str] = ""
    # Object type tag expected by OpenAI-style clients.
    object: Optional[str] = "chat_completion"
|
||||
|
||||
|
||||
def process_messages(history: list[Message]):
|
||||
updated_history = []
|
||||
for hist in history:
|
||||
|
|
@ -67,30 +81,77 @@ async def chat_completion(req: ChatMessage, res: Response):
|
|||
f"model_server => arch_function: {client_model_name}, messages: {json.dumps(messages)}"
|
||||
)
|
||||
|
||||
# Retrieve the first token, handling the Stream object carefully
|
||||
|
||||
try:
|
||||
resp = const.arch_function_client.chat.completions.create(
|
||||
messages=messages,
|
||||
model=client_model_name,
|
||||
stream=False,
|
||||
stream=const.PREFILL_ENABLED,
|
||||
extra_body=const.arch_function_generation_params,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"model_server <= arch_function: error: {e}")
|
||||
raise
|
||||
|
||||
tool_calls = const.arch_function_hanlder.extract_tool_calls(
|
||||
resp.choices[0].message.content
|
||||
)
|
||||
if const.PREFILL_ENABLED:
|
||||
first_token_content = ""
|
||||
for token in resp:
|
||||
first_token_content = token.choices[
|
||||
0
|
||||
].delta.content.strip() # Clean up the content
|
||||
if first_token_content: # Break if it's non-empty
|
||||
break
|
||||
|
||||
# Check if the first token requires tool call handling
|
||||
if first_token_content != const.TOOL_CALL_TOKEN:
|
||||
# Engage pre-filling response if no tool call is indicated
|
||||
resp.close()
|
||||
logger.info("Tool call is not found! Engage pre filling")
|
||||
prefill_content = random.choice(const.PREFILL_LIST)
|
||||
messages.append({"role": "assistant", "content": prefill_content})
|
||||
|
||||
# Send a new completion request with the updated messages
|
||||
# the model will continue the final message in the chat instead of starting a new one
|
||||
# disable add_generation_prompt which tells the template to add tokens that indicate the start of a bot response.
|
||||
extra_body = {
|
||||
**const.arch_function_generation_params,
|
||||
"continue_final_message": True,
|
||||
"add_generation_prompt": False,
|
||||
}
|
||||
pre_fill_resp = const.arch_function_client.chat.completions.create(
|
||||
messages=messages,
|
||||
model=client_model_name,
|
||||
stream=False,
|
||||
extra_body=extra_body,
|
||||
)
|
||||
full_response = pre_fill_resp.choices[0].message.content
|
||||
else:
|
||||
# Initialize full response and iterate over tokens to gather the full response
|
||||
full_response = first_token_content
|
||||
for token in resp:
|
||||
if hasattr(token.choices[0].delta, "content"):
|
||||
full_response += token.choices[0].delta.content
|
||||
else:
|
||||
logger.info("Stream is disabled, not engaging pre-filling")
|
||||
full_response = resp.choices[0].message.content
|
||||
|
||||
tool_calls = const.arch_function_hanlder.extract_tool_calls(full_response)
|
||||
|
||||
if tool_calls:
|
||||
resp.choices[0].message.tool_calls = tool_calls
|
||||
resp.choices[0].message.content = None
|
||||
message = Message(content="", tool_calls=tool_calls)
|
||||
else:
|
||||
message = Message(content=full_response, tool_calls=[])
|
||||
choice = Choice(message=message)
|
||||
chat_completion_response = ChatCompletionResponse(
|
||||
choices=[choice], model=client_model_name
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"model_server <= arch_function: (tools): {json.dumps([tool_call['function'] for tool_call in tool_calls])}"
|
||||
)
|
||||
logger.info(
|
||||
f"model_server <= arch_function: response body: {json.dumps(resp.to_dict())}"
|
||||
f"model_server <= arch_function: response body: {json.dumps(chat_completion_response.dict())}"
|
||||
)
|
||||
|
||||
return resp
|
||||
return chat_completion_response
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue