mirror of
https://github.com/katanemo/plano.git
synced 2026-06-23 15:38:07 +02:00
new implemenetation
This commit is contained in:
parent
665dbc2d4e
commit
abfc81b0e7
3 changed files with 224 additions and 340 deletions
|
|
@ -1,5 +1,5 @@
|
|||
import json
|
||||
from app.function_calling.hallucination_handler import hallucination_detect
|
||||
from app.function_calling.hallucination_handler import HallucinationStateHandler
|
||||
import pytest
|
||||
import os
|
||||
|
||||
|
|
@ -44,27 +44,15 @@ function_description = get_weather_api["function"]
|
|||
if type(function_description) != list:
|
||||
function_description = [get_weather_api["function"]]
|
||||
|
||||
parameter_names = {}
|
||||
for func in function_description:
|
||||
func_name = func["name"]
|
||||
parameters = func["parameters"]["properties"]
|
||||
parameter_names[func_name] = list(parameters.keys())
|
||||
|
||||
|
||||
@pytest.mark.parametrize("case", test_cases)
|
||||
def test_hallucination(case):
|
||||
current_state = {
|
||||
"state": "start",
|
||||
"tool_call": "",
|
||||
"entropy": [],
|
||||
"varentropy": [],
|
||||
"logprobs": [],
|
||||
"tokens": [],
|
||||
"content": "",
|
||||
"hallucination": False,
|
||||
"parameter_names": parameter_names,
|
||||
"function_description": function_description,
|
||||
}
|
||||
for token_content, logprobs in zip(case["tokens"], case["logprobs"]):
|
||||
result = hallucination_detect(token_content, logprobs, current_state, 0.7, 4)
|
||||
assert result == case["expect"]
|
||||
state = HallucinationStateHandler()
|
||||
state.process_function(function_description)
|
||||
for token, logprob in zip(case["tokens"], case["logprobs"]):
|
||||
if token != "</tool_call>":
|
||||
state.current_token = token
|
||||
state.tokens.append(token)
|
||||
state.logprobs.append(logprob)
|
||||
state.process_token()
|
||||
assert state.hallucination == case["expect"]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue