diff --git a/model_server/src/cli.py b/model_server/src/cli.py
index d863d028..8d032961 100644
--- a/model_server/src/cli.py
+++ b/model_server/src/cli.py
@@ -8,8 +8,8 @@ from src.commons.utils import (
wait_for_health_check,
check_lsof,
install_lsof,
- find_process_by_port,
- kill_process_by_port,
+ find_processes_by_port,
+ kill_processes,
)
@@ -23,7 +23,7 @@ def start_server(port=51000):
"python",
"-m",
"uvicorn",
- "app.main:app",
+ "src.main:app",
"--host",
"0.0.0.0",
"--port",
@@ -56,14 +56,16 @@ def stop_server(port=51000, wait=True, timeout=10):
sys.exit(1)
logger.info(f"Stopping processes on port {port}...")
- port_processes = find_process_by_port(port)
+ port_processes = find_processes_by_port(port)
if port_processes is None:
logger.info(f"No processes found listening on port {port}.")
else:
if len(port_processes):
- process_killed = kill_process_by_port(port_processes, wait, timeout)
+ process_killed = kill_processes(port_processes, wait, timeout)
if not process_killed:
logger.error(f"Unable to kill all processes on {port}")
+ else:
+ logger.info(f"All processes on port {port} have been killed.")
else:
logger.error(f"Unable to find processes on {port}")
diff --git a/model_server/src/commons/constants.py b/model_server/src/commons/constants.py
deleted file mode 100644
index 1b9035e0..00000000
--- a/model_server/src/commons/constants.py
+++ /dev/null
@@ -1,79 +0,0 @@
-# ========================== Arch-Intent Default Params ==========================
-ARCH_INTENT_MODEL_ALIAS = "Arch-Intent"
-ARCH_INTENT_INSTRUCTION = "Are there any tools can help?"
-
-ARCH_INTENT_TASK_PROMPT = """
-You are a helpful assistant.
-"""
-
-
-ARCH_INTENT_TOOL_PROMPT_TEMPLATE = """
-You task is to check if there are any tools that can be used to help the last user message in conversations according to the available tools listed below.
-
-
-{tool_text}
-
-"""
-
-
-ARCH_INTENT_FORMAT_PROMPT = """
-Provide your tool assessment for ONLY THE LAST USER MESSAGE in the above conversation:
-- First line must read 'Yes' or 'No'.
-- If yes, a second line must include a comma-separated list of tool indexes.
-"""
-
-
-ARCH_INTENT_GENERATION_CONFIG = {
- "generation_params": {"max_tokens": 1, "stop_token_ids": [151645]}
-}
-
-
-# ========================== Arch-Function Default Params ==========================
-ARCH_FUNCTION_MODEL_ALIAS = "Arch-Function"
-
-ARCH_FUNCTION_TASK_PROMPT = """
-You are a helpful assistant.
-"""
-
-
-ARCH_FUNCTION_TOOL_PROMPT_TEMPLATE = """
-# Tools
-
-You may call one or more functions to assist with the user query.
-
-You are provided with function signatures within XML tags:
-
-{tool_text}
-
-"""
-
-
-ARCH_FUNCTION_FORMAT_PROMPT = """
-For each function call, return a json object with function name and arguments within XML tags:
-
-{"name": , "arguments": }
-
-"""
-
-ARCH_FUNCTION_GENERATION_CONFIG = {
- "generation_params": {
- "temperature": 0.2,
- "top_p": 1.0,
- "top_k": 50,
- "max_tokens": 512,
- "stop_token_ids": [151645],
- },
- "prefill_params": {
- "continue_final_message": True,
- "add_generation_prompt": False,
- },
- "prefill_prefix": [
- "May",
- "Could",
- "Sure",
- "Definitely",
- "Certainly",
- "Of course",
- "Can",
- ],
-}
diff --git a/model_server/src/commons/globals.py b/model_server/src/commons/globals.py
index 0dadb2a3..4cf71a31 100644
--- a/model_server/src/commons/globals.py
+++ b/model_server/src/commons/globals.py
@@ -1,36 +1,36 @@
from openai import OpenAI
-from src.commons.constants import *
-from src.core.function_calling import ArchIntentHandler, ArchFunctionHandler
-from src.core.guardrails import get_guardrail_handler
from src.commons.utils import get_model_server_logger
+from src.core.guardrails import get_guardrail_handler
+from src.core.function_calling import (
+ ArchIntentConfig,
+ ArchIntentHandler,
+ ArchFunctionConfig,
+ ArchFunctionHandler,
+)
+# Define logger
logger = get_model_server_logger()
+
# Define the client
ARCH_ENDPOINT = "https://api.fc.archgw.com/v1"
ARCH_API_KEY = "EMPTY"
ARCH_CLIENT = OpenAI(base_url=ARCH_ENDPOINT, api_key=ARCH_API_KEY)
+# Define model names
+ARCH_INTENT_MODEL_ALIAS = "Arch-Intent"
+ARCH_FUNCTION_MODEL_ALIAS = "Arch-Function"
+
+
# Define model handlers
handler_map = {
"Arch-Intent": ArchIntentHandler(
- ARCH_CLIENT,
- ARCH_INTENT_MODEL_ALIAS,
- ARCH_INTENT_TASK_PROMPT,
- ARCH_INTENT_TOOL_PROMPT_TEMPLATE,
- ARCH_INTENT_FORMAT_PROMPT,
- ARCH_INTENT_INSTRUCTION,
- **ARCH_INTENT_GENERATION_CONFIG,
+ ARCH_CLIENT, ARCH_INTENT_MODEL_ALIAS, ArchIntentConfig
),
"Arch-Function": ArchFunctionHandler(
- ARCH_CLIENT,
- ARCH_FUNCTION_MODEL_ALIAS,
- ARCH_FUNCTION_TASK_PROMPT,
- ARCH_FUNCTION_TOOL_PROMPT_TEMPLATE,
- ARCH_FUNCTION_FORMAT_PROMPT,
- **ARCH_FUNCTION_GENERATION_CONFIG,
+ ARCH_CLIENT, ARCH_FUNCTION_MODEL_ALIAS, ArchFunctionConfig
),
"Arch-Guard": get_guardrail_handler(),
}
diff --git a/model_server/src/commons/utils.py b/model_server/src/commons/utils.py
index 5616b369..99fe03af 100644
--- a/model_server/src/commons/utils.py
+++ b/model_server/src/commons/utils.py
@@ -12,7 +12,7 @@ PROJ_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__fil
# Default log directory and file
-DEFAULT_LOG_DIR = os.path.join(PROJ_DIR, "logs")
+DEFAULT_LOG_DIR = os.path.join(PROJ_DIR, ".logs")
DEFAULT_LOG_FILE = "modelserver.log"
@@ -50,7 +50,7 @@ def get_model_server_logger(log_dir=None, log_file=None):
Get or initialize the logger instance for the model server.
Parameters:
- - log_dir (str): Custom directory to store the log file. Defaults to `~/archgw_logs`.
+ - log_dir (str): Custom directory to store the log file. Defaults to `./.logs`.
- log_file (str): Custom log file name. Defaults to `modelserver.log`.
Returns:
@@ -146,13 +146,13 @@ def terminate_process_by_pid(pid, timeout):
subprocess.run(["kill", "-9", str(pid)], check=False)
-def find_process_by_port(port=51000):
+def find_processes_by_port(port=51000):
"""Find processes listening on a specific port."""
port_processes = []
try:
- lsof_command = f"lsof -n -i:{port} | grep LISTEN"
+ lsof_command = f"lsof -n | grep {port} | grep -i LISTEN"
result = subprocess.run(
lsof_command, shell=True, capture_output=True, text=True
)
@@ -167,7 +167,7 @@ def find_process_by_port(port=51000):
return []
-def kill_process_by_port(port_processes=51000, wait=True, timeout=10):
+def kill_processes(port_processes, wait=True, timeout=10):
"""Kill processes on a specific port."""
try:
diff --git a/model_server/src/core/function_calling.py b/model_server/src/core/function_calling.py
index 9eb8ccd6..0352ab47 100644
--- a/model_server/src/core/function_calling.py
+++ b/model_server/src/core/function_calling.py
@@ -1,11 +1,12 @@
import json
import random
import builtins
+import textwrap
from openai import OpenAI
from typing import Any, Dict, List, Tuple, Union
from overrides import override
-from src.core.base_handler import (
+from src.core.model_utils import (
Message,
ChatMessage,
Choice,
@@ -14,43 +15,57 @@ from src.core.base_handler import (
)
-SUPPORT_DATA_TYPES = ["int", "float", "bool", "str", "list", "tuple", "set", "dict"]
+class ArchIntentConfig:
+ TASK_PROMPT = textwrap.dedent(
+ """
+ You are a helpful assistant.
+ """
+ ).strip()
+
+ TOOL_PROMPT_TEMPLATE = textwrap.dedent(
+ """
+ You task is to check if there are any tools that can be used to help the last user message in conversations according to the available tools listed below.
+
+
+ {tool_text}
+
+ """
+ ).strip()
+
+ FORMAT_PROMPT = textwrap.dedent(
+ """
+ Provide your tool assessment for ONLY THE LAST USER MESSAGE in the above conversation:
+ - First line must read 'Yes' or 'No'.
+ - If yes, a second line must include a comma-separated list of tool indexes.
+ """
+ ).strip()
+
+ EXTRA_INSTRUCTION = "Are there any tools can help?"
+
+ GENERATION_PARAMS = {"max_tokens": 1, "stop_token_ids": [151645]}
class ArchIntentHandler(ArchBaseHandler):
- def __init__(
- self,
- client: OpenAI,
- model_name: str,
- task_prompt: str,
- tool_prompt_template: str,
- format_prompt: str,
- extra_instruction: str,
- generation_params: Dict,
- ):
+ def __init__(self, client: OpenAI, model_name: str, config: ArchIntentConfig):
"""
Initializes the intent handler.
Args:
client (OpenAI): An OpenAI client instance.
model_name (str): Name of the model to use.
- task_prompt (str): The main task prompt for the system.
- tool_prompt_template (str): A prompt to describe tools.
- format_prompt (str): A prompt specifying the desired output format.
- extra_instruction (str): Instructions specific to intent handling.
- generation_params (Dict): Generation parameters for the model.
+ config (ArchIntentConfig): The configuration for Arch-Intent.
"""
super().__init__(
client,
model_name,
- task_prompt,
- tool_prompt_template,
- format_prompt,
- generation_params,
+ config.TASK_PROMPT,
+ config.TOOL_PROMPT_TEMPLATE,
+ config.FORMAT_PROMPT,
+ config.GENERATION_PARAMS,
)
- self.extra_instruction = extra_instruction
+ self.extra_instruction = config.EXTRA_INSTRUCTION
@override
def _convert_tools(self, tools: List[Dict[str, Any]]) -> str:
@@ -125,17 +140,73 @@ class ArchIntentHandler(ArchBaseHandler):
return chat_completion_response
+# =============================================================================================================
+
+
+class ArchFunctionConfig:
+ TASK_PROMPT = textwrap.dedent(
+ """
+ You are a helpful assistant.
+ """
+ ).strip()
+
+ TOOL_PROMPT_TEMPLATE = textwrap.dedent(
+ """
+ # Tools
+
+ You may call one or more functions to assist with the user query.
+
+ You are provided with function signatures within XML tags:
+
+ {tool_text}
+
+ """
+ ).strip()
+
+ FORMAT_PROMPT = textwrap.dedent(
+ """
+ For each function call, return a json object with function name and arguments within XML tags:
+
+ {"name": , "arguments": }
+
+ """
+ ).strip()
+
+ GENERATION_PARAMS = (
+ {
+ "temperature": 0.2,
+ "top_p": 1.0,
+ "top_k": 50,
+ "max_tokens": 512,
+ "stop_token_ids": [151645],
+ },
+ )
+
+ PREFILL_CONFIG = {
+ "prefill_params": {
+ "continue_final_message": True,
+ "add_generation_prompt": False,
+ },
+ "prefill_prefix": [
+ "May",
+ "Could",
+ "Sure",
+ "Definitely",
+ "Certainly",
+ "Of course",
+ "Can",
+ ],
+ }
+
+ SUPPORT_DATA_TYPES = ["int", "float", "bool", "str", "list", "tuple", "set", "dict"]
+
+
class ArchFunctionHandler(ArchBaseHandler):
def __init__(
self,
client: OpenAI,
model_name: str,
- task_prompt: str,
- tool_prompt_template: str,
- format_prompt: str,
- generation_params: Dict,
- prefill_params: Dict,
- prefill_prefix: List,
+ config: ArchFunctionConfig,
):
"""
Initializes the function handler.
@@ -143,30 +214,26 @@ class ArchFunctionHandler(ArchBaseHandler):
Args:
client (OpenAI): An OpenAI client instance.
model_name (str): Name of the model to use.
- task_prompt (str): The main task prompt for the system.
- tool_prompt_template (str): A prompt to describe tools.
- format_prompt (str): A prompt specifying the desired output format.
- generation_params (Dict): Generation parameters for the model.
- prefill_params (Dict): Additional parameters for prefilling responses.
- prefill_prefix (List[str]): List of prefixes for prefill responses.
+ config (ArchFunctionConfig): The configuration for Arch-Function
"""
super().__init__(
client,
model_name,
- task_prompt,
- tool_prompt_template,
- format_prompt,
- generation_params,
+ config.TASK_PROMPT,
+ config.TOOL_PROMPT_TEMPLATE,
+ config.FORMAT_PROMPT,
+ config.GENERATION_PARAMS,
)
- self.prefill_params = prefill_params
- self.prefill_prefix = prefill_prefix
+ self.prefill_params = config.PREFILL_CONFIG["prefill_params"]
+ self.prefill_prefix = config.PREFILL_CONFIG["prefill_prefix"]
# Predefine data types for verification. Only support Python for now.
# [TODO] Extend the list of support data types
self.support_data_types = {
- type_name: getattr(builtins, type_name) for type_name in SUPPORT_DATA_TYPES
+ type_name: getattr(builtins, type_name)
+ for type_name in config.SUPPORT_DATA_TYPES
}
@override
diff --git a/model_server/src/core/guardrails.py b/model_server/src/core/guardrails.py
index 64e283ae..ef6c8e04 100644
--- a/model_server/src/core/guardrails.py
+++ b/model_server/src/core/guardrails.py
@@ -3,22 +3,9 @@ import torch
import numpy as np
import src.commons.utils as utils
-from typing import List
-from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from optimum.intel import OVModelForSequenceClassification
-
-
-class GuardRequest(BaseModel):
- input: str
- task: str
-
-
-class GuardResponse(BaseModel):
- prob: List
- verdict: bool
- sentence: List
- latency: float = 0
+from src.core.model_utils import GuardRequest, GuardResponse
class ArchGuardHanlder:
diff --git a/model_server/src/core/base_handler.py b/model_server/src/core/model_utils.py
similarity index 94%
rename from model_server/src/core/base_handler.py
rename to model_server/src/core/model_utils.py
index c6b88749..c411d835 100644
--- a/model_server/src/core/base_handler.py
+++ b/model_server/src/core/model_utils.py
@@ -32,6 +32,21 @@ class ChatCompletionResponse(BaseModel):
model: str
+class GuardRequest(BaseModel):
+ input: str
+ task: str
+
+
+class GuardResponse(BaseModel):
+ prob: List
+ verdict: bool
+ sentence: List
+ latency: float = 0
+
+
+# ================================================================================================
+
+
class ArchBaseHandler:
def __init__(
self,
@@ -53,9 +68,7 @@ class ArchBaseHandler:
format_prompt (str): A prompt specifying the desired output format.
generation_params (Dict): Generation parameters for the model.
"""
-
self.client = client
-
self.model_name = model_name
self.task_prompt = task_prompt
diff --git a/model_server/src/main.py b/model_server/src/main.py
index 19fef239..0ca8c7c7 100644
--- a/model_server/src/main.py
+++ b/model_server/src/main.py
@@ -1,8 +1,7 @@
import os
from src.commons.globals import handler_map
-from src.core.base_handler import ChatMessage
-from src.core.guardrails import GuardRequest
+from src.core.model_utils import ChatMessage, GuardRequest
from fastapi import FastAPI, Response
from opentelemetry import trace
diff --git a/model_server/tests/core/__init__.py b/model_server/tests/core/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/model_server/tests/core/test_cases.json b/model_server/tests/core/test_cases.json
new file mode 100644
index 00000000..229b4ea1
--- /dev/null
+++ b/model_server/tests/core/test_cases.json
@@ -0,0 +1,949 @@
+[
+ {
+ "case": "tool_call_halluciation",
+ "tokens": [
+ ""
+ ],
+ "expect": 1,
+ "logprobs": [
+ [
+ -0.3333307206630707,
+ -1.5310522317886353,
+ -3.5098977088928223,
+ -3.9004578590393066,
+ -5.775152683258057,
+ -5.814209461212158,
+ -5.9574151039123535,
+ -6.0094895362854,
+ -6.0094895362854,
+ -6.673445224761963
+ ]
+ ]
+ },
+ {
+ "case": "parameter_value_hallucination",
+ "expect": 0,
+ "tokens": [
+ "",
+ "\n",
+ "{'",
+ "name",
+ "':",
+ " '",
+ "get",
+ "_current",
+ "_weather",
+ "',",
+ " '",
+ "arguments",
+ "':",
+ " {'",
+ "location",
+ "':",
+ " '",
+ "Sea",
+ ",",
+ " Australia",
+ "',",
+ " '",
+ "unit",
+ "':",
+ " '",
+ "c",
+ "elsius",
+ "',",
+ " '",
+ "days",
+ "':",
+ " '",
+ "1",
+ "'}}\n",
+ ""
+ ],
+ "logprobs": [
+ [
+ -0.008103232830762863,
+ -5.085402488708496,
+ -6.777836799621582,
+ -7.558959007263184,
+ -9.850253105163574,
+ -10.266852378845215,
+ -10.540244102478027,
+ -10.722506523132324,
+ -10.800618171691895,
+ -10.917786598205566
+ ],
+ [
+ 0.0,
+ -23.25142478942871,
+ -25.139137268066406,
+ -26.2847843170166,
+ -28.992677688598633,
+ -29.070789337158203,
+ -29.55248260498047,
+ -29.91700553894043,
+ -30.20341682434082,
+ -30.307567596435547
+ ],
+ [
+ 0.0,
+ -21.66313934326172,
+ -23.06916046142578,
+ -23.32953453063965,
+ -25.65988540649414,
+ -25.985353469848633,
+ -26.519121170043945,
+ -27.07892417907715,
+ -27.977216720581055,
+ -28.458908081054688
+ ],
+ [
+ 0.0,
+ -28.094383239746094,
+ -28.56305694580078,
+ -29.109844207763672,
+ -29.44832992553711,
+ -31.79170036315918,
+ -32.0,
+ -32.05207443237305,
+ -32.31244659423828,
+ -32.364524841308594
+ ],
+ [
+ 0.0,
+ -30.489830017089844,
+ -31.140766143798828,
+ -31.81774139404297,
+ -34.525634765625,
+ -35.8275032043457,
+ -36.504478454589844,
+ -39.05614471435547,
+ -40.123680114746094,
+ -40.696502685546875
+ ],
+ [
+ 0.0,
+ -25.646865844726562,
+ -26.66232681274414,
+ -27.781936645507812,
+ -28.979660034179688,
+ -31.140764236450195,
+ -31.92188835144043,
+ -31.973962783813477,
+ -33.04149627685547,
+ -33.58828353881836
+ ],
+ [
+ 0.0,
+ -23.511798858642578,
+ -24.136695861816406,
+ -25.230268478393555,
+ -25.777053833007812,
+ -25.80309295654297,
+ -26.45402717590332,
+ -26.636289596557617,
+ -26.740440368652344,
+ -26.896663665771484
+ ],
+ [
+ 0.0,
+ -22.366153717041016,
+ -24.683483123779297,
+ -26.610252380371094,
+ -26.610252380371094,
+ -27.313264846801758,
+ -27.67778778076172,
+ -28.510986328125,
+ -28.615135192871094,
+ -29.13588523864746
+ ],
+ [
+ 0.0,
+ -22.52237319946289,
+ -24.292919158935547,
+ -24.344993591308594,
+ -24.39706802368164,
+ -24.73555564880371,
+ -29.943042755126953,
+ -29.969079971313477,
+ -30.021154403686523,
+ -30.0341739654541
+ ],
+ [
+ 0.0,
+ -30.17738151550293,
+ -30.411718368530273,
+ -30.88039207458496,
+ -30.984540939331055,
+ -31.270952224731445,
+ -31.895851135253906,
+ -32.46867370605469,
+ -32.624900817871094,
+ -33.484134674072266
+ ],
+ [
+ 0.0,
+ -28.146459579467773,
+ -29.396255493164062,
+ -30.099267959594727,
+ -31.127744674682617,
+ -31.179821014404297,
+ -32.807159423828125,
+ -33.7445068359375,
+ -33.770545959472656,
+ -34.069976806640625
+ ],
+ [
+ 0.0,
+ -26.323841094970703,
+ -26.558177947998047,
+ -30.515867233276367,
+ -30.932466506958008,
+ -31.37510108947754,
+ -31.531326293945312,
+ -31.70056915283203,
+ -32.065093994140625,
+ -32.364524841308594
+ ],
+ [
+ 0.0,
+ -26.922698974609375,
+ -30.28152847290039,
+ -31.505287170410156,
+ -33.30187225341797,
+ -33.73148727416992,
+ -34.27827453613281,
+ -34.33034896850586,
+ -34.460533142089844,
+ -34.720909118652344
+ ],
+ [
+ 0.0,
+ -21.532955169677734,
+ -26.94873809814453,
+ -29.109848022460938,
+ -30.80228042602539,
+ -31.55736541748047,
+ -33.484134674072266,
+ -34.681854248046875,
+ -35.384864807128906,
+ -35.853538513183594
+ ],
+ [
+ 0.0,
+ -19.502033233642578,
+ -20.46541976928711,
+ -24.110658645629883,
+ -24.501218795776367,
+ -25.256305694580078,
+ -25.82912826538086,
+ -25.881202697753906,
+ -26.063465118408203,
+ -26.063465118408203
+ ],
+ [
+ 0.0,
+ -24.37103271484375,
+ -25.256305694580078,
+ -25.933277130126953,
+ -26.714401245117188,
+ -28.2506103515625,
+ -31.010576248168945,
+ -32.07810974121094,
+ -34.62977981567383,
+ -35.241661071777344
+ ],
+ [
+ -1.1920922133867862e-06,
+ -14.398697853088379,
+ -14.424736976623535,
+ -17.158666610717773,
+ -17.41904067993164,
+ -18.200162887573242,
+ -18.434499740600586,
+ -18.66883659362793,
+ -19.71033477783203,
+ -19.71033477783203
+ ],
+ [
+ -0.0001445904199499637,
+ -8.98305892944336,
+ -11.35246467590332,
+ -13.1490478515625,
+ -13.669795989990234,
+ -14.073375701904297,
+ -14.516012191772461,
+ -14.555068969726562,
+ -15.622602462768555,
+ -15.635622024536133
+ ],
+ [
+ -0.44747352600097656,
+ -1.0202960968017578,
+ -8.467000961303711,
+ -10.914518356323242,
+ -11.25300407409668,
+ -11.435266494750977,
+ -12.346576690673828,
+ -13.075624465942383,
+ -13.12769889831543,
+ -13.231849670410156
+ ],
+ [
+ -3.123767137527466,
+ -1.1188862323760986,
+ -1.639634370803833,
+ -2.0562336444854736,
+ -2.8633930683135986,
+ -2.9675419330596924,
+ -3.4882919788360596,
+ -3.69659161567688,
+ -4.217339515686035,
+ -4.243376731872559
+ ],
+ [
+ -7.199982064776123e-05,
+ -9.76410961151123,
+ -11.144091606140137,
+ -16.507802963256836,
+ -17.132701873779297,
+ -17.44515037536621,
+ -17.9138240814209,
+ -18.33042335510254,
+ -18.9162654876709,
+ -19.39795684814453
+ ],
+ [
+ 0.0,
+ -22.991050720214844,
+ -23.824249267578125,
+ -24.969894409179688,
+ -25.46460723876953,
+ -25.829130172729492,
+ -26.480066299438477,
+ -26.909683227539062,
+ -27.33930206298828,
+ -27.391376495361328
+ ],
+ [
+ -0.21928852796554565,
+ -1.625309705734253,
+ -9.775025367736816,
+ -12.977627754211426,
+ -16.388530731201172,
+ -17.091541290283203,
+ -19.044347763061523,
+ -19.38283348083496,
+ -19.460947036743164,
+ -19.59113311767578
+ ],
+ [
+ 0.0,
+ -24.006507873535156,
+ -27.443450927734375,
+ -27.729862213134766,
+ -28.12042236328125,
+ -28.276647567749023,
+ -28.927583694458008,
+ -30.099267959594727,
+ -31.479251861572266,
+ -32.07810974121094
+ ],
+ [
+ 0.0,
+ -18.17412567138672,
+ -18.772987365722656,
+ -21.689178466796875,
+ -21.92351531982422,
+ -23.7200984954834,
+ -23.79821014404297,
+ -23.79821014404297,
+ -24.032546997070312,
+ -25.308382034301758
+ ],
+ [
+ -0.12947827577590942,
+ -2.1083219051361084,
+ -12.419143676757812,
+ -15.23118782043457,
+ -15.595710754394531,
+ -15.830047607421875,
+ -17.001731872558594,
+ -17.60059356689453,
+ -18.121341705322266,
+ -18.251529693603516
+ ],
+ [
+ 0.0,
+ -19.449962615966797,
+ -24.371034622192383,
+ -24.917821884155273,
+ -25.529701232910156,
+ -25.85516929626465,
+ -26.037429809570312,
+ -26.115543365478516,
+ -26.623271942138672,
+ -26.649309158325195
+ ],
+ [
+ -0.03332124650478363,
+ -3.4181859493255615,
+ -15.759925842285156,
+ -15.812002182006836,
+ -16.593124389648438,
+ -17.894996643066406,
+ -18.09027671813965,
+ -18.79328727722168,
+ -19.144792556762695,
+ -20.147233963012695
+ ],
+ [
+ 0.0,
+ -21.142393112182617,
+ -22.157852172851562,
+ -23.511798858642578,
+ -24.657445907592773,
+ -25.021968841552734,
+ -25.5427188873291,
+ -25.59479331970215,
+ -25.75101661682129,
+ -25.95931625366211
+ ],
+ [
+ 0.0,
+ -23.04312515258789,
+ -24.94385528564453,
+ -26.323841094970703,
+ -27.54759979248047,
+ -28.563060760498047,
+ -29.786819458007812,
+ -30.620018005371094,
+ -30.69812774658203,
+ -31.08869171142578
+ ],
+ [
+ 0.0,
+ -26.167617797851562,
+ -28.771360397338867,
+ -29.55248260498047,
+ -30.906429290771484,
+ -31.114728927612305,
+ -31.414159774780273,
+ -31.622459411621094,
+ -31.713590621948242,
+ -31.726608276367188
+ ],
+ [
+ -0.05012698099017143,
+ -3.018392562866211,
+ -11.740934371948242,
+ -13.146955490112305,
+ -13.797887802124023,
+ -14.943536758422852,
+ -16.037107467651367,
+ -16.375595092773438,
+ -16.714080810546875,
+ -17.36501693725586
+ ],
+ [
+ -0.9704352021217346,
+ -0.7360983490943909,
+ -2.1941938400268555,
+ -4.225115776062012,
+ -5.0062360763549805,
+ -5.2666120529174805,
+ -5.839434623718262,
+ -7.2714948654174805,
+ -8.33902645111084,
+ -8.495253562927246
+ ],
+ [
+ -0.014467108063399792,
+ -4.258565902709961,
+ -8.789079666137695,
+ -10.429437637329102,
+ -10.793962478637695,
+ -11.835458755493164,
+ -11.939607620239258,
+ -13.31959342956543,
+ -13.866378784179688,
+ -15.038063049316406
+ ],
+ [
+ 0.0,
+ -20.08787727355957,
+ -21.350692749023438,
+ -21.415786743164062,
+ -21.50691795349121,
+ -21.50691795349121,
+ -22.7176570892334,
+ -24.13669776916504,
+ -24.188772201538086,
+ -24.34499740600586
+ ]
+ ]
+ },
+ {
+ "case": "fail_case",
+ "expect": 0,
+ "tokens": [
+ "",
+ "\n",
+ "{'",
+ "name",
+ "':",
+ " '",
+ "get",
+ "_current",
+ "_weather",
+ "',",
+ " '",
+ "arguments",
+ "':",
+ " {'",
+ "location",
+ "':",
+ " '",
+ "Seattle",
+ ",",
+ " WA",
+ "',",
+ " '",
+ "unit",
+ "':",
+ " '",
+ "c",
+ "elsius",
+ "',",
+ " '",
+ "days",
+ "':",
+ " '",
+ "7",
+ "'}}\n",
+ ""
+ ],
+ "logprobs": [
+ [
+ -0.00013815402053296566,
+ -9.113236427307129,
+ -10.571331977844238,
+ -14.099404335021973,
+ -14.28166675567627,
+ -15.583537101745605,
+ -15.81787395477295,
+ -16.143341064453125,
+ -16.143341064453125,
+ -16.260509490966797
+ ],
+ [
+ 0.0,
+ -26.896663665771484,
+ -27.32628059387207,
+ -27.41741180419922,
+ -32.07810974121094,
+ -32.07810974121094,
+ -32.28641128540039,
+ -32.29943084716797,
+ -32.44263458251953,
+ -32.520748138427734
+ ],
+ [
+ 0.0,
+ -22.444263458251953,
+ -24.527257919311523,
+ -27.15703773498535,
+ -28.016273498535156,
+ -28.2506103515625,
+ -28.693246841430664,
+ -29.070789337158203,
+ -29.565500259399414,
+ -29.812854766845703
+ ],
+ [
+ 0.0,
+ -27.860050201416016,
+ -28.641170501708984,
+ -29.448333740234375,
+ -30.932466506958008,
+ -31.63547706604004,
+ -32.33848571777344,
+ -32.85923767089844,
+ -33.17168426513672,
+ -33.45809555053711
+ ],
+ [
+ 0.0,
+ -31.81774139404297,
+ -31.895854949951172,
+ -32.05207824707031,
+ -35.43694305419922,
+ -36.3482551574707,
+ -38.61351013183594,
+ -39.26444625854492,
+ -40.61839294433594,
+ -41.71196365356445
+ ],
+ [
+ 0.0,
+ -27.33930206298828,
+ -27.834014892578125,
+ -28.849472045898438,
+ -30.567943572998047,
+ -32.98942565917969,
+ -33.067535400390625,
+ -33.067535400390625,
+ -35.67127990722656,
+ -35.69731903076172
+ ],
+ [
+ 0.0,
+ -25.33441925048828,
+ -26.063465118408203,
+ -26.219690322875977,
+ -26.2457275390625,
+ -26.53213882446289,
+ -27.365337371826172,
+ -28.354759216308594,
+ -28.667207717895508,
+ -28.74532127380371
+ ],
+ [
+ 0.0,
+ -24.423107147216797,
+ -24.579330444335938,
+ -26.81855010986328,
+ -28.12042236328125,
+ -28.32872200012207,
+ -28.61513328552246,
+ -29.16191864013672,
+ -29.187957763671875,
+ -29.240032196044922
+ ],
+ [
+ 0.0,
+ -22.027664184570312,
+ -23.850284576416016,
+ -23.980472564697266,
+ -24.292922973632812,
+ -24.787633895874023,
+ -29.279088973999023,
+ -29.55248260498047,
+ -29.903987884521484,
+ -30.190399169921875
+ ],
+ [
+ 0.0,
+ -31.609439849853516,
+ -31.817739486694336,
+ -32.54678726196289,
+ -32.676971435546875,
+ -32.781124114990234,
+ -32.98942565917969,
+ -33.106590270996094,
+ -33.57526397705078,
+ -34.369407653808594
+ ],
+ [
+ 0.0,
+ -29.34418296813965,
+ -29.63059425354004,
+ -30.021156311035156,
+ -30.984540939331055,
+ -33.21073913574219,
+ -34.30431365966797,
+ -34.56468963623047,
+ -34.70789337158203,
+ -34.79902648925781
+ ],
+ [
+ 0.0,
+ -25.438566207885742,
+ -25.69894027709961,
+ -30.190397262573242,
+ -30.802276611328125,
+ -31.58340072631836,
+ -31.609437942504883,
+ -31.64849281311035,
+ -31.973960876464844,
+ -32.29943084716797
+ ],
+ [
+ 0.0,
+ -27.157039642333984,
+ -32.104148864746094,
+ -32.33848571777344,
+ -34.04393768310547,
+ -34.12205505371094,
+ -34.40846252441406,
+ -34.42148208618164,
+ -34.772987365722656,
+ -34.87713623046875
+ ],
+ [
+ 0.0,
+ -24.813671112060547,
+ -26.974777221679688,
+ -31.010578155517578,
+ -31.08869171142578,
+ -32.1822624206543,
+ -35.33279037475586,
+ -35.489013671875,
+ -36.999183654785156,
+ -37.88446044921875
+ ],
+ [
+ 0.0,
+ -20.46541976928711,
+ -20.647682189941406,
+ -23.069164276123047,
+ -24.136699676513672,
+ -25.438570022583008,
+ -25.646869659423828,
+ -26.193655014038086,
+ -26.297805786132812,
+ -26.506103515625
+ ],
+ [
+ 0.0,
+ -27.18307113647461,
+ -28.30268096923828,
+ -28.56305694580078,
+ -29.526439666748047,
+ -32.416595458984375,
+ -35.202598571777344,
+ -36.426361083984375,
+ -39.31651306152344,
+ -39.38160705566406
+ ],
+ [
+ 0.0,
+ -18.7469482421875,
+ -20.100894927978516,
+ -21.402767181396484,
+ -21.428804397583008,
+ -22.20992660522461,
+ -22.34011459350586,
+ -22.730674743652344,
+ -23.069162368774414,
+ -23.980472564697266
+ ],
+ [
+ -3.576278118089249e-07,
+ -15.2579345703125,
+ -16.481693267822266,
+ -17.991863250732422,
+ -19.215621948242188,
+ -20.25712013244629,
+ -21.350692749023438,
+ -22.314077377319336,
+ -22.496337890625,
+ -22.938974380493164
+ ],
+ [
+ -0.08506780862808228,
+ -2.506549835205078,
+ -14.848289489746094,
+ -15.473188400268555,
+ -16.33242416381836,
+ -16.358461380004883,
+ -16.566761016845703,
+ -17.03543472290039,
+ -17.686370849609375,
+ -17.816556930541992
+ ],
+ [
+ -0.0194891095161438,
+ -4.445854187011719,
+ -5.591499328613281,
+ -5.956024169921875,
+ -6.685070037841797,
+ -13.142353057861328,
+ -13.558952331542969,
+ -15.173273086547852,
+ -15.303461074829102,
+ -15.85024642944336
+ ],
+ [
+ -0.0005990855861455202,
+ -7.4212646484375,
+ -15.675132751464844,
+ -15.72720718383789,
+ -16.76870346069336,
+ -16.76870346069336,
+ -17.706050872802734,
+ -18.669435501098633,
+ -19.398483276367188,
+ -19.658857345581055
+ ],
+ [
+ 0.0,
+ -24.110658645629883,
+ -25.829130172729492,
+ -26.011390686035156,
+ -26.011390686035156,
+ -26.532140731811523,
+ -26.58421516418457,
+ -27.651750564575195,
+ -27.75589942932129,
+ -28.055330276489258
+ ],
+ [
+ -1.1408883333206177,
+ -0.38580334186553955,
+ -7.494022369384766,
+ -12.519245147705078,
+ -14.576202392578125,
+ -16.034297943115234,
+ -16.945608139038086,
+ -17.908992767333984,
+ -18.664077758789062,
+ -19.34105110168457
+ ],
+ [
+ 0.0,
+ -26.688365936279297,
+ -29.83889389038086,
+ -30.177383422851562,
+ -30.64605712890625,
+ -31.244916915893555,
+ -31.270954132080078,
+ -32.83319854736328,
+ -34.655818939208984,
+ -34.89015579223633
+ ],
+ [
+ 0.0,
+ -18.929210662841797,
+ -19.16354751586914,
+ -23.589908599853516,
+ -24.683481216430664,
+ -24.995929718017578,
+ -25.516677856445312,
+ -25.542715072631836,
+ -25.77705192565918,
+ -26.063465118408203
+ ],
+ [
+ -0.2519786059856415,
+ -1.5017764568328857,
+ -12.437495231628418,
+ -15.457839012145996,
+ -15.744250297546387,
+ -16.837820053100586,
+ -17.41064453125,
+ -17.56686782836914,
+ -17.61894416809082,
+ -18.035541534423828
+ ],
+ [
+ 0.0,
+ -20.517494201660156,
+ -24.683483123779297,
+ -25.67290496826172,
+ -26.58421516418457,
+ -27.651750564575195,
+ -27.781936645507812,
+ -27.912124633789062,
+ -28.09438705444336,
+ -28.445892333984375
+ ],
+ [
+ -3.40932747349143e-05,
+ -10.284820556640625,
+ -18.252273559570312,
+ -20.17904281616211,
+ -21.663175582885742,
+ -22.027700424194336,
+ -22.288074493408203,
+ -22.704673767089844,
+ -23.12127113342285,
+ -23.277496337890625
+ ],
+ [
+ 0.0,
+ -22.60049057006836,
+ -25.46460723876953,
+ -25.829130172729492,
+ -26.063467025756836,
+ -27.287227630615234,
+ -27.391376495361328,
+ -27.4694881439209,
+ -27.67778778076172,
+ -28.055330276489258
+ ],
+ [
+ 0.0,
+ -23.902362823486328,
+ -28.823436737060547,
+ -29.240036010742188,
+ -29.31814956665039,
+ -29.917007446289062,
+ -30.021160125732422,
+ -31.21887969970703,
+ -32.416603088378906,
+ -32.416603088378906
+ ],
+ [
+ 0.0,
+ -28.641170501708984,
+ -31.947925567626953,
+ -32.59886169433594,
+ -33.848655700683594,
+ -34.109031677246094,
+ -34.73393249511719,
+ -35.02033996582031,
+ -35.02033996582031,
+ -36.074859619140625
+ ],
+ [
+ -0.013183215633034706,
+ -4.335395336151123,
+ -19.619365692138672,
+ -20.035964965820312,
+ -20.244266510009766,
+ -21.311800003051758,
+ -21.441987991333008,
+ -22.561595916748047,
+ -23.108383178710938,
+ -23.264606475830078
+ ],
+ [
+ -8.344646857949556e-07,
+ -14.190400123596191,
+ -15.9088716506958,
+ -18.17412567138672,
+ -18.46053695678711,
+ -18.46053695678711,
+ -18.512611389160156,
+ -18.90317153930664,
+ -19.059398651123047,
+ -19.085433959960938
+ ],
+ [
+ 0.0,
+ -17.70545196533203,
+ -18.903175354003906,
+ -20.829944610595703,
+ -22.574451446533203,
+ -22.860862731933594,
+ -23.069162368774414,
+ -23.32953643798828,
+ -23.694061279296875,
+ -24.188772201538086
+ ],
+ [
+ 0.0,
+ -20.022781372070312,
+ -21.038240432739258,
+ -21.220502853393555,
+ -22.496337890625,
+ -22.769729614257812,
+ -23.589908599853516,
+ -23.65500259399414,
+ -23.94141387939453,
+ -24.266881942749023
+ ]
+ ]
+ }
+]
diff --git a/model_server/tests/core/test_function_calling.py b/model_server/tests/core/test_function_calling.py
index c59a28dc..2c786809 100644
--- a/model_server/tests/core/test_function_calling.py
+++ b/model_server/tests/core/test_function_calling.py
@@ -4,7 +4,7 @@ import pytest
from fastapi import Response
from unittest.mock import AsyncMock, MagicMock, patch
from src.commons.globals import handler_map
-from src.core.base_handler import (
+from src.core.model_utils import (
Message,
ChatMessage,
ChatCompletionResponse,
@@ -47,7 +47,7 @@ def sample_request(sample_messages):
)
-@patch("app.commons.globals.handler_map")
+@patch("src.commons.globals.handler_map")
def test_process_messages(mock_hanlder):
messages = sample_messages()
processed = handler_map["Arch-Function"]._process_messages(messages)
@@ -68,39 +68,39 @@ def test_process_messages(mock_hanlder):
# [TODO] Review: Update the following test
-@patch("app.commons.constants.arch_function_client")
-@patch("app.commons.constants.arch_function_hanlder")
-@pytest.mark.asyncio
-async def test_chat_completion(mock_hanlder, mock_client):
- # Mock the model list return for client
- mock_client.models.list.return_value = MagicMock(
- data=[MagicMock(id="sample_model")]
- )
- request = sample_request(sample_messages())
- # Simulate stream response as list of tokens
- mock_response = AsyncMock()
- mock_response.__aiter__.return_value = [
- MagicMock(choices=[MagicMock(delta=MagicMock(content="Hi there!"))]),
- MagicMock(choices=[MagicMock(delta=MagicMock(content=""))]), # end of stream
- ]
- mock_client.chat.completions.create.return_value = mock_response
+# @patch("src.commons.globals.ARCH_CLIENT")
+# @patch("src.commons.constants.handler_map")
+# @pytest.mark.asyncio
+# async def test_chat_completion(mock_hanlder, mock_client):
+# # Mock the model list return for client
+# mock_client.models.list.return_value = MagicMock(
+# data=[MagicMock(id="sample_model")]
+# )
+# request = sample_request(sample_messages())
+# # Simulate stream response as list of tokens
+# mock_response = AsyncMock()
+# mock_response.__aiter__.return_value = [
+# MagicMock(choices=[MagicMock(delta=MagicMock(content="Hi there!"))]),
+# MagicMock(choices=[MagicMock(delta=MagicMock(content=""))]), # end of stream
+# ]
+# mock_client.chat.completions.create.return_value = mock_response
- # Mock the tool formatter
- mock_hanlder._format_system_prompt.return_value = ""
+# # Mock the tool formatter
+# mock_hanlder._format_system_prompt.return_value = ""
- response = Response()
- chat_response = await chat_completion(request, response)
+# response = Response()
+# chat_response = await chat_completion(request, response)
- assert isinstance(chat_response, ChatCompletionResponse)
- assert chat_response.choices[0].message.content is not None
+# assert isinstance(chat_response, ChatCompletionResponse)
+# assert chat_response.choices[0].message.content is not None
- first_call_args = mock_client.chat.completions.create.call_args_list[0][1]
- assert first_call_args["stream"] == True
- assert "model" in first_call_args
- assert first_call_args["messages"][0]["content"] == ""
+# first_call_args = mock_client.chat.completions.create.call_args_list[0][1]
+# assert first_call_args["stream"] == True
+# assert "model" in first_call_args
+# assert first_call_args["messages"][0]["content"] == ""
- # Check that the arguments for the second call to 'create' include the pre-fill completion
- second_call_args = mock_client.chat.completions.create.call_args_list[1][1]
- assert second_call_args["stream"] == False
- assert "model" in second_call_args
- assert second_call_args["messages"][-1]["content"] in const.PREFILL_LIST
+# # Check that the arguments for the second call to 'create' include the pre-fill completion
+# second_call_args = mock_client.chat.completions.create.call_args_list[1][1]
+# assert second_call_args["stream"] == False
+# assert "model" in second_call_args
+# assert second_call_args["messages"][-1]["content"] in const.PREFILL_LIST
diff --git a/model_server/tests/core/test_guardrails.py b/model_server/tests/core/test_guardrails.py
index 5ba7ad11..b2087580 100644
--- a/model_server/tests/core/test_guardrails.py
+++ b/model_server/tests/core/test_guardrails.py
@@ -11,9 +11,9 @@ arch_guard_model_type = {
# [TODO] Review: check the following code to test under `cpu`, `cuda`, and `mps`
# Test for `get_guardrail_handler()` function on `cpu`
-@patch("app.model_handler.guardrail.AutoTokenizer.from_pretrained")
-@patch("app.model_handler.guardrail.OVModelForSequenceClassification.from_pretrained")
-@patch("app.model_handler.guardrail.AutoModelForSequenceClassification.from_pretrained")
+@patch("src.core.guardrail.AutoTokenizer.from_pretrained")
+@patch("src.core.guardrail.OVModelForSequenceClassification.from_pretrained")
+@patch("src.core.guardrail.AutoModelForSequenceClassification.from_pretrained")
def test_guardrail_handler_on_cpu(mock_auto_model, mock_ov_model, mock_tokenizer):
device = "cpu"
@@ -34,9 +34,9 @@ def test_guardrail_handler_on_cpu(mock_auto_model, mock_ov_model, mock_tokenizer
# Test for `get_guardrail_handler()` function on `cuda`
-@patch("app.model_handler.guardrail.AutoTokenizer.from_pretrained")
-@patch("app.model_handler.guardrail.OVModelForSequenceClassification.from_pretrained")
-@patch("app.model_handler.guardrail.AutoModelForSequenceClassification.from_pretrained")
+@patch("src.core.guardrail.AutoTokenizer.from_pretrained")
+@patch("src.core.guardrail.OVModelForSequenceClassification.from_pretrained")
+@patch("src.core.guardrail.AutoModelForSequenceClassification.from_pretrained")
def test_guardrail_handler_on_cuda(mock_auto_model, mock_ov_model, mock_tokenizer):
device = "cuda"
@@ -57,9 +57,9 @@ def test_guardrail_handler_on_cuda(mock_auto_model, mock_ov_model, mock_tokenize
# Test for `get_guardrail_handler()` function on `mps`
-@patch("app.model_handler.guardrail.AutoTokenizer.from_pretrained")
-@patch("app.model_handler.guardrail.OVModelForSequenceClassification.from_pretrained")
-@patch("app.model_handler.guardrail.AutoModelForSequenceClassification.from_pretrained")
+@patch("src.core.guardrail.AutoTokenizer.from_pretrained")
+@patch("src.core.guardrail.OVModelForSequenceClassification.from_pretrained")
+@patch("src.core.guardrail.AutoModelForSequenceClassification.from_pretrained")
def test_guardrail_handler_on_mps(mock_auto_model, mock_ov_model, mock_tokenizer):
device = "mps"