From 423cfc0872ab1cf3e81c214cb8457d29a885b562 Mon Sep 17 00:00:00 2001 From: cotran Date: Mon, 9 Dec 2024 11:33:41 -0800 Subject: [PATCH] add hallucination --- model_server/src/core/function_calling.py | 75 +++++++++++-------- ...lucination_handler.py => hallucination.py} | 25 +++++-- 2 files changed, 62 insertions(+), 38 deletions(-) rename model_server/src/core/{hallucination_handler.py => hallucination.py} (93%) diff --git a/model_server/src/core/function_calling.py b/model_server/src/core/function_calling.py index f652fa22..489395d1 100644 --- a/model_server/src/core/function_calling.py +++ b/model_server/src/core/function_calling.py @@ -13,6 +13,7 @@ from src.core.model_utils import ( ChatCompletionResponse, ArchBaseHandler, ) +from src.core.hallucination import HallucinationStateHandler class ArchIntentConfig: @@ -172,15 +173,15 @@ class ArchFunctionConfig: """ ).strip() - GENERATION_PARAMS = ( - { - "temperature": 0.2, - "top_p": 1.0, - "top_k": 50, - "max_tokens": 512, - "stop_token_ids": [151645], - }, - ) + GENERATION_PARAMS = { + "temperature": 0.2, + "top_p": 1.0, + "top_k": 50, + "max_tokens": 512, + "stop_token_ids": [151645], + "logprobs": True, + "top_logprobs": 10, + } PREFILL_CONFIG = { "prefill_params": { @@ -429,6 +430,20 @@ class ArchFunctionHandler(ArchBaseHandler): } ] + def _engage_parameter_gathering(self, messages: List[Dict[str, str]]): + """ + Engage parameter gathering for tool calls + """ + prefill_response = self.client.chat.completions.create( + messages=self._add_prefill_message(messages), + model=self.model_name, + extra_body={ + **self.generation_params, + **self.prefill_params, + }, + ) + return prefill_response + @override async def chat_completion(self, req: ChatMessage) -> ChatCompletionResponse: """ @@ -454,49 +469,47 @@ class ArchFunctionHandler(ArchBaseHandler): stream=True, extra_body=self.generation_params, ) + hallu_handler = HallucinationStateHandler( + response_iterator=response, function=req.tools + ) model_response, has_tool_call = "", None - for token in response: - token_content = token.choices[0].delta.content.strip() - - if token_content: - if has_tool_call is None and token_content != "": - has_tool_call = False - response.close() - break - else: + for token in hallu_handler: + if len(hallu_handler.tokens) > 0 and has_tool_call == False: + if hallu_handler.tokens[-0] == "": has_tool_call = True + else: + has_tool_call = False + break + if hallu_handler.hallucination == True: + prefill_response = self._engage_parameter_gathering(messages) + model_response = prefill_response.choices[0].message.content + break - if has_tool_call is True: - model_response += token_content + # start parameter gathering if the model is not generating tool calls + if hallu_handler.hallucination == False: + model_response = "".join(hallu_handler.tokens) # start parameter gathering if the model is not generating tool calls if has_tool_call is False: - prefill_response = self.client.chat.completions.create( - messages=self._add_prefill_message(messages), - model=self.model_name, - extra_body={ - **self.generation_params, - **self.prefill_params, - }, - ) + prefill_response = await self._engage_parameter_gathering(messages) model_response = prefill_response.choices[0].message.content # Extract tool calls from model response extracted = self._extract_tool_calls(model_response) - if extracted["tool_calls"]: + if extracted["result"]: # [TODO] Review: define the behavior in the case that tool call extraction fails # if not extracted["status"]: verified = self._verify_tool_calls( - tools=req.tools, tool_calls=extracted["tool_calls"] + tools=req.tools, tool_calls=extracted["result"] ) # [TODO] Review: In the case that tool calls are invalid, define the protocol to collect debugging output and the behavior to handle it appropriately if verified["status"]: - model_response = Message(content="", tool_calls=extracted["tool_calls"]) + model_response = Message(content="", tool_calls=extracted["result"]) # else: else: diff --git a/model_server/src/core/hallucination_handler.py b/model_server/src/core/hallucination.py similarity index 93% rename from model_server/src/core/hallucination_handler.py rename to model_server/src/core/hallucination.py index 4d923ce1..38611dc1 100644 --- a/model_server/src/core/hallucination_handler.py +++ b/model_server/src/core/hallucination.py @@ -27,10 +27,10 @@ class MaskToken(Enum): HALLUCINATION_THRESHOLD_DICT = { - MaskToken.TOOL_CALL.value: {"entropy": 0.1, "varentropy": 0.5}, + MaskToken.TOOL_CALL.value: {"entropy": 0.05, "varentropy": 0.25}, MaskToken.PARAMETER_VALUE.value: { - "entropy": 0.5, - "varentropy": 2.5, + "entropy": 0.05, + "varentropy": 0.25, }, } @@ -109,7 +109,7 @@ class HallucinationStateHandler: token_probs_map (list): List mapping tokens to their entropy and variance of entropy. """ - def __init__(self, response_iterator=None): + def __init__(self, response_iterator=None, function=None): """ Initializes the HallucinationStateHandler with default values. """ @@ -124,7 +124,19 @@ class HallucinationStateHandler: self.parameter_name: List[str] = [] self.token_probs_map: List[Tuple[str, float, float]] = [] self.response_iterator = response_iterator - self.has_tool_call = False + self._process_function(function) + + def _process_function(self, function): + self.function = function + if self.function is None: + raise ValueError("API descriptions not set.") + parameter_names = {} + for func in self.function: + func_name = func["name"] + parameters = func["parameters"]["properties"] + parameter_names[func_name] = list(parameters.keys()) + self.function_description = parameter_names + self.function_properties = {x["name"]: x["parameters"] for x in self.function} def append_and_check_token_hallucination(self, token, logprob): """ @@ -139,8 +151,7 @@ class HallucinationStateHandler: """ self.tokens.append(token) self.logprobs.append(logprob) - if self.has_tool_call: - self._process_token() + self._process_token() return self.hallucination def __iter__(self):