Refacotr model configs

This commit is contained in:
Shuguang Chen 2024-12-08 15:43:19 -08:00
parent 320f4612b8
commit 95e167c2f6
12 changed files with 1144 additions and 206 deletions

View file

@ -8,8 +8,8 @@ from src.commons.utils import (
wait_for_health_check,
check_lsof,
install_lsof,
find_process_by_port,
kill_process_by_port,
find_processes_by_port,
kill_processes,
)
@ -23,7 +23,7 @@ def start_server(port=51000):
"python",
"-m",
"uvicorn",
"app.main:app",
"src.main:app",
"--host",
"0.0.0.0",
"--port",
@ -56,14 +56,16 @@ def stop_server(port=51000, wait=True, timeout=10):
sys.exit(1)
logger.info(f"Stopping processes on port {port}...")
port_processes = find_process_by_port(port)
port_processes = find_processes_by_port(port)
if port_processes is None:
logger.info(f"No processes found listening on port {port}.")
else:
if len(port_processes):
process_killed = kill_process_by_port(port_processes, wait, timeout)
process_killed = kill_processes(port_processes, wait, timeout)
if not process_killed:
logger.error(f"Unable to kill all processes on {port}")
else:
logger.info(f"All processes on port {port} have been killed.")
else:
logger.error(f"Unable to find processes on {port}")

View file

@ -1,79 +0,0 @@
# ========================== Arch-Intent Default Params ==========================
ARCH_INTENT_MODEL_ALIAS = "Arch-Intent"
ARCH_INTENT_INSTRUCTION = "Are there any tools can help?"
ARCH_INTENT_TASK_PROMPT = """
You are a helpful assistant.
"""
ARCH_INTENT_TOOL_PROMPT_TEMPLATE = """
You task is to check if there are any tools that can be used to help the last user message in conversations according to the available tools listed below.
<tools>
{tool_text}
</tools>
"""
ARCH_INTENT_FORMAT_PROMPT = """
Provide your tool assessment for ONLY THE LAST USER MESSAGE in the above conversation:
- First line must read 'Yes' or 'No'.
- If yes, a second line must include a comma-separated list of tool indexes.
"""
ARCH_INTENT_GENERATION_CONFIG = {
"generation_params": {"max_tokens": 1, "stop_token_ids": [151645]}
}
# ========================== Arch-Function Default Params ==========================
ARCH_FUNCTION_MODEL_ALIAS = "Arch-Function"
ARCH_FUNCTION_TASK_PROMPT = """
You are a helpful assistant.
"""
ARCH_FUNCTION_TOOL_PROMPT_TEMPLATE = """
# Tools
You may call one or more functions to assist with the user query.
You are provided with function signatures within <tools></tools> XML tags:
<tools>
{tool_text}
</tools>
"""
ARCH_FUNCTION_FORMAT_PROMPT = """
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
<tool_call>
{"name": <function-name>, "arguments": <args-json-object>}
</tool_call>
"""
ARCH_FUNCTION_GENERATION_CONFIG = {
"generation_params": {
"temperature": 0.2,
"top_p": 1.0,
"top_k": 50,
"max_tokens": 512,
"stop_token_ids": [151645],
},
"prefill_params": {
"continue_final_message": True,
"add_generation_prompt": False,
},
"prefill_prefix": [
"May",
"Could",
"Sure",
"Definitely",
"Certainly",
"Of course",
"Can",
],
}

View file

@ -1,36 +1,36 @@
from openai import OpenAI
from src.commons.constants import *
from src.core.function_calling import ArchIntentHandler, ArchFunctionHandler
from src.core.guardrails import get_guardrail_handler
from src.commons.utils import get_model_server_logger
from src.core.guardrails import get_guardrail_handler
from src.core.function_calling import (
ArchIntentConfig,
ArchIntentHandler,
ArchFunctionConfig,
ArchFunctionHandler,
)
# Define logger
logger = get_model_server_logger()
# Define the client
ARCH_ENDPOINT = "https://api.fc.archgw.com/v1"
ARCH_API_KEY = "EMPTY"
ARCH_CLIENT = OpenAI(base_url=ARCH_ENDPOINT, api_key=ARCH_API_KEY)
# Define model names
ARCH_INTENT_MODEL_ALIAS = "Arch-Intent"
ARCH_FUNCTION_MODEL_ALIAS = "Arch-Function"
# Define model handlers
handler_map = {
"Arch-Intent": ArchIntentHandler(
ARCH_CLIENT,
ARCH_INTENT_MODEL_ALIAS,
ARCH_INTENT_TASK_PROMPT,
ARCH_INTENT_TOOL_PROMPT_TEMPLATE,
ARCH_INTENT_FORMAT_PROMPT,
ARCH_INTENT_INSTRUCTION,
**ARCH_INTENT_GENERATION_CONFIG,
ARCH_CLIENT, ARCH_INTENT_MODEL_ALIAS, ArchIntentConfig
),
"Arch-Function": ArchFunctionHandler(
ARCH_CLIENT,
ARCH_FUNCTION_MODEL_ALIAS,
ARCH_FUNCTION_TASK_PROMPT,
ARCH_FUNCTION_TOOL_PROMPT_TEMPLATE,
ARCH_FUNCTION_FORMAT_PROMPT,
**ARCH_FUNCTION_GENERATION_CONFIG,
ARCH_CLIENT, ARCH_FUNCTION_MODEL_ALIAS, ArchFunctionConfig
),
"Arch-Guard": get_guardrail_handler(),
}

