refactor model_handler

This commit is contained in:
Shuguang Chen 2024-12-05 11:00:22 -08:00
parent afe1410b37
commit b686cf8b87
5 changed files with 300 additions and 127 deletions

View file

@ -1,7 +1,7 @@
import os
from app.commons.globals import handler_map
from app.model_handler.function_calling import ChatMessage
from app.model_handler.base_handler import ChatMessage
from app.model_handler.guardrails import GuardRequest
from fastapi import FastAPI, Response, Request

View file

@ -0,0 +1,167 @@
import json
from openai import OpenAI
from pydantic import BaseModel
from typing import Any, Dict, List, Optional
from overrides import final
class Message(BaseModel):
    """A single message of a chat conversation.

    Every field is optional and falls back to an empty value when omitted.
    """

    # Speaker role. `ArchBaseHandler._process_messages` rewrites "tool" to
    # "user" and requires the final message to end up with the "user" role.
    role: Optional[str] = ""
    # Text payload of the message.
    content: Optional[str] = ""
    # Identifier of the related tool call -- presumably set on tool
    # responses; not read anywhere in this module (confirm with callers).
    tool_call_id: Optional[str] = ""
    # Tool calls attached to the message. `_process_messages` only reads the
    # "function" entry of the first element.
    tool_calls: Optional[List[Dict[str, Any]]] = []
class ChatMessage(BaseModel):
    """Request payload for a chat completion: a conversation plus tools.

    Both fields are required (no defaults are declared).
    """

    # Conversation messages, passed to `_process_messages` in this order.
    messages: list[Message]
    # Tool definitions as free-form dictionaries; their schema is not
    # constrained here.
    tools: List[Dict[str, Any]]
class Choice(BaseModel):
    """One candidate answer inside a `ChatCompletionResponse`."""

    # Numeric identifier of the choice; defaults to 0.
    id: Optional[int] = 0
    # The generated message (required).
    message: Message
    # Why generation ended; defaults to "stop".
    finish_reason: Optional[str] = "stop"
class ChatCompletionResponse(BaseModel):
    """Response returned by a handler's `chat_completion` method.

    NOTE(review): the field names look modelled on the OpenAI chat
    completion response -- confirm before relying on wire compatibility.
    """

    # Numeric identifier of the response; defaults to 0.
    id: Optional[int] = 0
    # Object type tag; defaults to "chat_completion".
    object: Optional[str] = "chat_completion"
    # Creation timestamp kept as a string; defaults to "" (format is not
    # defined in this module).
    created: Optional[str] = ""
    # Name of the model that produced the response (required).
    model: str
    # Generated choices (required).
    choices: List[Choice]
class ArchBaseHandler:
    """Common base for Arch model handlers.

    Stores the client, model name, prompt templates and generation
    parameters, and implements the shared prompt/message preparation.
    Subclasses must provide `_convert_tools` and `chat_completion`.
    """

    def __init__(
        self,
        client: OpenAI,
        model_name: str,
        task_prompt: str,
        tool_prompt: str,
        format_prompt: str,
        generation_params: Dict,
    ):
        """
        Initializes the base handler.

        Args:
            client (OpenAI): An OpenAI client instance.
            model_name (str): Name of the model to use.
            task_prompt (str): The main task prompt for the system.
            tool_prompt (str): A prompt to describe tools; must contain a
                `{tool_text}` placeholder.
            format_prompt (str): A prompt specifying the desired output format.
            generation_params (Dict): Generation parameters for the model.
        """
        self.client = client
        self.model_name = model_name
        self.task_prompt = task_prompt
        self.tool_prompt = tool_prompt
        self.format_prompt = format_prompt
        self.generation_params = generation_params

    def _convert_tools(self, tools: List[Dict[str, Any]]) -> str:
        """
        Converts a list of tools into the desired internal representation.

        Args:
            tools (List[Dict[str, Any]]): A list of tools represented as dictionaries.

        Raises:
            NotImplementedError: Method should be overridden in subclasses.
        """
        raise NotImplementedError()

    @final
    def _format_system(self, tools: List[Dict[str, Any]]) -> str:
        """
        Formats the system prompt using provided tools.

        Args:
            tools (List[Dict[str, Any]]): A list of tools represented as dictionaries.

        Returns:
            str: The task prompt, the tool prompt (with `{tool_text}` filled
            in) and the format prompt, separated by blank lines.
        """
        tool_text = self._convert_tools(tools)
        sections = [
            self.task_prompt,
            self.tool_prompt.format(tool_text=tool_text),
            self.format_prompt,
        ]
        return "\n\n".join(sections)

    @final
    def _process_messages(
        self,
        messages: List[Message],
        tools: Optional[List[Dict[str, Any]]] = None,
        extra_instructions: Optional[str] = None,
    ) -> List[Dict[str, Any]]:
        """
        Processes a list of messages and formats them appropriately.

        Args:
            messages (List[Message]): A list of message objects.
            tools (List[Dict[str, Any]], optional): A list of tools to include in the system prompt.
            extra_instructions (str, optional): Additional instructions to append to the last user message.

        Returns:
            List[Dict[str, Any]]: A list of processed message dictionaries.

        Raises:
            ValueError: If there is no message to process or the last
                processed message does not carry the "user" role.
        """
        processed_messages = []

        # The system prompt is only emitted when tools are supplied.
        if tools:
            processed_messages.append(
                {"role": "system", "content": self._format_system(tools)}
            )

        for message in messages:
            role, content, tool_calls = (
                message.role,
                message.content,
                message.tool_calls,
            )

            if tool_calls:
                # [TODO] Extend to support multiple function calls
                role = "assistant"
                content = f"<tool_call>\n{json.dumps(tool_calls[0]['function'])}\n</tool_call>"
            elif message.role == "tool":
                # Tool results are presented to the model as a user turn.
                role = "user"
                content = (
                    f"<tool_response>\n{json.dumps(message.content)}\n</tool_response>"
                )

            processed_messages.append({"role": role, "content": content})

        # Explicit validation instead of `assert`: assertions are stripped
        # under `python -O`, and an empty conversation used to surface as an
        # opaque IndexError.
        if not processed_messages or processed_messages[-1]["role"] != "user":
            last_role = processed_messages[-1]["role"] if processed_messages else None
            raise ValueError(
                "The conversation must end with a user or tool message, "
                f"got last role: {last_role!r}"
            )

        if extra_instructions:
            processed_messages[-1]["content"] += extra_instructions

        return processed_messages

    async def chat_completion(self, req: ChatMessage) -> ChatCompletionResponse:
        """
        Abstract method for generating chat completions.

        Args:
            req (ChatMessage): A chat message request object.

        Raises:
            NotImplementedError: Method should be overridden in subclasses.
        """
        raise NotImplementedError()

View file

@ -3,121 +3,20 @@ import random
import builtins
from openai import OpenAI
from pydantic import BaseModel
from typing import Any, Dict, List, Optional
from overrides import override, final
from typing import Any, Dict, List, Tuple, Union
from overrides import override
from app.model_handler.base_handler import (
Message,
ChatMessage,
Choice,
ChatCompletionResponse,
ArchBaseHandler,
)
# Names of the Python built-in types treated as supported data types.
# NOTE(review): presumably used to validate tool-call parameter types; the
# consumer is not visible in this part of the file -- confirm.
SUPPORT_DATA_TYPES = ["int", "float", "bool", "str", "list", "tuple", "set", "dict"]
class Message(BaseModel):
    """One message of a chat conversation; all fields default to empty values."""

    # Speaker role. `ArchBaseHandler._process_messages` maps "tool" to
    # "user" and expects the last message to resolve to the "user" role.
    role: Optional[str] = ""
    # Text payload of the message.
    content: Optional[str] = ""
    # Identifier of the related tool call -- presumably set on tool
    # responses; not read in the visible code (confirm with callers).
    tool_call_id: Optional[str] = ""
    # Attached tool calls; only the "function" entry of the first element is
    # read by `_process_messages`.
    tool_calls: Optional[List[Dict[str, Any]]] = []
class ChatMessage(BaseModel):
    """Chat completion request: the conversation and the tools on offer."""

    # Conversation messages, handed to `_process_messages` in this order.
    messages: list[Message]
    # Tool definitions as free-form dictionaries (required, no default).
    tools: List[Dict[str, Any]]
class Choice(BaseModel):
    """A single generated candidate within a `ChatCompletionResponse`."""

    # Numeric identifier of the choice; defaults to 0.
    id: Optional[int] = 0
    # The generated message (required).
    message: Message
    # Reason generation ended; defaults to "stop".
    finish_reason: Optional[str] = "stop"
class ChatCompletionResponse(BaseModel):
    """Result object produced by a handler's `chat_completion` method."""

    # Numeric identifier of the response; defaults to 0.
    id: Optional[int] = 0
    # Object type tag; defaults to "chat_completion".
    object: Optional[str] = "chat_completion"
    # Creation timestamp stored as a string; defaults to "" (format is not
    # specified in this module).
    created: Optional[str] = ""
    # Name of the model that generated the response (required).
    model: str
    # Generated choices (required).
    choices: List[Choice]
