mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
refactor model_handler
This commit is contained in:
parent
afe1410b37
commit
b686cf8b87
5 changed files with 300 additions and 127 deletions
|
|
@ -3,121 +3,20 @@ import random
|
|||
import builtins
|
||||
|
||||
from openai import OpenAI
|
||||
from pydantic import BaseModel
|
||||
from typing import Any, Dict, List, Optional
|
||||
from overrides import override, final
|
||||
from typing import Any, Dict, List, Tuple, Union
|
||||
from overrides import override
|
||||
from app.model_handler.base_handler import (
|
||||
Message,
|
||||
ChatMessage,
|
||||
Choice,
|
||||
ChatCompletionResponse,
|
||||
ArchBaseHandler,
|
||||
)
|
||||
|
||||
|
||||
SUPPORT_DATA_TYPES = ["int", "float", "bool", "str", "list", "tuple", "set", "dict"]
|
||||
|
||||
|
||||
class Message(BaseModel):
|
||||
role: Optional[str] = ""
|
||||
content: Optional[str] = ""
|
||||
tool_call_id: Optional[str] = ""
|
||||
tool_calls: Optional[List[Dict[str, Any]]] = []
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
messages: list[Message]
|
||||
tools: List[Dict[str, Any]]
|
||||
|
||||
|
||||
class Choice(BaseModel):
|
||||
id: Optional[int] = 0
|
||||
message: Message
|
||||
finish_reason: Optional[str] = "stop"
|
||||
|
||||
|
||||
class ChatCompletionResponse(BaseModel):
|
||||
id: Optional[int] = 0
|
||||
object: Optional[str] = "chat_completion"
|
||||
created: Optional[str] = ""
|
||||
model: str
|
||||
choices: List[Choice]
|
||||
|
||||
|
||||
class ArchBaseHandler:
|
||||
def __init__(
|
||||
self,
|
||||
client: OpenAI,
|
||||
model_name: str,
|
||||
task_prompt: str,
|
||||
tool_prompt: str,
|
||||
format_prompt: str,
|
||||
generation_params: Dict,
|
||||
):
|
||||
self.client = client
|
||||
|
||||
self.model_name = model_name
|
||||
|
||||
self.task_prompt = task_prompt
|
||||
self.tool_prompt = tool_prompt
|
||||
self.format_prompt = format_prompt
|
||||
|
||||
self.generation_params = generation_params
|
||||
|
||||
def _convert_tools(self, tools: List[Dict[str, Any]]):
|
||||
raise NotImplementedError()
|
||||
|
||||
@final
|
||||
def _format_system(self, tools: List[Dict[str, Any]]):
|
||||
tool_text = self._convert_tools(tools)
|
||||
|
||||
system_prompt = (
|
||||
self.task_prompt
|
||||
+ "\n\n"
|
||||
+ self.tool_prompt.format(tool_text=tool_text)
|
||||
+ "\n\n"
|
||||
+ self.format_prompt
|
||||
)
|
||||
|
||||
return system_prompt
|
||||
|
||||
@final
|
||||
def _process_messages(
|
||||
self,
|
||||
messages: List[Message],
|
||||
tools: List[Dict[str, Any]] = None,
|
||||
extra_instructions: str = None,
|
||||
):
|
||||
processed_messages = []
|
||||
|
||||
if tools:
|
||||
processed_messages.append(
|
||||
{"role": "system", "content": self._format_system(tools)}
|
||||
)
|
||||
|
||||
for message in messages:
|
||||
role, content, tool_calls = (
|
||||
message.role,
|
||||
message.content,
|
||||
message.tool_calls,
|
||||
)
|
||||
|
||||
if tool_calls:
|
||||
# [TODO] Extend to support multiple function calls
|
||||
role = "assistant"
|
||||
content = f"<tool_call>\n{json.dumps(tool_calls[0]['function'])}\n</tool_call>"
|
||||
elif message.role == "tool":
|
||||
role = "user"
|
||||
content = (
|
||||
f"<tool_response>\n{json.dumps(message.content)}\n</tool_response>"
|
||||
)
|
||||
|
||||
processed_messages.append({"role": role, "content": content})
|
||||
|
||||
assert processed_messages[-1]["role"] == "user"
|
||||
|
||||
if extra_instructions:
|
||||
processed_messages[-1]["content"] += extra_instructions
|
||||
|
||||
return processed_messages
|
||||
|
||||
async def chat_completion(self, req: ChatMessage):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class ArchIntentHandler(ArchBaseHandler):
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -129,6 +28,19 @@ class ArchIntentHandler(ArchBaseHandler):
|
|||
intent_instruction: str,
|
||||
generation_params: Dict,
|
||||
):
|
||||
"""
|
||||
Initializes the intent handler.
|
||||
|
||||
Args:
|
||||
client (OpenAI): An OpenAI client instance.
|
||||
model_name (str): Name of the model to use.
|
||||
task_prompt (str): The main task prompt for the system.
|
||||
tool_prompt (str): A prompt to describe tools.
|
||||
format_prompt (str): A prompt specifying the desired output format.
|
||||
intent_instruction (str): Instructions specific to intent handling.
|
||||
generation_params (Dict): Generation parameters for the model.
|
||||
"""
|
||||
|
||||
super().__init__(
|
||||
client,
|
||||
model_name,
|
||||
|
|
@ -141,16 +53,35 @@ class ArchIntentHandler(ArchBaseHandler):
|
|||
self.intent_instruction = intent_instruction
|
||||
|
||||
@override
|
||||
def _convert_tools(self, tools: List[Dict[str, Any]]):
|
||||
def _convert_tools(self, tools: List[Dict[str, Any]]) -> str:
|
||||
"""
|
||||
Converts a list of tools into a JSON-like format with indexed keys.
|
||||
|
||||
Args:
|
||||
tools (List[Dict[str, Any]]): A list of tools represented as dictionaries.
|
||||
|
||||
Returns:
|
||||
str: A string representation of converted tools.
|
||||
"""
|
||||
|
||||
converted = [
|
||||
json.dumps({"index": f"T{idx}"} | tool) for idx, tool in enumerate(tools)
|
||||
]
|
||||
return "\n".join(converted)
|
||||
|
||||
@override
|
||||
async def chat_completion(self, req: ChatMessage):
|
||||
async def chat_completion(self, req: ChatMessage) -> ChatCompletionResponse:
|
||||
"""
|
||||
Note: Currently only support vllm inference
|
||||
Generates a chat completion for a given request.
|
||||
|
||||
Args:
|
||||
req (ChatMessage): A chat message request object.
|
||||
|
||||
Returns:
|
||||
ChatCompletionResponse: The model's response to the chat request.
|
||||
|
||||
Note:
|
||||
Currently only support vllm inference
|
||||
"""
|
||||
|
||||
messages = self._process_messages(
|
||||
|
|
@ -185,6 +116,20 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
prefill_params: Dict,
|
||||
prefill_prefix: List,
|
||||
):
|
||||
"""
|
||||
Initializes the function handler.
|
||||
|
||||
Args:
|
||||
client (OpenAI): An OpenAI client instance.
|
||||
model_name (str): Name of the model to use.
|
||||
task_prompt (str): The main task prompt for the system.
|
||||
tool_prompt (str): A prompt to describe tools.
|
||||
format_prompt (str): A prompt specifying the desired output format.
|
||||
generation_params (Dict): Generation parameters for the model.
|
||||
prefill_params (Dict): Additional parameters for prefilling responses.
|
||||
prefill_prefix (List[str]): List of prefixes for prefill responses.
|
||||
"""
|
||||
|
||||
super().__init__(
|
||||
client,
|
||||
model_name,
|
||||
|
|
@ -204,11 +149,31 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
}
|
||||
|
||||
@override
|
||||
def _convert_tools(self, tools: List[Dict[str, Any]]):
|
||||
def _convert_tools(self, tools: List[Dict[str, Any]]) -> str:
|
||||
"""
|
||||
Converts a list of tools into JSON format.
|
||||
|
||||
Args:
|
||||
tools (List[Dict[str, Any]]): A list of tools represented as dictionaries.
|
||||
|
||||
Returns:
|
||||
str: A string representation of converted tools.
|
||||
"""
|
||||
|
||||
converted = [json.dumps(tool) for tool in tools]
|
||||
return "\n".join(converted)
|
||||
|
||||
def _fix_json_string(self, json_str: str):
|
||||
def _fix_json_string(self, json_str: str) -> str:
|
||||
"""
|
||||
Fixes malformed JSON strings by ensuring proper bracket matching.
|
||||
|
||||
Args:
|
||||
json_str (str): A JSON string that might be malformed.
|
||||
|
||||
Returns:
|
||||
str: A corrected JSON string.
|
||||
"""
|
||||
|
||||
# Remove any leading or trailing whitespace or newline characters
|
||||
json_str = json_str.strip()
|
||||
|
||||
|
|
@ -246,7 +211,22 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
# Attempt to parse the corrected string to ensure it’s valid JSON
|
||||
return fixed_str.replace("'", '"')
|
||||
|
||||
def _extract_tool_calls(self, content: str):
|
||||
def _extract_tool_calls(
|
||||
self, content: str
|
||||
) -> Tuple[List[Dict[str, Any]], bool, Union[str, Exception]]:
|
||||
"""
|
||||
Extracts tool call information from a given string.
|
||||
|
||||
Args:
|
||||
content (str): The content string containing potential tool call information.
|
||||
|
||||
Returns:
|
||||
Tuple[List[Dict[str, Any]], bool, Union[str, Exception]]:
|
||||
- A list of tool call dictionaries.
|
||||
- A boolean indicating if the extraction was valid.
|
||||
- An error message or exception if extraction failed.
|
||||
"""
|
||||
|
||||
tool_calls, is_valid, error_message = [], True, ""
|
||||
|
||||
flag = False
|
||||
|
|
@ -284,7 +264,21 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
|
||||
def _verify_tool_calls(
|
||||
self, tools: List[Dict[str, Any]], tool_calls: List[Dict[str, Any]]
|
||||
):
|
||||
) -> Tuple[bool, Union[Dict[str, Any], None], str]:
|
||||
"""
|
||||
Verifies the validity of extracted tool calls against the provided tools.
|
||||
|
||||
Args:
|
||||
tools (List[Dict[str, Any]]): A list of available tools.
|
||||
tool_calls (List[Dict[str, Any]]): A list of tool calls to verify.
|
||||
|
||||
Returns:
|
||||
Tuple[bool, Union[Dict[str, Any], None], str]:
|
||||
- A boolean indicating if the tool calls are valid.
|
||||
- The invalid tool call dictionary if any.
|
||||
- An error message.
|
||||
"""
|
||||
|
||||
is_valid, error_tool_call, error_message = True, None, ""
|
||||
|
||||
functions = {}
|
||||
|
|
@ -328,9 +322,21 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
return is_valid, error_tool_call, error_message
|
||||
|
||||
@override
|
||||
async def chat_completion(self, req: ChatMessage, enable_prefilling=True):
|
||||
async def chat_completion(
|
||||
self, req: ChatMessage, enable_prefilling=True
|
||||
) -> ChatCompletionResponse:
|
||||
"""
|
||||
Note: Currently only support vllm inference
|
||||
Generates a chat completion response for a given request.
|
||||
|
||||
Args:
|
||||
req (ChatMessage): A chat message request object.
|
||||
enable_prefilling (bool, optional): Whether to enable prefill responses. Defaults to True.
|
||||
|
||||
Returns:
|
||||
ChatCompletionResponse: The model's response to the chat request.
|
||||
|
||||
Note:
|
||||
Currently only support vllm inference
|
||||
"""
|
||||
|
||||
messages = self._process_messages(req.messages, req.tools)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue