mirror of
https://github.com/katanemo/plano.git
synced 2026-06-26 15:39:40 +02:00
refactor model_handler
This commit is contained in:
parent
afe1410b37
commit
b686cf8b87
5 changed files with 300 additions and 127 deletions
|
|
@ -1,7 +1,7 @@
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from app.commons.globals import handler_map
|
from app.commons.globals import handler_map
|
||||||
from app.model_handler.function_calling import ChatMessage
|
from app.model_handler.base_handler import ChatMessage
|
||||||
from app.model_handler.guardrails import GuardRequest
|
from app.model_handler.guardrails import GuardRequest
|
||||||
|
|
||||||
from fastapi import FastAPI, Response, Request
|
from fastapi import FastAPI, Response, Request
|
||||||
|
|
|
||||||
167
model_server/app/model_handler/base_handler.py
Normal file
167
model_server/app/model_handler/base_handler.py
Normal file
|
|
@ -0,0 +1,167 @@
|
||||||
|
import json
|
||||||
|
|
||||||
|
from openai import OpenAI
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
from overrides import final
|
||||||
|
|
||||||
|
|
||||||
|
class Message(BaseModel):
|
||||||
|
role: Optional[str] = ""
|
||||||
|
content: Optional[str] = ""
|
||||||
|
tool_call_id: Optional[str] = ""
|
||||||
|
tool_calls: Optional[List[Dict[str, Any]]] = []
|
||||||
|
|
||||||
|
|
||||||
|
class ChatMessage(BaseModel):
|
||||||
|
messages: list[Message]
|
||||||
|
tools: List[Dict[str, Any]]
|
||||||
|
|
||||||
|
|
||||||
|
class Choice(BaseModel):
|
||||||
|
id: Optional[int] = 0
|
||||||
|
message: Message
|
||||||
|
finish_reason: Optional[str] = "stop"
|
||||||
|
|
||||||
|
|
||||||
|
class ChatCompletionResponse(BaseModel):
|
||||||
|
id: Optional[int] = 0
|
||||||
|
object: Optional[str] = "chat_completion"
|
||||||
|
created: Optional[str] = ""
|
||||||
|
model: str
|
||||||
|
choices: List[Choice]
|
||||||
|
|
||||||
|
|
||||||
|
class ArchBaseHandler:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
client: OpenAI,
|
||||||
|
model_name: str,
|
||||||
|
task_prompt: str,
|
||||||
|
tool_prompt: str,
|
||||||
|
format_prompt: str,
|
||||||
|
generation_params: Dict,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initializes the base handler.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
client (OpenAI): An OpenAI client instance.
|
||||||
|
model_name (str): Name of the model to use.
|
||||||
|
task_prompt (str): The main task prompt for the system.
|
||||||
|
tool_prompt (str): A prompt to describe tools.
|
||||||
|
format_prompt (str): A prompt specifying the desired output format.
|
||||||
|
generation_params (Dict): Generation parameters for the model.
|
||||||
|
"""
|
||||||
|
|
||||||
|
self.client = client
|
||||||
|
|
||||||
|
self.model_name = model_name
|
||||||
|
|
||||||
|
self.task_prompt = task_prompt
|
||||||
|
self.tool_prompt = tool_prompt
|
||||||
|
self.format_prompt = format_prompt
|
||||||
|
|
||||||
|
self.generation_params = generation_params
|
||||||
|
|
||||||
|
def _convert_tools(self, tools: List[Dict[str, Any]]) -> str:
|
||||||
|
"""
|
||||||
|
Converts a list of tools into the desired internal representation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tools (List[Dict[str, Any]]): A list of tools represented as dictionaries.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NotImplementedError: Method should be overridden in subclasses.
|
||||||
|
"""
|
||||||
|
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@final
|
||||||
|
def _format_system(self, tools: List[Dict[str, Any]]) -> str:
|
||||||
|
"""
|
||||||
|
Formats the system prompt using provided tools.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tools (List[Dict[str, Any]]): A list of tools represented as dictionaries.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: A formatted system prompt.
|
||||||
|
"""
|
||||||
|
|
||||||
|
tool_text = self._convert_tools(tools)
|
||||||
|
|
||||||
|
system_prompt = (
|
||||||
|
self.task_prompt
|
||||||
|
+ "\n\n"
|
||||||
|
+ self.tool_prompt.format(tool_text=tool_text)
|
||||||
|
+ "\n\n"
|
||||||
|
+ self.format_prompt
|
||||||
|
)
|
||||||
|
|
||||||
|
return system_prompt
|
||||||
|
|
||||||
|
@final
|
||||||
|
def _process_messages(
|
||||||
|
self,
|
||||||
|
messages: List[Message],
|
||||||
|
tools: List[Dict[str, Any]] = None,
|
||||||
|
extra_instructions: str = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Processes a list of messages and formats them appropriately.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages (List[Message]): A list of message objects.
|
||||||
|
tools (List[Dict[str, Any]], optional): A list of tools to include in the system prompt.
|
||||||
|
extra_instructions (str, optional): Additional instructions to append to the last user message.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[Dict[str, Any]]: A list of processed message dictionaries.
|
||||||
|
"""
|
||||||
|
|
||||||
|
processed_messages = []
|
||||||
|
|
||||||
|
if tools:
|
||||||
|
processed_messages.append(
|
||||||
|
{"role": "system", "content": self._format_system(tools)}
|
||||||
|
)
|
||||||
|
|
||||||
|
for message in messages:
|
||||||
|
role, content, tool_calls = (
|
||||||
|
message.role,
|
||||||
|
message.content,
|
||||||
|
message.tool_calls,
|
||||||
|
)
|
||||||
|
|
||||||
|
if tool_calls:
|
||||||
|
# [TODO] Extend to support multiple function calls
|
||||||
|
role = "assistant"
|
||||||
|
content = f"<tool_call>\n{json.dumps(tool_calls[0]['function'])}\n</tool_call>"
|
||||||
|
elif message.role == "tool":
|
||||||
|
role = "user"
|
||||||
|
content = (
|
||||||
|
f"<tool_response>\n{json.dumps(message.content)}\n</tool_response>"
|
||||||
|
)
|
||||||
|
|
||||||
|
processed_messages.append({"role": role, "content": content})
|
||||||
|
|
||||||
|
assert processed_messages[-1]["role"] == "user"
|
||||||
|
|
||||||
|
if extra_instructions:
|
||||||
|
processed_messages[-1]["content"] += extra_instructions
|
||||||
|
|
||||||
|
return processed_messages
|
||||||
|
|
||||||
|
async def chat_completion(self, req: ChatMessage) -> ChatCompletionResponse:
|
||||||
|
"""
|
||||||
|
Abstract method for generating chat completions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
req (ChatMessage): A chat message request object.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NotImplementedError: Method should be overridden in subclasses.
|
||||||
|
"""
|
||||||
|
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
@ -3,121 +3,20 @@ import random
|
||||||
import builtins
|
import builtins
|
||||||
|
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
from pydantic import BaseModel
|
from typing import Any, Dict, List, Tuple, Union
|
||||||
from typing import Any, Dict, List, Optional
|
from overrides import override
|
||||||
from overrides import override, final
|
from app.model_handler.base_handler import (
|
||||||
|
Message,
|
||||||
|
ChatMessage,
|
||||||
|
Choice,
|
||||||
|
ChatCompletionResponse,
|
||||||
|
ArchBaseHandler,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
SUPPORT_DATA_TYPES = ["int", "float", "bool", "str", "list", "tuple", "set", "dict"]
|
SUPPORT_DATA_TYPES = ["int", "float", "bool", "str", "list", "tuple", "set", "dict"]
|
||||||
|
|
||||||
|
|
||||||
class Message(BaseModel):
|
|
||||||
role: Optional[str] = ""
|
|
||||||
content: Optional[str] = ""
|
|
||||||
tool_call_id: Optional[str] = ""
|
|
||||||
tool_calls: Optional[List[Dict[str, Any]]] = []
|
|
||||||
|
|
||||||
|
|
||||||
class ChatMessage(BaseModel):
|
|
||||||
messages: list[Message]
|
|
||||||
tools: List[Dict[str, Any]]
|
|
||||||
|
|
||||||
|
|
||||||
class Choice(BaseModel):
|
|
||||||
id: Optional[int] = 0
|
|
||||||
message: Message
|
|
||||||
finish_reason: Optional[str] = "stop"
|
|
||||||
|
|
||||||
|
|
||||||
class ChatCompletionResponse(BaseModel):
|
|
||||||
id: Optional[int] = 0
|
|
||||||
object: Optional[str] = "chat_completion"
|
|
||||||
created: Optional[str] = ""
|
|
||||||
model: str
|
|
||||||
choices: List[Choice]
|
|
||||||
|
|
||||||
|
|
||||||
class ArchBaseHandler:
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
client: OpenAI,
|
|
||||||
model_name: str,
|
|
||||||
task_prompt: str,
|
|
||||||
tool_prompt: str,
|
|
||||||
format_prompt: str,
|
|
||||||
generation_params: Dict,
|
|
||||||
):
|
|
||||||
self.client = client
|
|
||||||
|
|
||||||
self.model_name = model_name
|
|
||||||
|
|
||||||
self.task_prompt = task_prompt
|
|
||||||
self.tool_prompt = tool_prompt
|
|
||||||
self.format_prompt = format_prompt
|
|
||||||
|
|
||||||
self.generation_params = generation_params
|
|
||||||
|
|
||||||
def _convert_tools(self, tools: List[Dict[str, Any]]):
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
@final
|
|
||||||
def _format_system(self, tools: List[Dict[str, Any]]):
|
|
||||||
tool_text = self._convert_tools(tools)
|
|
||||||
|
|
||||||
system_prompt = (
|
|
||||||
self.task_prompt
|
|
||||||
+ "\n\n"
|
|
||||||
+ self.tool_prompt.format(tool_text=tool_text)
|
|
||||||
+ "\n\n"
|
|
||||||
+ self.format_prompt
|
|
||||||
)
|
|
||||||
|
|
||||||
return system_prompt
|
|
||||||
|
|
||||||
@final
|
|
||||||
def _process_messages(
|
|
||||||
self,
|
|
||||||
messages: List[Message],
|
|
||||||
tools: List[Dict[str, Any]] = None,
|
|
||||||
extra_instructions: str = None,
|
|
||||||
):
|
|
||||||
processed_messages = []
|
|
||||||
|
|
||||||
if tools:
|
|
||||||
processed_messages.append(
|
|
||||||
{"role": "system", "content": self._format_system(tools)}
|
|
||||||
)
|
|
||||||
|
|
||||||
for message in messages:
|
|
||||||
role, content, tool_calls = (
|
|
||||||
message.role,
|
|
||||||
message.content,
|
|
||||||
message.tool_calls,
|
|
||||||
)
|
|
||||||
|
|
||||||
if tool_calls:
|
|
||||||
# [TODO] Extend to support multiple function calls
|
|
||||||
role = "assistant"
|
|
||||||
content = f"<tool_call>\n{json.dumps(tool_calls[0]['function'])}\n</tool_call>"
|
|
||||||
elif message.role == "tool":
|
|
||||||
role = "user"
|
|
||||||
content = (
|
|
||||||
f"<tool_response>\n{json.dumps(message.content)}\n</tool_response>"
|
|
||||||
)
|
|
||||||
|
|
||||||
processed_messages.append({"role": role, "content": content})
|
|
||||||
|
|
||||||
assert processed_messages[-1]["role"] == "user"
|
|
||||||
|
|
||||||
if extra_instructions:
|
|
||||||
processed_messages[-1]["content"] += extra_instructions
|
|
||||||
|
|
||||||
return processed_messages
|
|
||||||
|
|
||||||
async def chat_completion(self, req: ChatMessage):
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
|
|
||||||
class ArchIntentHandler(ArchBaseHandler):
|
class ArchIntentHandler(ArchBaseHandler):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
@ -129,6 +28,19 @@ class ArchIntentHandler(ArchBaseHandler):
|
||||||
intent_instruction: str,
|
intent_instruction: str,
|
||||||
generation_params: Dict,
|
generation_params: Dict,
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Initializes the intent handler.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
client (OpenAI): An OpenAI client instance.
|
||||||
|
model_name (str): Name of the model to use.
|
||||||
|
task_prompt (str): The main task prompt for the system.
|
||||||
|
tool_prompt (str): A prompt to describe tools.
|
||||||
|
format_prompt (str): A prompt specifying the desired output format.
|
||||||
|
intent_instruction (str): Instructions specific to intent handling.
|
||||||
|
generation_params (Dict): Generation parameters for the model.
|
||||||
|
"""
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
client,
|
client,
|
||||||
model_name,
|
model_name,
|
||||||
|
|
@ -141,16 +53,35 @@ class ArchIntentHandler(ArchBaseHandler):
|
||||||
self.intent_instruction = intent_instruction
|
self.intent_instruction = intent_instruction
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def _convert_tools(self, tools: List[Dict[str, Any]]):
|
def _convert_tools(self, tools: List[Dict[str, Any]]) -> str:
|
||||||
|
"""
|
||||||
|
Converts a list of tools into a JSON-like format with indexed keys.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tools (List[Dict[str, Any]]): A list of tools represented as dictionaries.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: A string representation of converted tools.
|
||||||
|
"""
|
||||||
|
|
||||||
converted = [
|
converted = [
|
||||||
json.dumps({"index": f"T{idx}"} | tool) for idx, tool in enumerate(tools)
|
json.dumps({"index": f"T{idx}"} | tool) for idx, tool in enumerate(tools)
|
||||||
]
|
]
|
||||||
return "\n".join(converted)
|
return "\n".join(converted)
|
||||||
|
|
||||||
@override
|
@override
|
||||||
async def chat_completion(self, req: ChatMessage):
|
async def chat_completion(self, req: ChatMessage) -> ChatCompletionResponse:
|
||||||
"""
|
"""
|
||||||
Note: Currently only support vllm inference
|
Generates a chat completion for a given request.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
req (ChatMessage): A chat message request object.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ChatCompletionResponse: The model's response to the chat request.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
Currently only support vllm inference
|
||||||
"""
|
"""
|
||||||
|
|
||||||
messages = self._process_messages(
|
messages = self._process_messages(
|
||||||
|
|
@ -185,6 +116,20 @@ class ArchFunctionHandler(ArchBaseHandler):
|
||||||
prefill_params: Dict,
|
prefill_params: Dict,
|
||||||
prefill_prefix: List,
|
prefill_prefix: List,
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Initializes the function handler.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
client (OpenAI): An OpenAI client instance.
|
||||||
|
model_name (str): Name of the model to use.
|
||||||
|
task_prompt (str): The main task prompt for the system.
|
||||||
|
tool_prompt (str): A prompt to describe tools.
|
||||||
|
format_prompt (str): A prompt specifying the desired output format.
|
||||||
|
generation_params (Dict): Generation parameters for the model.
|
||||||
|
prefill_params (Dict): Additional parameters for prefilling responses.
|
||||||
|
prefill_prefix (List[str]): List of prefixes for prefill responses.
|
||||||
|
"""
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
client,
|
client,
|
||||||
model_name,
|
model_name,
|
||||||
|
|
@ -204,11 +149,31 @@ class ArchFunctionHandler(ArchBaseHandler):
|
||||||
}
|
}
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def _convert_tools(self, tools: List[Dict[str, Any]]):
|
def _convert_tools(self, tools: List[Dict[str, Any]]) -> str:
|
||||||
|
"""
|
||||||
|
Converts a list of tools into JSON format.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tools (List[Dict[str, Any]]): A list of tools represented as dictionaries.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: A string representation of converted tools.
|
||||||
|
"""
|
||||||
|
|
||||||
converted = [json.dumps(tool) for tool in tools]
|
converted = [json.dumps(tool) for tool in tools]
|
||||||
return "\n".join(converted)
|
return "\n".join(converted)
|
||||||
|
|
||||||
def _fix_json_string(self, json_str: str):
|
def _fix_json_string(self, json_str: str) -> str:
|
||||||
|
"""
|
||||||
|
Fixes malformed JSON strings by ensuring proper bracket matching.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
json_str (str): A JSON string that might be malformed.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: A corrected JSON string.
|
||||||
|
"""
|
||||||
|
|
||||||
# Remove any leading or trailing whitespace or newline characters
|
# Remove any leading or trailing whitespace or newline characters
|
||||||
json_str = json_str.strip()
|
json_str = json_str.strip()
|
||||||
|
|
||||||
|
|
@ -246,7 +211,22 @@ class ArchFunctionHandler(ArchBaseHandler):
|
||||||
# Attempt to parse the corrected string to ensure it’s valid JSON
|
# Attempt to parse the corrected string to ensure it’s valid JSON
|
||||||
return fixed_str.replace("'", '"')
|
return fixed_str.replace("'", '"')
|
||||||
|
|
||||||
def _extract_tool_calls(self, content: str):
|
def _extract_tool_calls(
|
||||||
|
self, content: str
|
||||||
|
) -> Tuple[List[Dict[str, Any]], bool, Union[str, Exception]]:
|
||||||
|
"""
|
||||||
|
Extracts tool call information from a given string.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content (str): The content string containing potential tool call information.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[List[Dict[str, Any]], bool, Union[str, Exception]]:
|
||||||
|
- A list of tool call dictionaries.
|
||||||
|
- A boolean indicating if the extraction was valid.
|
||||||
|
- An error message or exception if extraction failed.
|
||||||
|
"""
|
||||||
|
|
||||||
tool_calls, is_valid, error_message = [], True, ""
|
tool_calls, is_valid, error_message = [], True, ""
|
||||||
|
|
||||||
flag = False
|
flag = False
|
||||||
|
|
@ -284,7 +264,21 @@ class ArchFunctionHandler(ArchBaseHandler):
|
||||||
|
|
||||||
def _verify_tool_calls(
|
def _verify_tool_calls(
|
||||||
self, tools: List[Dict[str, Any]], tool_calls: List[Dict[str, Any]]
|
self, tools: List[Dict[str, Any]], tool_calls: List[Dict[str, Any]]
|
||||||
):
|
) -> Tuple[bool, Union[Dict[str, Any], None], str]:
|
||||||
|
"""
|
||||||
|
Verifies the validity of extracted tool calls against the provided tools.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tools (List[Dict[str, Any]]): A list of available tools.
|
||||||
|
tool_calls (List[Dict[str, Any]]): A list of tool calls to verify.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[bool, Union[Dict[str, Any], None], str]:
|
||||||
|
- A boolean indicating if the tool calls are valid.
|
||||||
|
- The invalid tool call dictionary if any.
|
||||||
|
- An error message.
|
||||||
|
"""
|
||||||
|
|
||||||
is_valid, error_tool_call, error_message = True, None, ""
|
is_valid, error_tool_call, error_message = True, None, ""
|
||||||
|
|
||||||
functions = {}
|
functions = {}
|
||||||
|
|
@ -328,9 +322,21 @@ class ArchFunctionHandler(ArchBaseHandler):
|
||||||
return is_valid, error_tool_call, error_message
|
return is_valid, error_tool_call, error_message
|
||||||
|
|
||||||
@override
|
@override
|
||||||
async def chat_completion(self, req: ChatMessage, enable_prefilling=True):
|
async def chat_completion(
|
||||||
|
self, req: ChatMessage, enable_prefilling=True
|
||||||
|
) -> ChatCompletionResponse:
|
||||||
"""
|
"""
|
||||||
Note: Currently only support vllm inference
|
Generates a chat completion response for a given request.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
req (ChatMessage): A chat message request object.
|
||||||
|
enable_prefilling (bool, optional): Whether to enable prefill responses. Defaults to True.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ChatCompletionResponse: The model's response to the chat request.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
Currently only support vllm inference
|
||||||
"""
|
"""
|
||||||
|
|
||||||
messages = self._process_messages(req.messages, req.tools)
|
messages = self._process_messages(req.messages, req.tools)
|
||||||
|
|
|
||||||
|
|
@ -203,16 +203,16 @@ class HallucinationStateHandler:
|
||||||
if self.tokens[-1].strip() not in ['"', ""]:
|
if self.tokens[-1].strip() not in ['"', ""]:
|
||||||
self.mask.append(MaskToken.PARAMETER_VALUE)
|
self.mask.append(MaskToken.PARAMETER_VALUE)
|
||||||
|
|
||||||
# [TODO] Review: update the following code: `is_parameter_property` should not be here
|
# [TODO] Review: update the following code: `is_parameter_property` should not be here, move to `ArchFunctionHandler`
|
||||||
# checking if the parameter doesn't have default and the token is the first parameter value token
|
# checking if the parameter doesn't have default and the token is the first parameter value token
|
||||||
if (
|
if (
|
||||||
len(self.mask) > 1
|
len(self.mask) > 1
|
||||||
and self.mask[-2] != MaskToken.PARAMETER_VALUE
|
and self.mask[-2] != MaskToken.PARAMETER_VALUE
|
||||||
and not is_parameter_property(
|
# and not is_parameter_property(
|
||||||
self.function_properties[self.function_name],
|
# self.function_properties[self.function_name],
|
||||||
self.parameter_name[-1],
|
# self.parameter_name[-1],
|
||||||
"default",
|
# "default",
|
||||||
)
|
# )
|
||||||
):
|
):
|
||||||
self._check_logprob()
|
self._check_logprob()
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ import pytest
|
||||||
from fastapi import Response
|
from fastapi import Response
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
from app.commons.globals import handler_map
|
from app.commons.globals import handler_map
|
||||||
from app.model_handler.function_calling import (
|
from app.model_handler.base_handler import (
|
||||||
Message,
|
Message,
|
||||||
ChatMessage,
|
ChatMessage,
|
||||||
ChatCompletionResponse,
|
ChatCompletionResponse,
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue