add type check and length checl

This commit is contained in:
cotran 2024-12-11 13:33:38 -08:00
parent 5bf9a80283
commit 770ebbdd4e

View file

@ -356,6 +356,17 @@ class ArchFunctionHandler(ArchBaseHandler):
return {"result": tool_calls, "status": is_valid, "message": error_message}
def _correcting_type(value, target_type):
try:
if target_type == float and isinstance(value, int):
return float(value)
elif target_type == list and isinstance(value, str):
return ast.literal_eval(value)
# Add more conversion rules as needed
except (ValueError, TypeError, json.JSONDecodeError):
pass
return value
def _verify_tool_calls(
self, tools: List[Dict[str, Any]], tool_calls: List[Dict[str, Any]]
) -> Dict[str, any]:
@ -410,7 +421,8 @@ class ArchFunctionHandler(ArchBaseHandler):
if data_type in self.support_data_types:
if not isinstance(
param_value, self.support_data_types[data_type]
self._correcting_type(param_value),
self.support_data_types[data_type],
):
is_valid = False
invalid_tool_call = tool_call
@ -457,6 +469,48 @@ class ArchFunctionHandler(ArchBaseHandler):
)
return prefill_response
def _check_length_and_pop_messages(messages, max_tokens=4096):
"""
Trims the `messages` list to ensure the total token count does not exceed `max_tokens`.
Args:
messages (list): List of message dictionaries.
max_tokens (int): Maximum allowed token count.
Returns:
list: Trimmed list of messages.
"""
def estimate_token_length(messages):
"""Estimate the total token length of the messages."""
total_tokens = 0
for message in messages:
# Approximate token length: assuming ~4 characters per token on average
total_tokens += len(message["content"]) // 4
return total_tokens
# Calculate initial token length
total_tokens = estimate_token_length(messages)
# Trim messages if token count exceeds the limit
while total_tokens > max_tokens:
# Find the first non-system message pair
for i in range(len(messages)):
if messages[i]["role"] != "system":
# Remove the 'user'/'assistant' pair
if i + 1 < len(messages) and messages[i + 1]["role"] in [
"user",
"assistant",
]:
del messages[i : i + 2]
else:
del messages[i]
break
# Recalculate token length
total_tokens = estimate_token_length(messages)
return messages
@override
async def chat_completion(self, req: ChatMessage) -> ChatCompletionResponse:
"""
@ -465,7 +519,6 @@ class ArchFunctionHandler(ArchBaseHandler):
Args:
req (ChatMessage): A chat message request object.
enable_prefilling (bool, optional): Whether to enable prefill responses. Defaults to True.
Returns:
ChatCompletionResponse: The model's response to the chat request.
@ -474,6 +527,7 @@ class ArchFunctionHandler(ArchBaseHandler):
"""
messages = self._process_messages(req.messages, req.tools)
messages = self._check_length_and_pop_messages(messages)
# always enable `stream=True` to collect model responses
response = self.client.chat.completions.create(
@ -488,15 +542,15 @@ class ArchFunctionHandler(ArchBaseHandler):
response_iterator=response, function=req.tools
)
model_response, has_tool_call = "", None
model_response, self.has_tool_call = "", None
for _ in self.hallu_handler:
# check if the first token is <tool_call>
if len(self.hallu_handler.tokens) > 0 and has_tool_call is None:
if len(self.hallu_handler.tokens) > 0 and self.has_tool_call is None:
if self.hallu_handler.tokens[0] == "<tool_call>":
has_tool_call = True
self.has_tool_call = True
else:
has_tool_call = False
self.has_tool_call = False
break
# if the model is hallucinating, start parameter gathering
@ -512,13 +566,13 @@ class ArchFunctionHandler(ArchBaseHandler):
model_response = prefill_response.choices[0].message.content
break
if has_tool_call and self.hallu_handler.hallucination is False:
if self.has_tool_call and self.hallu_handler.hallucination is False:
# [TODO] - Review: remove the following code
print("Tool call found, no hallucination detected!")
model_response = "".join(self.hallu_handler.tokens)
# start parameter gathering if the model is not generating tool calls
if has_tool_call is False:
if self.has_tool_call is False:
# [TODO] - Review: remove the following code
print("No tool call found, start parameter gathering")
print(f"Token entropy/varentropy map: {self.hallu_handler.token_probs_map}")
@ -528,7 +582,7 @@ class ArchFunctionHandler(ArchBaseHandler):
# Extract tool calls from model response
extracted = self._extract_tool_calls(model_response)
# [TODO] - Review: remvoe the following code
print(f"[Extracted] - {extracted}")
# print(f"[Extracted] - {extracted}")
if len(extracted["result"]) and extracted["status"]:
# [TODO] Review: define the behavior in the case that tool call extraction fails
@ -538,7 +592,7 @@ class ArchFunctionHandler(ArchBaseHandler):
tools=req.tools, tool_calls=extracted["result"]
)
# [TODO] - Review: remvoe the following code
print(f"[Verified] - {verified}")
# print(f"[Verified] - {verified}")
# [TODO] Review: In the case that tool calls are invalid, define the protocol to collect debugging output and the behavior to handle it appropriately
if verified["status"]: