From e44d189d86bf8b7f942780a192e15c434ed47faa Mon Sep 17 00:00:00 2001 From: cotran Date: Mon, 25 Nov 2024 14:45:29 -0800 Subject: [PATCH] address issues --- .../function_calling/hallucination_handler.py | 2 +- model_server/app/tests/test_hallucination.py | 100 ++++++++++++++++-- 2 files changed, 95 insertions(+), 7 deletions(-) diff --git a/model_server/app/function_calling/hallucination_handler.py b/model_server/app/function_calling/hallucination_handler.py index acbda584..dc8e8ee0 100644 --- a/model_server/app/function_calling/hallucination_handler.py +++ b/model_server/app/function_calling/hallucination_handler.py @@ -212,7 +212,7 @@ class HallucinationStateHandler: if ( len(self.mask) > 1 and self.mask[-2] != "v" - and not check_parameter_property( + and not is_parameter_property( self.function_properties[self.function_name], self.parameter_name[-1], "default", diff --git a/model_server/app/tests/test_hallucination.py b/model_server/app/tests/test_hallucination.py index 25ad3303..8c23324b 100644 --- a/model_server/app/tests/test_hallucination.py +++ b/model_server/app/tests/test_hallucination.py @@ -47,14 +47,102 @@ if type(function_description) != list: @pytest.mark.parametrize("case", test_cases) def test_hallucination(case): - state = HallucinationStateHandler() - state.process_function(function_description) + state = HallucinationStateHandler( + response_iterator=None, function=function_description + ) for token, logprob in zip(case["tokens"], case["logprobs"]): if token != "": - state.current_token = token - state.tokens.append(token) - state.logprobs.append(logprob) - state.process_token() + state.check_token_hallucination(token, logprob) if state.hallucination: break assert state.hallucination == case["expect"] + + +@pytest.mark.parametrize("is_hallucinate_sample", [True, False]) +def test_hallucination_prompt(is_hallucinate_sample): + TASK_PROMPT = """ + You are a helpful assistant. + """.strip() + + TOOL_PROMPT = """ + # Tools + + You may call one or more functions to assist with the user query. + + You are provided with function signatures within XML tags: + + {tool_text} + + """.strip() + + FORMAT_PROMPT = """ + For each function call, return a json object with function name and arguments within XML tags: + + {"name": , "arguments": } + + """.strip() + + def convert_tools(tools): + return "\n".join([json.dumps(tool) for tool in tools]) + + def format_prompt(tools): + tool_text = convert_tools(tools) + + return ( + TASK_PROMPT + + "\n\n" + + TOOL_PROMPT.format(tool_text=tool_text) + + "\n\n" + + FORMAT_PROMPT + + "\n" + ) + + openai_format_tools = [get_weather_api] + + system_prompt = format_prompt(openai_format_tools) + + from openai import OpenAI + + client = OpenAI(base_url="https://api.fc.archgw.com/v1", api_key="EMPTY") + + # List models API + model = client.models.list().data[0].id + assert model == "Arch-Function" + if not is_hallucinate_sample: + messages = [ + {"role": "system", "content": system_prompt}, + # {"role": "user", "content": "can you help me check weather?"}, + {"role": "user", "content": "How is the weather in Seattle in 7 days?"}, + # {"role": "assistant", "content": "Of course!"}, + # {"role": "user", "content": "Seattle please"} + ] + else: + messages = [ + {"role": "system", "content": system_prompt}, + # {"role": "user", "content": "can you help me check weather?"}, + {"role": "user", "content": "How is the weather in Seattle in?"}, + # {"role": "assistant", "content": "Of course!"}, + # {"role": "user", "content": "Seattle please"} + ] + + extra_body = { + "temperature": 0.6, + "top_p": 1.0, + "top_k": 50, + # "continue_final_message": True, + # "add_generation_prompt": False, + "logprobs": True, + "top_logprobs": 10, + } + + resp = client.chat.completions.create( + model="Arch-Function", messages=messages, extra_body=extra_body, stream=True + ) + + hallu = HallucinationStateHandler( + response_iterator=resp, function=function_description + ) + + for token in hallu: + assert len(hallu.tokens) >= 0 + assert hallu.hallucination == is_hallucinate_sample