mirror of
https://github.com/katanemo/plano.git
synced 2026-07-02 15:51:02 +02:00
fix error in function name + new thresholds (#406)
* fix error in function name + new thresholds * fix * fix * remove example * remove example
This commit is contained in:
parent
4ec03af16e
commit
e7b370cd2f
3 changed files with 28 additions and 35 deletions
|
|
@ -44,7 +44,7 @@ HALLUCINATION_THRESHOLD_DICT = {
|
||||||
},
|
},
|
||||||
MaskToken.PARAMETER_VALUE.value: {
|
MaskToken.PARAMETER_VALUE.value: {
|
||||||
"entropy": 0.28,
|
"entropy": 0.28,
|
||||||
"varentropy": 1.2,
|
"varentropy": 1.4,
|
||||||
"probability": 0.8,
|
"probability": 0.8,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
@ -60,7 +60,7 @@ def check_threshold(entropy: float, varentropy: float, thd: Dict) -> bool:
|
||||||
thd (dict): A dictionary containing the threshold values with keys 'entropy' and 'varentropy'.
|
thd (dict): A dictionary containing the threshold values with keys 'entropy' and 'varentropy'.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
bool: True if either the entropy or varentropy exceeds their respective thresholds, False otherwise.
|
bool: True if both the entropy and varentropy exceeds their respective thresholds, False otherwise.
|
||||||
"""
|
"""
|
||||||
return entropy > thd["entropy"] and varentropy > thd["varentropy"]
|
return entropy > thd["entropy"] and varentropy > thd["varentropy"]
|
||||||
|
|
||||||
|
|
@ -82,7 +82,7 @@ def calculate_uncertainty(log_probs: List[float]) -> Tuple[float, float]:
|
||||||
token_probs = torch.exp(log_probs)
|
token_probs = torch.exp(log_probs)
|
||||||
entropy = -torch.sum(log_probs * token_probs, dim=-1) / math.log(2, math.e)
|
entropy = -torch.sum(log_probs * token_probs, dim=-1) / math.log(2, math.e)
|
||||||
varentropy = torch.sum(
|
varentropy = torch.sum(
|
||||||
token_probs * (log_probs / math.log(2, math.e)) + entropy.unsqueeze(-1) ** 2,
|
token_probs * (log_probs / math.log(2, math.e) + entropy.unsqueeze(-1)) ** 2,
|
||||||
dim=-1,
|
dim=-1,
|
||||||
)
|
)
|
||||||
return entropy.item(), varentropy.item(), token_probs[0].item()
|
return entropy.item(), varentropy.item(), token_probs[0].item()
|
||||||
|
|
@ -303,22 +303,30 @@ class HallucinationState:
|
||||||
self.mask.append(MaskToken.PARAMETER_VALUE)
|
self.mask.append(MaskToken.PARAMETER_VALUE)
|
||||||
|
|
||||||
# checking if the parameter doesn't have enum and the token is the first parameter value token
|
# checking if the parameter doesn't have enum and the token is the first parameter value token
|
||||||
if (
|
# check if function name is in function properties
|
||||||
len(self.mask) > 1
|
if self.function_name in self.function_properties:
|
||||||
and self.mask[-2] != MaskToken.PARAMETER_VALUE
|
if (
|
||||||
and is_parameter_required(
|
len(self.mask) > 1
|
||||||
self.function_properties[self.function_name],
|
and self.mask[-2] != MaskToken.PARAMETER_VALUE
|
||||||
self.parameter_name[-1],
|
and is_parameter_required(
|
||||||
|
self.function_properties[self.function_name],
|
||||||
|
self.parameter_name[-1],
|
||||||
|
)
|
||||||
|
and not is_parameter_property(
|
||||||
|
self.function_properties[self.function_name],
|
||||||
|
self.parameter_name[-1],
|
||||||
|
"enum",
|
||||||
|
)
|
||||||
|
):
|
||||||
|
if self.parameter_name[-1] not in self.check_parameter_name:
|
||||||
|
self._check_logprob()
|
||||||
|
self.check_parameter_name[self.parameter_name[-1]] = True
|
||||||
|
else:
|
||||||
|
self._check_logprob()
|
||||||
|
self.error_message = f"Function name {self.function_name} not found in function properties"
|
||||||
|
logger.warning(
|
||||||
|
f"Function name {self.function_name} not found in function properties"
|
||||||
)
|
)
|
||||||
and not is_parameter_property(
|
|
||||||
self.function_properties[self.function_name],
|
|
||||||
self.parameter_name[-1],
|
|
||||||
"enum",
|
|
||||||
)
|
|
||||||
):
|
|
||||||
if self.parameter_name[-1] not in self.check_parameter_name:
|
|
||||||
self._check_logprob()
|
|
||||||
self.check_parameter_name[self.parameter_name[-1]] = True
|
|
||||||
else:
|
else:
|
||||||
self.mask.append(MaskToken.NOT_USED)
|
self.mask.append(MaskToken.NOT_USED)
|
||||||
# if the state is parameter value and the token is an end token, change the state
|
# if the state is parameter value and the token is an end token, change the state
|
||||||
|
|
|
||||||
|
|
@ -54,20 +54,6 @@ def get_hallucination_data_complex():
|
||||||
return req, True, True, True
|
return req, True, True, True
|
||||||
|
|
||||||
|
|
||||||
def get_hallucination_data_easy():
|
|
||||||
# Create instances of the Message class
|
|
||||||
message1 = Message(role="user", content="How is the weather in Seattle?")
|
|
||||||
|
|
||||||
# Create a list of tools
|
|
||||||
tools = [get_weather_api]
|
|
||||||
|
|
||||||
# Create an instance of the ChatMessage class
|
|
||||||
req = ChatMessage(messages=[message1], tools=tools)
|
|
||||||
|
|
||||||
# model will hallucinate
|
|
||||||
return req, True, True, True
|
|
||||||
|
|
||||||
|
|
||||||
def get_hallucination_data_medium():
|
def get_hallucination_data_medium():
|
||||||
# Create instances of the Message class
|
# Create instances of the Message class
|
||||||
message1 = Message(role="user", content="How is the weather in?")
|
message1 = Message(role="user", content="How is the weather in?")
|
||||||
|
|
@ -142,7 +128,6 @@ def get_greeting_data():
|
||||||
"get_data_func",
|
"get_data_func",
|
||||||
[
|
[
|
||||||
get_hallucination_data_complex,
|
get_hallucination_data_complex,
|
||||||
get_hallucination_data_easy,
|
|
||||||
get_complete_data,
|
get_complete_data,
|
||||||
get_irrelevant_data,
|
get_irrelevant_data,
|
||||||
get_complete_data_2,
|
get_complete_data_2,
|
||||||
|
|
|
||||||
|
|
@ -63,7 +63,7 @@ test_cases:
|
||||||
- role: "assistant"
|
- role: "assistant"
|
||||||
content: "Can you please provide me the days for the weather forecast?"
|
content: "Can you please provide me the days for the weather forecast?"
|
||||||
- role: "user"
|
- role: "user"
|
||||||
content: "los angeles in 5 days"
|
content: "5 days"
|
||||||
tools:
|
tools:
|
||||||
- type: "function"
|
- type: "function"
|
||||||
function:
|
function:
|
||||||
|
|
@ -82,7 +82,7 @@ test_cases:
|
||||||
required: ["location", "days"]
|
required: ["location", "days"]
|
||||||
expected:
|
expected:
|
||||||
- type: "metadata"
|
- type: "metadata"
|
||||||
hallucination: true
|
hallucination: false
|
||||||
|
|
||||||
- id: "[WEATHER AGENT] - multi turn, single tool, clarification"
|
- id: "[WEATHER AGENT] - multi turn, single tool, clarification"
|
||||||
input:
|
input:
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue