mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
add type check and length checl
This commit is contained in:
parent
5bf9a80283
commit
770ebbdd4e
1 changed files with 64 additions and 10 deletions
|
|
@ -356,6 +356,17 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
|
||||
return {"result": tool_calls, "status": is_valid, "message": error_message}
|
||||
|
||||
def _correcting_type(value, target_type):
|
||||
try:
|
||||
if target_type == float and isinstance(value, int):
|
||||
return float(value)
|
||||
elif target_type == list and isinstance(value, str):
|
||||
return ast.literal_eval(value)
|
||||
# Add more conversion rules as needed
|
||||
except (ValueError, TypeError, json.JSONDecodeError):
|
||||
pass
|
||||
return value
|
||||
|
||||
def _verify_tool_calls(
|
||||
self, tools: List[Dict[str, Any]], tool_calls: List[Dict[str, Any]]
|
||||
) -> Dict[str, any]:
|
||||
|
|
@ -410,7 +421,8 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
|
||||
if data_type in self.support_data_types:
|
||||
if not isinstance(
|
||||
param_value, self.support_data_types[data_type]
|
||||
self._correcting_type(param_value),
|
||||
self.support_data_types[data_type],
|
||||
):
|
||||
is_valid = False
|
||||
invalid_tool_call = tool_call
|
||||
|
|
@ -457,6 +469,48 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
)
|
||||
return prefill_response
|
||||
|
||||
def _check_length_and_pop_messages(messages, max_tokens=4096):
|
||||
"""
|
||||
Trims the `messages` list to ensure the total token count does not exceed `max_tokens`.
|
||||
|
||||
Args:
|
||||
messages (list): List of message dictionaries.
|
||||
max_tokens (int): Maximum allowed token count.
|
||||
|
||||
Returns:
|
||||
list: Trimmed list of messages.
|
||||
"""
|
||||
|
||||
def estimate_token_length(messages):
|
||||
"""Estimate the total token length of the messages."""
|
||||
total_tokens = 0
|
||||
for message in messages:
|
||||
# Approximate token length: assuming ~4 characters per token on average
|
||||
total_tokens += len(message["content"]) // 4
|
||||
return total_tokens
|
||||
|
||||
# Calculate initial token length
|
||||
total_tokens = estimate_token_length(messages)
|
||||
|
||||
# Trim messages if token count exceeds the limit
|
||||
while total_tokens > max_tokens:
|
||||
# Find the first non-system message pair
|
||||
for i in range(len(messages)):
|
||||
if messages[i]["role"] != "system":
|
||||
# Remove the 'user'/'assistant' pair
|
||||
if i + 1 < len(messages) and messages[i + 1]["role"] in [
|
||||
"user",
|
||||
"assistant",
|
||||
]:
|
||||
del messages[i : i + 2]
|
||||
else:
|
||||
del messages[i]
|
||||
break
|
||||
# Recalculate token length
|
||||
total_tokens = estimate_token_length(messages)
|
||||
|
||||
return messages
|
||||
|
||||
@override
|
||||
async def chat_completion(self, req: ChatMessage) -> ChatCompletionResponse:
|
||||
"""
|
||||
|
|
@ -465,7 +519,6 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
Args:
|
||||
req (ChatMessage): A chat message request object.
|
||||
enable_prefilling (bool, optional): Whether to enable prefill responses. Defaults to True.
|
||||
|
||||
Returns:
|
||||
ChatCompletionResponse: The model's response to the chat request.
|
||||
|
||||
|
|
@ -474,6 +527,7 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
"""
|
||||
|
||||
messages = self._process_messages(req.messages, req.tools)
|
||||
messages = self._check_length_and_pop_messages(messages)
|
||||
|
||||
# always enable `stream=True` to collect model responses
|
||||
response = self.client.chat.completions.create(
|
||||
|
|
@ -488,15 +542,15 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
response_iterator=response, function=req.tools
|
||||
)
|
||||
|
||||
model_response, has_tool_call = "", None
|
||||
model_response, self.has_tool_call = "", None
|
||||
|
||||
for _ in self.hallu_handler:
|
||||
# check if the first token is <tool_call>
|
||||
if len(self.hallu_handler.tokens) > 0 and has_tool_call is None:
|
||||
if len(self.hallu_handler.tokens) > 0 and self.has_tool_call is None:
|
||||
if self.hallu_handler.tokens[0] == "<tool_call>":
|
||||
has_tool_call = True
|
||||
self.has_tool_call = True
|
||||
else:
|
||||
has_tool_call = False
|
||||
self.has_tool_call = False
|
||||
break
|
||||
|
||||
# if the model is hallucinating, start parameter gathering
|
||||
|
|
@ -512,13 +566,13 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
model_response = prefill_response.choices[0].message.content
|
||||
break
|
||||
|
||||
if has_tool_call and self.hallu_handler.hallucination is False:
|
||||
if self.has_tool_call and self.hallu_handler.hallucination is False:
|
||||
# [TODO] - Review: remove the following code
|
||||
print("Tool call found, no hallucination detected!")
|
||||
model_response = "".join(self.hallu_handler.tokens)
|
||||
|
||||
# start parameter gathering if the model is not generating tool calls
|
||||
if has_tool_call is False:
|
||||
if self.has_tool_call is False:
|
||||
# [TODO] - Review: remove the following code
|
||||
print("No tool call found, start parameter gathering")
|
||||
print(f"Token entropy/varentropy map: {self.hallu_handler.token_probs_map}")
|
||||
|
|
@ -528,7 +582,7 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
# Extract tool calls from model response
|
||||
extracted = self._extract_tool_calls(model_response)
|
||||
# [TODO] - Review: remvoe the following code
|
||||
print(f"[Extracted] - {extracted}")
|
||||
# print(f"[Extracted] - {extracted}")
|
||||
|
||||
if len(extracted["result"]) and extracted["status"]:
|
||||
# [TODO] Review: define the behavior in the case that tool call extraction fails
|
||||
|
|
@ -538,7 +592,7 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
tools=req.tools, tool_calls=extracted["result"]
|
||||
)
|
||||
# [TODO] - Review: remvoe the following code
|
||||
print(f"[Verified] - {verified}")
|
||||
# print(f"[Verified] - {verified}")
|
||||
|
||||
# [TODO] Review: In the case that tool calls are invalid, define the protocol to collect debugging output and the behavior to handle it appropriately
|
||||
if verified["status"]:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue