mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
Refine model_server
This commit is contained in:
parent
a5bd005411
commit
4fcfd83639
6 changed files with 149 additions and 64 deletions
|
|
@ -23,7 +23,7 @@ class ArchIntentHandler(ArchBaseHandler):
|
|||
client: OpenAI,
|
||||
model_name: str,
|
||||
task_prompt: str,
|
||||
tool_prompt: str,
|
||||
tool_prompt_template: str,
|
||||
format_prompt: str,
|
||||
extra_instruction: str,
|
||||
generation_params: Dict,
|
||||
|
|
@ -35,7 +35,7 @@ class ArchIntentHandler(ArchBaseHandler):
|
|||
client (OpenAI): An OpenAI client instance.
|
||||
model_name (str): Name of the model to use.
|
||||
task_prompt (str): The main task prompt for the system.
|
||||
tool_prompt (str): A prompt to describe tools.
|
||||
tool_prompt_template (str): A prompt to describe tools.
|
||||
format_prompt (str): A prompt specifying the desired output format.
|
||||
extra_instruction (str): Instructions specific to intent handling.
|
||||
generation_params (Dict): Generation parameters for the model.
|
||||
|
|
@ -45,7 +45,7 @@ class ArchIntentHandler(ArchBaseHandler):
|
|||
client,
|
||||
model_name,
|
||||
task_prompt,
|
||||
tool_prompt,
|
||||
tool_prompt_template,
|
||||
format_prompt,
|
||||
generation_params,
|
||||
)
|
||||
|
|
@ -69,6 +69,19 @@ class ArchIntentHandler(ArchBaseHandler):
|
|||
]
|
||||
return "\n".join(converted)
|
||||
|
||||
def detect_intent(self, content: str) -> bool:
|
||||
"""
|
||||
Detect if any intent match with prompts
|
||||
|
||||
Args:
|
||||
content: str: Model response that contains intent detection results
|
||||
|
||||
Returns:
|
||||
bool: A boolean value to indicate if any intent match with prompts or not
|
||||
"""
|
||||
|
||||
return content.choices[0].message.content == "Yes"
|
||||
|
||||
@override
|
||||
async def chat_completion(self, req: ChatMessage) -> ChatCompletionResponse:
|
||||
"""
|
||||
|
|
@ -110,7 +123,7 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
client: OpenAI,
|
||||
model_name: str,
|
||||
task_prompt: str,
|
||||
tool_prompt: str,
|
||||
tool_prompt_template: str,
|
||||
format_prompt: str,
|
||||
generation_params: Dict,
|
||||
prefill_params: Dict,
|
||||
|
|
@ -123,7 +136,7 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
client (OpenAI): An OpenAI client instance.
|
||||
model_name (str): Name of the model to use.
|
||||
task_prompt (str): The main task prompt for the system.
|
||||
tool_prompt (str): A prompt to describe tools.
|
||||
tool_prompt_template (str): A prompt to describe tools.
|
||||
format_prompt (str): A prompt specifying the desired output format.
|
||||
generation_params (Dict): Generation parameters for the model.
|
||||
prefill_params (Dict): Additional parameters for prefilling responses.
|
||||
|
|
@ -134,7 +147,7 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
client,
|
||||
model_name,
|
||||
task_prompt,
|
||||
tool_prompt,
|
||||
tool_prompt_template,
|
||||
format_prompt,
|
||||
generation_params,
|
||||
)
|
||||
|
|
@ -392,15 +405,24 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
else:
|
||||
model_response = response.choices[0].message.content
|
||||
|
||||
tool_calls, is_valid, error_message = self._extract_tool_calls(model_response)
|
||||
(
|
||||
tool_calls,
|
||||
extraction_status,
|
||||
extraction_error_message,
|
||||
) = self._extract_tool_calls(model_response)
|
||||
|
||||
if tool_calls:
|
||||
is_valid, error_tool_call, error_message = self._verify_tool_calls(
|
||||
tools=req.tools, tool_calls=tool_calls
|
||||
)
|
||||
# [TODO] Review: define the behavior in the case that tool call extraction fails
|
||||
# if not extraction_status:
|
||||
|
||||
(
|
||||
verification_status,
|
||||
invalid_tool_call,
|
||||
verification_error_message,
|
||||
) = self._verify_tool_calls(tools=req.tools, tool_calls=tool_calls)
|
||||
|
||||
# [TODO] Review: In the case that tool calls are invalid, define the protocol to collect debugging output and the behavior to handle it appropriately
|
||||
if is_valid:
|
||||
if verification_status:
|
||||
model_response = Message(content="", tool_calls=tool_calls)
|
||||
# else:
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue