diff --git a/model_server/app/main.py b/model_server/app/main.py
index 79d94f0d..6d1de4e7 100644
--- a/model_server/app/main.py
+++ b/model_server/app/main.py
@@ -1,7 +1,7 @@
import os
from app.commons.globals import handler_map
-from app.model_handler.function_calling import ChatMessage
+from app.model_handler.base_handler import ChatMessage
from app.model_handler.guardrails import GuardRequest
from fastapi import FastAPI, Response, Request
diff --git a/model_server/app/model_handler/base_handler.py b/model_server/app/model_handler/base_handler.py
new file mode 100644
index 00000000..6b88a2bb
--- /dev/null
+++ b/model_server/app/model_handler/base_handler.py
@@ -0,0 +1,167 @@
+import json
+
+from openai import OpenAI
+from pydantic import BaseModel
+from typing import Any, Dict, List, Optional
+from overrides import final
+
+
+class Message(BaseModel):  # A single chat message; every field is optional and has a default.
+ role: Optional[str] = ""  # Speaker role; `_process_messages` special-cases the value "tool".
+ content: Optional[str] = ""  # Message text; JSON-encoded by `_process_messages` when role is "tool".
+ tool_call_id: Optional[str] = ""  # Tool-call identifier; not read anywhere in this module.
+ tool_calls: Optional[List[Dict[str, Any]]] = []  # Tool calls; only `tool_calls[0]["function"]` is read in this module.
+
+
+class ChatMessage(BaseModel):  # Request payload accepted by the handlers' `chat_completion`.
+ messages: list[Message]  # Conversation messages (required).
+ tools: List[Dict[str, Any]]  # Tool definitions as dictionaries (required); rendered into the system prompt.
+
+
+class Choice(BaseModel):  # One entry of `ChatCompletionResponse.choices`.
+ id: Optional[int] = 0  # Choice identifier, defaults to 0; presumably its index in `choices` — TODO confirm.
+ message: Message  # The message carried by this choice (required).
+ finish_reason: Optional[str] = "stop"  # Why generation ended; defaults to "stop".
+
+
+class ChatCompletionResponse(BaseModel):  # Declared return type of the handlers' `chat_completion`.
+ id: Optional[int] = 0  # Response identifier; defaults to 0.
+ object: Optional[str] = "chat_completion"  # Object-type tag; defaults to "chat_completion".
+ created: Optional[str] = ""  # Presumably a creation timestamp kept as a string — TODO confirm.
+ model: str  # Model name (required).
+ choices: List[Choice]  # Choices contained in the response (required).
+
+
+class ArchBaseHandler:
+    """Base class for Arch model handlers: builds prompts and declares the `chat_completion` hook."""
+
+    def __init__(
+        self,
+        client: OpenAI,
+        model_name: str,
+        task_prompt: str,
+        tool_prompt: str,
+        format_prompt: str,
+        generation_params: Dict,
+    ):
+        """
+        Initializes the base handler.
+
+        Args:
+            client (OpenAI): An OpenAI client instance.
+            model_name (str): Name of the model to use.
+            task_prompt (str): The main task prompt for the system.
+            tool_prompt (str): A prompt template describing tools; it is formatted with a `tool_text` field.
+            format_prompt (str): A prompt specifying the desired output format.
+            generation_params (Dict): Generation parameters for the model.
+        """
+
+        self.client = client
+        self.model_name = model_name
+        self.task_prompt = task_prompt
+        self.tool_prompt = tool_prompt
+        self.format_prompt = format_prompt
+
+        self.generation_params = generation_params
+
+    def _convert_tools(self, tools: List[Dict[str, Any]]) -> str:
+        """
+        Converts a list of tools into the desired internal representation.
+
+        Args:
+            tools (List[Dict[str, Any]]): A list of tools represented as dictionaries.
+
+        Raises:
+            NotImplementedError: Method should be overridden in subclasses.
+        """
+
+        raise NotImplementedError()
+
+    @final
+    def _format_system(self, tools: List[Dict[str, Any]]) -> str:
+        """
+        Formats the system prompt using provided tools.
+
+        Args:
+            tools (List[Dict[str, Any]]): A list of tools represented as dictionaries.
+
+        Returns:
+            str: The task prompt, the tool prompt and the format prompt joined by blank lines.
+        """
+
+        tool_text = self._convert_tools(tools)
+
+        system_prompt = (
+            self.task_prompt
+            + "\n\n"
+            + self.tool_prompt.format(tool_text=tool_text)
+            + "\n\n"
+            + self.format_prompt
+        )
+
+        return system_prompt
+
+    @final
+    def _process_messages(
+        self,
+        messages: List[Message],
+        tools: Optional[List[Dict[str, Any]]] = None,
+        extra_instructions: Optional[str] = None,
+    ) -> List[Dict[str, Any]]:
+        """
+        Converts messages into role/content dictionaries, rewriting tool calls and tool results.
+
+        Args:
+            messages (List[Message]): A list of message objects.
+            tools (List[Dict[str, Any]], optional): A list of tools to include in the system prompt.
+            extra_instructions (str, optional): Additional instructions to append to the last user message.
+
+        Returns:
+            List[Dict[str, Any]]: A list of processed message dictionaries.
+
+        Raises:
+            ValueError: If there are no processed messages or the last one does not have the `user` role.
+        """
+
+        processed_messages = []
+
+        if tools:
+            processed_messages.append(
+                {"role": "system", "content": self._format_system(tools)}
+            )
+
+        for message in messages:
+            role, content = message.role, message.content
+
+            if message.tool_calls:
+                # [TODO] Extend to support multiple function calls
+                role = "assistant"
+                content = f"\n{json.dumps(message.tool_calls[0]['function'])}\n"
+            elif message.role == "tool":
+                role = "user"
+                content = f"\n{json.dumps(message.content)}\n"
+
+            processed_messages.append({"role": role, "content": content})
+
+        # Explicit check instead of `assert`: assertions are stripped under `python -O`,
+        # and indexing an empty list would raise an unrelated IndexError.
+        if not processed_messages or processed_messages[-1]["role"] != "user":
+            raise ValueError("The last processed message must have the `user` role.")
+
+        if extra_instructions:
+            processed_messages[-1]["content"] += extra_instructions
+
+        return processed_messages
+
+    async def chat_completion(self, req: ChatMessage) -> ChatCompletionResponse:
+        """
+        Abstract method for generating chat completions.
+
+        Args:
+            req (ChatMessage): A chat message request object.
+
+        Raises:
+            NotImplementedError: Method should be overridden in subclasses.
+        """
+
+        raise NotImplementedError()
diff --git a/model_server/app/model_handler/function_calling.py b/model_server/app/model_handler/function_calling.py
index 0cee18ca..da135dec 100644
--- a/model_server/app/model_handler/function_calling.py
+++ b/model_server/app/model_handler/function_calling.py
@@ -3,121 +3,20 @@ import random
import builtins
from openai import OpenAI
-from pydantic import BaseModel
-from typing import Any, Dict, List, Optional
-from overrides import override, final
+from typing import Any, Dict, List, Tuple, Union
+from overrides import override
+from app.model_handler.base_handler import (
+ Message,
+ ChatMessage,
+ Choice,
+ ChatCompletionResponse,
+ ArchBaseHandler,
+)
SUPPORT_DATA_TYPES = ["int", "float", "bool", "str", "list", "tuple", "set", "dict"]
-class Message(BaseModel):
- role: Optional[str] = ""
- content: Optional[str] = ""
- tool_call_id: Optional[str] = ""
- tool_calls: Optional[List[Dict[str, Any]]] = []
-
-
-class ChatMessage(BaseModel):
- messages: list[Message]
- tools: List[Dict[str, Any]]
-
-
-class Choice(BaseModel):
- id: Optional[int] = 0
- message: Message
- finish_reason: Optional[str] = "stop"
-
-
-class ChatCompletionResponse(BaseModel):
- id: Optional[int] = 0
- object: Optional[str] = "chat_completion"
- created: Optional[str] = ""
- model: str
- choices: List[Choice]
-
-
-class ArchBaseHandler:
- def __init__(
- self,
- client: OpenAI,
- model_name: str,
- task_prompt: str,
- tool_prompt: str,
- format_prompt: str,
- generation_params: Dict,
- ):
- self.client = client
-
- self.model_name = model_name
-
- self.task_prompt = task_prompt
- self.tool_prompt = tool_prompt
- self.format_prompt = format_prompt
-
- self.generation_params = generation_params
-
- def _convert_tools(self, tools: List[Dict[str, Any]]):
- raise NotImplementedError()
-
- @final
- def _format_system(self, tools: List[Dict[str, Any]]):
- tool_text = self._convert_tools(tools)
-
- system_prompt = (
- self.task_prompt
- + "\n\n"
- + self.tool_prompt.format(tool_text=tool_text)
- + "\n\n"
- + self.format_prompt
- )
-
- return system_prompt
-
- @final
- def _process_messages(
- self,
- messages: List[Message],
- tools: List[Dict[str, Any]] = None,
- extra_instructions: str = None,
- ):
- processed_messages = []
-
- if tools:
- processed_messages.append(
- {"role": "system", "content": self._format_system(tools)}
- )
-
- for message in messages:
- role, content, tool_calls = (
- message.role,
- message.content,
- message.tool_calls,
- )
-
- if tool_calls:
- # [TODO] Extend to support multiple function calls
- role = "assistant"
- content = f"\n{json.dumps(tool_calls[0]['function'])}\n"
- elif message.role == "tool":
- role = "user"
- content = (
- f"\n{json.dumps(message.content)}\n"
- )
-
- processed_messages.append({"role": role, "content": content})
-
- assert processed_messages[-1]["role"] == "user"
-
- if extra_instructions:
- processed_messages[-1]["content"] += extra_instructions
-
- return processed_messages
-
- async def chat_completion(self, req: ChatMessage):
- raise NotImplementedError()
-
-
class ArchIntentHandler(ArchBaseHandler):
def __init__(
self,
@@ -129,6 +28,19 @@ class ArchIntentHandler(ArchBaseHandler):
intent_instruction: str,
generation_params: Dict,
):
+ """
+ Initializes the intent handler.
+
+ Args:
+ client (OpenAI): An OpenAI client instance.
+ model_name (str): Name of the model to use.
+ task_prompt (str): The main task prompt for the system.
+ tool_prompt (str): A prompt to describe tools.
+ format_prompt (str): A prompt specifying the desired output format.
+ intent_instruction (str): Instructions specific to intent handling.
+ generation_params (Dict): Generation parameters for the model.
+ """
+
super().__init__(
client,
model_name,
@@ -141,16 +53,35 @@ class ArchIntentHandler(ArchBaseHandler):
self.intent_instruction = intent_instruction
@override
- def _convert_tools(self, tools: List[Dict[str, Any]]):
+ def _convert_tools(self, tools: List[Dict[str, Any]]) -> str:
+ """
+ Converts a list of tools into a JSON-like format with indexed keys.
+
+ Args:
+ tools (List[Dict[str, Any]]): A list of tools represented as dictionaries.
+
+ Returns:
+ str: A string representation of converted tools.
+ """
+
converted = [
json.dumps({"index": f"T{idx}"} | tool) for idx, tool in enumerate(tools)
]
return "\n".join(converted)
@override
- async def chat_completion(self, req: ChatMessage):
+ async def chat_completion(self, req: ChatMessage) -> ChatCompletionResponse:
"""
- Note: Currently only support vllm inference
+ Generates a chat completion for a given request.
+
+ Args:
+ req (ChatMessage): A chat message request object.
+
+ Returns:
+ ChatCompletionResponse: The model's response to the chat request.
+
+ Note:
+ Currently only supports vLLM inference.
"""
messages = self._process_messages(
@@ -185,6 +116,20 @@ class ArchFunctionHandler(ArchBaseHandler):
prefill_params: Dict,
prefill_prefix: List,
):
+ """
+ Initializes the function handler.
+
+ Args:
+ client (OpenAI): An OpenAI client instance.
+ model_name (str): Name of the model to use.
+ task_prompt (str): The main task prompt for the system.
+ tool_prompt (str): A prompt to describe tools.
+ format_prompt (str): A prompt specifying the desired output format.
+ generation_params (Dict): Generation parameters for the model.
+ prefill_params (Dict): Additional parameters for prefilling responses.
+ prefill_prefix (List[str]): List of prefixes for prefill responses.
+ """
+
super().__init__(
client,
model_name,
@@ -204,11 +149,31 @@ class ArchFunctionHandler(ArchBaseHandler):
}
@override
- def _convert_tools(self, tools: List[Dict[str, Any]]):
+ def _convert_tools(self, tools: List[Dict[str, Any]]) -> str:
+ """
+ Converts a list of tools into JSON format.
+
+ Args:
+ tools (List[Dict[str, Any]]): A list of tools represented as dictionaries.
+
+ Returns:
+ str: A string representation of converted tools.
+ """
+
converted = [json.dumps(tool) for tool in tools]
return "\n".join(converted)
- def _fix_json_string(self, json_str: str):
+ def _fix_json_string(self, json_str: str) -> str:
+ """
+ Fixes malformed JSON strings by ensuring proper bracket matching.
+
+ Args:
+ json_str (str): A JSON string that might be malformed.
+
+ Returns:
+ str: A corrected JSON string.
+ """
+
# Remove any leading or trailing whitespace or newline characters
json_str = json_str.strip()
@@ -246,7 +211,22 @@ class ArchFunctionHandler(ArchBaseHandler):
# Attempt to parse the corrected string to ensure it’s valid JSON
return fixed_str.replace("'", '"')
- def _extract_tool_calls(self, content: str):
+ def _extract_tool_calls(
+ self, content: str
+ ) -> Tuple[List[Dict[str, Any]], bool, Union[str, Exception]]:
+ """
+ Extracts tool call information from a given string.
+
+ Args:
+ content (str): The content string containing potential tool call information.
+
+ Returns:
+ Tuple[List[Dict[str, Any]], bool, Union[str, Exception]]:
+ - A list of tool call dictionaries.
+ - A boolean indicating if the extraction was valid.
+ - An error message or exception if extraction failed.
+ """
+
tool_calls, is_valid, error_message = [], True, ""
flag = False
@@ -284,7 +264,21 @@ class ArchFunctionHandler(ArchBaseHandler):
def _verify_tool_calls(
self, tools: List[Dict[str, Any]], tool_calls: List[Dict[str, Any]]
- ):
+ ) -> Tuple[bool, Union[Dict[str, Any], None], str]:
+ """
+ Verifies the validity of extracted tool calls against the provided tools.
+
+ Args:
+ tools (List[Dict[str, Any]]): A list of available tools.
+ tool_calls (List[Dict[str, Any]]): A list of tool calls to verify.
+
+ Returns:
+ Tuple[bool, Union[Dict[str, Any], None], str]:
+ - A boolean indicating if the tool calls are valid.
+ - The invalid tool call dictionary if any.
+ - An error message.
+ """
+
is_valid, error_tool_call, error_message = True, None, ""
functions = {}
@@ -328,9 +322,21 @@ class ArchFunctionHandler(ArchBaseHandler):
return is_valid, error_tool_call, error_message
@override
- async def chat_completion(self, req: ChatMessage, enable_prefilling=True):
+ async def chat_completion(
+ self, req: ChatMessage, enable_prefilling=True
+ ) -> ChatCompletionResponse:
"""
- Note: Currently only support vllm inference
+ Generates a chat completion response for a given request.
+
+ Args:
+ req (ChatMessage): A chat message request object.
+ enable_prefilling (bool, optional): Whether to enable prefill responses. Defaults to True.
+
+ Returns:
+ ChatCompletionResponse: The model's response to the chat request.
+
+ Note:
+ Currently only supports vLLM inference.
"""
messages = self._process_messages(req.messages, req.tools)
diff --git a/model_server/app/model_handler/hallucination_handler.py b/model_server/app/model_handler/hallucination_handler.py
index 60eda200..09607db3 100644
--- a/model_server/app/model_handler/hallucination_handler.py
+++ b/model_server/app/model_handler/hallucination_handler.py
@@ -203,16 +203,16 @@ class HallucinationStateHandler:
if self.tokens[-1].strip() not in ['"', ""]:
self.mask.append(MaskToken.PARAMETER_VALUE)
- # [TODO] Review: update the following code: `is_parameter_property` should not be here
+ # [TODO] Review: `is_parameter_property` should not be called here; move the default-value check to `ArchFunctionHandler` (the check below is currently commented out)
# checking if the parameter doesn't have default and the token is the first parameter value token
if (
len(self.mask) > 1
and self.mask[-2] != MaskToken.PARAMETER_VALUE
- and not is_parameter_property(
- self.function_properties[self.function_name],
- self.parameter_name[-1],
- "default",
- )
+ # and not is_parameter_property(
+ # self.function_properties[self.function_name],
+ # self.parameter_name[-1],
+ # "default",
+ # )
):
self._check_logprob()
else:
diff --git a/model_server/app/tests/test_function_calling.py b/model_server/app/tests/test_function_calling.py
index b3c7bf2f..2df2a1f4 100644
--- a/model_server/app/tests/test_function_calling.py
+++ b/model_server/app/tests/test_function_calling.py
@@ -4,7 +4,7 @@ import pytest
from fastapi import Response
from unittest.mock import AsyncMock, MagicMock, patch
from app.commons.globals import handler_map
-from app.model_handler.function_calling import (
+from app.model_handler.base_handler import (
Message,
ChatMessage,
ChatCompletionResponse,