mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
format fix
This commit is contained in:
parent
1b39ee3dd8
commit
0c3d52bfe4
3 changed files with 17 additions and 17 deletions
|
|
@ -272,9 +272,9 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
if required_param not in func_args:
|
||||
verification_dict["is_valid"] = False
|
||||
verification_dict["invalid_tool_call"] = tool_call
|
||||
verification_dict[
|
||||
"error_message"
|
||||
] = f"`{required_param}` is required by the function `{func_name}` but not found in the tool call!"
|
||||
verification_dict["error_message"] = (
|
||||
f"`{required_param}` is required by the function `{func_name}` but not found in the tool call!"
|
||||
)
|
||||
break
|
||||
|
||||
# Verify the data type of each parameter in the tool calls
|
||||
|
|
@ -286,9 +286,9 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
if param_name not in function_properties:
|
||||
verification_dict["is_valid"] = False
|
||||
verification_dict["invalid_tool_call"] = tool_call
|
||||
verification_dict[
|
||||
"error_message"
|
||||
] = f"Parameter `{param_name}` is not defined in the function `{func_name}`."
|
||||
verification_dict["error_message"] = (
|
||||
f"Parameter `{param_name}` is not defined in the function `{func_name}`."
|
||||
)
|
||||
break
|
||||
else:
|
||||
param_value = func_args[param_name]
|
||||
|
|
@ -304,16 +304,16 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
if not isinstance(param_value, data_type):
|
||||
verification_dict["is_valid"] = False
|
||||
verification_dict["invalid_tool_call"] = tool_call
|
||||
verification_dict[
|
||||
"error_message"
|
||||
] = f"Parameter `{param_name}` is expected to have the data type `{data_type}`, got `{type(param_value)}`."
|
||||
verification_dict["error_message"] = (
|
||||
f"Parameter `{param_name}` is expected to have the data type `{data_type}`, got `{type(param_value)}`."
|
||||
)
|
||||
break
|
||||
else:
|
||||
verification_dict["is_valid"] = False
|
||||
verification_dict["invalid_tool_call"] = tool_call
|
||||
verification_dict[
|
||||
"error_message"
|
||||
] = f"Data type `{target_type}` is not supported."
|
||||
verification_dict["error_message"] = (
|
||||
f"Data type `{target_type}` is not supported."
|
||||
)
|
||||
|
||||
return verification_dict
|
||||
|
||||
|
|
@ -376,8 +376,8 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
|
||||
has_tool_calls, has_hallucination = None, False
|
||||
for _ in self.hallucination_state:
|
||||
# check if moodel response starts with tool calls
|
||||
if len(self.hallucination_state.tokens)>5 and has_tool_calls is None:
|
||||
# check if moodel response starts with tool calls, we do it after 5 tokens because we only check the first part of the response.
|
||||
if len(self.hallucination_state.tokens) > 5 and has_tool_calls is None:
|
||||
content = "".join(self.hallucination_state.tokens)
|
||||
if "tool_calls" in content:
|
||||
has_tool_calls = True
|
||||
|
|
|
|||
|
|
@ -201,7 +201,7 @@ class HallucinationState:
|
|||
r = next(self.response_iterator)
|
||||
if hasattr(r.choices[0].delta, "content"):
|
||||
token_content = r.choices[0].delta.content
|
||||
if token_content != '':
|
||||
if token_content != "":
|
||||
try:
|
||||
logprobs = [
|
||||
p.logprob
|
||||
|
|
@ -214,7 +214,7 @@ class HallucinationState:
|
|||
self.append_and_check_token_hallucination(
|
||||
token_content, [None]
|
||||
)
|
||||
|
||||
|
||||
return token_content
|
||||
except StopIteration:
|
||||
raise StopIteration
|
||||
|
|
|
|||
|
|
@ -110,6 +110,6 @@ async def test_function_calling(get_data_func):
|
|||
final_response = await model_handler.chat_completion(req)
|
||||
latency = time.perf_counter() - start_time
|
||||
|
||||
assert intent == (len(final_response.choices[0].message.tool_calls)>=1)
|
||||
assert intent == (len(final_response.choices[0].message.tool_calls) >= 1)
|
||||
|
||||
assert hallucination == model_handler.hallucination_state.hallucination
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue