2024-11-18 00:09:02 -08:00
|
|
|
import json
|
|
|
|
|
from app.function_calling.hallucination_handler import hallucination_detect
|
|
|
|
|
import pytest
|
2024-11-18 00:53:49 -08:00
|
|
|
import os
|
2024-11-18 00:09:02 -08:00
|
|
|
|
2024-11-18 00:53:49 -08:00
|
|
|
# Resolve the fixture relative to this file so the tests do not depend on the
# working directory pytest is launched from.
current_dir = os.path.dirname(__file__)
json_file_path = os.path.join(current_dir, "test_cases.json")

# Decode explicitly as UTF-8: without `encoding`, open() falls back to the
# platform locale encoding, which can mis-decode non-ASCII tokens in the
# fixture on some systems.
with open(json_file_path, encoding="utf-8") as f:
    # Parsed once at import time; used below to parametrize the test.
    test_cases = json.load(f)
|
2024-11-18 00:09:02 -08:00
|
|
|
|
|
|
|
|
# Tool definition handed to the hallucination detector under test.
# NOTE(review): the parameter types are spelled "str" rather than the JSON
# Schema name "string" -- presumably what the handler expects; confirm before
# changing, since the detector may inspect these values.
get_weather_api = {
    "type": "function",
    "function": {
        "name": "get_current_weather",
        "description": "Get current weather at a location.",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {
                    "type": "str",
                    "description": "The location to get the weather for",
                    "format": "City, State",
                },
                "unit": {
                    "type": "str",
                    "description": "The unit to return the weather in.",
                    "enum": ["celsius", "fahrenheit"],
                    "default": "celsius",
                },
                "days": {
                    "type": "str",
                    "description": "the number of days for the request.",
                },
            },
            "required": ["location", "days"],
        },
    },
}

# Normalise to a list of function descriptions so the code below can treat the
# single-tool and multi-tool cases uniformly.
function_description = get_weather_api["function"]
if not isinstance(function_description, list):
    function_description = [function_description]

# Map each function name to its declared parameter names, in declaration order.
parameter_names = {
    func["name"]: list(func["parameters"]["properties"])
    for func in function_description
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("case", test_cases)
def test_hallucination(case):
    """Replay one recorded token stream through ``hallucination_detect``.

    ``case`` is one entry of the JSON fixture; this test reads its
    ``"tokens"``, ``"logprobs"`` and ``"expect"`` keys. Tokens are paired with
    their logprobs and passed to the detector one at a time, together with a
    single state dict that is shared across all calls for the case. The value
    returned for the last token is compared with ``case["expect"]``.
    """
    # Fresh state for every case so nothing carries over between cases.
    state = dict(
        state="start",
        tool_call="",
        entropy=[],
        varentropy=[],
        logprobs=[],
        tokens=[],
        content="",
        hallucination=False,
        parameter_names=parameter_names,
        function_description=function_description,
    )

    token_stream = zip(case["tokens"], case["logprobs"])
    for token, token_logprobs in token_stream:
        # 0.7 and 4 are passed positionally; presumably detector thresholds --
        # TODO confirm against hallucination_detect's signature.
        result = hallucination_detect(token, token_logprobs, state, 0.7, 4)

    # Only the result for the final token is checked.
    assert result == case["expect"]
|