View file

@ -12,7 +12,7 @@ PROJ_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__fil
# Default log directory and file
DEFAULT_LOG_DIR = os.path.join(PROJ_DIR, "logs")
DEFAULT_LOG_DIR = os.path.join(PROJ_DIR, ".logs")
DEFAULT_LOG_FILE = "modelserver.log"
@ -50,7 +50,7 @@ def get_model_server_logger(log_dir=None, log_file=None):
Get or initialize the logger instance for the model server.
Parameters:
- log_dir (str): Custom directory to store the log file. Defaults to `~/archgw_logs`.
- log_dir (str): Custom directory to store the log file. Defaults to `./.logs`.
- log_file (str): Custom log file name. Defaults to `modelserver.log`.
Returns:
@ -146,13 +146,13 @@ def terminate_process_by_pid(pid, timeout):
subprocess.run(["kill", "-9", str(pid)], check=False)
def find_process_by_port(port=51000):
def find_processes_by_port(port=51000):
"""Find processes listening on a specific port."""
port_processes = []
try:
lsof_command = f"lsof -n -i:{port} | grep LISTEN"
lsof_command = f"lsof -n | grep {port} | grep -i LISTEN"
result = subprocess.run(
lsof_command, shell=True, capture_output=True, text=True
)
@ -167,7 +167,7 @@ def find_process_by_port(port=51000):
return []
def kill_process_by_port(port_processes=51000, wait=True, timeout=10):
def kill_processes(port_processes, wait=True, timeout=10):
"""Kill processes on a specific port."""
try:

View file

