address issues

This commit is contained in:
cotran 2024-11-25 14:45:29 -08:00
parent 4da2184d56
commit e44d189d86
2 changed files with 95 additions and 7 deletions

View file

@ -212,7 +212,7 @@ class HallucinationStateHandler:
if (
len(self.mask) > 1
and self.mask[-2] != "v"
and not check_parameter_property(
and not is_parameter_property(
self.function_properties[self.function_name],
self.parameter_name[-1],
"default",

View file

@ -47,14 +47,102 @@ if type(function_description) != list:
@pytest.mark.parametrize("case", test_cases)
def test_hallucination(case):
state = HallucinationStateHandler()
state.process_function(function_description)
state = HallucinationStateHandler(
response_iterator=None, function=function_description
)
for token, logprob in zip(case["tokens"], case["logprobs"]):
if token != "</tool_call>":
state.current_token = token
state.tokens.append(token)
state.logprobs.append(logprob)
state.process_token()
state.check_token_hallucination(token, logprob)
if state.hallucination:
break
assert state.hallucination == case["expect"]
@pytest.mark.parametrize("is_hallucinate_sample", [True, False])
def test_hallucination_prompt(is_hallucinate_sample):
TASK_PROMPT = """
You are a helpful assistant.
""".strip()
TOOL_PROMPT = """
# Tools
You may call one or more functions to assist with the user query.
You are provided with function signatures within <tools></tools> XML tags:
<tools>
{tool_text}
</tools>
""".strip()
FORMAT_PROMPT = """
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
<tool_call>
{"name": <function-name>, "arguments": <args-json-object>}
</tool_call>
""".strip()
def convert_tools(tools):
return "\n".join([json.dumps(tool) for tool in tools])
def format_prompt(tools):
tool_text = convert_tools(tools)
return (
TASK_PROMPT
+ "\n\n"
+ TOOL_PROMPT.format(tool_text=tool_text)
+ "\n\n"
+ FORMAT_PROMPT
+ "\n"
)
openai_format_tools = [get_weather_api]
system_prompt = format_prompt(openai_format_tools)
from openai import OpenAI
client = OpenAI(base_url="https://api.fc.archgw.com/v1", api_key="EMPTY")
# List models API
model = client.models.list().data[0].id
assert model == "Arch-Function"
if not is_hallucinate_sample:
messages = [
{"role": "system", "content": system_prompt},
# {"role": "user", "content": "can you help me check weather?"},
{"role": "user", "content": "How is the weather in Seattle in 7 days?"},
# {"role": "assistant", "content": "Of course!"},
# {"role": "user", "content": "Seattle please"}
]
else:
messages = [
{"role": "system", "content": system_prompt},
# {"role": "user", "content": "can you help me check weather?"},
{"role": "user", "content": "How is the weather in Seattle in?"},
# {"role": "assistant", "content": "Of course!"},
# {"role": "user", "content": "Seattle please"}
]
extra_body = {
"temperature": 0.6,
"top_p": 1.0,
"top_k": 50,
# "continue_final_message": True,
# "add_generation_prompt": False,
"logprobs": True,
"top_logprobs": 10,
}
resp = client.chat.completions.create(
model="Arch-Function", messages=messages, extra_body=extra_body, stream=True
)
hallu = HallucinationStateHandler(
response_iterator=resp, function=function_description
)
for token in hallu:
assert len(hallu.tokens) >= 0
assert hallu.hallucination == is_hallucinate_sample