From 6ec4c144075afb956a2a42cf1a3e2bdd7dc5657a Mon Sep 17 00:00:00 2001 From: Shuguang Chen <54548843+nehcgs@users.noreply.github.com> Date: Mon, 31 Mar 2025 15:08:38 -0700 Subject: [PATCH] Fix prompt prefilling --- model_server/src/core/function_calling.py | 94 +++++++++-------------- 1 file changed, 35 insertions(+), 59 deletions(-) diff --git a/model_server/src/core/function_calling.py b/model_server/src/core/function_calling.py index 13a77876..781149d5 100644 --- a/model_server/src/core/function_calling.py +++ b/model_server/src/core/function_calling.py @@ -48,22 +48,6 @@ class ArchFunctionConfig: "top_logprobs": 10, } - PREFILL_CONFIG = { - "prefill_params": { - "continue_final_message": True, - "add_generation_prompt": False, - }, - "prefill_prefix": [ - "May", - "Could", - "Sure", - "Definitely", - "Certainly", - "Of course", - "Can", - ], - } - SUPPORT_DATA_TYPES = ["int", "float", "bool", "str", "list", "tuple", "set", "dict"] @@ -91,8 +75,13 @@ class ArchFunctionHandler(ArchBaseHandler): config.GENERATION_PARAMS, ) - self.prefill_params = config.PREFILL_CONFIG["prefill_params"] - self.prefill_prefix = config.PREFILL_CONFIG["prefill_prefix"] + self.generation_params = self.generation_params | { + "continue_final_message": True, + "add_generation_prompt": False, + } + + self.default_prefix = '```json\n{"' + self.clarify_prefix = '```json\n{"required_functions":' self.hallucination_state = None @@ -163,8 +152,7 @@ class ArchFunctionHandler(ArchBaseHandler): unmatched_opening = stack.pop() fixed_str += opening_bracket[unmatched_opening] - # Attempt to parse the corrected string to ensure it’s valid JSON - return fixed_str.replace("'", '"') + return fixed_str def _parse_model_resonse(self, content: str) -> Dict[str, any]: """ @@ -197,7 +185,11 @@ class ArchFunctionHandler(ArchBaseHandler): if content.startswith("json"): content = content[4:].strip() - model_response = json.loads(self._fix_json_string(content)) + content = self._fix_json_string(content) + try: + 
model_response = json.loads(content) + except Exception: + model_response = json.loads(content.replace("'", '"')) response_dict["response"] = model_response.get("response", "") response_dict["required_functions"] = model_response.get( @@ -325,7 +317,7 @@ class ArchFunctionHandler(ArchBaseHandler): return verification_dict - def _add_prefill_message(self, messages: List[Dict[str, str]]): + def _prefill_message(self, messages: List[Dict[str, str]], prefill_message): """ Update messages and generation params for prompt prefilling @@ -335,29 +327,7 @@ class ArchFunctionHandler(ArchBaseHandler): Returns: prefill_messages (List[Dict[str, str]]): A list of messages. """ - - return messages + [ - { - "role": "assistant", - "content": random.choice(self.prefill_prefix), - } - ] - - def _engage_parameter_gathering(self, messages: List[Dict[str, str]]): - """ - Engage parameter gathering for tool calls - """ - - # TODO: log enaging parameter gathering - prefill_response = self.client.chat.completions.create( - messages=self._add_prefill_message(messages), - model=self.model_name, - extra_body={ - **self.generation_params, - **self.prefill_params, - }, - ) - return prefill_response + return messages + [{"role": "assistant", "content": prefill_message}] @override async def chat_completion(self, req: ChatMessage) -> ChatCompletionResponse: @@ -385,7 +355,7 @@ class ArchFunctionHandler(ArchBaseHandler): # always enable `stream=True` to collect model responses response = self.client.chat.completions.create( - messages=messages, + messages=self._prefill_message(messages, self.default_prefix), model=self.model_name, stream=True, extra_body=self.generation_params, @@ -415,16 +385,13 @@ class ArchFunctionHandler(ArchBaseHandler): has_tool_calls, has_hallucination = None, False for _ in self.hallucination_state: - # check if the first token is - content = "".join(self.hallucination_state.tokens) - if "tool_calls" in content: - logger.info( - f"[Content]: {content}" - ) - has_tool_calls 
= True - else: - has_tool_calls = False - + # check if model response starts with tool calls + if has_tool_calls is None: + content = "".join(self.hallucination_state.tokens) + if "tool_calls" in content: + has_tool_calls = True + else: + has_tool_calls = False # if the model is hallucinating, start parameter gathering if self.hallucination_state.hallucination is True: @@ -436,10 +403,19 @@ class ArchFunctionHandler(ArchBaseHandler): logger.info( f"[Hallucination]: {self.hallucination_state.error_message}" ) - prefill_response = self._engage_parameter_gathering(messages) - model_response = prefill_response.choices[0].message.content + response = self.client.chat.completions.create( + messages=self._prefill_message(messages, self.clarify_prefix), + model=self.model_name, + stream=False, + extra_body=self.generation_params, + ) + model_response = ( + self.clarify_prefix + response.choices[0].message.content + ) else: - model_response = "".join(self.hallucination_state.tokens) + model_response = self.default_prefix + "".join( + self.hallucination_state.tokens + ) # else: # # start parameter gathering if the model is not generating tool calls # prefill_response = self._engage_parameter_gathering(messages)