new implemenetation

This commit is contained in:
cotran 2024-11-22 11:11:26 -08:00
parent 665dbc2d4e
commit abfc81b0e7
3 changed files with 224 additions and 340 deletions

View file

@ -1,5 +1,5 @@
import json
from app.function_calling.hallucination_handler import hallucination_detect
from app.function_calling.hallucination_handler import HallucinationStateHandler
import pytest
import os
@ -44,27 +44,15 @@ function_description = get_weather_api["function"]
if type(function_description) != list:
function_description = [get_weather_api["function"]]
parameter_names = {}
for func in function_description:
func_name = func["name"]
parameters = func["parameters"]["properties"]
parameter_names[func_name] = list(parameters.keys())
@pytest.mark.parametrize("case", test_cases)
def test_hallucination(case):
current_state = {
"state": "start",
"tool_call": "",
"entropy": [],
"varentropy": [],
"logprobs": [],
"tokens": [],
"content": "",
"hallucination": False,
"parameter_names": parameter_names,
"function_description": function_description,
}
for token_content, logprobs in zip(case["tokens"], case["logprobs"]):
result = hallucination_detect(token_content, logprobs, current_state, 0.7, 4)
assert result == case["expect"]
state = HallucinationStateHandler()
state.process_function(function_description)
for token, logprob in zip(case["tokens"], case["logprobs"]):
if token != "</tool_call>":
state.current_token = token
state.tokens.append(token)
state.logprobs.append(logprob)
state.process_token()
assert state.hallucination == case["expect"]