class ArchBaseHandler:
    """Shared base for the Arch handlers defined in this module.

    Keeps the client, model name, prompt templates and generation
    parameters, and prepares system prompts and message lists. Subclasses
    implement `_convert_tools` and `chat_completion`.
    """

    def __init__(
        self,
        client: OpenAI,
        model_name: str,
        task_prompt: str,
        tool_prompt: str,
        format_prompt: str,
        generation_params: Dict,
    ):
        """Store the client, model name, prompts and generation parameters.

        `tool_prompt` is used as a `str.format` template and must contain a
        `{tool_text}` placeholder.
        """
        self.client = client
        self.model_name = model_name
        self.task_prompt = task_prompt
        self.tool_prompt = tool_prompt
        self.format_prompt = format_prompt
        self.generation_params = generation_params

    def _convert_tools(self, tools: List[Dict[str, Any]]):
        """Render `tools` as prompt text; must be overridden by subclasses."""
        raise NotImplementedError()

    @final
    def _format_system(self, tools: List[Dict[str, Any]]):
        """Build the system prompt from the task, tool and format prompts.

        The three parts are separated by blank lines; the tool prompt gets
        its `{tool_text}` placeholder filled with `_convert_tools(tools)`.
        """
        tool_text = self._convert_tools(tools)
        sections = [
            self.task_prompt,
            self.tool_prompt.format(tool_text=tool_text),
            self.format_prompt,
        ]
        return "\n\n".join(sections)

    @final
    def _process_messages(
        self,
        messages: List[Message],
        tools: List[Dict[str, Any]] = None,
        extra_instructions: str = None,
    ):
        """Convert request messages into role/content dictionaries.

        A system message is prepended when `tools` is non-empty. Messages
        carrying tool calls become assistant `<tool_call>` turns, and
        messages with role "tool" become user `<tool_response>` turns.
        `extra_instructions`, when given, is appended to the last message.

        Raises:
            ValueError: If there is no message to process or the last
                processed message does not carry the "user" role.
        """
        processed_messages = []

        # The system prompt is only emitted when tools are supplied.
        if tools:
            processed_messages.append(
                {"role": "system", "content": self._format_system(tools)}
            )

        for message in messages:
            role, content, tool_calls = (
                message.role,
                message.content,
                message.tool_calls,
            )

            if tool_calls:
                # [TODO] Extend to support multiple function calls
                role = "assistant"
                content = f"<tool_call>\n{json.dumps(tool_calls[0]['function'])}\n</tool_call>"
            elif message.role == "tool":
                # Tool results are presented to the model as a user turn.
                role = "user"
                content = (
                    f"<tool_response>\n{json.dumps(message.content)}\n</tool_response>"
                )

            processed_messages.append({"role": role, "content": content})

        # Explicit validation instead of `assert`: assertions vanish under
        # `python -O`, and an empty conversation used to raise IndexError.
        if not processed_messages or processed_messages[-1]["role"] != "user":
            last_role = processed_messages[-1]["role"] if processed_messages else None
            raise ValueError(
                "The conversation must end with a user or tool message, "
                f"got last role: {last_role!r}"
            )

        if extra_instructions:
            processed_messages[-1]["content"] += extra_instructions

        return processed_messages

    async def chat_completion(self, req: ChatMessage):
        """Generate a chat completion for `req`; must be overridden by subclasses."""
        raise NotImplementedError()
