Cotran/integration (#341)

* add hallucination

* add test and fix bug
This commit is contained in:
CTran 2024-12-09 13:30:52 -08:00 committed by GitHub
parent 8f1b21124b
commit 9dd7f15eab
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 154 additions and 1085 deletions

View file

@ -13,6 +13,7 @@ from src.core.model_utils import (
ChatCompletionResponse,
ArchBaseHandler,
)
from src.core.hallucination import HallucinationStateHandler
class ArchIntentConfig:
@ -178,6 +179,8 @@ class ArchFunctionConfig:
"top_k": 50,
"max_tokens": 512,
"stop_token_ids": [151645],
"logprobs": True,
"top_logprobs": 10,
}
PREFILL_CONFIG = {
@ -391,7 +394,8 @@ class ArchFunctionHandler(ArchBaseHandler):
return is_valid, invalid_tool_call, error_message
# Verify the data type of each parameter in the tool calls
for param_name, param_value in func_args:
for param_name in func_args:
param_value = func_args[param_name]
data_type = functions[func_name]["properties"][param_name]["type"]
if data_type in self.support_data_types:
@ -427,6 +431,22 @@ class ArchFunctionHandler(ArchBaseHandler):
}
]
def _engage_parameter_gathering(self, messages: List[Dict[str, str]]):
    """
    Request a prefilled completion used to gather tool-call parameters.

    The conversation is passed through ``self._add_prefill_message`` and sent
    to ``self.client`` with the generation parameters overlaid by the prefill
    parameters (prefill values win on key collisions).

    Args:
        messages: Conversation messages to send to the model.

    Returns:
        The response object returned by ``self.client.chat.completions.create``.
    """
    # TODO: log that parameter gathering was engaged
    # Later unpacking overrides earlier keys, so prefill params take precedence.
    request_options = {**self.generation_params, **self.prefill_params}
    return self.client.chat.completions.create(
        messages=self._add_prefill_message(messages),
        model=self.model_name,
        extra_body=request_options,
    )
@override
async def chat_completion(self, req: ChatMessage) -> ChatCompletionResponse:
"""
@ -453,32 +473,35 @@ class ArchFunctionHandler(ArchBaseHandler):
extra_body=self.generation_params,
)
# initialize the hallucination handler, which is an iterator
self.hallu_handler = HallucinationStateHandler(
response_iterator=response, function=req.tools
)
model_response, has_tool_call = "", None
for token in response:
token_content = token.choices[0].delta.content.strip()
if token_content:
if has_tool_call is None and token_content != "<tool_call>":
has_tool_call = False
response.close()
break
else:
for token in self.hallu_handler:
# check if the first token is <tool_call>
if len(self.hallu_handler.tokens) > 0 and has_tool_call == None:
if self.hallu_handler.tokens[0] == "<tool_call>":
has_tool_call = True
else:
has_tool_call = False
break
if has_tool_call is True:
model_response += token_content
# if the model is hallucinating, start parameter gathering
if self.hallu_handler.hallucination == True:
prefill_response = self._engage_parameter_gathering(messages)
model_response = prefill_response.choices[0].message.content
break
# no hallucination detected: use the streamed tokens as the model response
if self.hallu_handler.hallucination == False:
model_response = "".join(self.hallu_handler.tokens)
# start parameter gathering if the model is not generating tool calls
if has_tool_call is False:
prefill_response = self.client.chat.completions.create(
messages=self._add_prefill_message(messages),
model=self.model_name,
extra_body={
**self.generation_params,
**self.prefill_params,
},
)
prefill_response = self._engage_parameter_gathering(messages)
model_response = prefill_response.choices[0].message.content
# Extract tool calls from model response

View file

@ -27,10 +27,10 @@ class MaskToken(Enum):
HALLUCINATION_THRESHOLD_DICT = {
MaskToken.TOOL_CALL.value: {"entropy": 0.1, "varentropy": 0.5},
MaskToken.TOOL_CALL.value: {"entropy": 0.001, "varentropy": 0.005},
MaskToken.PARAMETER_VALUE.value: {
"entropy": 0.5,
"varentropy": 2.5,
"entropy": 0.001,
"varentropy": 0.005,
},
}
@ -105,11 +105,10 @@ class HallucinationStateHandler:
hallucination (bool): Flag indicating if a hallucination is detected.
hallucination_message (str): Message describing the hallucination.
parameter_name (list): List of extracted parameter names.
function_description (dict): Description of functions and their parameters.
token_probs_map (list): List mapping tokens to their entropy and variance of entropy.
"""
def __init__(self, response_iterator=None):
def __init__(self, response_iterator=None, function=None):
"""
Initializes the HallucinationStateHandler with default values.
"""
@ -124,7 +123,15 @@ class HallucinationStateHandler:
self.parameter_name: List[str] = []
self.token_probs_map: List[Tuple[str, float, float]] = []
self.response_iterator = response_iterator
self.has_tool_call = False
self._process_function(function)
def _process_function(self, function):
    """
    Store the tool (API) descriptions and index their parameter schemas by name.

    Sets ``self.function`` to the given descriptions and builds
    ``self.function_properties``, a dict mapping each tool's
    ``entry["function"]["name"]`` to its ``entry["function"]["parameters"]``.

    Args:
        function: Iterable of tool descriptions. Each entry is read as
            ``entry["function"]`` and is assumed to be dict-like
            (NOTE(review): confirm against what the caller passes as
            ``req.tools``).

    Raises:
        ValueError: If ``function`` is None.
    """
    self.function = function
    if self.function is None:
        raise ValueError("API descriptions not set.")
    # A tool that takes no arguments may omit "parameters" entirely; fall back
    # to an empty schema instead of failing with KeyError at construction time.
    self.function_properties = {
        tool["function"]["name"]: tool["function"].get("parameters", {})
        for tool in self.function
    }
def append_and_check_token_hallucination(self, token, logprob):
"""
@ -139,8 +146,7 @@ class HallucinationStateHandler:
"""
self.tokens.append(token)
self.logprobs.append(logprob)
if self.has_tool_call:
self._process_token()
self._process_token()
return self.hallucination
def __iter__(self):