mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
Fix prompt prefilling
This commit is contained in:
parent
afe7cc9e9e
commit
6ec4c14407
1 changed files with 35 additions and 59 deletions
|
|
@ -48,22 +48,6 @@ class ArchFunctionConfig:
|
|||
"top_logprobs": 10,
|
||||
}
|
||||
|
||||
PREFILL_CONFIG = {
|
||||
"prefill_params": {
|
||||
"continue_final_message": True,
|
||||
"add_generation_prompt": False,
|
||||
},
|
||||
"prefill_prefix": [
|
||||
"May",
|
||||
"Could",
|
||||
"Sure",
|
||||
"Definitely",
|
||||
"Certainly",
|
||||
"Of course",
|
||||
"Can",
|
||||
],
|
||||
}
|
||||
|
||||
SUPPORT_DATA_TYPES = ["int", "float", "bool", "str", "list", "tuple", "set", "dict"]
|
||||
|
||||
|
||||
|
|
@ -91,8 +75,13 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
config.GENERATION_PARAMS,
|
||||
)
|
||||
|
||||
self.prefill_params = config.PREFILL_CONFIG["prefill_params"]
|
||||
self.prefill_prefix = config.PREFILL_CONFIG["prefill_prefix"]
|
||||
self.generation_params = self.generation_params | {
|
||||
"continue_final_message": True,
|
||||
"add_generation_prompt": False,
|
||||
}
|
||||
|
||||
self.default_prefix = '```json\n{"'
|
||||
self.clarify_prefix = '```json\n{"required_functions":'
|
||||
|
||||
self.hallucination_state = None
|
||||
|
||||
|
|
@ -163,8 +152,7 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
unmatched_opening = stack.pop()
|
||||
fixed_str += opening_bracket[unmatched_opening]
|
||||
|
||||
# Attempt to parse the corrected string to ensure it’s valid JSON
|
||||
return fixed_str.replace("'", '"')
|
||||
return fixed_str
|
||||
|
||||
def _parse_model_resonse(self, content: str) -> Dict[str, any]:
|
||||
"""
|
||||
|
|
@ -197,7 +185,11 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
if content.startswith("json"):
|
||||
content = content[4:].strip()
|
||||
|
||||
model_response = json.loads(self._fix_json_string(content))
|
||||
content = self._fix_json_string(content)
|
||||
try:
|
||||
model_response = json.loads(content)
|
||||
except Exception:
|
||||
model_response = json.loads(content.replace("'", '"'))
|
||||
|
||||
response_dict["response"] = model_response.get("response", "")
|
||||
response_dict["required_functions"] = model_response.get(
|
||||
|
|
@ -325,7 +317,7 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
|
||||
return verification_dict
|
||||
|
||||
def _add_prefill_message(self, messages: List[Dict[str, str]]):
|
||||
def _prefill_message(self, messages: List[Dict[str, str]], prefill_message):
|
||||
"""
|
||||
Update messages and generation params for prompt prefilling
|
||||
|
||||
|
|
@ -335,29 +327,7 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
Returns:
|
||||
prefill_messages (List[Dict[str, str]]): A list of messages.
|
||||
"""
|
||||
|
||||
return messages + [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": random.choice(self.prefill_prefix),
|
||||
}
|
||||
]
|
||||
|
||||
def _engage_parameter_gathering(self, messages: List[Dict[str, str]]):
|
||||
"""
|
||||
Engage parameter gathering for tool calls
|
||||
"""
|
||||
|
||||
# TODO: log enaging parameter gathering
|
||||
prefill_response = self.client.chat.completions.create(
|
||||
messages=self._add_prefill_message(messages),
|
||||
model=self.model_name,
|
||||
extra_body={
|
||||
**self.generation_params,
|
||||
**self.prefill_params,
|
||||
},
|
||||
)
|
||||
return prefill_response
|
||||
return messages + [{"role": "assistant", "content": prefill_message}]
|
||||
|
||||
@override
|
||||
async def chat_completion(self, req: ChatMessage) -> ChatCompletionResponse:
|
||||
|
|
@ -385,7 +355,7 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
|
||||
# always enable `stream=True` to collect model responses
|
||||
response = self.client.chat.completions.create(
|
||||
messages=messages,
|
||||
messages=self._prefill_message(messages, self.default_prefix),
|
||||
model=self.model_name,
|
||||
stream=True,
|
||||
extra_body=self.generation_params,
|
||||
|
|
@ -415,16 +385,13 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
|
||||
has_tool_calls, has_hallucination = None, False
|
||||
for _ in self.hallucination_state:
|
||||
# check if the first token is <tool_call>
|
||||
content = "".join(self.hallucination_state.tokens)
|
||||
if "tool_calls" in content:
|
||||
logger.info(
|
||||
f"[Content]: {content}"
|
||||
)
|
||||
has_tool_calls = True
|
||||
else:
|
||||
has_tool_calls = False
|
||||
|
||||
# check if moodel response starts with tool calls
|
||||
if has_tool_calls is None:
|
||||
content = "".join(self.hallucination_state.tokens)
|
||||
if "tool_calls" in content:
|
||||
has_tool_calls = True
|
||||
else:
|
||||
has_tool_calls = False
|
||||
|
||||
# if the model is hallucinating, start parameter gathering
|
||||
if self.hallucination_state.hallucination is True:
|
||||
|
|
@ -436,10 +403,19 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
logger.info(
|
||||
f"[Hallucination]: {self.hallucination_state.error_message}"
|
||||
)
|
||||
prefill_response = self._engage_parameter_gathering(messages)
|
||||
model_response = prefill_response.choices[0].message.content
|
||||
response = self.client.chat.completions.create(
|
||||
messages=self._prefill_message(messages, self.clarify_prefix),
|
||||
model=self.model_name,
|
||||
stream=False,
|
||||
extra_body=self.generation_params,
|
||||
)
|
||||
model_response = (
|
||||
self.clarify_prefix + response.choices[0].message.content
|
||||
)
|
||||
else:
|
||||
model_response = "".join(self.hallucination_state.tokens)
|
||||
model_response = self.default_prefix + "".join(
|
||||
self.hallucination_state.tokens
|
||||
)
|
||||
# else:
|
||||
# # start parameter gathering if the model is not generating tool calls
|
||||
# prefill_response = self._engage_parameter_gathering(messages)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue