mirror of
https://github.com/katanemo/plano.git
synced 2026-06-20 15:28:07 +02:00
add type check and length check
This commit is contained in:
parent
5bf9a80283
commit
770ebbdd4e
1 changed files with 64 additions and 10 deletions
|
|
@ -356,6 +356,17 @@ class ArchFunctionHandler(ArchBaseHandler):
|
||||||
|
|
||||||
return {"result": tool_calls, "status": is_valid, "message": error_message}
|
return {"result": tool_calls, "status": is_valid, "message": error_message}
|
||||||
|
|
||||||
|
@staticmethod
def _correcting_type(value, target_type=None):
    """Best-effort coercion of a tool-call argument towards ``target_type``.

    Declared as a static method because the function takes no ``self`` yet is
    invoked through an instance (``self._correcting_type(...)``) elsewhere in
    this file.

    Conversion rules currently implemented:

    * ``int`` -> ``float`` when ``target_type`` is ``float``.
    * ``str`` -> ``list`` when ``target_type`` is ``list`` and the string is a
      Python list literal (parsed with ``ast.literal_eval``, which only
      evaluates literals and never executes code).

    Args:
        value: The argument value to coerce.
        target_type: The type the value is expected to have. When omitted
            (``None``) no rule applies and ``value`` is returned unchanged.
            NOTE(review): the call in ``_verify_tool_calls`` passes only the
            value; it should also pass the expected type for any coercion
            to take effect.

    Returns:
        The converted value when a rule applies and the conversion succeeds;
        otherwise ``value`` itself, unchanged. This function never raises for
        malformed input.
    """
    try:
        if target_type is float and isinstance(value, int):
            return float(value)
        if target_type is list and isinstance(value, str):
            parsed = ast.literal_eval(value)
            # Only accept the parse when it really produced a list; a string
            # such as "{'a': 1}" or "3" is not a successful list conversion.
            if isinstance(parsed, list):
                return parsed
        # Add more conversion rules as needed
    except (ValueError, TypeError, SyntaxError, OverflowError):
        # literal_eval raises SyntaxError for free text that is not a Python
        # literal; float() raises OverflowError for ints beyond float range.
        # In every failure case fall through and hand back the original.
        pass
    return value
|
||||||
|
|
||||||
def _verify_tool_calls(
|
def _verify_tool_calls(
|
||||||
self, tools: List[Dict[str, Any]], tool_calls: List[Dict[str, Any]]
|
self, tools: List[Dict[str, Any]], tool_calls: List[Dict[str, Any]]
|
||||||
) -> Dict[str, any]:
|
) -> Dict[str, any]:
|
||||||
|
|
@ -410,7 +421,8 @@ class ArchFunctionHandler(ArchBaseHandler):
|
||||||
|
|
||||||
if data_type in self.support_data_types:
|
if data_type in self.support_data_types:
|
||||||
if not isinstance(
|
if not isinstance(
|
||||||
param_value, self.support_data_types[data_type]
|
self._correcting_type(param_value),
|
||||||
|
self.support_data_types[data_type],
|
||||||
):
|
):
|
||||||
is_valid = False
|
is_valid = False
|
||||||
invalid_tool_call = tool_call
|
invalid_tool_call = tool_call
|
||||||
|
|
@ -457,6 +469,48 @@ class ArchFunctionHandler(ArchBaseHandler):
|
||||||
)
|
)
|
||||||
return prefill_response
|
return prefill_response
|
||||||
|
|
||||||
|
@staticmethod
def _check_length_and_pop_messages(messages, max_tokens=4096):
    """Trim the oldest non-system messages until the estimate fits the budget.

    Declared as a static method because the function takes no ``self`` yet is
    invoked through an instance (``self._check_length_and_pop_messages(...)``)
    elsewhere in this file.

    The token count is a rough estimate of ~4 characters per token, computed
    per message with integer division. While the estimate exceeds
    ``max_tokens``, the first message whose role is not ``"system"`` is
    removed, together with the message right after it when that one has role
    ``"user"`` or ``"assistant"``. System messages are never removed.

    Args:
        messages (list): List of message dictionaries with a ``"role"`` key
            and, normally, a ``"content"`` key. The list is modified in place.
        max_tokens (int): Maximum allowed estimated token count.

    Returns:
        list: The same ``messages`` list object, after trimming. If only
        system messages remain, the list is returned as is even when it is
        still over the budget.
    """

    def estimate_token_length(message):
        """Estimate one message's tokens; missing or ``None`` content is 0."""
        content = message.get("content") or ""
        # Approximate token length: assuming ~4 characters per token on average
        return len(content) // 4

    total_tokens = sum(estimate_token_length(m) for m in messages)

    while total_tokens > max_tokens:
        # Locate the oldest message that is allowed to be dropped.
        first = next(
            (i for i, m in enumerate(messages) if m["role"] != "system"),
            None,
        )
        if first is None:
            # Nothing but system messages left: stop instead of looping forever.
            break

        # Drop the following message too when it has a "user"/"assistant" role.
        end = first + 1
        if end < len(messages) and messages[end]["role"] in ("user", "assistant"):
            end += 1

        # The estimate is a plain sum of per-message values, so subtracting
        # the removed messages is exact and avoids rescanning the whole list.
        total_tokens -= sum(estimate_token_length(m) for m in messages[first:end])
        del messages[first:end]

    return messages
