mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
add hallucination
This commit is contained in:
parent
e0d4ee7357
commit
423cfc0872
2 changed files with 62 additions and 38 deletions
|
|
@ -13,6 +13,7 @@ from src.core.model_utils import (
|
|||
ChatCompletionResponse,
|
||||
ArchBaseHandler,
|
||||
)
|
||||
from src.core.hallucination import HallucinationStateHandler
|
||||
|
||||
|
||||
class ArchIntentConfig:
|
||||
|
|
@ -172,15 +173,15 @@ class ArchFunctionConfig:
|
|||
"""
|
||||
).strip()
|
||||
|
||||
# Extra sampling parameters sent with every completion request.
# This must be a plain dict: it is unpacked with `**` when building the request's
# `extra_body`. A trailing comma after the closing brace would turn it into a
# one-element tuple and break that unpacking.
GENERATION_PARAMS = {
    "temperature": 0.2,
    "top_p": 1.0,
    "top_k": 50,
    "max_tokens": 512,
    # NOTE(review): presumably the model's end-of-turn token id -- confirm against the tokenizer.
    "stop_token_ids": [151645],
    # Ask for per-token log-probabilities (top 10 alternatives per position).
    # NOTE(review): presumably consumed by the hallucination check -- confirm.
    "logprobs": True,
    "top_logprobs": 10,
}
|
||||
|
||||
PREFILL_CONFIG = {
|
||||
"prefill_params": {
|
||||
|
|
@ -429,6 +430,20 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
}
|
||||
]
|
||||
|
||||
def _engage_parameter_gathering(self, messages: List[Dict[str, str]]):
    """
    Engage parameter gathering for tool calls.

    Args:
        messages: Chat messages; they are passed through
            ``self._add_prefill_message`` before being sent.

    Returns:
        Whatever ``self.client.chat.completions.create`` returns for the
        prefilled request.

    Note:
        This is a plain (non-``async``) method: it returns the result of
        ``create`` directly rather than a coroutine.
    """
    prefilled_messages = self._add_prefill_message(messages)
    target_model = self.model_name
    # Later keys win: prefill settings override generation settings on conflict.
    request_options = {**self.generation_params, **self.prefill_params}
    return self.client.chat.completions.create(
        messages=prefilled_messages,
        model=target_model,
        extra_body=request_options,
    )
|
||||
|
||||
@override
|
||||
async def chat_completion(self, req: ChatMessage) -> ChatCompletionResponse:
|
||||
"""
|
||||
|
|
@ -454,49 +469,47 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
stream=True,
|
||||
extra_body=self.generation_params,
|
||||
)
|
||||
hallu_handler = HallucinationStateHandler(
|
||||
response_iterator=response, function=req.tools
|
||||
)
|
||||
|
||||
model_response, has_tool_call = "", None
|
||||
|
||||
for token in response:
|
||||
token_content = token.choices[0].delta.content.strip()
|
||||
|
||||
if token_content:
|
||||
if has_tool_call is None and token_content != "<tool_call>":
|
||||
has_tool_call = False
|
||||
response.close()
|
||||
break
|
||||
else:
|
||||
for token in hallu_handler:
|
||||
if len(hallu_handler.tokens) > 0 and has_tool_call == False:
|
||||
if hallu_handler.tokens[-0] == "<tool_call>":
|
||||
has_tool_call = True
|
||||
else:
|
||||
has_tool_call = False
|
||||
break
|
||||
if hallu_handler.hallucination == True:
|
||||
prefill_response = self._engage_parameter_gathering(messages)
|
||||
model_response = prefill_response.choices[0].message.content
|
||||
break
|
||||
|
||||
if has_tool_call is True:
|
||||
model_response += token_content
|
||||
# start parameter gathering if the model is not generating tool calls
|
||||
if hallu_handler.hallucination == False:
|
||||
model_response = "".join(hallu_handler.tokens)
|
||||
|
||||
# start parameter gathering if the model is not generating tool calls
|
||||
if has_tool_call is False:
|
||||
prefill_response = self.client.chat.completions.create(
|
||||
messages=self._add_prefill_message(messages),
|
||||
model=self.model_name,
|
||||
extra_body={
|
||||
**self.generation_params,
|
||||
**self.prefill_params,
|
||||
},
|
||||
)
|
||||
prefill_response = await self._engage_parameter_gathering(messages)
|
||||
model_response = prefill_response.choices[0].message.content
|
||||
|
||||
# Extract tool calls from model response
|
||||
extracted = self._extract_tool_calls(model_response)
|
||||
|
||||
if extracted["tool_calls"]:
|
||||
if extracted["result"]:
|
||||
# [TODO] Review: define the behavior in the case that tool call extraction fails
|
||||
# if not extracted["status"]:
|
||||
|
||||
verified = self._verify_tool_calls(
|
||||
tools=req.tools, tool_calls=extracted["tool_calls"]
|
||||
tools=req.tools, tool_calls=extracted["result"]
|
||||
)
|
||||
|
||||
# [TODO] Review: In the case that tool calls are invalid, define the protocol to collect debugging output and the behavior to handle it appropriately
|
||||
if verified["status"]:
|
||||
model_response = Message(content="", tool_calls=extracted["tool_calls"])
|
||||
model_response = Message(content="", tool_calls=extracted["result"])
|
||||
# else:
|
||||
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -27,10 +27,10 @@ class MaskToken(Enum):
|
|||
|
||||
|
||||
# Entropy / varentropy thresholds, keyed by mask-token value.
# Each key appears exactly once: in a dict literal a repeated key silently
# overwrites the earlier value, so duplicates only hide which threshold is live.
# NOTE(review): how these thresholds are compared against token statistics is
# not visible here -- confirm against the checking code before tuning.
HALLUCINATION_THRESHOLD_DICT = {
    MaskToken.TOOL_CALL.value: {"entropy": 0.05, "varentropy": 0.25},
    MaskToken.PARAMETER_VALUE.value: {
        "entropy": 0.05,
        "varentropy": 0.25,
    },
}
|
||||
|
||||
|
|
@ -109,7 +109,7 @@ class HallucinationStateHandler:
|
|||
token_probs_map (list): List mapping tokens to their entropy and variance of entropy.
|
||||
"""
|
||||
|
||||
def __init__(self, response_iterator=None):
|
||||
def __init__(self, response_iterator=None, function=None):
|
||||
"""
|
||||
Initializes the HallucinationStateHandler with default values.
|
||||
"""
|
||||
|
|
@ -124,7 +124,19 @@ class HallucinationStateHandler:
|
|||
self.parameter_name: List[str] = []
|
||||
self.token_probs_map: List[Tuple[str, float, float]] = []
|
||||
self.response_iterator = response_iterator
|
||||
self.has_tool_call = False
|
||||
self._process_function(function)
|
||||
|
||||
def _process_function(self, function):
    """
    Index the supplied function (tool) descriptions by function name.

    Sets three attributes on the instance:
        function: the descriptions, materialised as a list.
        function_description: maps each function name to the list of its
            parameter names (the keys of ``parameters["properties"]``).
        function_properties: maps each function name to its ``parameters``
            mapping.

    Args:
        function: Iterable of dicts, each with a ``"name"`` key and,
            optionally, a ``"parameters"`` mapping that may contain
            ``"properties"``.

    Raises:
        ValueError: If ``function`` is None.
    """
    if function is None:
        # Keep the attribute assignment the callers may inspect after a failure.
        self.function = function
        raise ValueError("API descriptions not set.")

    # Materialise once so a one-shot iterable (e.g. a generator) is not
    # exhausted before both lookup tables have been built.
    self.function = list(function)

    parameter_names = {}
    function_properties = {}
    for func in self.function:
        func_name = func["name"]
        # A function that takes no arguments may omit "parameters" or
        # "properties"; treat that as an empty parameter set.
        parameters = func.get("parameters") or {}
        properties = parameters.get("properties") or {}
        parameter_names[func_name] = list(properties)
        function_properties[func_name] = parameters

    self.function_description = parameter_names
    self.function_properties = function_properties
|
||||
|
||||
def append_and_check_token_hallucination(self, token, logprob):
|
||||
"""
|
||||
|
|
@ -139,8 +151,7 @@ class HallucinationStateHandler:
|
|||
"""
|
||||
self.tokens.append(token)
|
||||
self.logprobs.append(logprob)
|
||||
if self.has_tool_call:
|
||||
self._process_token()
|
||||
self._process_token()
|
||||
return self.hallucination
|
||||
|
||||
def __iter__(self):
|
||||
Loading…
Add table
Add a link
Reference in a new issue