diff --git a/model_server/app/main.py b/model_server/app/main.py index 79d94f0d..6d1de4e7 100644 --- a/model_server/app/main.py +++ b/model_server/app/main.py @@ -1,7 +1,7 @@ import os from app.commons.globals import handler_map -from app.model_handler.function_calling import ChatMessage +from app.model_handler.base_handler import ChatMessage from app.model_handler.guardrails import GuardRequest from fastapi import FastAPI, Response, Request diff --git a/model_server/app/model_handler/base_handler.py b/model_server/app/model_handler/base_handler.py new file mode 100644 index 00000000..6b88a2bb --- /dev/null +++ b/model_server/app/model_handler/base_handler.py @@ -0,0 +1,167 @@ +import json + +from openai import OpenAI +from pydantic import BaseModel +from typing import Any, Dict, List, Optional +from overrides import final + + +class Message(BaseModel): + role: Optional[str] = "" + content: Optional[str] = "" + tool_call_id: Optional[str] = "" + tool_calls: Optional[List[Dict[str, Any]]] = [] + + +class ChatMessage(BaseModel): + messages: list[Message] + tools: List[Dict[str, Any]] + + +class Choice(BaseModel): + id: Optional[int] = 0 + message: Message + finish_reason: Optional[str] = "stop" + + +class ChatCompletionResponse(BaseModel): + id: Optional[int] = 0 + object: Optional[str] = "chat_completion" + created: Optional[str] = "" + model: str + choices: List[Choice] + + +class ArchBaseHandler: + def __init__( + self, + client: OpenAI, + model_name: str, + task_prompt: str, + tool_prompt: str, + format_prompt: str, + generation_params: Dict, + ): + """ + Initializes the base handler. + + Args: + client (OpenAI): An OpenAI client instance. + model_name (str): Name of the model to use. + task_prompt (str): The main task prompt for the system. + tool_prompt (str): A prompt to describe tools. + format_prompt (str): A prompt specifying the desired output format. + generation_params (Dict): Generation parameters for the model. 
+ """ + + self.client = client + + self.model_name = model_name + + self.task_prompt = task_prompt + self.tool_prompt = tool_prompt + self.format_prompt = format_prompt + + self.generation_params = generation_params + + def _convert_tools(self, tools: List[Dict[str, Any]]) -> str: + """ + Converts a list of tools into the desired internal representation. + + Args: + tools (List[Dict[str, Any]]): A list of tools represented as dictionaries. + + Raises: + NotImplementedError: Method should be overridden in subclasses. + """ + + raise NotImplementedError() + + @final + def _format_system(self, tools: List[Dict[str, Any]]) -> str: + """ + Formats the system prompt using provided tools. + + Args: + tools (List[Dict[str, Any]]): A list of tools represented as dictionaries. + + Returns: + str: A formatted system prompt. + """ + + tool_text = self._convert_tools(tools) + + system_prompt = ( + self.task_prompt + + "\n\n" + + self.tool_prompt.format(tool_text=tool_text) + + "\n\n" + + self.format_prompt + ) + + return system_prompt + + @final + def _process_messages( + self, + messages: List[Message], + tools: List[Dict[str, Any]] = None, + extra_instructions: str = None, + ): + """ + Processes a list of messages and formats them appropriately. + + Args: + messages (List[Message]): A list of message objects. + tools (List[Dict[str, Any]], optional): A list of tools to include in the system prompt. + extra_instructions (str, optional): Additional instructions to append to the last user message. + + Returns: + List[Dict[str, Any]]: A list of processed message dictionaries. 
+ """ + + processed_messages = [] + + if tools: + processed_messages.append( + {"role": "system", "content": self._format_system(tools)} + ) + + for message in messages: + role, content, tool_calls = ( + message.role, + message.content, + message.tool_calls, + ) + + if tool_calls: + # [TODO] Extend to support multiple function calls + role = "assistant" + content = f"\n{json.dumps(tool_calls[0]['function'])}\n" + elif message.role == "tool": + role = "user" + content = ( + f"\n{json.dumps(message.content)}\n" + ) + + processed_messages.append({"role": role, "content": content}) + + assert processed_messages[-1]["role"] == "user" + + if extra_instructions: + processed_messages[-1]["content"] += extra_instructions + + return processed_messages + + async def chat_completion(self, req: ChatMessage) -> ChatCompletionResponse: + """ + Abstract method for generating chat completions. + + Args: + req (ChatMessage): A chat message request object. + + Raises: + NotImplementedError: Method should be overridden in subclasses. 
+ """ + + raise NotImplementedError() diff --git a/model_server/app/model_handler/function_calling.py b/model_server/app/model_handler/function_calling.py index 0cee18ca..da135dec 100644 --- a/model_server/app/model_handler/function_calling.py +++ b/model_server/app/model_handler/function_calling.py @@ -3,121 +3,20 @@ import random import builtins from openai import OpenAI -from pydantic import BaseModel -from typing import Any, Dict, List, Optional -from overrides import override, final +from typing import Any, Dict, List, Tuple, Union +from overrides import override +from app.model_handler.base_handler import ( + Message, + ChatMessage, + Choice, + ChatCompletionResponse, + ArchBaseHandler, +) SUPPORT_DATA_TYPES = ["int", "float", "bool", "str", "list", "tuple", "set", "dict"] -class Message(BaseModel): - role: Optional[str] = "" - content: Optional[str] = "" - tool_call_id: Optional[str] = "" - tool_calls: Optional[List[Dict[str, Any]]] = [] - - -class ChatMessage(BaseModel): - messages: list[Message] - tools: List[Dict[str, Any]] - - -class Choice(BaseModel): - id: Optional[int] = 0 - message: Message - finish_reason: Optional[str] = "stop" - - -class ChatCompletionResponse(BaseModel): - id: Optional[int] = 0 - object: Optional[str] = "chat_completion" - created: Optional[str] = "" - model: str - choices: List[Choice] - - -class ArchBaseHandler: - def __init__( - self, - client: OpenAI, - model_name: str, - task_prompt: str, - tool_prompt: str, - format_prompt: str, - generation_params: Dict, - ): - self.client = client - - self.model_name = model_name - - self.task_prompt = task_prompt - self.tool_prompt = tool_prompt - self.format_prompt = format_prompt - - self.generation_params = generation_params - - def _convert_tools(self, tools: List[Dict[str, Any]]): - raise NotImplementedError() - - @final - def _format_system(self, tools: List[Dict[str, Any]]): - tool_text = self._convert_tools(tools) - - system_prompt = ( - self.task_prompt - + "\n\n" - + 
self.tool_prompt.format(tool_text=tool_text) - + "\n\n" - + self.format_prompt - ) - - return system_prompt - - @final - def _process_messages( - self, - messages: List[Message], - tools: List[Dict[str, Any]] = None, - extra_instructions: str = None, - ): - processed_messages = [] - - if tools: - processed_messages.append( - {"role": "system", "content": self._format_system(tools)} - ) - - for message in messages: - role, content, tool_calls = ( - message.role, - message.content, - message.tool_calls, - ) - - if tool_calls: - # [TODO] Extend to support multiple function calls - role = "assistant" - content = f"\n{json.dumps(tool_calls[0]['function'])}\n" - elif message.role == "tool": - role = "user" - content = ( - f"\n{json.dumps(message.content)}\n" - ) - - processed_messages.append({"role": role, "content": content}) - - assert processed_messages[-1]["role"] == "user" - - if extra_instructions: - processed_messages[-1]["content"] += extra_instructions - - return processed_messages - - async def chat_completion(self, req: ChatMessage): - raise NotImplementedError() - - class ArchIntentHandler(ArchBaseHandler): def __init__( self, @@ -129,6 +28,19 @@ class ArchIntentHandler(ArchBaseHandler): intent_instruction: str, generation_params: Dict, ): + """ + Initializes the intent handler. + + Args: + client (OpenAI): An OpenAI client instance. + model_name (str): Name of the model to use. + task_prompt (str): The main task prompt for the system. + tool_prompt (str): A prompt to describe tools. + format_prompt (str): A prompt specifying the desired output format. + intent_instruction (str): Instructions specific to intent handling. + generation_params (Dict): Generation parameters for the model. 
+ """ + super().__init__( client, model_name, @@ -141,16 +53,35 @@ class ArchIntentHandler(ArchBaseHandler): self.intent_instruction = intent_instruction @override - def _convert_tools(self, tools: List[Dict[str, Any]]): + def _convert_tools(self, tools: List[Dict[str, Any]]) -> str: + """ + Converts a list of tools into a JSON-like format with indexed keys. + + Args: + tools (List[Dict[str, Any]]): A list of tools represented as dictionaries. + + Returns: + str: A string representation of converted tools. + """ + converted = [ json.dumps({"index": f"T{idx}"} | tool) for idx, tool in enumerate(tools) ] return "\n".join(converted) @override - async def chat_completion(self, req: ChatMessage): + async def chat_completion(self, req: ChatMessage) -> ChatCompletionResponse: """ - Note: Currently only support vllm inference + Generates a chat completion for a given request. + + Args: + req (ChatMessage): A chat message request object. + + Returns: + ChatCompletionResponse: The model's response to the chat request. + + Note: + Currently only supports vLLM inference """ messages = self._process_messages( @@ -185,6 +116,20 @@ class ArchFunctionHandler(ArchBaseHandler): prefill_params: Dict, prefill_prefix: List, ): + """ + Initializes the function handler. + + Args: + client (OpenAI): An OpenAI client instance. + model_name (str): Name of the model to use. + task_prompt (str): The main task prompt for the system. + tool_prompt (str): A prompt to describe tools. + format_prompt (str): A prompt specifying the desired output format. + generation_params (Dict): Generation parameters for the model. + prefill_params (Dict): Additional parameters for prefilling responses. + prefill_prefix (List[str]): List of prefixes for prefill responses. 
+ """ + super().__init__( client, model_name, @@ -204,11 +149,31 @@ class ArchFunctionHandler(ArchBaseHandler): } @override - def _convert_tools(self, tools: List[Dict[str, Any]]): + def _convert_tools(self, tools: List[Dict[str, Any]]) -> str: + """ + Converts a list of tools into JSON format. + + Args: + tools (List[Dict[str, Any]]): A list of tools represented as dictionaries. + + Returns: + str: A string representation of converted tools. + """ + converted = [json.dumps(tool) for tool in tools] return "\n".join(converted) - def _fix_json_string(self, json_str: str): + def _fix_json_string(self, json_str: str) -> str: + """ + Fixes malformed JSON strings by ensuring proper bracket matching. + + Args: + json_str (str): A JSON string that might be malformed. + + Returns: + str: A corrected JSON string. + """ + # Remove any leading or trailing whitespace or newline characters json_str = json_str.strip() @@ -246,7 +211,22 @@ class ArchFunctionHandler(ArchBaseHandler): # Attempt to parse the corrected string to ensure it’s valid JSON return fixed_str.replace("'", '"') - def _extract_tool_calls(self, content: str): + def _extract_tool_calls( + self, content: str + ) -> Tuple[List[Dict[str, Any]], bool, Union[str, Exception]]: + """ + Extracts tool call information from a given string. + + Args: + content (str): The content string containing potential tool call information. + + Returns: + Tuple[List[Dict[str, Any]], bool, Union[str, Exception]]: + - A list of tool call dictionaries. + - A boolean indicating if the extraction was valid. + - An error message or exception if extraction failed. + """ + tool_calls, is_valid, error_message = [], True, "" flag = False @@ -284,7 +264,21 @@ class ArchFunctionHandler(ArchBaseHandler): def _verify_tool_calls( self, tools: List[Dict[str, Any]], tool_calls: List[Dict[str, Any]] - ): + ) -> Tuple[bool, Union[Dict[str, Any], None], str]: + """ + Verifies the validity of extracted tool calls against the provided tools. 
+ + Args: + tools (List[Dict[str, Any]]): A list of available tools. + tool_calls (List[Dict[str, Any]]): A list of tool calls to verify. + + Returns: + Tuple[bool, Union[Dict[str, Any], None], str]: + - A boolean indicating if the tool calls are valid. + - The invalid tool call dictionary if any. + - An error message. + """ + is_valid, error_tool_call, error_message = True, None, "" functions = {} @@ -328,9 +322,21 @@ class ArchFunctionHandler(ArchBaseHandler): return is_valid, error_tool_call, error_message @override - async def chat_completion(self, req: ChatMessage, enable_prefilling=True): + async def chat_completion( + self, req: ChatMessage, enable_prefilling=True + ) -> ChatCompletionResponse: """ - Note: Currently only support vllm inference + Generates a chat completion response for a given request. + + Args: + req (ChatMessage): A chat message request object. + enable_prefilling (bool, optional): Whether to enable prefill responses. Defaults to True. + + Returns: + ChatCompletionResponse: The model's response to the chat request. 
+ + Note: + Currently only supports vLLM inference """ messages = self._process_messages(req.messages, req.tools) diff --git a/model_server/app/model_handler/hallucination_handler.py b/model_server/app/model_handler/hallucination_handler.py index 60eda200..09607db3 100644 --- a/model_server/app/model_handler/hallucination_handler.py +++ b/model_server/app/model_handler/hallucination_handler.py @@ -203,16 +203,16 @@ class HallucinationStateHandler: if self.tokens[-1].strip() not in ['"', ""]: self.mask.append(MaskToken.PARAMETER_VALUE) - # [TODO] Review: update the following code: `is_parameter_property` should not be here + # [TODO] Review: `is_parameter_property` should not be here; move it to `ArchFunctionHandler`. NOTE: the `default` check below is commented out, so the logprob check also runs for parameters that have a default # checking if the parameter doesn't have default and the token is the first parameter value token if ( len(self.mask) > 1 and self.mask[-2] != MaskToken.PARAMETER_VALUE - and not is_parameter_property( - self.function_properties[self.function_name], - self.parameter_name[-1], - "default", - ) + # and not is_parameter_property( + # self.function_properties[self.function_name], + # self.parameter_name[-1], + # "default", + # ) ): self._check_logprob() else: diff --git a/model_server/app/tests/test_function_calling.py b/model_server/app/tests/test_function_calling.py index b3c7bf2f..2df2a1f4 100644 --- a/model_server/app/tests/test_function_calling.py +++ b/model_server/app/tests/test_function_calling.py @@ -4,7 +4,7 @@ import pytest from fastapi import Response from unittest.mock import AsyncMock, MagicMock, patch from app.commons.globals import handler_map -from app.model_handler.function_calling import ( +from app.model_handler.base_handler import ( Message, ChatMessage, ChatCompletionResponse,