|
||||||
|
|
||||||
@override
|
@override
|
||||||
async def chat_completion(self, req: ChatMessage) -> ChatCompletionResponse:
|
async def chat_completion(self, req: ChatMessage) -> ChatCompletionResponse:
|
||||||
"""
|
"""
|
||||||
|
|
@ -465,7 +519,6 @@ class ArchFunctionHandler(ArchBaseHandler):
|
||||||
Args:
|
Args:
|
||||||
req (ChatMessage): A chat message request object.
|
req (ChatMessage): A chat message request object.
|
||||||
enable_prefilling (bool, optional): Whether to enable prefill responses. Defaults to True.
|
enable_prefilling (bool, optional): Whether to enable prefill responses. Defaults to True.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
ChatCompletionResponse: The model's response to the chat request.
|
ChatCompletionResponse: The model's response to the chat request.
|
||||||
|
|
||||||
|
|
@ -474,6 +527,7 @@ class ArchFunctionHandler(ArchBaseHandler):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
messages = self._process_messages(req.messages, req.tools)
|
messages = self._process_messages(req.messages, req.tools)
|
||||||
|
messages = self._check_length_and_pop_messages(messages)
|
||||||
|
|
||||||
# always enable `stream=True` to collect model responses
|
# always enable `stream=True` to collect model responses
|
||||||
response = self.client.chat.completions.create(
|
response = self.client.chat.completions.create(
|
||||||
|
|
@ -488,15 +542,15 @@ class ArchFunctionHandler(ArchBaseHandler):
|
||||||
response_iterator=response, function=req.tools
|
response_iterator=response, function=req.tools
|
||||||
)
|
)
|
||||||
|
|
||||||
model_response, has_tool_call = "", None
|
model_response, self.has_tool_call = "", None
|
||||||
|
|
||||||
for _ in self.hallu_handler:
|
for _ in self.hallu_handler:
|
||||||
# check if the first token is <tool_call>
|
# check if the first token is <tool_call>
|
||||||
if len(self.hallu_handler.tokens) > 0 and has_tool_call is None:
|
if len(self.hallu_handler.tokens) > 0 and self.has_tool_call is None:
|
||||||
if self.hallu_handler.tokens[0] == "<tool_call>":
|
if self.hallu_handler.tokens[0] == "<tool_call>":
|
||||||
has_tool_call = True
|
self.has_tool_call = True
|
||||||
else:
|
else:
|
||||||
has_tool_call = False
|
self.has_tool_call = False
|
||||||
break
|
break
|
||||||
|
|
||||||
# if the model is hallucinating, start parameter gathering
|
# if the model is hallucinating, start parameter gathering
|
||||||
|
|
@ -512,13 +566,13 @@ class ArchFunctionHandler(ArchBaseHandler):
|
||||||
model_response = prefill_response.choices[0].message.content
|
model_response = prefill_response.choices[0].message.content
|
||||||
break
|
break
|
||||||
|
|
||||||
if has_tool_call and self.hallu_handler.hallucination is False:
|
if self.has_tool_call and self.hallu_handler.hallucination is False:
|
||||||
# [TODO] - Review: remove the following code
|
# [TODO] - Review: remove the following code
|
||||||
print("Tool call found, no hallucination detected!")
|
print("Tool call found, no hallucination detected!")
|
||||||
model_response = "".join(self.hallu_handler.tokens)
|
model_response = "".join(self.hallu_handler.tokens)
|
||||||
|
|
||||||
# start parameter gathering if the model is not generating tool calls
|
# start parameter gathering if the model is not generating tool calls
|
||||||
if has_tool_call is False:
|
if self.has_tool_call is False:
|
||||||
# [TODO] - Review: remove the following code
|
# [TODO] - Review: remove the following code
|
||||||
print("No tool call found, start parameter gathering")
|
print("No tool call found, start parameter gathering")
|
||||||
print(f"Token entropy/varentropy map: {self.hallu_handler.token_probs_map}")
|
print(f"Token entropy/varentropy map: {self.hallu_handler.token_probs_map}")
|
||||||
|
|
@ -528,7 +582,7 @@ class ArchFunctionHandler(ArchBaseHandler):
|
||||||
# Extract tool calls from model response
|
# Extract tool calls from model response
|
||||||
extracted = self._extract_tool_calls(model_response)
|
extracted = self._extract_tool_calls(model_response)
|
||||||
# [TODO] - Review: remvoe the following code
|
# [TODO] - Review: remvoe the following code
|
||||||
print(f"[Extracted] - {extracted}")
|
# print(f"[Extracted] - {extracted}")
|
||||||
|
|
||||||
if len(extracted["result"]) and extracted["status"]:
|
if len(extracted["result"]) and extracted["status"]:
|
||||||
# [TODO] Review: define the behavior in the case that tool call extraction fails
|
# [TODO] Review: define the behavior in the case that tool call extraction fails
|
||||||
|
|
@ -538,7 +592,7 @@ class ArchFunctionHandler(ArchBaseHandler):
|
||||||
tools=req.tools, tool_calls=extracted["result"]
|
tools=req.tools, tool_calls=extracted["result"]
|
||||||
)
|
)
|
||||||
# [TODO] - Review: remvoe the following code
|
# [TODO] - Review: remvoe the following code
|
||||||
print(f"[Verified] - {verified}")
|
# print(f"[Verified] - {verified}")
|
||||||
|
|
||||||
# [TODO] Review: In the case that tool calls are invalid, define the protocol to collect debugging output and the behavior to handle it appropriately
|
# [TODO] Review: In the case that tool calls are invalid, define the protocol to collect debugging output and the behavior to handle it appropriately
|
||||||
if verified["status"]:
|
if verified["status"]:
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue