Fix prompt prefilling

This commit is contained in:
Shuguang Chen 2025-03-31 15:08:38 -07:00
parent afe7cc9e9e
commit 6ec4c14407

View file

@ -48,22 +48,6 @@ class ArchFunctionConfig:
"top_logprobs": 10,
}
PREFILL_CONFIG = {
"prefill_params": {
"continue_final_message": True,
"add_generation_prompt": False,
},
"prefill_prefix": [
"May",
"Could",
"Sure",
"Definitely",
"Certainly",
"Of course",
"Can",
],
}
SUPPORT_DATA_TYPES = ["int", "float", "bool", "str", "list", "tuple", "set", "dict"]
@ -91,8 +75,13 @@ class ArchFunctionHandler(ArchBaseHandler):
config.GENERATION_PARAMS,
)
self.prefill_params = config.PREFILL_CONFIG["prefill_params"]
self.prefill_prefix = config.PREFILL_CONFIG["prefill_prefix"]
self.generation_params = self.generation_params | {
"continue_final_message": True,
"add_generation_prompt": False,
}
self.default_prefix = '```json\n{"'
self.clarify_prefix = '```json\n{"required_functions":'
self.hallucination_state = None
@ -163,8 +152,7 @@ class ArchFunctionHandler(ArchBaseHandler):
unmatched_opening = stack.pop()
fixed_str += opening_bracket[unmatched_opening]
# Attempt to parse the corrected string to ensure it's valid JSON
return fixed_str.replace("'", '"')
return fixed_str
def _parse_model_resonse(self, content: str) -> Dict[str, any]:
"""
@ -197,7 +185,11 @@ class ArchFunctionHandler(ArchBaseHandler):
if content.startswith("json"):
content = content[4:].strip()
model_response = json.loads(self._fix_json_string(content))
content = self._fix_json_string(content)
try:
model_response = json.loads(content)
except Exception:
model_response = json.loads(content.replace("'", '"'))
response_dict["response"] = model_response.get("response", "")
response_dict["required_functions"] = model_response.get(
@ -325,7 +317,7 @@ class ArchFunctionHandler(ArchBaseHandler):
return verification_dict
def _add_prefill_message(self, messages: List[Dict[str, str]]):
def _prefill_message(self, messages: List[Dict[str, str]], prefill_message):
"""
Update messages and generation params for prompt prefilling
@ -335,29 +327,7 @@ class ArchFunctionHandler(ArchBaseHandler):
Returns:
prefill_messages (List[Dict[str, str]]): A list of messages.
"""
return messages + [
{
"role": "assistant",
"content": random.choice(self.prefill_prefix),
}
]
def _engage_parameter_gathering(self, messages: List[Dict[str, str]]):
"""
Engage parameter gathering for tool calls
"""
# TODO: log engaging parameter gathering
prefill_response = self.client.chat.completions.create(
messages=self._add_prefill_message(messages),
model=self.model_name,
extra_body={
**self.generation_params,
**self.prefill_params,
},
)
return prefill_response
return messages + [{"role": "assistant", "content": prefill_message}]
@override
async def chat_completion(self, req: ChatMessage) -> ChatCompletionResponse:
@ -385,7 +355,7 @@ class ArchFunctionHandler(ArchBaseHandler):
# always enable `stream=True` to collect model responses
response = self.client.chat.completions.create(
messages=messages,
messages=self._prefill_message(messages, self.default_prefix),
model=self.model_name,
stream=True,
extra_body=self.generation_params,
@ -415,16 +385,13 @@ class ArchFunctionHandler(ArchBaseHandler):
has_tool_calls, has_hallucination = None, False
for _ in self.hallucination_state:
# check if the first token is <tool_call>
content = "".join(self.hallucination_state.tokens)
if "tool_calls" in content:
logger.info(
f"[Content]: {content}"
)
has_tool_calls = True
else:
has_tool_calls = False
# check if the model response contains tool calls
if has_tool_calls is None:
content = "".join(self.hallucination_state.tokens)
if "tool_calls" in content:
has_tool_calls = True
else:
has_tool_calls = False
# if the model is hallucinating, start parameter gathering
if self.hallucination_state.hallucination is True:
@ -436,10 +403,19 @@ class ArchFunctionHandler(ArchBaseHandler):
logger.info(
f"[Hallucination]: {self.hallucination_state.error_message}"
)
prefill_response = self._engage_parameter_gathering(messages)
model_response = prefill_response.choices[0].message.content
response = self.client.chat.completions.create(
messages=self._prefill_message(messages, self.clarify_prefix),
model=self.model_name,
stream=False,
extra_body=self.generation_params,
)
model_response = (
self.clarify_prefix + response.choices[0].message.content
)
else:
model_response = "".join(self.hallucination_state.tokens)
model_response = self.default_prefix + "".join(
self.hallucination_state.tokens
)
# else:
# # start parameter gathering if the model is not generating tool calls
# prefill_response = self._engage_parameter_gathering(messages)