class ArchIntentHandler(ArchBaseHandler):
def __init__(
self,
@ -129,6 +28,19 @@ class ArchIntentHandler(ArchBaseHandler):
intent_instruction: str,
generation_params: Dict,
):
"""
Initializes the intent handler.
Args:
client (OpenAI): An OpenAI client instance.
model_name (str): Name of the model to use.
task_prompt (str): The main task prompt for the system.
tool_prompt (str): A prompt to describe tools.
format_prompt (str): A prompt specifying the desired output format.
intent_instruction (str): Instructions specific to intent handling.
generation_params (Dict): Generation parameters for the model.
"""
super().__init__(
client,
model_name,
@ -141,16 +53,35 @@ class ArchIntentHandler(ArchBaseHandler):
self.intent_instruction = intent_instruction
@override
def _convert_tools(self, tools: List[Dict[str, Any]]):
def _convert_tools(self, tools: List[Dict[str, Any]]) -> str:
"""
Converts a list of tools into a JSON-like format with indexed keys.
Args:
tools (List[Dict[str, Any]]): A list of tools represented as dictionaries.
Returns:
str: A string representation of converted tools.
"""
converted = [
json.dumps({"index": f"T{idx}"} | tool) for idx, tool in enumerate(tools)
]
return "\n".join(converted)
@override
async def chat_completion(self, req: ChatMessage):
async def chat_completion(self, req: ChatMessage) -> ChatCompletionResponse:
"""
Note: Currently only support vllm inference
Generates a chat completion for a given request.
Args:
req (ChatMessage): A chat message request object.
Returns:
ChatCompletionResponse: The model's response to the chat request.
Note:
Currently only supports vLLM inference.
"""
messages = self._process_messages(
@ -185,6 +116,20 @@ class ArchFunctionHandler(ArchBaseHandler):
prefill_params: Dict,
prefill_prefix: List,
):
"""
Initializes the function handler.
Args:
client (OpenAI): An OpenAI client instance.
model_name (str): Name of the model to use.
task_prompt (str): The main task prompt for the system.
tool_prompt (str): A prompt to describe tools.
format_prompt (str): A prompt specifying the desired output format.
generation_params (Dict): Generation parameters for the model.
prefill_params (Dict): Additional parameters for prefilling responses.
prefill_prefix (List[str]): List of prefixes for prefill responses.
"""
super().__init__(
client,
model_name,
@ -204,11 +149,31 @@ class ArchFunctionHandler(ArchBaseHandler):
}
@override
def _convert_tools(self, tools: List[Dict[str, Any]]):
def _convert_tools(self, tools: List[Dict[str, Any]]) -> str:
"""
Converts a list of tools into JSON format.
Args:
tools (List[Dict[str, Any]]): A list of tools represented as dictionaries.
Returns:
str: A string representation of converted tools.
"""
converted = [json.dumps(tool) for tool in tools]
return "\n".join(converted)
def _fix_json_string(self, json_str: str):
def _fix_json_string(self, json_str: str) -> str:
"""
Fixes malformed JSON strings by ensuring proper bracket matching.
Args:
json_str (str): A JSON string that might be malformed.
Returns:
str: A corrected JSON string.
"""
# Remove any leading or trailing whitespace or newline characters
json_str = json_str.strip()
@ -246,7 +211,22 @@ class ArchFunctionHandler(ArchBaseHandler):
# Attempt to parse the corrected string to ensure it is valid JSON
return fixed_str.replace("'", '"')
def _extract_tool_calls(self, content: str):
def _extract_tool_calls(
self, content: str
) -> Tuple[List[Dict[str, Any]], bool, Union[str, Exception]]:
"""
Extracts tool call information from a given string.
Args:
content (str): The content string containing potential tool call information.
Returns:
Tuple[List[Dict[str, Any]], bool, Union[str, Exception]]:
- A list of tool call dictionaries.
- A boolean indicating if the extraction was valid.
- An error message or exception if extraction failed.
"""
tool_calls, is_valid, error_message = [], True, ""
flag = False
@ -284,7 +264,21 @@ class ArchFunctionHandler(ArchBaseHandler):
def _verify_tool_calls(
self, tools: List[Dict[str, Any]], tool_calls: List[Dict[str, Any]]
):
) -> Tuple[bool, Union[Dict[str, Any], None], str]:
"""
Verifies the validity of extracted tool calls against the provided tools.
Args:
tools (List[Dict[str, Any]]): A list of available tools.
tool_calls (List[Dict[str, Any]]): A list of tool calls to verify.
Returns:
Tuple[bool, Union[Dict[str, Any], None], str]:
- A boolean indicating if the tool calls are valid.
- The invalid tool call dictionary if any.
- An error message.
"""
is_valid, error_tool_call, error_message = True, None, ""
functions = {}
@ -328,9 +322,21 @@ class ArchFunctionHandler(ArchBaseHandler):
return is_valid, error_tool_call, error_message
@override
async def chat_completion(self, req: ChatMessage, enable_prefilling=True):
async def chat_completion(
self, req: ChatMessage, enable_prefilling=True
) -> ChatCompletionResponse:
"""
Note: Currently only support vllm inference
Generates a chat completion response for a given request.
Args:
req (ChatMessage): A chat message request object.
enable_prefilling (bool, optional): Whether to enable prefill responses. Defaults to True.
Returns:
ChatCompletionResponse: The model's response to the chat request.
Note:
Currently only supports vLLM inference.
"""
messages = self._process_messages(req.messages, req.tools)

View file

@ -203,16 +203,16 @@ class HallucinationStateHandler:
if self.tokens[-1].strip() not in ['"', ""]:
self.mask.append(MaskToken.PARAMETER_VALUE)
# [TODO] Review: update the following code: `is_parameter_property` should not be here
# [TODO] Review: update the following code: `is_parameter_property` should not be here, move to `ArchFunctionHandler`
# checking if the parameter doesn't have default and the token is the first parameter value token
if (
len(self.mask) > 1
and self.mask[-2] != MaskToken.PARAMETER_VALUE
and not is_parameter_property(
self.function_properties[self.function_name],
self.parameter_name[-1],
"default",
)
# and not is_parameter_property(
# self.function_properties[self.function_name],
# self.parameter_name[-1],
# "default",
# )
):
self._check_logprob()
else:

View file

@ -4,7 +4,7 @@ import pytest
from fastapi import Response
from unittest.mock import AsyncMock, MagicMock, patch
from app.commons.globals import handler_map
from app.model_handler.function_calling import (
from app.model_handler.base_handler import (
Message,
ChatMessage,
ChatCompletionResponse,