@ -1,11 +1,12 @@
import json
import random
import builtins
import textwrap
from openai import OpenAI
from typing import Any, Dict, List, Tuple, Union
from overrides import override
from src.core.base_handler import (
from src.core.model_utils import (
Message,
ChatMessage,
Choice,
@ -14,43 +15,57 @@ from src.core.base_handler import (
)
SUPPORT_DATA_TYPES = ["int", "float", "bool", "str", "list", "tuple", "set", "dict"]
class ArchIntentConfig:
TASK_PROMPT = textwrap.dedent(
"""
You are a helpful assistant.
"""
).strip()
TOOL_PROMPT_TEMPLATE = textwrap.dedent(
"""
You task is to check if there are any tools that can be used to help the last user message in conversations according to the available tools listed below.
<tools>
{tool_text}
</tools>
"""
).strip()
FORMAT_PROMPT = textwrap.dedent(
"""
Provide your tool assessment for ONLY THE LAST USER MESSAGE in the above conversation:
- First line must read 'Yes' or 'No'.
- If yes, a second line must include a comma-separated list of tool indexes.
"""
).strip()
EXTRA_INSTRUCTION = "Are there any tools can help?"
GENERATION_PARAMS = {"max_tokens": 1, "stop_token_ids": [151645]}
class ArchIntentHandler(ArchBaseHandler):
def __init__(
self,
client: OpenAI,
model_name: str,
task_prompt: str,
tool_prompt_template: str,
format_prompt: str,
extra_instruction: str,
generation_params: Dict,
):
def __init__(self, client: OpenAI, model_name: str, config: ArchIntentConfig):
"""
Initializes the intent handler.
Args:
client (OpenAI): An OpenAI client instance.
model_name (str): Name of the model to use.
task_prompt (str): The main task prompt for the system.
tool_prompt_template (str): A prompt to describe tools.
format_prompt (str): A prompt specifying the desired output format.
extra_instruction (str): Instructions specific to intent handling.
generation_params (Dict): Generation parameters for the model.
config (ArchIntentConfig): The configuration for Arch-Intent.
"""
super().__init__(
client,
model_name,
task_prompt,
tool_prompt_template,
format_prompt,
generation_params,
config.TASK_PROMPT,
config.TOOL_PROMPT_TEMPLATE,
config.FORMAT_PROMPT,
config.GENERATION_PARAMS,
)
self.extra_instruction = extra_instruction
self.extra_instruction = config.EXTRA_INSTRUCTION
@override
def _convert_tools(self, tools: List[Dict[str, Any]]) -> str:
@ -125,17 +140,73 @@ class ArchIntentHandler(ArchBaseHandler):
return chat_completion_response
# =============================================================================================================
class ArchFunctionConfig:
TASK_PROMPT = textwrap.dedent(
"""
You are a helpful assistant.
"""
).strip()
TOOL_PROMPT_TEMPLATE = textwrap.dedent(
"""
# Tools
You may call one or more functions to assist with the user query.
You are provided with function signatures within <tools></tools> XML tags:
<tools>
{tool_text}
</tools>
"""
).strip()
FORMAT_PROMPT = textwrap.dedent(
"""
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
<tool_call>
{"name": <function-name>, "arguments": <args-json-object>}
</tool_call>
"""
).strip()
GENERATION_PARAMS = (
{
"temperature": 0.2,
"top_p": 1.0,
"top_k": 50,
"max_tokens": 512,
"stop_token_ids": [151645],
},
)
PREFILL_CONFIG = {
"prefill_params": {
"continue_final_message": True,
"add_generation_prompt": False,
},
"prefill_prefix": [
"May",
"Could",
"Sure",
"Definitely",
"Certainly",
"Of course",
"Can",
],
}
SUPPORT_DATA_TYPES = ["int", "float", "bool", "str", "list", "tuple", "set", "dict"]
class ArchFunctionHandler(ArchBaseHandler):
def __init__(
self,
client: OpenAI,
model_name: str,
task_prompt: str,
tool_prompt_template: str,
format_prompt: str,
generation_params: Dict,
prefill_params: Dict,
prefill_prefix: List,
config: ArchFunctionConfig,
):
"""
Initializes the function handler.
@ -143,30 +214,26 @@ class ArchFunctionHandler(ArchBaseHandler):
Args:
client (OpenAI): An OpenAI client instance.
model_name (str): Name of the model to use.
task_prompt (str): The main task prompt for the system.
tool_prompt_template (str): A prompt to describe tools.
format_prompt (str): A prompt specifying the desired output format.
generation_params (Dict): Generation parameters for the model.
prefill_params (Dict): Additional parameters for prefilling responses.
prefill_prefix (List[str]): List of prefixes for prefill responses.
config (ArchFunctionConfig): The configuration for Arch-Function
"""
super().__init__(
client,
model_name,
task_prompt,
tool_prompt_template,
format_prompt,
generation_params,
config.TASK_PROMPT,
config.TOOL_PROMPT_TEMPLATE,
config.FORMAT_PROMPT,
config.GENERATION_PARAMS,
)
self.prefill_params = prefill_params
self.prefill_prefix = prefill_prefix
self.prefill_params = config.PREFILL_CONFIG["prefill_params"]
self.prefill_prefix = config.PREFILL_CONFIG["prefill_prefix"]
# Predefine data types for verification. Only support Python for now.
# [TODO] Extend the list of support data types
self.support_data_types = {
type_name: getattr(builtins, type_name) for type_name in SUPPORT_DATA_TYPES
type_name: getattr(builtins, type_name)
for type_name in config.SUPPORT_DATA_TYPES
}
@override

View file

@ -3,22 +3,9 @@ import torch
import numpy as np
import src.commons.utils as utils
from typing import List
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from optimum.intel import OVModelForSequenceClassification
class GuardRequest(BaseModel):
input: str
task: str
class GuardResponse(BaseModel):
prob: List
verdict: bool
sentence: List
latency: float = 0
from src.core.model_utils import GuardRequest, GuardResponse
class ArchGuardHanlder:

View file

@ -32,6 +32,21 @@ class ChatCompletionResponse(BaseModel):
model: str
class GuardRequest(BaseModel):
input: str
task: str
class GuardResponse(BaseModel):
prob: List
verdict: bool
sentence: List
latency: float = 0
# ================================================================================================
class ArchBaseHandler:
def __init__(
self,
@ -53,9 +68,7 @@ class ArchBaseHandler:
format_prompt (str): A prompt specifying the desired output format.
generation_params (Dict): Generation parameters for the model.
"""
self.client = client
self.model_name = model_name
self.task_prompt = task_prompt

View file

@ -1,8 +1,7 @@
import os
from src.commons.globals import handler_map
from src.core.base_handler import ChatMessage
from src.core.guardrails import GuardRequest
from src.core.model_utils import ChatMessage, GuardRequest
from fastapi import FastAPI, Response
from opentelemetry import trace

View file

View file

@ -0,0 +1,949 @@
[
{
"case": "tool_call_halluciation",
"tokens": [
"<tool_call>"
],
"expect": 1,
"logprobs": [
[
-0.3333307206630707,
-1.5310522317886353,
-3.5098977088928223,
-3.9004578590393066,
-5.775152683258057,
-5.814209461212158,
-5.9574151039123535,
-6.0094895362854,
-6.0094895362854,
-6.673445224761963
]
]
},
{
"case": "parameter_value_hallucination",
"expect": 0,
"tokens": [
"<tool_call>",
"\n",
"{'",
"name",
"':",
" '",
"get",
"_current",
"_weather",
"',",
" '",
"arguments",
"':",
" {'",
"location",
"':",
" '",
"Sea",
",",
" Australia",
"',",
" '",
"unit",
"':",
" '",
"c",
"elsius",
"',",
" '",
"days",
"':",
" '",
"1",
"'}}\n",
"</tool_call>"
],
"logprobs": [
[
-0.008103232830762863,
-5.085402488708496,
-6.777836799621582,
-7.558959007263184,
-9.850253105163574,
-10.266852378845215,
-10.540244102478027,
-10.722506523132324,
-10.800618171691895,
-10.917786598205566
],
[
0.0,
-23.25142478942871,
-25.139137268066406,
-26.2847843170166,
-28.992677688598633,
-29.070789337158203,
-29.55248260498047,
-29.91700553894043,
-30.20341682434082,
-30.307567596435547
],
[
0.0,
-21.66313934326172,
-23.06916046142578,
-23.32953453063965,
-25.65988540649414,
-25.985353469848633,
-26.519121170043945,
-27.07892417907715,
-27.977216720581055,
-28.458908081054688
],
[
0.0,
-28.094383239746094,
-28.56305694580078,
-29.109844207763672,
-29.44832992553711,
-31.79170036315918,
-32.0,
-32.05207443237305,
-32.31244659423828,
-32.364524841308594
],
[
0.0,
-30.489830017089844,
-31.140766143798828,
-31.81774139404297,
-34.525634765625,
-35.8275032043457,
-36.504478454589844,
-39.05614471435547,
-40.123680114746094,
-40.696502685546875
],
[
0.0,
-25.646865844726562,
-26.66232681274414,
-27.781936645507812,
-28.979660034179688,
-31.140764236450195,
-31.92188835144043,
-31.973962783813477,
-33.04149627685547,
-33.58828353881836
],
[
0.0,
-23.511798858642578,
-24.136695861816406,
-25.230268478393555,
-25.777053833007812,
-25.80309295654297,
-26.45402717590332,
-26.636289596557617,
-26.740440368652344,
-26.896663665771484
],
[
0.0,
-22.366153717041016,
-24.683483123779297,
-26.610252380371094,
-26.610252380371094,
-27.313264846801758,
-27.67778778076172,
-28.510986328125,
-28.615135192871094,
-29.13588523864746
],
[
0.0,
-22.52237319946289,
-24.292919158935547,
-24.344993591308594,
-24.39706802368164,
-24.73555564880371,
-29.943042755126953,
-29.969079971313477,
-30.021154403686523,
-30.0341739654541
],
[
0.0,
-30.17738151550293,
-30.411718368530273,
-30.88039207458496,
-30.984540939331055,
-31.270952224731445,
-31.895851135253906,
-32.46867370605469,
-32.624900817871094,
-33.484134674072266
],
[
0.0,
-28.146459579467773,
-29.396255493164062,
-30.099267959594727,
-31.127744674682617,
-31.179821014404297,
-32.807159423828125,
-33.7445068359375,
-33.770545959472656,
-34.069976806640625
],
[
0.0,
-26.323841094970703,
-26.558177947998047,
-30.515867233276367,
-30.932466506958008,
-31.37510108947754,
-31.531326293945312,
-31.70056915283203,
-32.065093994140625,
-32.364524841308594
],
[
0.0,
-26.922698974609375,
-30.28152847290039,
-31.505287170410156,
-33.30187225341797,
-33.73148727416992,
-34.27827453613281,
-34.33034896850586,
-34.460533142089844,
-34.720909118652344
],
[
0.0,
-21.532955169677734,
-26.94873809814453,
-29.109848022460938,
-30.80228042602539,
-31.55736541748047,
-33.484134674072266,
-34.681854248046875,
-35.384864807128906,
-35.853538513183594
],
[
0.0,
-19.502033233642578,
-20.46541976928711,
-24.110658645629883,
-24.501218795776367,
-25.256305694580078,
-25.82912826538086,
-25.881202697753906,
-26.063465118408203,
-26.063465118408203
],
[
0.0,
-24.37103271484375,
-25.256305694580078,
-25.933277130126953,
-26.714401245117188,
-28.2506103515625,
-31.010576248168945,
-32.07810974121094,
-34.62977981567383,
-35.241661071777344
],
[
-1.1920922133867862e-06,
-14.398697853088379,
-14.424736976623535,
-17.158666610717773,
-17.41904067993164,
-18.200162887573242,
-18.434499740600586,
-18.66883659362793,
-19.71033477783203,
-19.71033477783203
],
[
-0.0001445904199499637,
-8.98305892944336,
-11.35246467590332,
-13.1490478515625,
-13.669795989990234,
-14.073375701904297,
-14.516012191772461,
-14.555068969726562,
-15.622602462768555,
-15.635622024536133
],
[
-0.44747352600097656,
-1.0202960968017578,
-8.467000961303711,
-10.914518356323242,
-11.25300407409668,
-11.435266494750977,
-12.346576690673828,
-13.075624465942383,
-13.12769889831543,
-13.231849670410156
],
[
-3.123767137527466,
-1.1188862323760986,
-1.639634370803833,
-2.0562336444854736,
-2.8633930683135986,
-2.9675419330596924,
-3.4882919788360596,
-3.69659161567688,
-4.217339515686035,
-4.243376731872559
],
[
-7.199982064776123e-05,
-9.76410961151123,
-11.144091606140137,
-16.507802963256836,
-17.132701873779297,
-17.44515037536621,
-17.9138240814209,
-18.33042335510254,
-18.9162654876709,
-19.39795684814453
],
[
0.0,
-22.991050720214844,
-23.824249267578125,
-24.969894409179688,
-25.46460723876953,
-25.829130172729492,
-26.480066299438477,
-26.909683227539062,
-27.33930206298828,
-27.391376495361328
],
[
-0.21928852796554565,
-1.625309705734253,
-9.775025367736816,
-12.977627754211426,
-16.388530731201172,
-17.091541290283203,
-19.044347763061523,
-19.38283348083496,
-19.460947036743164,
-19.59113311767578
],
[
0.0,
-24.006507873535156,
-27.443450927734375,
-27.729862213134766,
-28.12042236328125,
-28.276647567749023,
-28.927583694458008,
-30.099267959594727,
-31.479251861572266,
-32.07810974121094
],
[
0.0,
-18.17412567138672,
-18.772987365722656,
-21.689178466796875,
-21.92351531982422,
-23.7200984954834,
-23.79821014404297,
-23.79821014404297,
-24.032546997070312,
-25.308382034301758
],
[
-0.12947827577590942,
-2.1083219051361084,
-12.419143676757812,
-15.23118782043457,
-15.595710754394531,
-15.830047607421875,
-17.001731872558594,
-17.60059356689453,
-18.121341705322266,
-18.251529693603516
],
[
0.0,
-19.449962615966797,
-24.371034622192383,
-24.917821884155273,
-25.529701232910156,
-25.85516929626465,
-26.037429809570312,
-26.115543365478516,
-26.623271942138672,
-26.649309158325195
],
[
-0.03332124650478363,
-3.4181859493255615,
-15.759925842285156,
-15.812002182006836,
-16.593124389648438,
-17.894996643066406,
-18.09027671813965,
-18.79328727722168,
-19.144792556762695,
-20.147233963012695
],
[
0.0,
-21.142393112182617,
-22.157852172851562,
-23.511798858642578,
-24.657445907592773,
-25.021968841552734,
-25.5427188873291,
-25.59479331970215,
-25.75101661682129,
-25.95931625366211
],
[
0.0,
-23.04312515258789,
-24.94385528564453,
-26.323841094970703,
-27.54759979248047,
-28.563060760498047,
-29.786819458007812,
-30.620018005371094,
-30.69812774658203,
-31.08869171142578
],
[
0.0,
-26.167617797851562,
-28.771360397338867,
-29.55248260498047,
-30.906429290771484,
-31.114728927612305,
-31.414159774780273,
-31.622459411621094,
-31.713590621948242,
-31.726608276367188
],
[
-0.05012698099017143,
-3.018392562866211,
-11.740934371948242,
-13.146955490112305,
-13.797887802124023,
-14.943536758422852,
-16.037107467651367,
-16.375595092773438,
-16.714080810546875,
-17.36501693725586
],
[
-0.9704352021217346,
-0.7360983490943909,
-2.1941938400268555,
-4.225115776062012,
-5.0062360763549805,
-5.2666120529174805,
-5.839434623718262,
-7.2714948654174805,
-8.33902645111084,
-8.495253562927246
],
[
-0.014467108063399792,
-4.258565902709961,
-8.789079666137695,
-10.429437637329102,
-10.793962478637695,
-11.835458755493164,
-11.939607620239258,
-13.31959342956543,
-13.866378784179688,
-15.038063049316406
],
[
0.0,
-20.08787727355957,
-21.350692749023438,
-21.415786743164062,
-21.50691795349121,
-21.50691795349121,
-22.7176570892334,
-24.13669776916504,
-24.188772201538086,
-24.34499740600586
]
]
},
{
"case": "fail_case",
"expect": 0,
"tokens": [
"<tool_call>",
"\n",
"{'",
"name",
"':",
" '",
"get",
"_current",
"_weather",
"',",
" '",
"arguments",
"':",
" {'",
"location",
"':",
" '",
"Seattle",
",",
" WA",
"',",
" '",
"unit",
"':",
" '",
"c",
"elsius",
"',",
" '",
"days",
"':",
" '",
"7",
"'}}\n",
"</tool_call>"
],
"logprobs": [
[
-0.00013815402053296566,
-9.113236427307129,
-10.571331977844238,
-14.099404335021973,
-14.28166675567627,
-15.583537101745605,
-15.81787395477295,
-16.143341064453125,
-16.143341064453125,
-16.260509490966797
],
[
0.0,
-26.896663665771484,
-27.32628059387207,
-27.41741180419922,
-32.07810974121094,
-32.07810974121094,
-32.28641128540039,
-32.29943084716797,
-32.44263458251953,
-32.520748138427734
],
[
0.0,
-22.444263458251953,
-24.527257919311523,
-27.15703773498535,
-28.016273498535156,
-28.2506103515625,
-28.693246841430664,
-29.070789337158203,
-29.565500259399414,
-29.812854766845703
],
[
0.0,
-27.860050201416016,
-28.641170501708984,
-29.448333740234375,
-30.932466506958008,
-31.63547706604004,
-32.33848571777344,
-32.85923767089844,
-33.17168426513672,
-33.45809555053711
],
[
0.0,
-31.81774139404297,
-31.895854949951172,
-32.05207824707031,
-35.43694305419922,
-36.3482551574707,
-38.61351013183594,
-39.26444625854492,
-40.61839294433594,
-41.71196365356445
],
[
0.0,
-27.33930206298828,
-27.834014892578125,
-28.849472045898438,
-30.567943572998047,
-32.98942565917969,
-33.067535400390625,
-33.067535400390625,
-35.67127990722656,
-35.69731903076172
],
[
0.0,
-25.33441925048828,
-26.063465118408203,
-26.219690322875977,
-26.2457275390625,
-26.53213882446289,
-27.365337371826172,
-28.354759216308594,
-28.667207717895508,
-28.74532127380371
],
[
0.0,
-24.423107147216797,
-24.579330444335938,
-26.81855010986328,
-28.12042236328125,
-28.32872200012207,
-28.61513328552246,
-29.16191864013672,
-29.187957763671875,
-29.240032196044922
],
[
0.0,
-22.027664184570312,
-23.850284576416016,
-23.980472564697266,
-24.292922973632812,
-24.787633895874023,
-29.279088973999023,
-29.55248260498047,
-29.903987884521484,
-30.190399169921875
],
[
0.0,
-31.609439849853516,
-31.817739486694336,
-32.54678726196289,
-32.676971435546875,
-32.781124114990234,
-32.98942565917969,
-33.106590270996094,
-33.57526397705078,
-34.369407653808594
],
[
0.0,
-29.34418296813965,
-29.63059425354004,
-30.021156311035156,
-30.984540939331055,
-33.21073913574219,
-34.30431365966797,
-34.56468963623047,
-34.70789337158203,
-34.79902648925781
],
[
0.0,
-25.438566207885742,
-25.69894027709961,
-30.190397262573242,
-30.802276611328125,
-31.58340072631836,
-31.609437942504883,
-31.64849281311035,
-31.973960876464844,
-32.29943084716797
],
[
0.0,
-27.157039642333984,
-32.104148864746094,
-32.33848571777344,
-34.04393768310547,
-34.12205505371094,
-34.40846252441406,
-34.42148208618164,
-34.772987365722656,
-34.87713623046875
],
[
0.0,
-24.813671112060547,
-26.974777221679688,
-31.010578155517578,
-31.08869171142578,
-32.1822624206543,
-35.33279037475586,
-35.489013671875,
-36.999183654785156,
-37.88446044921875
],
[
0.0,
-20.46541976928711,
-20.647682189941406,
-23.069164276123047,
-24.136699676513672,
-25.438570022583008,
-25.646869659423828,
-26.193655014038086,
-26.297805786132812,
-26.506103515625
],
[
0.0,
-27.18307113647461,
-28.30268096923828,
-28.56305694580078,
-29.526439666748047,
-32.416595458984375,
-35.202598571777344,
-36.426361083984375,
-39.31651306152344,
-39.38160705566406
],
[
0.0,
-18.7469482421875,
-20.100894927978516,
-21.402767181396484,
-21.428804397583008,
-22.20992660522461,
-22.34011459350586,
-22.730674743652344,
-23.069162368774414,
-23.980472564697266
],
[
-3.576278118089249e-07,
-15.2579345703125,
-16.481693267822266,
-17.991863250732422,
-19.215621948242188,
-20.25712013244629,
-21.350692749023438,
-22.314077377319336,
-22.496337890625,
-22.938974380493164
],
[
-0.08506780862808228,
-2.506549835205078,
-14.848289489746094,
-15.473188400268555,
-16.33242416381836,
-16.358461380004883,
-16.566761016845703,
-17.03543472290039,
-17.686370849609375,
-17.816556930541992
],
[
-0.0194891095161438,
-4.445854187011719,
-5.591499328613281,
-5.956024169921875,
-6.685070037841797,
-13.142353057861328,
-13.558952331542969,
-15.173273086547852,
-15.303461074829102,
-15.85024642944336
],
[
-0.0005990855861455202,
-7.4212646484375,
-15.675132751464844,
-15.72720718383789,
-16.76870346069336,
-16.76870346069336,
-17.706050872802734,
-18.669435501098633,
-19.398483276367188,
-19.658857345581055
],
[
0.0,
-24.110658645629883,
-25.829130172729492,
-26.011390686035156,
-26.011390686035156,
-26.532140731811523,
-26.58421516418457,
-27.651750564575195,
-27.75589942932129,
-28.055330276489258
],
[
-1.1408883333206177,
-0.38580334186553955,
-7.494022369384766,
-12.519245147705078,
-14.576202392578125,
-16.034297943115234,
-16.945608139038086,
-17.908992767333984,
-18.664077758789062,
-19.34105110168457
],
[
0.0,
-26.688365936279297,
-29.83889389038086,
-30.177383422851562,
-30.64605712890625,
-31.244916915893555,
-31.270954132080078,
-32.83319854736328,
-34.655818939208984,
-34.89015579223633
],
[
0.0,
-18.929210662841797,
-19.16354751586914,
-23.589908599853516,
-24.683481216430664,
-24.995929718017578,
-25.516677856445312,
-25.542715072631836,
-25.77705192565918,
-26.063465118408203
],
[
-0.2519786059856415,
-1.5017764568328857,
-12.437495231628418,
-15.457839012145996,
-15.744250297546387,
-16.837820053100586,
-17.41064453125,
-17.56686782836914,
-17.61894416809082,
-18.035541534423828
],
[
0.0,
-20.517494201660156,
-24.683483123779297,
-25.67290496826172,
-26.58421516418457,
-27.651750564575195,
-27.781936645507812,
-27.912124633789062,
-28.09438705444336,
-28.445892333984375
],
[
-3.40932747349143e-05,
-10.284820556640625,
-18.252273559570312,
-20.17904281616211,
-21.663175582885742,
-22.027700424194336,
-22.288074493408203,
-22.704673767089844,
-23.12127113342285,
-23.277496337890625
],
[
0.0,
-22.60049057006836,
-25.46460723876953,
-25.829130172729492,
-26.063467025756836,
-27.287227630615234,
-27.391376495361328,
-27.4694881439209,
-27.67778778076172,
-28.055330276489258
],
[
0.0,
-23.902362823486328,
-28.823436737060547,
-29.240036010742188,
-29.31814956665039,
-29.917007446289062,
-30.021160125732422,
-31.21887969970703,
-32.416603088378906,
-32.416603088378906
],
[
0.0,
-28.641170501708984,
-31.947925567626953,
-32.59886169433594,
-33.848655700683594,
-34.109031677246094,
-34.73393249511719,
-35.02033996582031,
-35.02033996582031,
-36.074859619140625
],
[
-0.013183215633034706,
-4.335395336151123,
-19.619365692138672,
-20.035964965820312,
-20.244266510009766,
-21.311800003051758,
-21.441987991333008,
-22.561595916748047,
-23.108383178710938,
-23.264606475830078
],
[
-8.344646857949556e-07,
-14.190400123596191,
-15.9088716506958,
-18.17412567138672,
-18.46053695678711,
-18.46053695678711,
-18.512611389160156,
-18.90317153930664,
-19.059398651123047,
-19.085433959960938
],
[
0.0,
-17.70545196533203,
-18.903175354003906,
-20.829944610595703,
-22.574451446533203,
-22.860862731933594,
-23.069162368774414,
-23.32953643798828,
-23.694061279296875,
-24.188772201538086
],
[
0.0,
-20.022781372070312,
-21.038240432739258,
-21.220502853393555,
-22.496337890625,
-22.769729614257812,
-23.589908599853516,
-23.65500259399414,
-23.94141387939453,
-24.266881942749023
]
]
}
]

View file

@ -4,7 +4,7 @@ import pytest
from fastapi import Response
from unittest.mock import AsyncMock, MagicMock, patch
from src.commons.globals import handler_map
from src.core.base_handler import (
from src.core.model_utils import (
Message,
ChatMessage,
ChatCompletionResponse,
@ -47,7 +47,7 @@ def sample_request(sample_messages):
)
@patch("app.commons.globals.handler_map")
@patch("src.commons.globals.handler_map")
def test_process_messages(mock_hanlder):
messages = sample_messages()
processed = handler_map["Arch-Function"]._process_messages(messages)
@ -68,39 +68,39 @@ def test_process_messages(mock_hanlder):
# [TODO] Review: Update the following test
@patch("app.commons.constants.arch_function_client")
@patch("app.commons.constants.arch_function_hanlder")
@pytest.mark.asyncio
async def test_chat_completion(mock_hanlder, mock_client):
# Mock the model list return for client
mock_client.models.list.return_value = MagicMock(
data=[MagicMock(id="sample_model")]
)
request = sample_request(sample_messages())
# Simulate stream response as list of tokens
mock_response = AsyncMock()
mock_response.__aiter__.return_value = [
MagicMock(choices=[MagicMock(delta=MagicMock(content="Hi there!"))]),
MagicMock(choices=[MagicMock(delta=MagicMock(content=""))]), # end of stream
]
mock_client.chat.completions.create.return_value = mock_response
# @patch("src.commons.globals.ARCH_CLIENT")
# @patch("src.commons.constants.handler_map")
# @pytest.mark.asyncio
# async def test_chat_completion(mock_hanlder, mock_client):
# # Mock the model list return for client
# mock_client.models.list.return_value = MagicMock(
# data=[MagicMock(id="sample_model")]
# )
# request = sample_request(sample_messages())
# # Simulate stream response as list of tokens
# mock_response = AsyncMock()
# mock_response.__aiter__.return_value = [
# MagicMock(choices=[MagicMock(delta=MagicMock(content="Hi there!"))]),
# MagicMock(choices=[MagicMock(delta=MagicMock(content=""))]), # end of stream
# ]
# mock_client.chat.completions.create.return_value = mock_response
# Mock the tool formatter
mock_hanlder._format_system_prompt.return_value = "<formatted_tools>"
# # Mock the tool formatter
# mock_hanlder._format_system_prompt.return_value = "<formatted_tools>"
response = Response()
chat_response = await chat_completion(request, response)
# response = Response()
# chat_response = await chat_completion(request, response)
assert isinstance(chat_response, ChatCompletionResponse)
assert chat_response.choices[0].message.content is not None
# assert isinstance(chat_response, ChatCompletionResponse)
# assert chat_response.choices[0].message.content is not None
first_call_args = mock_client.chat.completions.create.call_args_list[0][1]
assert first_call_args["stream"] == True
assert "model" in first_call_args
assert first_call_args["messages"][0]["content"] == "<formatted_tools>"
# first_call_args = mock_client.chat.completions.create.call_args_list[0][1]
# assert first_call_args["stream"] == True
# assert "model" in first_call_args
# assert first_call_args["messages"][0]["content"] == "<formatted_tools>"
# Check that the arguments for the second call to 'create' include the pre-fill completion
second_call_args = mock_client.chat.completions.create.call_args_list[1][1]
assert second_call_args["stream"] == False
assert "model" in second_call_args
assert second_call_args["messages"][-1]["content"] in const.PREFILL_LIST
# # Check that the arguments for the second call to 'create' include the pre-fill completion
# second_call_args = mock_client.chat.completions.create.call_args_list[1][1]
# assert second_call_args["stream"] == False
# assert "model" in second_call_args
# assert second_call_args["messages"][-1]["content"] in const.PREFILL_LIST

View file

@ -11,9 +11,9 @@ arch_guard_model_type = {
# [TODO] Review: check the following code to test under `cpu`, `cuda`, and `mps`
# Test for `get_guardrail_handler()` function on `cpu`
@patch("app.model_handler.guardrail.AutoTokenizer.from_pretrained")
@patch("app.model_handler.guardrail.OVModelForSequenceClassification.from_pretrained")
@patch("app.model_handler.guardrail.AutoModelForSequenceClassification.from_pretrained")
@patch("src.core.guardrail.AutoTokenizer.from_pretrained")
@patch("src.core.guardrail.OVModelForSequenceClassification.from_pretrained")
@patch("src.core.guardrail.AutoModelForSequenceClassification.from_pretrained")
def test_guardrail_handler_on_cpu(mock_auto_model, mock_ov_model, mock_tokenizer):
device = "cpu"
@ -34,9 +34,9 @@ def test_guardrail_handler_on_cpu(mock_auto_model, mock_ov_model, mock_tokenizer
# Test for `get_guardrail_handler()` function on `cuda`
@patch("app.model_handler.guardrail.AutoTokenizer.from_pretrained")
@patch("app.model_handler.guardrail.OVModelForSequenceClassification.from_pretrained")
@patch("app.model_handler.guardrail.AutoModelForSequenceClassification.from_pretrained")
@patch("src.core.guardrail.AutoTokenizer.from_pretrained")
@patch("src.core.guardrail.OVModelForSequenceClassification.from_pretrained")
@patch("src.core.guardrail.AutoModelForSequenceClassification.from_pretrained")
def test_guardrail_handler_on_cuda(mock_auto_model, mock_ov_model, mock_tokenizer):
device = "cuda"
@ -57,9 +57,9 @@ def test_guardrail_handler_on_cuda(mock_auto_model, mock_ov_model, mock_tokenize
# Test for `get_guardrail_handler()` function on `mps`
@patch("app.model_handler.guardrail.AutoTokenizer.from_pretrained")
@patch("app.model_handler.guardrail.OVModelForSequenceClassification.from_pretrained")
@patch("app.model_handler.guardrail.AutoModelForSequenceClassification.from_pretrained")
@patch("src.core.guardrail.AutoTokenizer.from_pretrained")
@patch("src.core.guardrail.OVModelForSequenceClassification.from_pretrained")
@patch("src.core.guardrail.AutoModelForSequenceClassification.from_pretrained")
def test_guardrail_handler_on_mps(mock_auto_model, mock_ov_model, mock_tokenizer):
device = "mps"