diff --git a/model_server/app/function_calling/hallucination_handler.py b/model_server/app/function_calling/hallucination_handler.py
index acbda584..dc8e8ee0 100644
--- a/model_server/app/function_calling/hallucination_handler.py
+++ b/model_server/app/function_calling/hallucination_handler.py
@@ -212,7 +212,7 @@ class HallucinationStateHandler:
if (
len(self.mask) > 1
and self.mask[-2] != "v"
- and not check_parameter_property(
+ and not is_parameter_property(
self.function_properties[self.function_name],
self.parameter_name[-1],
"default",
diff --git a/model_server/app/tests/test_hallucination.py b/model_server/app/tests/test_hallucination.py
index 25ad3303..8c23324b 100644
--- a/model_server/app/tests/test_hallucination.py
+++ b/model_server/app/tests/test_hallucination.py
@@ -47,14 +47,97 @@ if type(function_description) != list:
 
 @pytest.mark.parametrize("case", test_cases)
 def test_hallucination(case):
-    state = HallucinationStateHandler()
-    state.process_function(function_description)
+    """Replay recorded tokens through the handler until it flags one."""
+    state = HallucinationStateHandler(
+        response_iterator=None, function=function_description
+    )
     for token, logprob in zip(case["tokens"], case["logprobs"]):
         if token != "":
-            state.current_token = token
-            state.tokens.append(token)
-            state.logprobs.append(logprob)
-            state.process_token()
+            state.check_token_hallucination(token, logprob)
             if state.hallucination:
                 break
     assert state.hallucination == case["expect"]
+
+
+@pytest.mark.parametrize("is_hallucinate_sample", [True, False])
+def test_hallucination_prompt(is_hallucinate_sample):
+    """Stream a live completion through the handler and check its verdict.
+
+    The complete question ("... in 7 days?") must not be flagged, while the
+    truncated one ("... in?") must be flagged as a hallucination.
+    """
+    from openai import OpenAI
+
+    TASK_PROMPT = """
+    You are a helpful assistant.
+    """.strip()
+
+    TOOL_PROMPT = """
+    # Tools
+
+    You may call one or more functions to assist with the user query.
+
+    You are provided with function signatures within <tools></tools> XML tags:
+    <tools>
+    {tool_text}
+    </tools>
+    """.strip()
+
+    FORMAT_PROMPT = """
+    For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
+    <tool_call>
+    {"name": <function-name>, "arguments": <args-json-object>}
+    </tool_call>
+    """.strip()
+
+    def format_prompt(tools):
+        """Build the system prompt with one JSON-encoded tool per line."""
+        tool_text = "\n".join(json.dumps(tool) for tool in tools)
+        return (
+            TASK_PROMPT
+            + "\n\n"
+            + TOOL_PROMPT.format(tool_text=tool_text)
+            + "\n\n"
+            + FORMAT_PROMPT
+            + "\n"
+        )
+
+    system_prompt = format_prompt([get_weather_api])
+
+    client = OpenAI(base_url="https://api.fc.archgw.com/v1", api_key="EMPTY")
+
+    # The first model listed by the endpoint must be the one under test.
+    model = client.models.list().data[0].id
+    assert model == "Arch-Function"
+
+    if is_hallucinate_sample:
+        question = "How is the weather in Seattle in?"
+    else:
+        question = "How is the weather in Seattle in 7 days?"
+    messages = [
+        {"role": "system", "content": system_prompt},
+        {"role": "user", "content": question},
+    ]
+
+    # NOTE(review): sampling at temperature 0.6 from a live endpoint can make
+    # this test flaky -- confirm that is acceptable.
+    extra_body = {
+        "temperature": 0.6,
+        "top_p": 1.0,
+        "top_k": 50,
+        "logprobs": True,
+        "top_logprobs": 10,
+    }
+
+    resp = client.chat.completions.create(
+        model=model, messages=messages, extra_body=extra_body, stream=True
+    )
+
+    hallu = HallucinationStateHandler(
+        response_iterator=resp, function=function_description
+    )
+
+    # Consume the whole stream; only the final flag is asserted.
+    for _ in hallu:
+        pass
+    assert hallu.hallucination == is_hallucinate_sample