mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
Cotran/intent (#339)
* add else * integrate hallucination * remove test
This commit is contained in:
parent
afec644789
commit
a40cdc7b75
2 changed files with 86 additions and 28 deletions
|
|
@ -12,6 +12,7 @@ from app.model_handler.base_handler import (
|
|||
ChatCompletionResponse,
|
||||
ArchBaseHandler,
|
||||
)
|
||||
from app.function_calling.hallucination_handler import HallucinationStateHandler
|
||||
|
||||
|
||||
SUPPORT_DATA_TYPES = ["int", "float", "bool", "str", "list", "tuple", "set", "dict"]
|
||||
|
|
@ -79,8 +80,10 @@ class ArchIntentHandler(ArchBaseHandler):
|
|||
Returns:
|
||||
bool: A boolean value to indicate if any intent match with prompts or not
|
||||
"""
|
||||
|
||||
return content.choices[0].message.content == "Yes"
|
||||
if hasattr(content.choices[0].message, "content"):
|
||||
return content.choices[0].message.content == "Yes"
|
||||
else:
|
||||
return False
|
||||
|
||||
@override
|
||||
async def chat_completion(self, req: ChatMessage) -> ChatCompletionResponse:
|
||||
|
|
@ -322,7 +325,7 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
if required_param not in func_args:
|
||||
is_valid = False
|
||||
error_tool_call = tool_call
|
||||
error_message = f"`{required_param}` is requried by the function `{func_name}` but not found in the tool call!"
|
||||
error_message = f"`{required_param}` is requiried by the function `{func_name}` but not found in the tool call!"
|
||||
return is_valid, error_tool_call, error_message
|
||||
|
||||
# Verify the data type of each parameter in the tool calls
|
||||
|
|
@ -340,6 +343,36 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
|
||||
return is_valid, error_tool_call, error_message
|
||||
|
||||
def _prefill_response(self, messages: List[Dict[str, str]]):
    """
    Request a completion that continues from a prefilled assistant turn.

    A randomly chosen entry of `self.prefill_prefix` is appended to `messages`
    (in place) as an assistant message, and the extended conversation is sent
    to the chat completion client in non-streaming mode.

    Args:
        messages (List[Dict[str, str]]): The conversation so far; it is mutated
            by appending the prefill assistant message.

    Returns:
        The response object returned by `self.client.chat.completions.create`.
    """
    prefix_message = {
        "role": "assistant",
        "content": random.choice(self.prefill_prefix),
    }
    messages.append(prefix_message)

    # Prefill parameters are unpacked last, so they win over generation
    # parameters whenever both define the same key.
    extra_params = {**self.generation_params, **self.prefill_params}

    return self.client.chat.completions.create(
        messages=messages,
        model=self.model_name,
        stream=False,
        extra_body=extra_params,
    )
|
||||
|
||||
@override
|
||||
async def chat_completion(
|
||||
self, req: ChatMessage, enable_prefilling=True
|
||||
|
|
@ -390,23 +423,7 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
|
||||
# start parameter gathering if the model is not generating a tool call
|
||||
if has_tool_call is False:
|
||||
messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": random.choice(self.prefill_prefix),
|
||||
}
|
||||
)
|
||||
|
||||
prefill_response = self.client.chat.completions.create(
|
||||
messages=messages,
|
||||
model=self.model_name,
|
||||
stream=False,
|
||||
extra_body={
|
||||
**self.generation_params,
|
||||
**self.prefill_params,
|
||||
},
|
||||
)
|
||||
|
||||
prefill_response = self._prefill_response(messages)
|
||||
model_response = prefill_response.choices[0].message.content
|
||||
else:
|
||||
model_response = response.choices[0].message.content
|
||||
|
|
|
|||
|
|
@ -72,6 +72,25 @@ def calculate_entropy(log_probs: List[float]) -> Tuple[float, float]:
|
|||
return entropy.item(), varentropy.item()
|
||||
|
||||
|
||||
def is_parameter_required(
    function_description: Dict,
    parameter_name: str,
) -> bool:
    """
    Check whether a parameter is listed as required for a function.

    Args:
        function_description (dict): The API description in JSON format. Its
            optional "required" entry holds the names of mandatory parameters.
        parameter_name (str): The name of the parameter to check.

    Returns:
        bool: True if `parameter_name` appears in the "required" entry, False
            otherwise (including when the entry is missing, empty or None).
    """
    # `or ()` covers a missing key as well as an explicit None/empty value, so
    # the membership test below never runs against None.
    required_parameters = function_description.get("required") or ()

    return parameter_name in required_parameters
|
||||
|
||||
|
||||
class HallucinationStateHandler:
|
||||
"""
|
||||
A class to handle the state of hallucination detection in token processing.
|
||||
|
|
@ -104,6 +123,7 @@ class HallucinationStateHandler:
|
|||
self.parameter_name: List[str] = []
|
||||
self.token_probs_map: List[Tuple[str, float, float]] = []
|
||||
self.response_iterator = response_iterator
|
||||
self.has_tool_call = False
|
||||
|
||||
def append_and_check_token_hallucination(self, token, logprob):
|
||||
"""
|
||||
|
|
@ -118,7 +138,8 @@ class HallucinationStateHandler:
|
|||
"""
|
||||
self.tokens.append(token)
|
||||
self.logprobs.append(logprob)
|
||||
self._process_token()
|
||||
if self.has_tool_call:
|
||||
self._process_token()
|
||||
return self.hallucination
|
||||
|
||||
def __iter__(self):
|
||||
|
|
@ -164,7 +185,7 @@ class HallucinationStateHandler:
|
|||
self.mask.append(MaskToken.FUNCTION_NAME)
|
||||
else:
|
||||
self.state = None
|
||||
self._is_function_name_hallucinated()
|
||||
self._get_function_name()
|
||||
|
||||
# Check if the token is a function name start token, change the state
|
||||
if content.endswith(FUNC_NAME_START_PATTERN):
|
||||
|
|
@ -182,8 +203,8 @@ class HallucinationStateHandler:
|
|||
PARAMETER_NAME_END_TOKENS
|
||||
):
|
||||
self.state = None
|
||||
self._is_parameter_name_hallucinated()
|
||||
self.parameter_name_done = True
|
||||
self._get_parameter_name()
|
||||
# if the parameter name is done and the token is a parameter name start token, change the state
|
||||
elif self.parameter_name_done and content.endswith(
|
||||
PARAMETER_NAME_START_PATTERN
|
||||
|
|
@ -208,11 +229,10 @@ class HallucinationStateHandler:
|
|||
if (
|
||||
len(self.mask) > 1
|
||||
and self.mask[-2] != MaskToken.PARAMETER_VALUE
|
||||
# and not is_parameter_property(
|
||||
# self.function_properties[self.function_name],
|
||||
# self.parameter_name[-1],
|
||||
# "default",
|
||||
# )
|
||||
and is_parameter_required(
|
||||
self.function_properties[self.function_name],
|
||||
self.parameter_name[-1],
|
||||
)
|
||||
):
|
||||
self._check_logprob()
|
||||
else:
|
||||
|
|
@ -266,3 +286,24 @@ class HallucinationStateHandler:
|
|||
if self.mask and self.mask[-1] == token
|
||||
else 0
|
||||
)
|
||||
|
||||
def _get_parameter_name(self):
    """
    Record the parameter name that was just completed.

    The name is rebuilt from the tokens that precede the most recent one (the
    most recent token itself is excluded), taking as many tokens as
    `_count_consecutive_token(MaskToken.PARAMETER_NAME)` returns. The result
    is appended to `self.parameter_name`; nothing is returned.
    """
    p_len = self._count_consecutive_token(MaskToken.PARAMETER_NAME)
    # Slice with an explicit end bound instead of `tokens[:-1][-p_len:]`:
    # when p_len is 0, `[-0:]` would select every token rather than none.
    parameter_name = "".join(self.tokens[-(p_len + 1) : -1])
    self.parameter_name.append(parameter_name)
|
||||
|
||||
def _get_function_name(self):
    """
    Record the function name that was just completed.

    The name is rebuilt from the tokens that precede the most recent one (the
    most recent token itself is excluded), taking as many tokens as
    `_count_consecutive_token(MaskToken.FUNCTION_NAME)` returns. The result is
    stored in `self.function_name`; nothing is returned.
    """
    f_len = self._count_consecutive_token(MaskToken.FUNCTION_NAME)
    # Slice with an explicit end bound instead of `tokens[:-1][-f_len:]`:
    # when f_len is 0, `[-0:]` would select every token rather than none.
    self.function_name = "".join(self.tokens[-(f_len + 1) : -1])
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue