From 95e167c2f62e99d6049e51345babcd835a138887 Mon Sep 17 00:00:00 2001 From: Shuguang Chen <54548843+nehcgs@users.noreply.github.com> Date: Sun, 8 Dec 2024 15:43:19 -0800 Subject: [PATCH] Refacotr model configs --- model_server/src/cli.py | 12 +- model_server/src/commons/constants.py | 79 -- model_server/src/commons/globals.py | 32 +- model_server/src/commons/utils.py | 10 +- model_server/src/core/function_calling.py | 149 ++- model_server/src/core/guardrails.py | 15 +- .../core/{base_handler.py => model_utils.py} | 17 +- model_server/src/main.py | 3 +- model_server/tests/core/__init__.py | 0 model_server/tests/core/test_cases.json | 949 ++++++++++++++++++ .../tests/core/test_function_calling.py | 66 +- model_server/tests/core/test_guardrails.py | 18 +- 12 files changed, 1144 insertions(+), 206 deletions(-) delete mode 100644 model_server/src/commons/constants.py rename model_server/src/core/{base_handler.py => model_utils.py} (94%) create mode 100644 model_server/tests/core/__init__.py create mode 100644 model_server/tests/core/test_cases.json diff --git a/model_server/src/cli.py b/model_server/src/cli.py index d863d028..8d032961 100644 --- a/model_server/src/cli.py +++ b/model_server/src/cli.py @@ -8,8 +8,8 @@ from src.commons.utils import ( wait_for_health_check, check_lsof, install_lsof, - find_process_by_port, - kill_process_by_port, + find_processes_by_port, + kill_processes, ) @@ -23,7 +23,7 @@ def start_server(port=51000): "python", "-m", "uvicorn", - "app.main:app", + "src.main:app", "--host", "0.0.0.0", "--port", @@ -56,14 +56,16 @@ def stop_server(port=51000, wait=True, timeout=10): sys.exit(1) logger.info(f"Stopping processes on port {port}...") - port_processes = find_process_by_port(port) + port_processes = find_processes_by_port(port) if port_processes is None: logger.info(f"No processes found listening on port {port}.") else: if len(port_processes): - process_killed = kill_process_by_port(port_processes, wait, timeout) + process_killed = kill_processes(port_processes, wait, timeout) if not process_killed: logger.error(f"Unable to kill all processes on {port}") + else: + logger.info(f"All processes on port {port} have been killed.") else: logger.error(f"Unable to find processes on {port}") diff --git a/model_server/src/commons/constants.py b/model_server/src/commons/constants.py deleted file mode 100644 index 1b9035e0..00000000 --- a/model_server/src/commons/constants.py +++ /dev/null @@ -1,79 +0,0 @@ -# ========================== Arch-Intent Default Params ========================== -ARCH_INTENT_MODEL_ALIAS = "Arch-Intent" -ARCH_INTENT_INSTRUCTION = "Are there any tools can help?" - -ARCH_INTENT_TASK_PROMPT = """ -You are a helpful assistant. -""" - - -ARCH_INTENT_TOOL_PROMPT_TEMPLATE = """ -You task is to check if there are any tools that can be used to help the last user message in conversations according to the available tools listed below. - - -{tool_text} - -""" - - -ARCH_INTENT_FORMAT_PROMPT = """ -Provide your tool assessment for ONLY THE LAST USER MESSAGE in the above conversation: -- First line must read 'Yes' or 'No'. -- If yes, a second line must include a comma-separated list of tool indexes. -""" - - -ARCH_INTENT_GENERATION_CONFIG = { - "generation_params": {"max_tokens": 1, "stop_token_ids": [151645]} -} - - -# ========================== Arch-Function Default Params ========================== -ARCH_FUNCTION_MODEL_ALIAS = "Arch-Function" - -ARCH_FUNCTION_TASK_PROMPT = """ -You are a helpful assistant. -""" - - -ARCH_FUNCTION_TOOL_PROMPT_TEMPLATE = """ -# Tools - -You may call one or more functions to assist with the user query. - -You are provided with function signatures within XML tags: - -{tool_text} - -""" - - -ARCH_FUNCTION_FORMAT_PROMPT = """ -For each function call, return a json object with function name and arguments within XML tags: - -{"name": , "arguments": } - -""" - -ARCH_FUNCTION_GENERATION_CONFIG = { - "generation_params": { - "temperature": 0.2, - "top_p": 1.0, - "top_k": 50, - "max_tokens": 512, - "stop_token_ids": [151645], - }, - "prefill_params": { - "continue_final_message": True, - "add_generation_prompt": False, - }, - "prefill_prefix": [ - "May", - "Could", - "Sure", - "Definitely", - "Certainly", - "Of course", - "Can", - ], -} diff --git a/model_server/src/commons/globals.py b/model_server/src/commons/globals.py index 0dadb2a3..4cf71a31 100644 --- a/model_server/src/commons/globals.py +++ b/model_server/src/commons/globals.py @@ -1,36 +1,36 @@ from openai import OpenAI -from src.commons.constants import * -from src.core.function_calling import ArchIntentHandler, ArchFunctionHandler -from src.core.guardrails import get_guardrail_handler from src.commons.utils import get_model_server_logger +from src.core.guardrails import get_guardrail_handler +from src.core.function_calling import ( + ArchIntentConfig, + ArchIntentHandler, + ArchFunctionConfig, + ArchFunctionHandler, +) +# Define logger logger = get_model_server_logger() + # Define the client ARCH_ENDPOINT = "https://api.fc.archgw.com/v1" ARCH_API_KEY = "EMPTY" ARCH_CLIENT = OpenAI(base_url=ARCH_ENDPOINT, api_key=ARCH_API_KEY) +# Define model names +ARCH_INTENT_MODEL_ALIAS = "Arch-Intent" +ARCH_FUNCTION_MODEL_ALIAS = "Arch-Function" + + # Define model handlers handler_map = { "Arch-Intent": ArchIntentHandler( - ARCH_CLIENT, - ARCH_INTENT_MODEL_ALIAS, - ARCH_INTENT_TASK_PROMPT, - ARCH_INTENT_TOOL_PROMPT_TEMPLATE, - ARCH_INTENT_FORMAT_PROMPT, - ARCH_INTENT_INSTRUCTION, - **ARCH_INTENT_GENERATION_CONFIG, + ARCH_CLIENT, ARCH_INTENT_MODEL_ALIAS, ArchIntentConfig ), "Arch-Function": ArchFunctionHandler( - ARCH_CLIENT, - ARCH_FUNCTION_MODEL_ALIAS, - ARCH_FUNCTION_TASK_PROMPT, - ARCH_FUNCTION_TOOL_PROMPT_TEMPLATE, - ARCH_FUNCTION_FORMAT_PROMPT, - **ARCH_FUNCTION_GENERATION_CONFIG, + ARCH_CLIENT, ARCH_FUNCTION_MODEL_ALIAS, ArchFunctionConfig ), "Arch-Guard": get_guardrail_handler(), } diff --git a/model_server/src/commons/utils.py b/model_server/src/commons/utils.py index 5616b369..99fe03af 100644 --- a/model_server/src/commons/utils.py +++ b/model_server/src/commons/utils.py @@ -12,7 +12,7 @@ PROJ_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__fil # Default log directory and file -DEFAULT_LOG_DIR = os.path.join(PROJ_DIR, "logs") +DEFAULT_LOG_DIR = os.path.join(PROJ_DIR, ".logs") DEFAULT_LOG_FILE = "modelserver.log" @@ -50,7 +50,7 @@ def get_model_server_logger(log_dir=None, log_file=None): Get or initialize the logger instance for the model server. Parameters: - - log_dir (str): Custom directory to store the log file. Defaults to `~/archgw_logs`. + - log_dir (str): Custom directory to store the log file. Defaults to `./.logs`. - log_file (str): Custom log file name. Defaults to `modelserver.log`. Returns: @@ -146,13 +146,13 @@ def terminate_process_by_pid(pid, timeout): subprocess.run(["kill", "-9", str(pid)], check=False) -def find_process_by_port(port=51000): +def find_processes_by_port(port=51000): """Find processes listening on a specific port.""" port_processes = [] try: - lsof_command = f"lsof -n -i:{port} | grep LISTEN" + lsof_command = f"lsof -n | grep {port} | grep -i LISTEN" result = subprocess.run( lsof_command, shell=True, capture_output=True, text=True ) @@ -167,7 +167,7 @@ def find_process_by_port(port=51000): return [] -def kill_process_by_port(port_processes=51000, wait=True, timeout=10): +def kill_processes(port_processes, wait=True, timeout=10): """Kill processes on a specific port.""" try: diff --git a/model_server/src/core/function_calling.py b/model_server/src/core/function_calling.py index 9eb8ccd6..0352ab47 100644 --- a/model_server/src/core/function_calling.py +++ b/model_server/src/core/function_calling.py @@ -1,11 +1,12 @@ import json import random import builtins +import textwrap from openai import OpenAI from typing import Any, Dict, List, Tuple, Union from overrides import override -from src.core.base_handler import ( +from src.core.model_utils import ( Message, ChatMessage, Choice, @@ -14,43 +15,57 @@ from src.core.base_handler import ( ) -SUPPORT_DATA_TYPES = ["int", "float", "bool", "str", "list", "tuple", "set", "dict"] +class ArchIntentConfig: + TASK_PROMPT = textwrap.dedent( + """ + You are a helpful assistant. + """ + ).strip() + + TOOL_PROMPT_TEMPLATE = textwrap.dedent( + """ + You task is to check if there are any tools that can be used to help the last user message in conversations according to the available tools listed below. + + + {tool_text} + + """ + ).strip() + + FORMAT_PROMPT = textwrap.dedent( + """ + Provide your tool assessment for ONLY THE LAST USER MESSAGE in the above conversation: + - First line must read 'Yes' or 'No'. + - If yes, a second line must include a comma-separated list of tool indexes. + """ + ).strip() + + EXTRA_INSTRUCTION = "Are there any tools can help?" + + GENERATION_PARAMS = {"max_tokens": 1, "stop_token_ids": [151645]} class ArchIntentHandler(ArchBaseHandler): - def __init__( - self, - client: OpenAI, - model_name: str, - task_prompt: str, - tool_prompt_template: str, - format_prompt: str, - extra_instruction: str, - generation_params: Dict, - ): + def __init__(self, client: OpenAI, model_name: str, config: ArchIntentConfig): """ Initializes the intent handler. Args: client (OpenAI): An OpenAI client instance. model_name (str): Name of the model to use. - task_prompt (str): The main task prompt for the system. - tool_prompt_template (str): A prompt to describe tools. - format_prompt (str): A prompt specifying the desired output format. - extra_instruction (str): Instructions specific to intent handling. - generation_params (Dict): Generation parameters for the model. + config (ArchIntentConfig): The configuration for Arch-Intent. """ super().__init__( client, model_name, - task_prompt, - tool_prompt_template, - format_prompt, - generation_params, + config.TASK_PROMPT, + config.TOOL_PROMPT_TEMPLATE, + config.FORMAT_PROMPT, + config.GENERATION_PARAMS, ) - self.extra_instruction = extra_instruction + self.extra_instruction = config.EXTRA_INSTRUCTION @override def _convert_tools(self, tools: List[Dict[str, Any]]) -> str: @@ -125,17 +140,73 @@ class ArchIntentHandler(ArchBaseHandler): return chat_completion_response +# ============================================================================================================= + + +class ArchFunctionConfig: + TASK_PROMPT = textwrap.dedent( + """ + You are a helpful assistant. + """ + ).strip() + + TOOL_PROMPT_TEMPLATE = textwrap.dedent( + """ + # Tools + + You may call one or more functions to assist with the user query. + + You are provided with function signatures within XML tags: + + {tool_text} + + """ + ).strip() + + FORMAT_PROMPT = textwrap.dedent( + """ + For each function call, return a json object with function name and arguments within XML tags: + + {"name": , "arguments": } + + """ + ).strip() + + GENERATION_PARAMS = ( + { + "temperature": 0.2, + "top_p": 1.0, + "top_k": 50, + "max_tokens": 512, + "stop_token_ids": [151645], + }, + ) + + PREFILL_CONFIG = { + "prefill_params": { + "continue_final_message": True, + "add_generation_prompt": False, + }, + "prefill_prefix": [ + "May", + "Could", + "Sure", + "Definitely", + "Certainly", + "Of course", + "Can", + ], + } + + SUPPORT_DATA_TYPES = ["int", "float", "bool", "str", "list", "tuple", "set", "dict"] + + class ArchFunctionHandler(ArchBaseHandler): def __init__( self, client: OpenAI, model_name: str, - task_prompt: str, - tool_prompt_template: str, - format_prompt: str, - generation_params: Dict, - prefill_params: Dict, - prefill_prefix: List, + config: ArchFunctionConfig, ): """ Initializes the function handler. @@ -143,30 +214,26 @@ class ArchFunctionHandler(ArchBaseHandler): Args: client (OpenAI): An OpenAI client instance. model_name (str): Name of the model to use. - task_prompt (str): The main task prompt for the system. - tool_prompt_template (str): A prompt to describe tools. - format_prompt (str): A prompt specifying the desired output format. - generation_params (Dict): Generation parameters for the model. - prefill_params (Dict): Additional parameters for prefilling responses. - prefill_prefix (List[str]): List of prefixes for prefill responses. + config (ArchFunctionConfig): The configuration for Arch-Function """ super().__init__( client, model_name, - task_prompt, - tool_prompt_template, - format_prompt, - generation_params, + config.TASK_PROMPT, + config.TOOL_PROMPT_TEMPLATE, + config.FORMAT_PROMPT, + config.GENERATION_PARAMS, ) - self.prefill_params = prefill_params - self.prefill_prefix = prefill_prefix + self.prefill_params = config.PREFILL_CONFIG["prefill_params"] + self.prefill_prefix = config.PREFILL_CONFIG["prefill_prefix"] # Predefine data types for verification. Only support Python for now. # [TODO] Extend the list of support data types self.support_data_types = { - type_name: getattr(builtins, type_name) for type_name in SUPPORT_DATA_TYPES + type_name: getattr(builtins, type_name) + for type_name in config.SUPPORT_DATA_TYPES } @override diff --git a/model_server/src/core/guardrails.py b/model_server/src/core/guardrails.py index 64e283ae..ef6c8e04 100644 --- a/model_server/src/core/guardrails.py +++ b/model_server/src/core/guardrails.py @@ -3,22 +3,9 @@ import torch import numpy as np import src.commons.utils as utils -from typing import List -from pydantic import BaseModel from transformers import AutoTokenizer, AutoModelForSequenceClassification from optimum.intel import OVModelForSequenceClassification - - -class GuardRequest(BaseModel): - input: str - task: str - - -class GuardResponse(BaseModel): - prob: List - verdict: bool - sentence: List - latency: float = 0 +from src.core.model_utils import GuardRequest, GuardResponse class ArchGuardHanlder: diff --git a/model_server/src/core/base_handler.py b/model_server/src/core/model_utils.py similarity index 94% rename from model_server/src/core/base_handler.py rename to model_server/src/core/model_utils.py index c6b88749..c411d835 100644 --- a/model_server/src/core/base_handler.py +++ b/model_server/src/core/model_utils.py @@ -32,6 +32,21 @@ class ChatCompletionResponse(BaseModel): model: str +class GuardRequest(BaseModel): + input: str + task: str + + +class GuardResponse(BaseModel): + prob: List + verdict: bool + sentence: List + latency: float = 0 + + +# ================================================================================================ + + class ArchBaseHandler: def __init__( self, @@ -53,9 +68,7 @@ class ArchBaseHandler: format_prompt (str): A prompt specifying the desired output format. generation_params (Dict): Generation parameters for the model. """ - self.client = client - self.model_name = model_name self.task_prompt = task_prompt diff --git a/model_server/src/main.py b/model_server/src/main.py index 19fef239..0ca8c7c7 100644 --- a/model_server/src/main.py +++ b/model_server/src/main.py @@ -1,8 +1,7 @@ import os from src.commons.globals import handler_map -from src.core.base_handler import ChatMessage -from src.core.guardrails import GuardRequest +from src.core.model_utils import ChatMessage, GuardRequest from fastapi import FastAPI, Response from opentelemetry import trace diff --git a/model_server/tests/core/__init__.py b/model_server/tests/core/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/model_server/tests/core/test_cases.json b/model_server/tests/core/test_cases.json new file mode 100644 index 00000000..229b4ea1 --- /dev/null +++ b/model_server/tests/core/test_cases.json @@ -0,0 +1,949 @@ +[ + { + "case": "tool_call_halluciation", + "tokens": [ + "" + ], + "expect": 1, + "logprobs": [ + [ + -0.3333307206630707, + -1.5310522317886353, + -3.5098977088928223, + -3.9004578590393066, + -5.775152683258057, + -5.814209461212158, + -5.9574151039123535, + -6.0094895362854, + -6.0094895362854, + -6.673445224761963 + ] + ] + }, + { + "case": "parameter_value_hallucination", + "expect": 0, + "tokens": [ + "", + "\n", + "{'", + "name", + "':", + " '", + "get", + "_current", + "_weather", + "',", + " '", + "arguments", + "':", + " {'", + "location", + "':", + " '", + "Sea", + ",", + " Australia", + "',", + " '", + "unit", + "':", + " '", + "c", + "elsius", + "',", + " '", + "days", + "':", + " '", + "1", + "'}}\n", + "" + ], + "logprobs": [ + [ + -0.008103232830762863, + -5.085402488708496, + -6.777836799621582, + -7.558959007263184, + -9.850253105163574, + -10.266852378845215, + -10.540244102478027, + -10.722506523132324, + -10.800618171691895, + -10.917786598205566 + ], + [ + 0.0, + -23.25142478942871, + -25.139137268066406, + -26.2847843170166, + -28.992677688598633, + -29.070789337158203, + -29.55248260498047, + -29.91700553894043, + -30.20341682434082, + -30.307567596435547 + ], + [ + 0.0, + -21.66313934326172, + -23.06916046142578, + -23.32953453063965, + -25.65988540649414, + -25.985353469848633, + -26.519121170043945, + -27.07892417907715, + -27.977216720581055, + -28.458908081054688 + ], + [ + 0.0, + -28.094383239746094, + -28.56305694580078, + -29.109844207763672, + -29.44832992553711, + -31.79170036315918, + -32.0, + -32.05207443237305, + -32.31244659423828, + -32.364524841308594 + ], + [ + 0.0, + -30.489830017089844, + -31.140766143798828, + -31.81774139404297, + -34.525634765625, + -35.8275032043457, + -36.504478454589844, + -39.05614471435547, + -40.123680114746094, + -40.696502685546875 + ], + [ + 0.0, + -25.646865844726562, + -26.66232681274414, + -27.781936645507812, + -28.979660034179688, + -31.140764236450195, + -31.92188835144043, + -31.973962783813477, + -33.04149627685547, + -33.58828353881836 + ], + [ + 0.0, + -23.511798858642578, + -24.136695861816406, + -25.230268478393555, + -25.777053833007812, + -25.80309295654297, + -26.45402717590332, + -26.636289596557617, + -26.740440368652344, + -26.896663665771484 + ], + [ + 0.0, + -22.366153717041016, + -24.683483123779297, + -26.610252380371094, + -26.610252380371094, + -27.313264846801758, + -27.67778778076172, + -28.510986328125, + -28.615135192871094, + -29.13588523864746 + ], + [ + 0.0, + -22.52237319946289, + -24.292919158935547, + -24.344993591308594, + -24.39706802368164, + -24.73555564880371, + -29.943042755126953, + -29.969079971313477, + -30.021154403686523, + -30.0341739654541 + ], + [ + 0.0, + -30.17738151550293, + -30.411718368530273, + -30.88039207458496, + -30.984540939331055, + -31.270952224731445, + -31.895851135253906, + -32.46867370605469, + -32.624900817871094, + -33.484134674072266 + ], + [ + 0.0, + -28.146459579467773, + -29.396255493164062, + -30.099267959594727, + -31.127744674682617, + -31.179821014404297, + -32.807159423828125, + -33.7445068359375, + -33.770545959472656, + -34.069976806640625 + ], + [ + 0.0, + -26.323841094970703, + -26.558177947998047, + -30.515867233276367, + -30.932466506958008, + -31.37510108947754, + -31.531326293945312, + -31.70056915283203, + -32.065093994140625, + -32.364524841308594 + ], + [ + 0.0, + -26.922698974609375, + -30.28152847290039, + -31.505287170410156, + -33.30187225341797, + -33.73148727416992, + -34.27827453613281, + -34.33034896850586, + -34.460533142089844, + -34.720909118652344 + ], + [ + 0.0, + -21.532955169677734, + -26.94873809814453, + -29.109848022460938, + -30.80228042602539, + -31.55736541748047, + -33.484134674072266, + -34.681854248046875, + -35.384864807128906, + -35.853538513183594 + ], + [ + 0.0, + -19.502033233642578, + -20.46541976928711, + -24.110658645629883, + -24.501218795776367, + -25.256305694580078, + -25.82912826538086, + -25.881202697753906, + -26.063465118408203, + -26.063465118408203 + ], + [ + 0.0, + -24.37103271484375, + -25.256305694580078, + -25.933277130126953, + -26.714401245117188, + -28.2506103515625, + -31.010576248168945, + -32.07810974121094, + -34.62977981567383, + -35.241661071777344 + ], + [ + -1.1920922133867862e-06, + -14.398697853088379, + -14.424736976623535, + -17.158666610717773, + -17.41904067993164, + -18.200162887573242, + -18.434499740600586, + -18.66883659362793, + -19.71033477783203, + -19.71033477783203 + ], + [ + -0.0001445904199499637, + -8.98305892944336, + -11.35246467590332, + -13.1490478515625, + -13.669795989990234, + -14.073375701904297, + -14.516012191772461, + -14.555068969726562, + -15.622602462768555, + -15.635622024536133 + ], + [ + -0.44747352600097656, + -1.0202960968017578, + -8.467000961303711, + -10.914518356323242, + -11.25300407409668, + -11.435266494750977, + -12.346576690673828, + -13.075624465942383, + -13.12769889831543, + -13.231849670410156 + ], + [ + -3.123767137527466, + -1.1188862323760986, + -1.639634370803833, + -2.0562336444854736, + -2.8633930683135986, + -2.9675419330596924, + -3.4882919788360596, + -3.69659161567688, + -4.217339515686035, + -4.243376731872559 + ], + [ + -7.199982064776123e-05, + -9.76410961151123, + -11.144091606140137, + -16.507802963256836, + -17.132701873779297, + -17.44515037536621, + -17.9138240814209, + -18.33042335510254, + -18.9162654876709, + -19.39795684814453 + ], + [ + 0.0, + -22.991050720214844, + -23.824249267578125, + -24.969894409179688, + -25.46460723876953, + -25.829130172729492, + -26.480066299438477, + -26.909683227539062, + -27.33930206298828, + -27.391376495361328 + ], + [ + -0.21928852796554565, + -1.625309705734253, + -9.775025367736816, + -12.977627754211426, + -16.388530731201172, + -17.091541290283203, + -19.044347763061523, + -19.38283348083496, + -19.460947036743164, + -19.59113311767578 + ], + [ + 0.0, + -24.006507873535156, + -27.443450927734375, + -27.729862213134766, + -28.12042236328125, + -28.276647567749023, + -28.927583694458008, + -30.099267959594727, + -31.479251861572266, + -32.07810974121094 + ], + [ + 0.0, + -18.17412567138672, + -18.772987365722656, + -21.689178466796875, + -21.92351531982422, + -23.7200984954834, + -23.79821014404297, + -23.79821014404297, + -24.032546997070312, + -25.308382034301758 + ], + [ + -0.12947827577590942, + -2.1083219051361084, + -12.419143676757812, + -15.23118782043457, + -15.595710754394531, + -15.830047607421875, + -17.001731872558594, + -17.60059356689453, + -18.121341705322266, + -18.251529693603516 + ], + [ + 0.0, + -19.449962615966797, + -24.371034622192383, + -24.917821884155273, + -25.529701232910156, + -25.85516929626465, + -26.037429809570312, + -26.115543365478516, + -26.623271942138672, + -26.649309158325195 + ], + [ + -0.03332124650478363, + -3.4181859493255615, + -15.759925842285156, + -15.812002182006836, + -16.593124389648438, + -17.894996643066406, + -18.09027671813965, + -18.79328727722168, + -19.144792556762695, + -20.147233963012695 + ], + [ + 0.0, + -21.142393112182617, + -22.157852172851562, + -23.511798858642578, + -24.657445907592773, + -25.021968841552734, + -25.5427188873291, + -25.59479331970215, + -25.75101661682129, + -25.95931625366211 + ], + [ + 0.0, + -23.04312515258789, + -24.94385528564453, + -26.323841094970703, + -27.54759979248047, + -28.563060760498047, + -29.786819458007812, + -30.620018005371094, + -30.69812774658203, + -31.08869171142578 + ], + [ + 0.0, + -26.167617797851562, + -28.771360397338867, + -29.55248260498047, + -30.906429290771484, + -31.114728927612305, + -31.414159774780273, + -31.622459411621094, + -31.713590621948242, + -31.726608276367188 + ], + [ + -0.05012698099017143, + -3.018392562866211, + -11.740934371948242, + -13.146955490112305, + -13.797887802124023, + -14.943536758422852, + -16.037107467651367, + -16.375595092773438, + -16.714080810546875, + -17.36501693725586 + ], + [ + -0.9704352021217346, + -0.7360983490943909, + -2.1941938400268555, + -4.225115776062012, + -5.0062360763549805, + -5.2666120529174805, + -5.839434623718262, + -7.2714948654174805, + -8.33902645111084, + -8.495253562927246 + ], + [ + -0.014467108063399792, + -4.258565902709961, + -8.789079666137695, + -10.429437637329102, + -10.793962478637695, + -11.835458755493164, + -11.939607620239258, + -13.31959342956543, + -13.866378784179688, + -15.038063049316406 + ], + [ + 0.0, + -20.08787727355957, + -21.350692749023438, + -21.415786743164062, + -21.50691795349121, + -21.50691795349121, + -22.7176570892334, + -24.13669776916504, + -24.188772201538086, + -24.34499740600586 + ] + ] + }, + { + "case": "fail_case", + "expect": 0, + "tokens": [ + "", + "\n", + "{'", + "name", + "':", + " '", + "get", + "_current", + "_weather", + "',", + " '", + "arguments", + "':", + " {'", + "location", + "':", + " '", + "Seattle", + ",", + " WA", + "',", + " '", + "unit", + "':", + " '", + "c", + "elsius", + "',", + " '", + "days", + "':", + " '", + "7", + "'}}\n", + "" + ], + "logprobs": [ + [ + -0.00013815402053296566, + -9.113236427307129, + -10.571331977844238, + -14.099404335021973, + -14.28166675567627, + -15.583537101745605, + -15.81787395477295, + -16.143341064453125, + -16.143341064453125, + -16.260509490966797 + ], + [ + 0.0, + -26.896663665771484, + -27.32628059387207, + -27.41741180419922, + -32.07810974121094, + -32.07810974121094, + -32.28641128540039, + -32.29943084716797, + -32.44263458251953, + -32.520748138427734 + ], + [ + 0.0, + -22.444263458251953, + -24.527257919311523, + -27.15703773498535, + -28.016273498535156, + -28.2506103515625, + -28.693246841430664, + -29.070789337158203, + -29.565500259399414, + -29.812854766845703 + ], + [ + 0.0, + -27.860050201416016, + -28.641170501708984, + -29.448333740234375, + -30.932466506958008, + -31.63547706604004, + -32.33848571777344, + -32.85923767089844, + -33.17168426513672, + -33.45809555053711 + ], + [ + 0.0, + -31.81774139404297, + -31.895854949951172, + -32.05207824707031, + -35.43694305419922, + -36.3482551574707, + -38.61351013183594, + -39.26444625854492, + -40.61839294433594, + -41.71196365356445 + ], + [ + 0.0, + -27.33930206298828, + -27.834014892578125, + -28.849472045898438, + -30.567943572998047, + -32.98942565917969, + -33.067535400390625, + -33.067535400390625, + -35.67127990722656, + -35.69731903076172 + ], + [ + 0.0, + -25.33441925048828, + -26.063465118408203, + -26.219690322875977, + -26.2457275390625, + -26.53213882446289, + -27.365337371826172, + -28.354759216308594, + -28.667207717895508, + -28.74532127380371 + ], + [ + 0.0, + -24.423107147216797, + -24.579330444335938, + -26.81855010986328, + -28.12042236328125, + -28.32872200012207, + -28.61513328552246, + -29.16191864013672, + -29.187957763671875, + -29.240032196044922 + ], + [ + 0.0, + -22.027664184570312, + -23.850284576416016, + -23.980472564697266, + -24.292922973632812, + -24.787633895874023, + -29.279088973999023, + -29.55248260498047, + -29.903987884521484, + -30.190399169921875 + ], + [ + 0.0, + -31.609439849853516, + -31.817739486694336, + -32.54678726196289, + -32.676971435546875, + -32.781124114990234, + -32.98942565917969, + -33.106590270996094, + -33.57526397705078, + -34.369407653808594 + ], + [ + 0.0, + -29.34418296813965, + -29.63059425354004, + -30.021156311035156, + -30.984540939331055, + -33.21073913574219, + -34.30431365966797, + -34.56468963623047, + -34.70789337158203, + -34.79902648925781 + ], + [ + 0.0, + -25.438566207885742, + -25.69894027709961, + -30.190397262573242, + -30.802276611328125, + -31.58340072631836, + -31.609437942504883, + -31.64849281311035, + -31.973960876464844, + -32.29943084716797 + ], + [ + 0.0, + -27.157039642333984, + -32.104148864746094, + -32.33848571777344, + -34.04393768310547, + -34.12205505371094, + -34.40846252441406, + -34.42148208618164, + -34.772987365722656, + -34.87713623046875 + ], + [ + 0.0, + -24.813671112060547, + -26.974777221679688, + -31.010578155517578, + -31.08869171142578, + -32.1822624206543, + -35.33279037475586, + -35.489013671875, + -36.999183654785156, + -37.88446044921875 + ], + [ + 0.0, + -20.46541976928711, + -20.647682189941406, + -23.069164276123047, + -24.136699676513672, + -25.438570022583008, + -25.646869659423828, + -26.193655014038086, + -26.297805786132812, + -26.506103515625 + ], + [ + 0.0, + -27.18307113647461, + -28.30268096923828, + -28.56305694580078, + -29.526439666748047, + -32.416595458984375, + -35.202598571777344, + -36.426361083984375, + -39.31651306152344, + -39.38160705566406 + ], + [ + 0.0, + -18.7469482421875, + -20.100894927978516, + -21.402767181396484, + -21.428804397583008, + -22.20992660522461, + -22.34011459350586, + -22.730674743652344, + -23.069162368774414, + -23.980472564697266 + ], + [ + -3.576278118089249e-07, + -15.2579345703125, + -16.481693267822266, + -17.991863250732422, + -19.215621948242188, + -20.25712013244629, + -21.350692749023438, + -22.314077377319336, + -22.496337890625, + -22.938974380493164 + ], + [ + -0.08506780862808228, + -2.506549835205078, + -14.848289489746094, + -15.473188400268555, + -16.33242416381836, + -16.358461380004883, + -16.566761016845703, + -17.03543472290039, + -17.686370849609375, + -17.816556930541992 + ], + [ + -0.0194891095161438, + -4.445854187011719, + -5.591499328613281, + -5.956024169921875, + -6.685070037841797, + -13.142353057861328, + -13.558952331542969, + -15.173273086547852, + -15.303461074829102, + -15.85024642944336 + ], + [ + -0.0005990855861455202, + -7.4212646484375, + -15.675132751464844, + -15.72720718383789, + -16.76870346069336, + -16.76870346069336, + -17.706050872802734, + -18.669435501098633, + -19.398483276367188, + -19.658857345581055 + ], + [ + 0.0, + -24.110658645629883, + -25.829130172729492, + -26.011390686035156, + -26.011390686035156, + -26.532140731811523, + -26.58421516418457, + -27.651750564575195, + -27.75589942932129, + -28.055330276489258 + ], + [ + -1.1408883333206177, + -0.38580334186553955, + -7.494022369384766, + -12.519245147705078, + -14.576202392578125, + -16.034297943115234, + -16.945608139038086, + -17.908992767333984, + -18.664077758789062, + -19.34105110168457 + ], + [ + 0.0, + -26.688365936279297, + -29.83889389038086, + -30.177383422851562, + -30.64605712890625, + -31.244916915893555, + -31.270954132080078, + -32.83319854736328, + -34.655818939208984, + -34.89015579223633 + ], + [ + 0.0, + -18.929210662841797, + -19.16354751586914, + -23.589908599853516, + -24.683481216430664, + -24.995929718017578, + -25.516677856445312, + -25.542715072631836, + -25.77705192565918, + -26.063465118408203 + ], + [ + -0.2519786059856415, + -1.5017764568328857, + -12.437495231628418, + -15.457839012145996, + -15.744250297546387, + -16.837820053100586, + -17.41064453125, + -17.56686782836914, + -17.61894416809082, + -18.035541534423828 + ], + [ + 0.0, + -20.517494201660156, + -24.683483123779297, + -25.67290496826172, + -26.58421516418457, + -27.651750564575195, + -27.781936645507812, + -27.912124633789062, + -28.09438705444336, + -28.445892333984375 + ], + [ + -3.40932747349143e-05, + -10.284820556640625, + -18.252273559570312, + -20.17904281616211, + -21.663175582885742, + -22.027700424194336, + -22.288074493408203, + -22.704673767089844, + -23.12127113342285, + -23.277496337890625 + ], + [ + 0.0, + -22.60049057006836, + -25.46460723876953, + -25.829130172729492, + -26.063467025756836, + -27.287227630615234, + -27.391376495361328, + -27.4694881439209, + -27.67778778076172, + -28.055330276489258 + ], + [ + 0.0, + -23.902362823486328, + -28.823436737060547, + -29.240036010742188, + -29.31814956665039, + -29.917007446289062, + -30.021160125732422, + -31.21887969970703, + -32.416603088378906, + -32.416603088378906 + ], + [ + 0.0, + -28.641170501708984, + -31.947925567626953, + -32.59886169433594, + -33.848655700683594, + -34.109031677246094, + -34.73393249511719, + -35.02033996582031, + -35.02033996582031, + -36.074859619140625 + ], + [ + -0.013183215633034706, + -4.335395336151123, + -19.619365692138672, + -20.035964965820312, + -20.244266510009766, + -21.311800003051758, + -21.441987991333008, + -22.561595916748047, + -23.108383178710938, + -23.264606475830078 + ], + [ + -8.344646857949556e-07, + -14.190400123596191, + -15.9088716506958, + -18.17412567138672, + -18.46053695678711, + -18.46053695678711, + -18.512611389160156, + -18.90317153930664, + -19.059398651123047, + -19.085433959960938 + ], + [ + 0.0, + -17.70545196533203, + -18.903175354003906, + -20.829944610595703, + -22.574451446533203, + -22.860862731933594, + -23.069162368774414, + -23.32953643798828, + -23.694061279296875, + -24.188772201538086 + ], + [ + 0.0, + -20.022781372070312, + -21.038240432739258, + -21.220502853393555, + -22.496337890625, + -22.769729614257812, + -23.589908599853516, + -23.65500259399414, + -23.94141387939453, + -24.266881942749023 + ] + ] + } +] diff --git a/model_server/tests/core/test_function_calling.py b/model_server/tests/core/test_function_calling.py index c59a28dc..2c786809 100644 --- a/model_server/tests/core/test_function_calling.py +++ b/model_server/tests/core/test_function_calling.py @@ -4,7 +4,7 @@ import pytest from fastapi import Response from unittest.mock import AsyncMock, MagicMock, patch from src.commons.globals import handler_map -from src.core.base_handler import ( +from src.core.model_utils import ( Message, ChatMessage, ChatCompletionResponse, @@ -47,7 +47,7 @@ def sample_request(sample_messages): ) -@patch("app.commons.globals.handler_map") +@patch("src.commons.globals.handler_map") def test_process_messages(mock_hanlder): messages = sample_messages() processed = handler_map["Arch-Function"]._process_messages(messages) @@ -68,39 +68,39 @@ def test_process_messages(mock_hanlder): # [TODO] Review: Update the following test -@patch("app.commons.constants.arch_function_client") -@patch("app.commons.constants.arch_function_hanlder") -@pytest.mark.asyncio -async def test_chat_completion(mock_hanlder, mock_client): - # Mock the model list return for client - mock_client.models.list.return_value = MagicMock( - data=[MagicMock(id="sample_model")] - ) - request = sample_request(sample_messages()) - # Simulate stream response as list of tokens - mock_response = AsyncMock() - mock_response.__aiter__.return_value = [ - MagicMock(choices=[MagicMock(delta=MagicMock(content="Hi there!"))]), - MagicMock(choices=[MagicMock(delta=MagicMock(content=""))]), # end of stream - ] - mock_client.chat.completions.create.return_value = mock_response +# @patch("src.commons.globals.ARCH_CLIENT") +# @patch("src.commons.constants.handler_map") +# @pytest.mark.asyncio +# async def test_chat_completion(mock_hanlder, mock_client): +# # Mock the model list return for client +# mock_client.models.list.return_value = MagicMock( +# data=[MagicMock(id="sample_model")] +# ) +# request = sample_request(sample_messages()) +# # Simulate stream response as list of tokens +# mock_response = AsyncMock() +# mock_response.__aiter__.return_value = [ +# MagicMock(choices=[MagicMock(delta=MagicMock(content="Hi there!"))]), +# MagicMock(choices=[MagicMock(delta=MagicMock(content=""))]), # end of stream +# ] +# mock_client.chat.completions.create.return_value = mock_response - # Mock the tool formatter - mock_hanlder._format_system_prompt.return_value = "" +# # Mock the tool formatter +# mock_hanlder._format_system_prompt.return_value = "" - response = Response() - chat_response = await chat_completion(request, response) +# response = Response() +# chat_response = await chat_completion(request, response) - assert isinstance(chat_response, ChatCompletionResponse) - assert chat_response.choices[0].message.content is not None +# assert isinstance(chat_response, ChatCompletionResponse) +# assert chat_response.choices[0].message.content is not None - first_call_args = mock_client.chat.completions.create.call_args_list[0][1] - assert first_call_args["stream"] == True - assert "model" in first_call_args - assert first_call_args["messages"][0]["content"] == "" +# first_call_args = mock_client.chat.completions.create.call_args_list[0][1] +# assert first_call_args["stream"] == True +# assert "model" in first_call_args +# assert first_call_args["messages"][0]["content"] == "" - # Check that the arguments for the second call to 'create' include the pre-fill completion - second_call_args = mock_client.chat.completions.create.call_args_list[1][1] - assert second_call_args["stream"] == False - assert "model" in second_call_args - assert second_call_args["messages"][-1]["content"] in const.PREFILL_LIST +# # Check that the arguments for the second call to 'create' include the pre-fill completion +# second_call_args = mock_client.chat.completions.create.call_args_list[1][1] +# assert second_call_args["stream"] == False +# assert "model" in second_call_args +# assert second_call_args["messages"][-1]["content"] in const.PREFILL_LIST diff --git a/model_server/tests/core/test_guardrails.py b/model_server/tests/core/test_guardrails.py index 5ba7ad11..b2087580 100644 --- a/model_server/tests/core/test_guardrails.py +++ b/model_server/tests/core/test_guardrails.py @@ -11,9 +11,9 @@ arch_guard_model_type = { # [TODO] Review: check the following code to test under `cpu`, `cuda`, and `mps` # Test for `get_guardrail_handler()` function on `cpu` -@patch("app.model_handler.guardrail.AutoTokenizer.from_pretrained") -@patch("app.model_handler.guardrail.OVModelForSequenceClassification.from_pretrained") -@patch("app.model_handler.guardrail.AutoModelForSequenceClassification.from_pretrained") +@patch("src.core.guardrail.AutoTokenizer.from_pretrained") +@patch("src.core.guardrail.OVModelForSequenceClassification.from_pretrained") +@patch("src.core.guardrail.AutoModelForSequenceClassification.from_pretrained") def test_guardrail_handler_on_cpu(mock_auto_model, mock_ov_model, mock_tokenizer): device = "cpu" @@ -34,9 +34,9 @@ def test_guardrail_handler_on_cpu(mock_auto_model, mock_ov_model, mock_tokenizer # Test for `get_guardrail_handler()` function on `cuda` -@patch("app.model_handler.guardrail.AutoTokenizer.from_pretrained") -@patch("app.model_handler.guardrail.OVModelForSequenceClassification.from_pretrained") -@patch("app.model_handler.guardrail.AutoModelForSequenceClassification.from_pretrained") +@patch("src.core.guardrail.AutoTokenizer.from_pretrained") +@patch("src.core.guardrail.OVModelForSequenceClassification.from_pretrained") +@patch("src.core.guardrail.AutoModelForSequenceClassification.from_pretrained") def test_guardrail_handler_on_cuda(mock_auto_model, mock_ov_model, mock_tokenizer): device = "cuda" @@ -57,9 +57,9 @@ def test_guardrail_handler_on_cuda(mock_auto_model, mock_ov_model, mock_tokenize # Test for `get_guardrail_handler()` function on `mps` -@patch("app.model_handler.guardrail.AutoTokenizer.from_pretrained") -@patch("app.model_handler.guardrail.OVModelForSequenceClassification.from_pretrained") -@patch("app.model_handler.guardrail.AutoModelForSequenceClassification.from_pretrained") +@patch("src.core.guardrail.AutoTokenizer.from_pretrained") +@patch("src.core.guardrail.OVModelForSequenceClassification.from_pretrained") +@patch("src.core.guardrail.AutoModelForSequenceClassification.from_pretrained") def test_guardrail_handler_on_mps(mock_auto_model, mock_ov_model, mock_tokenizer): device = "mps"