diff --git a/model_server/app/cli.py b/model_server/app/cli.py index ca1b3a0a..42b0c341 100644 --- a/model_server/app/cli.py +++ b/model_server/app/cli.py @@ -1,10 +1,7 @@ import importlib import sys -import os import time import requests -import psutil -import tempfile import subprocess import logging diff --git a/model_server/app/commons/constants.py b/model_server/app/commons/constants.py index 0a77830b..f4c062f8 100644 --- a/model_server/app/commons/constants.py +++ b/model_server/app/commons/constants.py @@ -1,38 +1,83 @@ -import app.commons.globals as glb -import app.commons.utilities as utils -import app.loader as loader +# ========================== Arch-Intent Default Params ========================== +ARCH_INTENT_MODEL_ALIAS = "Arch-Intent" +ARCH_INTENT_INSTRUCTION = "Are there any tools can help?" -from app.function_calling.model_handler import ArchFunctionHandler -from app.prompt_guard.model_handler import ArchGuardHanlder +ARCH_INTENT_TASK_PROMPT = """ +You are a helpful assistant. +""".strip() -logger = utils.get_model_server_logger() -arch_function_hanlder = ArchFunctionHandler() -PREFILL_LIST = ["May", "Could", "Sure", "Definitely", "Certainly", "Of course", "Can"] -PREFILL_ENABLED = True -TOOL_CALL_TOKEN = "" -arch_function_endpoint = "https://api.fc.archgw.com/v1" -arch_function_client = utils.get_client(arch_function_endpoint) -arch_function_generation_params = { - "temperature": 0.2, - "top_p": 1.0, - "top_k": 50, - "max_tokens": 512, - "stop_token_ids": [151645], - # "top_logprobs": 10, +ARCH_INTENT_TOOL_PROMPT = """ +You task is to check if there are any tools that can be used to help the last user message in conversations according to the available tools listed below. + + +{tool_text} + +""".strip() + + +ARCH_INTENT_FORMAT_PROMPT = """ +Provide your tool assessment for ONLY THE LAST USER MESSAGE in the above conversation: +- First line must read 'Yes' or 'No'. +- If yes, a second line must include a comma-separated list of tool indexes. 
+""".strip() + + +ARCH_INTENT_GENERATION_CONFIG = { + "generation_params": { + "stop_token_ids": [151645], + "max_tokens": 1, + "guided_choice": ["Yes", "No"], + } } -arch_guard_model_type = { - "cpu": "katanemo/Arch-Guard-cpu", - "cuda": "katanemo/Arch-Guard", - "mps": "katanemo/Arch-Guard", + +# ========================== Arch-Function Default Params ========================== +ARCH_FUNCTION_MODEL_ALIAS = "Arch-Function" + +ARCH_FUNCTION_TASK_PROMPT = """ +You are a helpful assistant. +""".strip() + + +ARCH_FUNCTION_TOOL_PROMPT = """ +# Tools + +You may call one or more functions to assist with the user query. + +You are provided with function signatures within XML tags: + +{tool_text} + +""".strip() + + +ARCH_FUNCTION_FORMAT_PROMPT = """ +For each function call, return a json object with function name and arguments within XML tags: + +{"name": , "arguments": } + +""".strip() + +ARCH_FUNCTION_GENERATION_CONFIG = { + "generation_params": { + "temperature": 0.2, + "top_p": 1.0, + "top_k": 50, + "max_tokens": 512, + "stop_token_ids": [151645], + }, + "prefill_params": { + "continue_final_message": True, + "add_generation_prompt": False, + }, + "prefill_prefix": [ + "May", + "Could", + "Sure", + "Definitely", + "Certainly", + "Of course", + "Can", + ], } - -# Model definition -embedding_model = loader.get_embedding_model() -zero_shot_model = loader.get_zero_shot_model() - -prompt_guard_dict = loader.get_prompt_guard(arch_guard_model_type[glb.DEVICE]) - -arch_guard_handler = ArchGuardHanlder(model_dict=prompt_guard_dict) -# Patterns for function name and parameter parsing diff --git a/model_server/app/commons/globals.py b/model_server/app/commons/globals.py index 6c82ede2..5a9fac29 100644 --- a/model_server/app/commons/globals.py +++ b/model_server/app/commons/globals.py @@ -1,4 +1,65 @@ import app.commons.utilities as utils +from app.commons.constants import * +from app.model_handler.function_calling import ArchIntentHandler, ArchFunctionHandler +from 
app.model_handler.guardrails import ArchGuardHanlder -DEVICE = utils.get_device() +from transformers import AutoTokenizer, AutoModelForSequenceClassification +from optimum.intel import OVModelForSequenceClassification +from openai import OpenAI + + +logger = utils.get_model_server_logger() + + +def get_guardrail_handler(): + device = utils.get_device() + + model_class, model_name = None, None + if device == "cpu": + model_class = OVModelForSequenceClassification + model_name = "katanemo/Arch-Guard-cpu" + else: + model_class = AutoModelForSequenceClassification + if device == "cuda": + model_name = "katanemo/Arch-Guard" + else: + model_name = "katanemo/Arch-Guard" + + guardrail_dict = { + "device": device, + "model_name": model_name, + "tokenizer": AutoTokenizer.from_pretrained(model_name, trust_remote_code=True), + "model": model_class.from_pretrained( + model_name, device_map=device, low_cpu_mem_usage=True + ), + } + + return ArchGuardHanlder(model_dict=guardrail_dict) + + +# Define the client +ARCH_CLIENT = OpenAI(base_url="https://api.fc.archgw.com/v1", api_key="EMPTY") + + +# Define model handlers +handler_map = { + "Arch-Intent": ArchIntentHandler( + ARCH_CLIENT, + ARCH_INTENT_MODEL_ALIAS, + ARCH_INTENT_TASK_PROMPT, + ARCH_INTENT_TOOL_PROMPT, + ARCH_INTENT_FORMAT_PROMPT, + ARCH_INTENT_INSTRUCTION, + **ARCH_INTENT_GENERATION_CONFIG, + ), + "Arch-Function": ArchFunctionHandler( + ARCH_CLIENT, + ARCH_FUNCTION_MODEL_ALIAS, + ARCH_FUNCTION_TASK_PROMPT, + ARCH_FUNCTION_TOOL_PROMPT, + ARCH_FUNCTION_FORMAT_PROMPT, + **ARCH_FUNCTION_GENERATION_CONFIG, + ), + "Arch-Guard": get_guardrail_handler(), +} diff --git a/model_server/app/commons/utilities.py b/model_server/app/commons/utilities.py index d920107d..0ef1a18f 100644 --- a/model_server/app/commons/utilities.py +++ b/model_server/app/commons/utilities.py @@ -1,11 +1,7 @@ import os -import yaml import torch -import string import logging -from openai import OpenAI - logger_instance = None @@ -31,11 +27,6 @@ def 
get_device(): return device -def get_client(endpoint): - client = OpenAI(base_url=endpoint, api_key="EMPTY") - return client - - def get_model_server_logger(): global logger_instance @@ -72,12 +63,3 @@ def get_model_server_logger(): # Initialize the logger instance after configuring handlers logger_instance = logging.getLogger("model_server_logger") return logger_instance - - -def remove_punctuations(s): - s = s.translate(str.maketrans(string.punctuation, " " * len(string.punctuation))) - return " ".join(s.split()).lower() - - -def get_label_map(labels): - return {remove_punctuations(label): label for label in labels} diff --git a/model_server/app/function_calling/model_handler.py b/model_server/app/function_calling/model_handler.py deleted file mode 100644 index e1da914c..00000000 --- a/model_server/app/function_calling/model_handler.py +++ /dev/null @@ -1,137 +0,0 @@ -import json -import random - -from typing import Any, Dict, List - - -ARCH_FUNCTION_CALLING_TASK_PROMPT = """ -You are a helpful assistant. -""".strip() - - -ARCH_FUNCTION_CALLING_TOOL_PROMPT = """ -# Tools - -You may call one or more functions to assist with the user query. 
- -You are provided with function signatures within XML tags: - -{tool_text} - -""".strip() - - -ARCH_FUNCTION_CALLING_FORMAT_PROMPT = """ -For each function call, return a json object with function name and arguments within XML tags: - -{"name": , "arguments": } - -""".strip() - - -class ArchFunctionHandler: - def __init__(self) -> None: - super().__init__() - - def _format_system(self, tools: List[Dict[str, Any]]): - def convert_tools(tools): - return "\n".join([json.dumps(tool) for tool in tools]) - - tool_text = convert_tools(tools) - - system_prompt = ( - ARCH_FUNCTION_CALLING_TASK_PROMPT - + "\n\n" - + ARCH_FUNCTION_CALLING_TOOL_PROMPT.format(tool_text=tool_text) - + "\n\n" - + ARCH_FUNCTION_CALLING_FORMAT_PROMPT - ) - - return system_prompt - - def _add_execution_results_prompting( - self, - messages: list[dict], - execution_results: list, - ) -> dict: - content = [] - for result in execution_results: - content.append(f"\n{json.dumps(result)}\n") - - content = "\n".join(content) - messages.append({"role": "user", "content": content}) - - return messages - - def extract_tool_calls(self, content: str): - tool_calls = [] - - flag = False - for line in content.split("\n"): - if "" == line: - flag = True - elif "" == line: - flag = False - else: - if flag: - try: - tool_content = json.loads(line) - except Exception: - fixed_content = self.fix_json_string(line) - try: - tool_content = json.loads(fixed_content) - except json.JSONDecodeError: - return content - - tool_calls.append( - { - "id": f"call_{random.randint(1000, 10000)}", - "type": "function", - "function": { - "name": tool_content["name"], - "arguments": tool_content["arguments"], - }, - } - ) - - flag = False - - return tool_calls - - def fix_json_string(self, json_str: str): - # Remove any leading or trailing whitespace or newline characters - json_str = json_str.strip() - - # Stack to keep track of brackets - stack = [] - - # Clean string to collect valid characters - fixed_str = "" - - # Dictionary 
for matching brackets - matching_bracket = {")": "(", "}": "{", "]": "["} - - # Dictionary for the opposite of matching_bracket - opening_bracket = {v: k for k, v in matching_bracket.items()} - - for char in json_str: - if char in "{[(": - stack.append(char) - fixed_str += char - elif char in "}])": - if stack and stack[-1] == matching_bracket[char]: - stack.pop() - fixed_str += char - else: - # Ignore the unmatched closing brackets - continue - else: - fixed_str += char - - # If there are unmatched opening brackets left in the stack, add corresponding closing brackets - while stack: - unmatched_opening = stack.pop() - fixed_str += opening_bracket[unmatched_opening] - - # Attempt to parse the corrected string to ensure it’s valid JSON - return fixed_str.replace("'", '"') diff --git a/model_server/app/function_calling/model_utils.py b/model_server/app/function_calling/model_utils.py deleted file mode 100644 index 6e7b926c..00000000 --- a/model_server/app/function_calling/model_utils.py +++ /dev/null @@ -1,157 +0,0 @@ -import json -import hashlib -import app.commons.constants as const -import random -from fastapi import Response -from pydantic import BaseModel -from app.commons.utilities import get_model_server_logger -from typing import Any, Dict, List, Optional - - -logger = get_model_server_logger() - - -class Message(BaseModel): - role: Optional[str] = "" - content: Optional[str] = "" - tool_calls: Optional[List[Dict[str, Any]]] = [] - tool_call_id: Optional[str] = "" - - -class ChatMessage(BaseModel): - messages: list[Message] - tools: List[Dict[str, Any]] - - -class Choice(BaseModel): - message: Message - finish_reason: Optional[str] = "stop" - index: Optional[int] = 0 - - -class ChatCompletionResponse(BaseModel): - choices: List[Choice] - model: Optional[str] = "Arch-Function" - created: Optional[str] = "" - id: Optional[str] = "" - object: Optional[str] = "chat_completion" - - -def process_messages(history: list[Message]): - updated_history = [] - for hist in 
history: - if hist.tool_calls: - if len(hist.tool_calls) > 1: - error_msg = f"Only one tool call is supported, tools counts: {len(hist.tool_calls)}" - logger.error(error_msg) - raise ValueError(error_msg) - tool_call_str = json.dumps(hist.tool_calls[0]["function"]) - updated_history.append( - { - "role": "assistant", - "content": f"\n{tool_call_str}\n", - } - ) - elif hist.role == "tool": - updated_history.append( - { - "role": "user", - "content": f"\n{hist.content}\n", - } - ) - else: - updated_history.append({"role": hist.role, "content": hist.content}) - return updated_history - - -async def chat_completion(req: ChatMessage, res: Response): - logger.info("starting request") - - tools_encoded = const.arch_function_hanlder._format_system(req.tools) - - messages = [{"role": "system", "content": tools_encoded}] - - updated_history = process_messages(req.messages) - for message in updated_history: - messages.append({"role": message["role"], "content": message["content"]}) - - client_model_name = const.arch_function_client.models.list().data[0].id - - logger.info( - f"model_server => arch_function: {client_model_name}, messages: {json.dumps(messages)}" - ) - - # Retrieve the first token, handling the Stream object carefully - - try: - resp = const.arch_function_client.chat.completions.create( - messages=messages, - model=client_model_name, - stream=const.PREFILL_ENABLED, - extra_body=const.arch_function_generation_params, - ) - except Exception as e: - logger.error(f"model_server <= arch_function: error: {e}") - raise - - if const.PREFILL_ENABLED: - first_token_content = "" - for token in resp: - first_token_content = token.choices[ - 0 - ].delta.content.strip() # Clean up the content - if first_token_content: # Break if it's non-empty - break - - # Check if the first token requires tool call handling - if first_token_content != const.TOOL_CALL_TOKEN: - # Engage pre-filling response if no tool call is indicated - resp.close() - logger.info("Tool call is not found! 
Engage pre filling") - prefill_content = random.choice(const.PREFILL_LIST) - messages.append({"role": "assistant", "content": prefill_content}) - - # Send a new completion request with the updated messages - # the model will continue the final message in the chat instead of starting a new one - # disable add_generation_prompt which tells the template to add tokens that indicate the start of a bot response. - extra_body = { - **const.arch_function_generation_params, - "continue_final_message": True, - "add_generation_prompt": False, - } - pre_fill_resp = const.arch_function_client.chat.completions.create( - messages=messages, - model=client_model_name, - stream=False, - extra_body=extra_body, - ) - full_response = pre_fill_resp.choices[0].message.content - else: - # Initialize full response and iterate over tokens to gather the full response - full_response = first_token_content - for token in resp: - if hasattr(token.choices[0].delta, "content"): - full_response += token.choices[0].delta.content - else: - logger.info("Stream is disabled, not engaging pre-filling") - full_response = resp.choices[0].message.content - - tool_calls = const.arch_function_hanlder.extract_tool_calls(full_response) - - if tool_calls: - message = Message(content="", tool_calls=tool_calls) - else: - message = Message(content=full_response, tool_calls=[]) - choice = Choice(message=message) - chat_completion_response = ChatCompletionResponse( - choices=[choice], model=client_model_name - ) - - logger.info( - f"model_server <= arch_function: (tools): {json.dumps([tool_call['function'] for tool_call in tool_calls])}" - ) - logger.info( - f"model_server <= arch_function: response body: {json.dumps(chat_completion_response.dict())}" - ) - - return chat_completion_response diff --git a/model_server/app/loader.py b/model_server/app/loader.py deleted file mode 100644 index 2be8777e..00000000 --- a/model_server/app/loader.py +++ /dev/null @@ -1,84 +0,0 @@ -import os -import app.commons.globals as glb 
- -from transformers import AutoTokenizer, AutoModel, pipeline -from optimum.onnxruntime import ( - ORTModelForFeatureExtraction, - ORTModelForSequenceClassification, -) -import app.commons.utilities as utils -import torch -from transformers import AutoModelForSequenceClassification, AutoTokenizer -from optimum.intel import OVModelForSequenceClassification - - -logger = utils.get_model_server_logger() - - -def get_embedding_model( - model_name=os.getenv("MODELS", "katanemo/bge-large-en-v1.5"), -): - logger.info("Loading Embedding Model...") - - if glb.DEVICE != "cuda": - model = ORTModelForFeatureExtraction.from_pretrained( - model_name, file_name="onnx/model.onnx" - ) - else: - model = AutoModel.from_pretrained(model_name, device_map=glb.DEVICE) - - embedding_model = { - "model_name": model_name, - "tokenizer": AutoTokenizer.from_pretrained(model_name, trust_remote_code=True), - "model": model, - } - - return embedding_model - - -def get_zero_shot_model( - model_name=os.getenv("ZERO_SHOT_MODELS", "katanemo/bart-large-mnli"), -): - logger.info("Loading Zero-shot Model...") - - if glb.DEVICE != "cuda": - model = ORTModelForSequenceClassification.from_pretrained( - model_name, file_name="onnx/model.onnx" - ) - else: - model = model_name - - zero_shot_model = { - "model_name": model_name, - "tokenizer": AutoTokenizer.from_pretrained(model_name), - "model": model, - } - - zero_shot_model["pipeline"] = pipeline( - "zero-shot-classification", - model=zero_shot_model["model"], - tokenizer=zero_shot_model["tokenizer"], - device=glb.DEVICE, - ) - - return zero_shot_model - - -def get_prompt_guard(model_name): - logger.info("Loading Guard Model...") - - if glb.DEVICE == "cpu": - model_class = OVModelForSequenceClassification - else: - model_class = AutoModelForSequenceClassification - - prompt_guard = { - "device": glb.DEVICE, - "model_name": model_name, - "tokenizer": AutoTokenizer.from_pretrained(model_name, trust_remote_code=True), - "model": model_class.from_pretrained( 
- model_name, device_map=glb.DEVICE, low_cpu_mem_usage=True - ), - } - - return prompt_guard diff --git a/model_server/app/main.py b/model_server/app/main.py index 43be1f74..79d94f0d 100644 --- a/model_server/app/main.py +++ b/model_server/app/main.py @@ -1,21 +1,10 @@ import os -import time -import torch -import app.commons.utilities as utils -import app.commons.globals as glb -import app.prompt_guard.model_utils as guard_utils -from typing import List, Dict -from pydantic import BaseModel -from fastapi import FastAPI, Response, HTTPException, Request -from app.function_calling.model_utils import ChatMessage - -from app.commons.constants import embedding_model, zero_shot_model, arch_guard_handler -from app.function_calling.model_utils import ( - chat_completion as arch_function_chat_completion, -) -from unittest.mock import patch +from app.commons.globals import handler_map +from app.model_handler.function_calling import ChatMessage +from app.model_handler.guardrails import GuardRequest +from fastapi import FastAPI, Response, Request from opentelemetry import trace from opentelemetry.sdk.trace import TracerProvider from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor @@ -23,6 +12,7 @@ from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExport from opentelemetry.sdk.trace.export import BatchSpanProcessor from opentelemetry.sdk.resources import Resource + resource = Resource.create( { "service.name": "model-server", @@ -34,10 +24,6 @@ trace.set_tracer_provider(TracerProvider(resource=resource)) tracer = trace.get_tracer(__name__) -logger = utils.get_model_server_logger() - -logger.info(f"Ready to serve traffic. 
available device: {glb.DEVICE}") - app = FastAPI() FastAPIInstrumentor().instrument_app(app) @@ -53,28 +39,6 @@ otlp_exporter = OTLPSpanExporter( trace.get_tracer_provider().add_span_processor(BatchSpanProcessor(otlp_exporter)) -class EmbeddingRequest(BaseModel): - input: str - model: str - - -class GuardRequest(BaseModel): - input: str - task: str - - -class ZeroShotRequest(BaseModel): - input: str - labels: List[str] - model: str - - -class HallucinationRequest(BaseModel): - prompt: str - parameters: Dict - model: str - - @app.get("/healthz") async def healthz(): return {"status": "ok"} @@ -84,172 +48,40 @@ async def healthz(): async def models(): return { "object": "list", - "data": [{"id": embedding_model["model_name"], "object": "model"}], + "data": [{"id": model_name, "object": "model"} for model_name in handler_map], } -@app.post("/embeddings") -async def embedding(req: EmbeddingRequest, res: Response): - logger.info(f"Embedding req: {req}") - - if req.model != embedding_model["model_name"]: - raise HTTPException(status_code=400, detail="unknown model: " + req.model) - - start_time = time.perf_counter() - - encoded_input = embedding_model["tokenizer"]( - req.input, padding=True, truncation=True, return_tensors="pt" - ).to(glb.DEVICE) - - with torch.no_grad(): - embeddings = embedding_model["model"](**encoded_input) - embeddings = embeddings[0][:, 0] - embeddings = ( - torch.nn.functional.normalize(embeddings, p=2, dim=1).detach().cpu().numpy() - ) - - logger.info(f"Embedding Call Complete Time: {time.perf_counter()-start_time}") - - data = [ - {"object": "embedding", "embedding": embedding, "index": index + 1} - for index, embedding in enumerate(embeddings.tolist()) - ] - - usage = { - "prompt_tokens": 0, - "total_tokens": 0, - } - - return {"data": data, "model": req.model, "object": "list", "usage": usage} - - -@app.post("/guard") -async def guard(req: GuardRequest, res: Response, max_num_words=300): - """ - Take input as text and return the prediction of 
toxic and jailbreak - """ - - if req.task in ["both", "toxic", "jailbreak"]: - arch_guard_handler.task = req.task - else: - raise NotImplementedError(f"{req.task} is not supported!") - - start_time = time.perf_counter() - - if len(req.input.split()) < max_num_words: - guard_result = arch_guard_handler.guard_predict(req.input) - else: - # text is long, split into chunks - chunks = guard_utils.split_text_into_chunks(req.input) - - guard_result = { - "jailbreak_prob": [], - "time": 0, - "jailbreak_verdict": False, - "toxic_sentence": [], - "jailbreak_sentence": [], - } - - for chunk in chunks: - chunk_result = arch_guard_handler.guard_predict(chunk) - guard_result["time"] += chunk_result["time"] - if chunk_result[f"{arch_guard_handler.task}_verdict"]: - guard_result[f"{arch_guard_handler.task}_verdict"] = True - guard_result[f"{arch_guard_handler.task}_sentence"].append( - chunk_result[f"{arch_guard_handler.task}_sentence"] - ) - guard_result[f"{arch_guard_handler.task}_prob"].append( - chunk_result[f"{arch_guard_handler.task}_prob"].item() - ) - - logger.info(f"Time taken for Guard: {time.perf_counter() - start_time}") - - return guard_result - - -@app.post("/zeroshot") -async def zeroshot(req: ZeroShotRequest, res: Response): - logger.info(f"zero-shot request: {req}") - - if req.model != zero_shot_model["model_name"]: - raise HTTPException(status_code=400, detail="unknown model: " + req.model) - - classifier = zero_shot_model["pipeline"] - - label_map = utils.get_label_map(req.labels) - - start_time = time.perf_counter() - - predictions = classifier( - req.input, candidate_labels=list(label_map.keys()), multi_label=True - ) - - logger.info(f"zero-shot taking {time.perf_counter() - start_time} seconds") - - predicted_class = label_map[predictions["labels"][0]] - predicted_score = predictions["scores"][0] - - scores = { - label_map[label]: score - for label, score in zip(predictions["labels"], predictions["scores"]) - } - - predicted_class = 
label_map[predictions["labels"][0]] - - return { - "predicted_class": predicted_class, - "predicted_class_score": predicted_score, - "scores": scores, - "model": req.model, - } - - -@app.post("/hallucination") -@patch("app.loader.glb.DEVICE", "cpu") # Mock the device to 'cpu' -async def hallucination(req: HallucinationRequest, res: Response): - """ - Take input as text and return the prediction of hallucination for each parameter - """ - logger.info(f"hallucination request: {req}") - if req.model != zero_shot_model["model_name"]: - raise HTTPException(status_code=400, detail="unknown model: " + req.model) - - start_time = time.perf_counter() - classifier = zero_shot_model["pipeline"] - - if "messages" in req.parameters: - req.parameters.pop("messages") - - candidate_labels = {f"{k} is {v}": k for k, v in req.parameters.items()} - - predictions = classifier( - req.prompt, - candidate_labels=list(candidate_labels.keys()), - hypothesis_template="{}", - multi_label=True, - ) - - params_scores = { - candidate_labels[label]: score - for label, score in zip(predictions["labels"], predictions["scores"]) - } - - logger.info( - f"hallucination time cost: {params_scores}, taking {time.perf_counter() - start_time} seconds" - ) - - return { - "params_scores": params_scores, - "model": req.model, - } - - -@app.post("/v1/chat/completions") -async def chat_completion(req: ChatMessage, res: Response, request: Request): +@app.post("/function_calling") +async def function_calling(req: ChatMessage, res: Response, request: Request): try: - result = await arch_function_chat_completion(req, res) - return result + intent_result = await handler_map["Arch-Intent"].chat_completion(req) + + if intent_result.choices[0].message.content == "Yes": + try: + function_result = await handler_map["Arch-Function"].chat_completion( + req + ) + return function_result + except Exception as e: + # [TODO] + # logger.error(f"Error in chat_completion from `Arch-Function`: {e}") + res.status_code = 500 + 
return {"error": f"[Arch-Function] - {e}"} + except Exception as e: - logger.error(f"Error in chat_completion: {e}") + # [TODO] + # logger.error(f"Error in chat_completion from `Arch-Intent`: {e}") res.status_code = 500 - return {"error": "Internal server error"} + return {"error": f"[Arch-Intent] - {e}"} + + +@app.post("/guardrails") +async def guardrails(req: GuardRequest, res: Response, max_num_words=300): + try: + guard_result = handler_map["Arch-Guard"].predict(req) + return guard_result + except Exception as e: + # [TODO] + res.status_code = 500 + return {"error": f"[Arch-Guard] - {e}"} diff --git a/model_server/app/function_calling/__init__.py b/model_server/app/model_handler/__init__.py similarity index 100% rename from model_server/app/function_calling/__init__.py rename to model_server/app/model_handler/__init__.py diff --git a/model_server/app/model_handler/function_calling.py b/model_server/app/model_handler/function_calling.py new file mode 100644 index 00000000..0cee18ca --- /dev/null +++ b/model_server/app/model_handler/function_calling.py @@ -0,0 +1,415 @@ +import json +import random +import builtins + +from openai import OpenAI +from pydantic import BaseModel +from typing import Any, Dict, List, Optional +from overrides import override, final + + +SUPPORT_DATA_TYPES = ["int", "float", "bool", "str", "list", "tuple", "set", "dict"] + + +class Message(BaseModel): + role: Optional[str] = "" + content: Optional[str] = "" + tool_call_id: Optional[str] = "" + tool_calls: Optional[List[Dict[str, Any]]] = [] + + +class ChatMessage(BaseModel): + messages: list[Message] + tools: List[Dict[str, Any]] + + +class Choice(BaseModel): + id: Optional[int] = 0 + message: Message + finish_reason: Optional[str] = "stop" + + +class ChatCompletionResponse(BaseModel): + id: Optional[int] = 0 + object: Optional[str] = "chat_completion" + created: Optional[str] = "" + model: str + choices: List[Choice] + + +class ArchBaseHandler: + def __init__( + self, + client: OpenAI, + 
model_name: str, + task_prompt: str, + tool_prompt: str, + format_prompt: str, + generation_params: Dict, + ): + self.client = client + + self.model_name = model_name + + self.task_prompt = task_prompt + self.tool_prompt = tool_prompt + self.format_prompt = format_prompt + + self.generation_params = generation_params + + def _convert_tools(self, tools: List[Dict[str, Any]]): + raise NotImplementedError() + + @final + def _format_system(self, tools: List[Dict[str, Any]]): + tool_text = self._convert_tools(tools) + + system_prompt = ( + self.task_prompt + + "\n\n" + + self.tool_prompt.format(tool_text=tool_text) + + "\n\n" + + self.format_prompt + ) + + return system_prompt + + @final + def _process_messages( + self, + messages: List[Message], + tools: List[Dict[str, Any]] = None, + extra_instructions: str = None, + ): + processed_messages = [] + + if tools: + processed_messages.append( + {"role": "system", "content": self._format_system(tools)} + ) + + for message in messages: + role, content, tool_calls = ( + message.role, + message.content, + message.tool_calls, + ) + + if tool_calls: + # [TODO] Extend to support multiple function calls + role = "assistant" + content = f"\n{json.dumps(tool_calls[0]['function'])}\n" + elif message.role == "tool": + role = "user" + content = ( + f"\n{json.dumps(message.content)}\n" + ) + + processed_messages.append({"role": role, "content": content}) + + assert processed_messages[-1]["role"] == "user" + + if extra_instructions: + processed_messages[-1]["content"] += extra_instructions + + return processed_messages + + async def chat_completion(self, req: ChatMessage): + raise NotImplementedError() + + +class ArchIntentHandler(ArchBaseHandler): + def __init__( + self, + client: OpenAI, + model_name: str, + task_prompt: str, + tool_prompt: str, + format_prompt: str, + intent_instruction: str, + generation_params: Dict, + ): + super().__init__( + client, + model_name, + task_prompt, + tool_prompt, + format_prompt, + generation_params, 
+ ) + + self.intent_instruction = intent_instruction + + @override + def _convert_tools(self, tools: List[Dict[str, Any]]): + converted = [ + json.dumps({"index": f"T{idx}"} | tool) for idx, tool in enumerate(tools) + ] + return "\n".join(converted) + + @override + async def chat_completion(self, req: ChatMessage): + """ + Note: Currently only support vllm inference + """ + + messages = self._process_messages( + req.messages, req.tools, self.intent_instruction + ) + + model_response = self.client.chat.completions.create( + messages=messages, + model=self.model_name, + stream=False, + extra_body=self.generation_params, + ) + + model_response = Message(content=model_response, tool_calls=[]) + + chat_completion_response = ChatCompletionResponse( + choices=[Choice(message=model_response)], model=self.model_name + ) + + return chat_completion_response + + +class ArchFunctionHandler(ArchBaseHandler): + def __init__( + self, + client: OpenAI, + model_name: str, + task_prompt: str, + tool_prompt: str, + format_prompt: str, + generation_params: Dict, + prefill_params: Dict, + prefill_prefix: List, + ): + super().__init__( + client, + model_name, + task_prompt, + tool_prompt, + format_prompt, + generation_params, + ) + + self.prefill_params = prefill_params + self.prefill_prefix = prefill_prefix + + # Predefine data types for verification. Only support Python for now. 
+ # [TODO] Extend the list of supported data types + self.support_data_types = { + type_name: getattr(builtins, type_name) for type_name in SUPPORT_DATA_TYPES + } + + @override + def _convert_tools(self, tools: List[Dict[str, Any]]): + converted = [json.dumps(tool) for tool in tools] + return "\n".join(converted) + + def _fix_json_string(self, json_str: str): + # Remove any leading or trailing whitespace or newline characters + json_str = json_str.strip() + + # Stack to keep track of brackets + stack = [] + + # Clean string to collect valid characters + fixed_str = "" + + # Dictionary for matching brackets + matching_bracket = {")": "(", "}": "{", "]": "["} + + # Dictionary for the opposite of matching_bracket + opening_bracket = {v: k for k, v in matching_bracket.items()} + + for char in json_str: + if char in "{[(": + stack.append(char) + fixed_str += char + elif char in "}])": + if stack and stack[-1] == matching_bracket[char]: + stack.pop() + fixed_str += char + else: + # Ignore the unmatched closing brackets + continue + else: + fixed_str += char + + # If there are unmatched opening brackets left in the stack, add corresponding closing brackets + while stack: + unmatched_opening = stack.pop() + fixed_str += opening_bracket[unmatched_opening] + + # Replace single quotes with double quotes; the caller parses the result as JSON + return fixed_str.replace("'", '"') + + def _extract_tool_calls(self, content: str): + tool_calls, is_valid, error_message = [], True, "" + + flag = False + for line in content.split("\n"): + if "" == line: + flag = True + elif "" == line: + flag = False + else: + if flag: + try: + tool_content = json.loads(line) + except Exception as e: + fixed_content = self._fix_json_string(line) + try: + tool_content = json.loads(fixed_content) + except Exception: + tool_calls, is_valid, error_message = [], False, e + return tool_calls, is_valid, error_message + + tool_calls.append( + { + "id": f"call_{random.randint(1000, 10000)}", + "type": "function", + "function": 
{ + "name": tool_content["name"], + "arguments": tool_content["arguments"], + }, + } + ) + + flag = False + + return tool_calls, is_valid, error_message + + def _verify_tool_calls( + self, tools: List[Dict[str, Any]], tool_calls: List[Dict[str, Any]] + ): + is_valid, error_tool_call, error_message = True, None, "" + + functions = {} + for tool in tools: + if tool["type"] == "function": + functions[tool["function"]["name"]] = tool["function"]["parameters"] + + for tool_call in tool_calls: + func_name, func_args = ( + tool_call["function"]["name"], + tool_call["function"]["arguments"], + ) + + # Check whether the function is available or not + if func_name not in functions: + is_valid = False + error_message = f"{func_name} is not defined!" + return is_valid, error_message + else: + # Check if all the requried parameters can be found in the tool calls + for required_param in functions[func_name].get("required", []): + if required_param not in func_args: + is_valid = False + error_tool_call = tool_call + error_message = f"`{required_param}` is requried by the function `{func_name}` but not found in the tool call!" + return is_valid, error_tool_call, error_message + + # Verify the data type of each parameter in the tool calls + for param_name, param_value in func_args: + data_type = functions[func_name]["properties"][param_name]["type"] + + if data_type in self.support_data_types: + if not isinstance( + param_value, self.support_data_types[data_type] + ): + is_valid = False + error_tool_call = tool_call + error_message = f"Parameter `{param_name}` is expected to have the data type `{self.support_data_types[data_type]}`, but got `{type(param_value)}`." 
+ return is_valid, error_tool_call, error_message + + return is_valid, error_tool_call, error_message + + @override + async def chat_completion(self, req: ChatMessage, enable_prefilling=True): + """ + Note: Currently only support vllm inference + """ + + messages = self._process_messages(req.messages, req.tools) + + # Retrieve the first token, handling the Stream object carefully + response = self.client.chat.completions.create( + messages=messages, + model=self.model_name, + stream=enable_prefilling, + extra_body=self.generation_params, + ) + + model_response = "" + + if enable_prefilling: + has_tool_call = None + + model_response = "" + for token in response: + token_content = token.choices[0].delta.content.strip() + + if has_tool_call is None and token_content != "": + has_tool_call = False + response.close() + break + else: + has_tool_call = True + + if has_tool_call is True: + model_response += token_content + + # start parameter gathering if the model is not generating a tool call + if has_tool_call is False: + messages.append( + { + "role": "assistant", + "content": random.choice(self.prefill_prefix), + } + ) + + prefill_response = self.client.chat.completions.create( + messages=messages, + model=self.model_name, + stream=False, + extra_body={ + **self.generation_params, + **self.prefill_params, + }, + ) + + model_response = prefill_response.choices[0].message.content + else: + model_response = response.choices[0].message.content + + tool_calls, is_valid, error_message = self._extract_tool_calls(model_response) + + if tool_calls: + is_valid, error_tool_call, error_message = self._verify_tool_calls( + tools=req.tools, tool_calls=tool_calls + ) + + # [TODO] Review: In the case that tool calls are invalid, define the protocol to collect debugging output and the behavior to handle it appropriately + if is_valid: + model_response = Message(content="", tool_calls=tool_calls) + # else: + + else: + model_response = Message(content=model_response, tool_calls=[]) + + 
chat_completion_response = ChatCompletionResponse( + choices=[Choice(message=model_response)], model=self.model_name + ) + + # [TODO] Review: define the protocol to collect debugging output + # logger.info( + # f"model_server <= arch_function: (tool_calls): {json.dumps([tool_call['function'] for tool_call in tool_calls])}" + # ) + # logger.info( + # f"model_server <= arch_function: response body: {json.dumps(chat_completion_response.dict())}" + # ) + + return chat_completion_response diff --git a/model_server/app/model_handler/guardrails.py b/model_server/app/model_handler/guardrails.py new file mode 100644 index 00000000..07aec8fb --- /dev/null +++ b/model_server/app/model_handler/guardrails.py @@ -0,0 +1,95 @@ +import time +import torch +import numpy as np + +from pydantic import BaseModel + + +class GuardRequest(BaseModel): + input: str + task: str + + +class ArchGuardHanlder: + def __init__(self, model_dict): + self.model = model_dict["model"] + self.tokenizer = model_dict["tokenizer"] + self.device = model_dict["device"] + + self.support_tasks = {"jailbreak": {"positive_class": 2, "threshold": 0.5}} + + def _split_text_into_chunks(self, text, max_num_words=300): + """ + Split the text into chunks of `max_num_words` words + """ + words = text.split() # Split text into words + + chunks = [ + " ".join(words[i : i + max_num_words]) + for i in range(0, len(words), max_num_words) + ] + + return chunks + + @staticmethod + def softmax(x): + return np.exp(x) / np.exp(x).sum(axis=0) + + def _predict_text(self, task, text, max_length=512): + inputs = self.tokenizer( + text, truncation=True, max_length=max_length, return_tensors="pt" + ).to(self.device) + + with torch.no_grad(): + logits = self.model(**inputs).logits.cpu().detach().numpy()[0] + prob = ArchGuardHanlder.softmax(logits)[ + self.support_tasks[task]["positive_class"] + ] + + if prob > self.support_tasks[task]["threshold"]: + verdict = True + sentence = text + else: + verdict = False + sentence = None + + 
result_dict = { + "prob": prob.item(), + "verdict": verdict, + "sentence": sentence, + } + + return result_dict + + def predict(self, req: GuardRequest, max_num_words=300): + """ + Note: currently only support jailbreak check + """ + + if req.task not in self.support_tasks: + raise NotImplementedError(f"{req.task} is not supported!") + + guard_result = { + "prob": [], + "verdict": False, + "sentence": [], + } + + start_time = time.perf_counter() + + if len(req.input.split()) < max_num_words: + guard_result = self._predict_text(req.task, req.input) + else: + # split into chunks if text is long + text_chunks = self._split_text_into_chunks(req.input) + + for chunk in text_chunks: + chunk_result = self._predict_text(req.task, chunk) + if chunk_result["verdict"]: + guard_result["verdict"] = True + guard_result["sentence"].append(chunk_result["sentence"]) + guard_result["prob"].append(chunk_result["prob"].item()) + + guard_result["latency"] = time.perf_counter() - start_time + + return guard_result diff --git a/model_server/app/function_calling/hallucination_handler.py b/model_server/app/model_handler/hallucination_handler.py similarity index 80% rename from model_server/app/function_calling/hallucination_handler.py rename to model_server/app/model_handler/hallucination_handler.py index 544782fb..60eda200 100644 --- a/model_server/app/function_calling/hallucination_handler.py +++ b/model_server/app/model_handler/hallucination_handler.py @@ -1,8 +1,6 @@ -import json import math import torch -import random -from typing import Any, Dict, List, Tuple +from typing import Dict, List, Tuple import itertools from enum import Enum @@ -74,26 +72,6 @@ def calculate_entropy(log_probs: List[float]) -> Tuple[float, float]: return entropy.item(), varentropy.item() -def is_parameter_property( - function_description: Dict, parameter_name: str, property_name: str -) -> bool: - """ - Check if a parameter in an API description has a specific property. 
- - Args: - function_description (dict): The API description in JSON format. - parameter_name (str): The name of the parameter to check. - property_name (str): The property to look for (e.g., 'format', 'default'). - - Returns: - bool: True if the parameter has the specified property, False otherwise. - """ - parameters = function_description.get("properties", {}) - parameter_info = parameters.get(parameter_name, {}) - - return property_name in parameter_info - - class HallucinationStateHandler: """ A class to handle the state of hallucination detection in token processing. @@ -111,7 +89,7 @@ class HallucinationStateHandler: token_probs_map (list): List mapping tokens to their entropy and variance of entropy. """ - def __init__(self, response_iterator=None, function=None): + def __init__(self, response_iterator=None): """ Initializes the HallucinationStateHandler with default values. """ @@ -126,19 +104,6 @@ class HallucinationStateHandler: self.parameter_name: List[str] = [] self.token_probs_map: List[Tuple[str, float, float]] = [] self.response_iterator = response_iterator - self._process_function(function) - - def _process_function(self, function): - self.function = function - if self.function is None: - raise ValueError("API descriptions not set.") - parameter_names = {} - for func in self.function: - func_name = func["name"] - parameters = func["parameters"]["properties"] - parameter_names[func_name] = list(parameters.keys()) - self.function_description = parameter_names - self.function_properties = {x["name"]: x["parameters"] for x in self.function} def append_and_check_token_hallucination(self, token, logprob): """ @@ -237,6 +202,8 @@ class HallucinationStateHandler: # checking if the token is a value token and is not empty if self.tokens[-1].strip() not in ['"', ""]: self.mask.append(MaskToken.PARAMETER_VALUE) + + # [TODO] Review: update the following code: `is_parameter_property` should not be here # checking if the parameter doesn't have default and the 
token is the first parameter value token if ( len(self.mask) > 1 @@ -299,26 +266,3 @@ class HallucinationStateHandler: if self.mask and self.mask[-1] == token else 0 ) - - def _is_function_name_hallucinated(self): - """ - Checks the extracted function name against the function descriptions. - Detects hallucinations if the function name is not found. - """ - f_len = self._count_consecutive_token(MaskToken.FUNCTION_NAME) - self.function_name = "".join(self.tokens[:-1][-f_len:]) - if self.function_name not in self.function_description.keys(): - self.error_type = "function_name" - self.error_message = f"Function name '{self.function_name}' not found in given function descriptions." - - def _is_parameter_name_hallucinated(self): - """ - Checks the extracted parameter name against the function descriptions. - Detects hallucinations if the parameter name is not found. - """ - p_len = self._count_consecutive_token(MaskToken.PARAMETER_NAME) - parameter_name = "".join(self.tokens[:-1][-p_len:]) - self.parameter_name.append(parameter_name) - if parameter_name not in self.function_description[self.function_name]: - self.error_type = "parameter_name" - self.error_message = f"Parameter name '{parameter_name}' not found in given function descriptions." 
diff --git a/model_server/app/prompt_guard/model_handler.py b/model_server/app/prompt_guard/model_handler.py deleted file mode 100644 index a200679b..00000000 --- a/model_server/app/prompt_guard/model_handler.py +++ /dev/null @@ -1,42 +0,0 @@ -import time -import torch -import app.prompt_guard.model_utils as model_utils - - -class ArchGuardHanlder: - def __init__(self, model_dict, threshold=0.5): - self.task = "jailbreak" - self.positive_class = 2 - - self.model = model_dict["model"] - self.tokenizer = model_dict["tokenizer"] - self.device = model_dict["device"] - - self.threshold = threshold - - def guard_predict(self, input_text, max_length=512): - start_time = time.perf_counter() - - inputs = self.tokenizer( - input_text, truncation=True, max_length=max_length, return_tensors="pt" - ).to(self.device) - - with torch.no_grad(): - logits = self.model(**inputs).logits.cpu().detach().numpy()[0] - prob = model_utils.softmax(logits)[self.positive_class] - - if prob > self.threshold: - verdict = True - sentence = input_text - else: - verdict = False - sentence = None - - result_dict = { - f"{self.task}_prob": prob.item(), - f"{self.task}_verdict": verdict, - f"{self.task}_sentence": sentence, - "time": time.perf_counter() - start_time, - } - - return result_dict diff --git a/model_server/app/prompt_guard/model_utils.py b/model_server/app/prompt_guard/model_utils.py deleted file mode 100644 index 0db2a72f..00000000 --- a/model_server/app/prompt_guard/model_utils.py +++ /dev/null @@ -1,19 +0,0 @@ -import numpy as np - - -def split_text_into_chunks(text, max_words=300): - """ - Max number of tokens for tokenizer is 512 - Split the text into chunks of 300 words (as approximation for tokens) - """ - words = text.split() # Split text into words - # Estimate token count based on word count (1 word ≈ 1 token) - chunk_size = max_words # Use the word count as an approximation for tokens - chunks = [ - " ".join(words[i : i + chunk_size]) for i in range(0, len(words), chunk_size) - 
] - return chunks - - -def softmax(x): - return np.exp(x) / np.exp(x).sum(axis=0) diff --git a/model_server/app/prompt_guard/__init__.py b/model_server/app/tests/__init__.py similarity index 100% rename from model_server/app/prompt_guard/__init__.py rename to model_server/app/tests/__init__.py diff --git a/model_server/app/tests/test_app.py b/model_server/app/tests/test_app.py index c91fc153..208bac2a 100644 --- a/model_server/app/tests/test_app.py +++ b/model_server/app/tests/test_app.py @@ -13,6 +13,7 @@ client = TestClient(app) logger.info(f"Model will be loaded on device: {glb.DEVICE}") +# [TODO] Review: update the following code # Unit tests for the health check endpoint @pytest.mark.asyncio @patch("app.loader.glb.DEVICE", glb.DEVICE) # Mock the device to 'cpu' @@ -22,6 +23,7 @@ async def test_healthz(): assert response.json() == {"status": "ok"} +# [TODO] Review: update the following code # Unit test for the models endpoint @pytest.mark.asyncio @patch("app.loader.glb.DEVICE", glb.DEVICE) # Mock the device to 'cpu' @@ -32,6 +34,7 @@ async def test_models(): assert len(response.json()["data"]) > 0 +# [TODO] Review: update the following code # Unit test for embeddings endpoint @pytest.mark.asyncio @patch("app.loader.glb.DEVICE", glb.DEVICE) # Mock the device to 'cpu' @@ -46,6 +49,7 @@ async def test_embedding(): assert response.status_code == 400 +# [TODO] Review: update the following code # Unit test for the guard endpoint @pytest.mark.asyncio @patch("app.loader.glb.DEVICE", glb.DEVICE) # Mock the device to 'cpu' @@ -56,6 +60,7 @@ async def test_guard(): assert "jailbreak_verdict" in response.json() +# [TODO] Review: update the following code # Unit test for the zero-shot endpoint @pytest.mark.asyncio @patch("app.loader.glb.DEVICE", glb.DEVICE) # Mock the device to 'cpu' @@ -73,6 +78,7 @@ async def test_zeroshot(): assert response.status_code == 400 +# [TODO] Review: update the following code # Unit test for the hallucination endpoint @pytest.mark.asyncio 
@patch("app.loader.glb.DEVICE", glb.DEVICE) # Mock the device to 'cpu' @@ -90,6 +96,7 @@ async def test_hallucination(): assert response.status_code == 400 +# [TODO] Review: update the following code # Unit test for the chat completion endpoint @pytest.mark.asyncio @patch("app.loader.glb.DEVICE", glb.DEVICE) # Mock the device to 'cpu' diff --git a/model_server/app/tests/test_cases.json b/model_server/app/tests/test_cases.json index 8fd7ec1e..d74328ee 100644 --- a/model_server/app/tests/test_cases.json +++ b/model_server/app/tests/test_cases.json @@ -1,794 +1,949 @@ -[{ - "case": "tool_call_halluciation", - "tokens" : [""], - "expect": 1, - "logprobs": [[-0.3333307206630707, - -1.5310522317886353, - -3.5098977088928223, - -3.9004578590393066, - -5.775152683258057, - -5.814209461212158, - -5.9574151039123535, - -6.0094895362854, - -6.0094895362854, - -6.673445224761963]] -}, -{ - "case" : "parameter_value_hallucination", - "expect" : 0, - "tokens" : ["", - "\n", - "{'", - "name", - "':", - " '", - "get", - "_current", - "_weather", - "',", - " '", - "arguments", - "':", - " {'", - "location", - "':", - " '", - "Sea", - ",", - " Australia", - "',", - " '", - "unit", - "':", - " '", - "c", - "elsius", - "',", - " '", - "days", - "':", - " '", - "1", - "'}}\n", - ""], - "logprobs": [[-0.008103232830762863, - -5.085402488708496, - -6.777836799621582, - -7.558959007263184, - -9.850253105163574, - -10.266852378845215, - -10.540244102478027, - -10.722506523132324, - -10.800618171691895, - -10.917786598205566], - [0.0, - -23.25142478942871, - -25.139137268066406, - -26.2847843170166, - -28.992677688598633, - -29.070789337158203, - -29.55248260498047, - -29.91700553894043, - -30.20341682434082, - -30.307567596435547], - [0.0, - -21.66313934326172, - -23.06916046142578, - -23.32953453063965, - -25.65988540649414, - -25.985353469848633, - -26.519121170043945, - -27.07892417907715, - -27.977216720581055, - -28.458908081054688], - [0.0, - -28.094383239746094, - 
-28.56305694580078, - -29.109844207763672, - -29.44832992553711, - -31.79170036315918, - -32.0, - -32.05207443237305, - -32.31244659423828, - -32.364524841308594], - [0.0, - -30.489830017089844, - -31.140766143798828, - -31.81774139404297, - -34.525634765625, - -35.8275032043457, - -36.504478454589844, - -39.05614471435547, - -40.123680114746094, - -40.696502685546875], - [0.0, - -25.646865844726562, - -26.66232681274414, - -27.781936645507812, - -28.979660034179688, - -31.140764236450195, - -31.92188835144043, - -31.973962783813477, - -33.04149627685547, - -33.58828353881836], - [0.0, - -23.511798858642578, - -24.136695861816406, - -25.230268478393555, - -25.777053833007812, - -25.80309295654297, - -26.45402717590332, - -26.636289596557617, - -26.740440368652344, - -26.896663665771484], - [0.0, - -22.366153717041016, - -24.683483123779297, - -26.610252380371094, - -26.610252380371094, - -27.313264846801758, - -27.67778778076172, - -28.510986328125, - -28.615135192871094, - -29.13588523864746], - [0.0, - -22.52237319946289, - -24.292919158935547, - -24.344993591308594, - -24.39706802368164, - -24.73555564880371, - -29.943042755126953, - -29.969079971313477, - -30.021154403686523, - -30.0341739654541], - [0.0, - -30.17738151550293, - -30.411718368530273, - -30.88039207458496, - -30.984540939331055, - -31.270952224731445, - -31.895851135253906, - -32.46867370605469, - -32.624900817871094, - -33.484134674072266], - [0.0, - -28.146459579467773, - -29.396255493164062, - -30.099267959594727, - -31.127744674682617, - -31.179821014404297, - -32.807159423828125, - -33.7445068359375, - -33.770545959472656, - -34.069976806640625], - [0.0, - -26.323841094970703, - -26.558177947998047, - -30.515867233276367, - -30.932466506958008, - -31.37510108947754, - -31.531326293945312, - -31.70056915283203, - -32.065093994140625, - -32.364524841308594], - [0.0, - -26.922698974609375, - -30.28152847290039, - -31.505287170410156, - -33.30187225341797, - -33.73148727416992, - 
-34.27827453613281, - -34.33034896850586, - -34.460533142089844, - -34.720909118652344], - [0.0, - -21.532955169677734, - -26.94873809814453, - -29.109848022460938, - -30.80228042602539, - -31.55736541748047, - -33.484134674072266, - -34.681854248046875, - -35.384864807128906, - -35.853538513183594], - [0.0, - -19.502033233642578, - -20.46541976928711, - -24.110658645629883, - -24.501218795776367, - -25.256305694580078, - -25.82912826538086, - -25.881202697753906, - -26.063465118408203, - -26.063465118408203], - [0.0, - -24.37103271484375, - -25.256305694580078, - -25.933277130126953, - -26.714401245117188, - -28.2506103515625, - -31.010576248168945, - -32.07810974121094, - -34.62977981567383, - -35.241661071777344], - [-1.1920922133867862e-06, - -14.398697853088379, - -14.424736976623535, - -17.158666610717773, - -17.41904067993164, - -18.200162887573242, - -18.434499740600586, - -18.66883659362793, - -19.71033477783203, - -19.71033477783203], - [-0.0001445904199499637, - -8.98305892944336, - -11.35246467590332, - -13.1490478515625, - -13.669795989990234, - -14.073375701904297, - -14.516012191772461, - -14.555068969726562, - -15.622602462768555, - -15.635622024536133], - [-0.44747352600097656, - -1.0202960968017578, - -8.467000961303711, - -10.914518356323242, - -11.25300407409668, - -11.435266494750977, - -12.346576690673828, - -13.075624465942383, - -13.12769889831543, - -13.231849670410156], - [-3.123767137527466, - -1.1188862323760986, - -1.639634370803833, - -2.0562336444854736, - -2.8633930683135986, - -2.9675419330596924, - -3.4882919788360596, - -3.69659161567688, - -4.217339515686035, - -4.243376731872559], - [-7.199982064776123e-05, - -9.76410961151123, - -11.144091606140137, - -16.507802963256836, - -17.132701873779297, - -17.44515037536621, - -17.9138240814209, - -18.33042335510254, - -18.9162654876709, - -19.39795684814453], - [0.0, - -22.991050720214844, - -23.824249267578125, - -24.969894409179688, - -25.46460723876953, - -25.829130172729492, - 
-26.480066299438477, - -26.909683227539062, - -27.33930206298828, - -27.391376495361328], - [-0.21928852796554565, - -1.625309705734253, - -9.775025367736816, - -12.977627754211426, - -16.388530731201172, - -17.091541290283203, - -19.044347763061523, - -19.38283348083496, - -19.460947036743164, - -19.59113311767578], - [0.0, - -24.006507873535156, - -27.443450927734375, - -27.729862213134766, - -28.12042236328125, - -28.276647567749023, - -28.927583694458008, - -30.099267959594727, - -31.479251861572266, - -32.07810974121094], - [0.0, - -18.17412567138672, - -18.772987365722656, - -21.689178466796875, - -21.92351531982422, - -23.7200984954834, - -23.79821014404297, - -23.79821014404297, - -24.032546997070312, - -25.308382034301758], - [-0.12947827577590942, - -2.1083219051361084, - -12.419143676757812, - -15.23118782043457, - -15.595710754394531, - -15.830047607421875, - -17.001731872558594, - -17.60059356689453, - -18.121341705322266, - -18.251529693603516], - [0.0, - -19.449962615966797, - -24.371034622192383, - -24.917821884155273, - -25.529701232910156, - -25.85516929626465, - -26.037429809570312, - -26.115543365478516, - -26.623271942138672, - -26.649309158325195], - [-0.03332124650478363, - -3.4181859493255615, - -15.759925842285156, - -15.812002182006836, - -16.593124389648438, - -17.894996643066406, - -18.09027671813965, - -18.79328727722168, - -19.144792556762695, - -20.147233963012695], - [0.0, - -21.142393112182617, - -22.157852172851562, - -23.511798858642578, - -24.657445907592773, - -25.021968841552734, - -25.5427188873291, - -25.59479331970215, - -25.75101661682129, - -25.95931625366211], - [0.0, - -23.04312515258789, - -24.94385528564453, - -26.323841094970703, - -27.54759979248047, - -28.563060760498047, - -29.786819458007812, - -30.620018005371094, - -30.69812774658203, - -31.08869171142578], - [0.0, - -26.167617797851562, - -28.771360397338867, - -29.55248260498047, - -30.906429290771484, - -31.114728927612305, - -31.414159774780273, - 
-31.622459411621094, - -31.713590621948242, - -31.726608276367188], - [-0.05012698099017143, - -3.018392562866211, - -11.740934371948242, - -13.146955490112305, - -13.797887802124023, - -14.943536758422852, - -16.037107467651367, - -16.375595092773438, - -16.714080810546875, - -17.36501693725586], - [-0.9704352021217346, - -0.7360983490943909, - -2.1941938400268555, - -4.225115776062012, - -5.0062360763549805, - -5.2666120529174805, - -5.839434623718262, - -7.2714948654174805, - -8.33902645111084, - -8.495253562927246], - [-0.014467108063399792, - -4.258565902709961, - -8.789079666137695, - -10.429437637329102, - -10.793962478637695, - -11.835458755493164, - -11.939607620239258, - -13.31959342956543, - -13.866378784179688, - -15.038063049316406], - [0.0, - -20.08787727355957, - -21.350692749023438, - -21.415786743164062, - -21.50691795349121, - -21.50691795349121, - -22.7176570892334, - -24.13669776916504, - -24.188772201538086, - -24.34499740600586]] -}, -{ - "case": "fail_case", - "expect" : 0, - "tokens" : ["", - "\n", - "{'", - "name", - "':", - " '", - "get", - "_current", - "_weather", - "',", - " '", - "arguments", - "':", - " {'", - "location", - "':", - " '", - "Seattle", - ",", - " WA", - "',", - " '", - "unit", - "':", - " '", - "c", - "elsius", - "',", - " '", - "days", - "':", - " '", - "7", - "'}}\n", - ""], - "logprobs":[[-0.00013815402053296566, - -9.113236427307129, - -10.571331977844238, - -14.099404335021973, - -14.28166675567627, - -15.583537101745605, - -15.81787395477295, - -16.143341064453125, - -16.143341064453125, - -16.260509490966797], - [0.0, - -26.896663665771484, - -27.32628059387207, - -27.41741180419922, - -32.07810974121094, - -32.07810974121094, - -32.28641128540039, - -32.29943084716797, - -32.44263458251953, - -32.520748138427734], - [0.0, - -22.444263458251953, - -24.527257919311523, - -27.15703773498535, - -28.016273498535156, - -28.2506103515625, - -28.693246841430664, - -29.070789337158203, - -29.565500259399414, - 
-29.812854766845703], - [0.0, - -27.860050201416016, - -28.641170501708984, - -29.448333740234375, - -30.932466506958008, - -31.63547706604004, - -32.33848571777344, - -32.85923767089844, - -33.17168426513672, - -33.45809555053711], - [0.0, - -31.81774139404297, - -31.895854949951172, - -32.05207824707031, - -35.43694305419922, - -36.3482551574707, - -38.61351013183594, - -39.26444625854492, - -40.61839294433594, - -41.71196365356445], - [0.0, - -27.33930206298828, - -27.834014892578125, - -28.849472045898438, - -30.567943572998047, - -32.98942565917969, - -33.067535400390625, - -33.067535400390625, - -35.67127990722656, - -35.69731903076172], - [0.0, - -25.33441925048828, - -26.063465118408203, - -26.219690322875977, - -26.2457275390625, - -26.53213882446289, - -27.365337371826172, - -28.354759216308594, - -28.667207717895508, - -28.74532127380371], - [0.0, - -24.423107147216797, - -24.579330444335938, - -26.81855010986328, - -28.12042236328125, - -28.32872200012207, - -28.61513328552246, - -29.16191864013672, - -29.187957763671875, - -29.240032196044922], - [0.0, - -22.027664184570312, - -23.850284576416016, - -23.980472564697266, - -24.292922973632812, - -24.787633895874023, - -29.279088973999023, - -29.55248260498047, - -29.903987884521484, - -30.190399169921875], - [0.0, - -31.609439849853516, - -31.817739486694336, - -32.54678726196289, - -32.676971435546875, - -32.781124114990234, - -32.98942565917969, - -33.106590270996094, - -33.57526397705078, - -34.369407653808594], - [0.0, - -29.34418296813965, - -29.63059425354004, - -30.021156311035156, - -30.984540939331055, - -33.21073913574219, - -34.30431365966797, - -34.56468963623047, - -34.70789337158203, - -34.79902648925781], - [0.0, - -25.438566207885742, - -25.69894027709961, - -30.190397262573242, - -30.802276611328125, - -31.58340072631836, - -31.609437942504883, - -31.64849281311035, - -31.973960876464844, - -32.29943084716797], - [0.0, - -27.157039642333984, - -32.104148864746094, - -32.33848571777344, 
- -34.04393768310547, - -34.12205505371094, - -34.40846252441406, - -34.42148208618164, - -34.772987365722656, - -34.87713623046875], - [0.0, - -24.813671112060547, - -26.974777221679688, - -31.010578155517578, - -31.08869171142578, - -32.1822624206543, - -35.33279037475586, - -35.489013671875, - -36.999183654785156, - -37.88446044921875], - [0.0, - -20.46541976928711, - -20.647682189941406, - -23.069164276123047, - -24.136699676513672, - -25.438570022583008, - -25.646869659423828, - -26.193655014038086, - -26.297805786132812, - -26.506103515625], - [0.0, - -27.18307113647461, - -28.30268096923828, - -28.56305694580078, - -29.526439666748047, - -32.416595458984375, - -35.202598571777344, - -36.426361083984375, - -39.31651306152344, - -39.38160705566406], - [0.0, - -18.7469482421875, - -20.100894927978516, - -21.402767181396484, - -21.428804397583008, - -22.20992660522461, - -22.34011459350586, - -22.730674743652344, - -23.069162368774414, - -23.980472564697266], - [-3.576278118089249e-07, - -15.2579345703125, - -16.481693267822266, - -17.991863250732422, - -19.215621948242188, - -20.25712013244629, - -21.350692749023438, - -22.314077377319336, - -22.496337890625, - -22.938974380493164], - [-0.08506780862808228, - -2.506549835205078, - -14.848289489746094, - -15.473188400268555, - -16.33242416381836, - -16.358461380004883, - -16.566761016845703, - -17.03543472290039, - -17.686370849609375, - -17.816556930541992], - [-0.0194891095161438, - -4.445854187011719, - -5.591499328613281, - -5.956024169921875, - -6.685070037841797, - -13.142353057861328, - -13.558952331542969, - -15.173273086547852, - -15.303461074829102, - -15.85024642944336], - [-0.0005990855861455202, - -7.4212646484375, - -15.675132751464844, - -15.72720718383789, - -16.76870346069336, - -16.76870346069336, - -17.706050872802734, - -18.669435501098633, - -19.398483276367188, - -19.658857345581055], - [0.0, - -24.110658645629883, - -25.829130172729492, - -26.011390686035156, - -26.011390686035156, - 
-26.532140731811523, - -26.58421516418457, - -27.651750564575195, - -27.75589942932129, - -28.055330276489258], - [-1.1408883333206177, - -0.38580334186553955, - -7.494022369384766, - -12.519245147705078, - -14.576202392578125, - -16.034297943115234, - -16.945608139038086, - -17.908992767333984, - -18.664077758789062, - -19.34105110168457], - [0.0, - -26.688365936279297, - -29.83889389038086, - -30.177383422851562, - -30.64605712890625, - -31.244916915893555, - -31.270954132080078, - -32.83319854736328, - -34.655818939208984, - -34.89015579223633], - [0.0, - -18.929210662841797, - -19.16354751586914, - -23.589908599853516, - -24.683481216430664, - -24.995929718017578, - -25.516677856445312, - -25.542715072631836, - -25.77705192565918, - -26.063465118408203], - [-0.2519786059856415, - -1.5017764568328857, - -12.437495231628418, - -15.457839012145996, - -15.744250297546387, - -16.837820053100586, - -17.41064453125, - -17.56686782836914, - -17.61894416809082, - -18.035541534423828], - [0.0, - -20.517494201660156, - -24.683483123779297, - -25.67290496826172, - -26.58421516418457, - -27.651750564575195, - -27.781936645507812, - -27.912124633789062, - -28.09438705444336, - -28.445892333984375], - [-3.40932747349143e-05, - -10.284820556640625, - -18.252273559570312, - -20.17904281616211, - -21.663175582885742, - -22.027700424194336, - -22.288074493408203, - -22.704673767089844, - -23.12127113342285, - -23.277496337890625], - [0.0, - -22.60049057006836, - -25.46460723876953, - -25.829130172729492, - -26.063467025756836, - -27.287227630615234, - -27.391376495361328, - -27.4694881439209, - -27.67778778076172, - -28.055330276489258], - [0.0, - -23.902362823486328, - -28.823436737060547, - -29.240036010742188, - -29.31814956665039, - -29.917007446289062, - -30.021160125732422, - -31.21887969970703, - -32.416603088378906, - -32.416603088378906], - [0.0, - -28.641170501708984, - -31.947925567626953, - -32.59886169433594, - -33.848655700683594, - -34.109031677246094, - 
-34.73393249511719, - -35.02033996582031, - -35.02033996582031, - -36.074859619140625], - [-0.013183215633034706, - -4.335395336151123, - -19.619365692138672, - -20.035964965820312, - -20.244266510009766, - -21.311800003051758, - -21.441987991333008, - -22.561595916748047, - -23.108383178710938, - -23.264606475830078], - [-8.344646857949556e-07, - -14.190400123596191, - -15.9088716506958, - -18.17412567138672, - -18.46053695678711, - -18.46053695678711, - -18.512611389160156, - -18.90317153930664, - -19.059398651123047, - -19.085433959960938], - [0.0, - -17.70545196533203, - -18.903175354003906, - -20.829944610595703, - -22.574451446533203, - -22.860862731933594, - -23.069162368774414, - -23.32953643798828, - -23.694061279296875, - -24.188772201538086], - [0.0, - -20.022781372070312, - -21.038240432739258, - -21.220502853393555, - -22.496337890625, - -22.769729614257812, - -23.589908599853516, - -23.65500259399414, - -23.94141387939453, - -24.266881942749023]] -} +[ + { + "case": "tool_call_halluciation", + "tokens": [ + "" + ], + "expect": 1, + "logprobs": [ + [ + -0.3333307206630707, + -1.5310522317886353, + -3.5098977088928223, + -3.9004578590393066, + -5.775152683258057, + -5.814209461212158, + -5.9574151039123535, + -6.0094895362854, + -6.0094895362854, + -6.673445224761963 + ] + ] + }, + { + "case": "parameter_value_hallucination", + "expect": 0, + "tokens": [ + "", + "\n", + "{'", + "name", + "':", + " '", + "get", + "_current", + "_weather", + "',", + " '", + "arguments", + "':", + " {'", + "location", + "':", + " '", + "Sea", + ",", + " Australia", + "',", + " '", + "unit", + "':", + " '", + "c", + "elsius", + "',", + " '", + "days", + "':", + " '", + "1", + "'}}\n", + "" + ], + "logprobs": [ + [ + -0.008103232830762863, + -5.085402488708496, + -6.777836799621582, + -7.558959007263184, + -9.850253105163574, + -10.266852378845215, + -10.540244102478027, + -10.722506523132324, + -10.800618171691895, + -10.917786598205566 + ], + [ + 0.0, + -23.25142478942871, 
+ -25.139137268066406, + -26.2847843170166, + -28.992677688598633, + -29.070789337158203, + -29.55248260498047, + -29.91700553894043, + -30.20341682434082, + -30.307567596435547 + ], + [ + 0.0, + -21.66313934326172, + -23.06916046142578, + -23.32953453063965, + -25.65988540649414, + -25.985353469848633, + -26.519121170043945, + -27.07892417907715, + -27.977216720581055, + -28.458908081054688 + ], + [ + 0.0, + -28.094383239746094, + -28.56305694580078, + -29.109844207763672, + -29.44832992553711, + -31.79170036315918, + -32.0, + -32.05207443237305, + -32.31244659423828, + -32.364524841308594 + ], + [ + 0.0, + -30.489830017089844, + -31.140766143798828, + -31.81774139404297, + -34.525634765625, + -35.8275032043457, + -36.504478454589844, + -39.05614471435547, + -40.123680114746094, + -40.696502685546875 + ], + [ + 0.0, + -25.646865844726562, + -26.66232681274414, + -27.781936645507812, + -28.979660034179688, + -31.140764236450195, + -31.92188835144043, + -31.973962783813477, + -33.04149627685547, + -33.58828353881836 + ], + [ + 0.0, + -23.511798858642578, + -24.136695861816406, + -25.230268478393555, + -25.777053833007812, + -25.80309295654297, + -26.45402717590332, + -26.636289596557617, + -26.740440368652344, + -26.896663665771484 + ], + [ + 0.0, + -22.366153717041016, + -24.683483123779297, + -26.610252380371094, + -26.610252380371094, + -27.313264846801758, + -27.67778778076172, + -28.510986328125, + -28.615135192871094, + -29.13588523864746 + ], + [ + 0.0, + -22.52237319946289, + -24.292919158935547, + -24.344993591308594, + -24.39706802368164, + -24.73555564880371, + -29.943042755126953, + -29.969079971313477, + -30.021154403686523, + -30.0341739654541 + ], + [ + 0.0, + -30.17738151550293, + -30.411718368530273, + -30.88039207458496, + -30.984540939331055, + -31.270952224731445, + -31.895851135253906, + -32.46867370605469, + -32.624900817871094, + -33.484134674072266 + ], + [ + 0.0, + -28.146459579467773, + -29.396255493164062, + -30.099267959594727, + 
-31.127744674682617, + -31.179821014404297, + -32.807159423828125, + -33.7445068359375, + -33.770545959472656, + -34.069976806640625 + ], + [ + 0.0, + -26.323841094970703, + -26.558177947998047, + -30.515867233276367, + -30.932466506958008, + -31.37510108947754, + -31.531326293945312, + -31.70056915283203, + -32.065093994140625, + -32.364524841308594 + ], + [ + 0.0, + -26.922698974609375, + -30.28152847290039, + -31.505287170410156, + -33.30187225341797, + -33.73148727416992, + -34.27827453613281, + -34.33034896850586, + -34.460533142089844, + -34.720909118652344 + ], + [ + 0.0, + -21.532955169677734, + -26.94873809814453, + -29.109848022460938, + -30.80228042602539, + -31.55736541748047, + -33.484134674072266, + -34.681854248046875, + -35.384864807128906, + -35.853538513183594 + ], + [ + 0.0, + -19.502033233642578, + -20.46541976928711, + -24.110658645629883, + -24.501218795776367, + -25.256305694580078, + -25.82912826538086, + -25.881202697753906, + -26.063465118408203, + -26.063465118408203 + ], + [ + 0.0, + -24.37103271484375, + -25.256305694580078, + -25.933277130126953, + -26.714401245117188, + -28.2506103515625, + -31.010576248168945, + -32.07810974121094, + -34.62977981567383, + -35.241661071777344 + ], + [ + -1.1920922133867862e-06, + -14.398697853088379, + -14.424736976623535, + -17.158666610717773, + -17.41904067993164, + -18.200162887573242, + -18.434499740600586, + -18.66883659362793, + -19.71033477783203, + -19.71033477783203 + ], + [ + -0.0001445904199499637, + -8.98305892944336, + -11.35246467590332, + -13.1490478515625, + -13.669795989990234, + -14.073375701904297, + -14.516012191772461, + -14.555068969726562, + -15.622602462768555, + -15.635622024536133 + ], + [ + -0.44747352600097656, + -1.0202960968017578, + -8.467000961303711, + -10.914518356323242, + -11.25300407409668, + -11.435266494750977, + -12.346576690673828, + -13.075624465942383, + -13.12769889831543, + -13.231849670410156 + ], + [ + -3.123767137527466, + -1.1188862323760986, + 
-1.639634370803833, + -2.0562336444854736, + -2.8633930683135986, + -2.9675419330596924, + -3.4882919788360596, + -3.69659161567688, + -4.217339515686035, + -4.243376731872559 + ], + [ + -7.199982064776123e-05, + -9.76410961151123, + -11.144091606140137, + -16.507802963256836, + -17.132701873779297, + -17.44515037536621, + -17.9138240814209, + -18.33042335510254, + -18.9162654876709, + -19.39795684814453 + ], + [ + 0.0, + -22.991050720214844, + -23.824249267578125, + -24.969894409179688, + -25.46460723876953, + -25.829130172729492, + -26.480066299438477, + -26.909683227539062, + -27.33930206298828, + -27.391376495361328 + ], + [ + -0.21928852796554565, + -1.625309705734253, + -9.775025367736816, + -12.977627754211426, + -16.388530731201172, + -17.091541290283203, + -19.044347763061523, + -19.38283348083496, + -19.460947036743164, + -19.59113311767578 + ], + [ + 0.0, + -24.006507873535156, + -27.443450927734375, + -27.729862213134766, + -28.12042236328125, + -28.276647567749023, + -28.927583694458008, + -30.099267959594727, + -31.479251861572266, + -32.07810974121094 + ], + [ + 0.0, + -18.17412567138672, + -18.772987365722656, + -21.689178466796875, + -21.92351531982422, + -23.7200984954834, + -23.79821014404297, + -23.79821014404297, + -24.032546997070312, + -25.308382034301758 + ], + [ + -0.12947827577590942, + -2.1083219051361084, + -12.419143676757812, + -15.23118782043457, + -15.595710754394531, + -15.830047607421875, + -17.001731872558594, + -17.60059356689453, + -18.121341705322266, + -18.251529693603516 + ], + [ + 0.0, + -19.449962615966797, + -24.371034622192383, + -24.917821884155273, + -25.529701232910156, + -25.85516929626465, + -26.037429809570312, + -26.115543365478516, + -26.623271942138672, + -26.649309158325195 + ], + [ + -0.03332124650478363, + -3.4181859493255615, + -15.759925842285156, + -15.812002182006836, + -16.593124389648438, + -17.894996643066406, + -18.09027671813965, + -18.79328727722168, + -19.144792556762695, + -20.147233963012695 + ], 
+ [ + 0.0, + -21.142393112182617, + -22.157852172851562, + -23.511798858642578, + -24.657445907592773, + -25.021968841552734, + -25.5427188873291, + -25.59479331970215, + -25.75101661682129, + -25.95931625366211 + ], + [ + 0.0, + -23.04312515258789, + -24.94385528564453, + -26.323841094970703, + -27.54759979248047, + -28.563060760498047, + -29.786819458007812, + -30.620018005371094, + -30.69812774658203, + -31.08869171142578 + ], + [ + 0.0, + -26.167617797851562, + -28.771360397338867, + -29.55248260498047, + -30.906429290771484, + -31.114728927612305, + -31.414159774780273, + -31.622459411621094, + -31.713590621948242, + -31.726608276367188 + ], + [ + -0.05012698099017143, + -3.018392562866211, + -11.740934371948242, + -13.146955490112305, + -13.797887802124023, + -14.943536758422852, + -16.037107467651367, + -16.375595092773438, + -16.714080810546875, + -17.36501693725586 + ], + [ + -0.9704352021217346, + -0.7360983490943909, + -2.1941938400268555, + -4.225115776062012, + -5.0062360763549805, + -5.2666120529174805, + -5.839434623718262, + -7.2714948654174805, + -8.33902645111084, + -8.495253562927246 + ], + [ + -0.014467108063399792, + -4.258565902709961, + -8.789079666137695, + -10.429437637329102, + -10.793962478637695, + -11.835458755493164, + -11.939607620239258, + -13.31959342956543, + -13.866378784179688, + -15.038063049316406 + ], + [ + 0.0, + -20.08787727355957, + -21.350692749023438, + -21.415786743164062, + -21.50691795349121, + -21.50691795349121, + -22.7176570892334, + -24.13669776916504, + -24.188772201538086, + -24.34499740600586 + ] + ] + }, + { + "case": "fail_case", + "expect": 0, + "tokens": [ + "", + "\n", + "{'", + "name", + "':", + " '", + "get", + "_current", + "_weather", + "',", + " '", + "arguments", + "':", + " {'", + "location", + "':", + " '", + "Seattle", + ",", + " WA", + "',", + " '", + "unit", + "':", + " '", + "c", + "elsius", + "',", + " '", + "days", + "':", + " '", + "7", + "'}}\n", + "" + ], + "logprobs": [ + [ + 
-0.00013815402053296566, + -9.113236427307129, + -10.571331977844238, + -14.099404335021973, + -14.28166675567627, + -15.583537101745605, + -15.81787395477295, + -16.143341064453125, + -16.143341064453125, + -16.260509490966797 + ], + [ + 0.0, + -26.896663665771484, + -27.32628059387207, + -27.41741180419922, + -32.07810974121094, + -32.07810974121094, + -32.28641128540039, + -32.29943084716797, + -32.44263458251953, + -32.520748138427734 + ], + [ + 0.0, + -22.444263458251953, + -24.527257919311523, + -27.15703773498535, + -28.016273498535156, + -28.2506103515625, + -28.693246841430664, + -29.070789337158203, + -29.565500259399414, + -29.812854766845703 + ], + [ + 0.0, + -27.860050201416016, + -28.641170501708984, + -29.448333740234375, + -30.932466506958008, + -31.63547706604004, + -32.33848571777344, + -32.85923767089844, + -33.17168426513672, + -33.45809555053711 + ], + [ + 0.0, + -31.81774139404297, + -31.895854949951172, + -32.05207824707031, + -35.43694305419922, + -36.3482551574707, + -38.61351013183594, + -39.26444625854492, + -40.61839294433594, + -41.71196365356445 + ], + [ + 0.0, + -27.33930206298828, + -27.834014892578125, + -28.849472045898438, + -30.567943572998047, + -32.98942565917969, + -33.067535400390625, + -33.067535400390625, + -35.67127990722656, + -35.69731903076172 + ], + [ + 0.0, + -25.33441925048828, + -26.063465118408203, + -26.219690322875977, + -26.2457275390625, + -26.53213882446289, + -27.365337371826172, + -28.354759216308594, + -28.667207717895508, + -28.74532127380371 + ], + [ + 0.0, + -24.423107147216797, + -24.579330444335938, + -26.81855010986328, + -28.12042236328125, + -28.32872200012207, + -28.61513328552246, + -29.16191864013672, + -29.187957763671875, + -29.240032196044922 + ], + [ + 0.0, + -22.027664184570312, + -23.850284576416016, + -23.980472564697266, + -24.292922973632812, + -24.787633895874023, + -29.279088973999023, + -29.55248260498047, + -29.903987884521484, + -30.190399169921875 + ], + [ + 0.0, + 
-31.609439849853516, + -31.817739486694336, + -32.54678726196289, + -32.676971435546875, + -32.781124114990234, + -32.98942565917969, + -33.106590270996094, + -33.57526397705078, + -34.369407653808594 + ], + [ + 0.0, + -29.34418296813965, + -29.63059425354004, + -30.021156311035156, + -30.984540939331055, + -33.21073913574219, + -34.30431365966797, + -34.56468963623047, + -34.70789337158203, + -34.79902648925781 + ], + [ + 0.0, + -25.438566207885742, + -25.69894027709961, + -30.190397262573242, + -30.802276611328125, + -31.58340072631836, + -31.609437942504883, + -31.64849281311035, + -31.973960876464844, + -32.29943084716797 + ], + [ + 0.0, + -27.157039642333984, + -32.104148864746094, + -32.33848571777344, + -34.04393768310547, + -34.12205505371094, + -34.40846252441406, + -34.42148208618164, + -34.772987365722656, + -34.87713623046875 + ], + [ + 0.0, + -24.813671112060547, + -26.974777221679688, + -31.010578155517578, + -31.08869171142578, + -32.1822624206543, + -35.33279037475586, + -35.489013671875, + -36.999183654785156, + -37.88446044921875 + ], + [ + 0.0, + -20.46541976928711, + -20.647682189941406, + -23.069164276123047, + -24.136699676513672, + -25.438570022583008, + -25.646869659423828, + -26.193655014038086, + -26.297805786132812, + -26.506103515625 + ], + [ + 0.0, + -27.18307113647461, + -28.30268096923828, + -28.56305694580078, + -29.526439666748047, + -32.416595458984375, + -35.202598571777344, + -36.426361083984375, + -39.31651306152344, + -39.38160705566406 + ], + [ + 0.0, + -18.7469482421875, + -20.100894927978516, + -21.402767181396484, + -21.428804397583008, + -22.20992660522461, + -22.34011459350586, + -22.730674743652344, + -23.069162368774414, + -23.980472564697266 + ], + [ + -3.576278118089249e-07, + -15.2579345703125, + -16.481693267822266, + -17.991863250732422, + -19.215621948242188, + -20.25712013244629, + -21.350692749023438, + -22.314077377319336, + -22.496337890625, + -22.938974380493164 + ], + [ + -0.08506780862808228, + 
-2.506549835205078, + -14.848289489746094, + -15.473188400268555, + -16.33242416381836, + -16.358461380004883, + -16.566761016845703, + -17.03543472290039, + -17.686370849609375, + -17.816556930541992 + ], + [ + -0.0194891095161438, + -4.445854187011719, + -5.591499328613281, + -5.956024169921875, + -6.685070037841797, + -13.142353057861328, + -13.558952331542969, + -15.173273086547852, + -15.303461074829102, + -15.85024642944336 + ], + [ + -0.0005990855861455202, + -7.4212646484375, + -15.675132751464844, + -15.72720718383789, + -16.76870346069336, + -16.76870346069336, + -17.706050872802734, + -18.669435501098633, + -19.398483276367188, + -19.658857345581055 + ], + [ + 0.0, + -24.110658645629883, + -25.829130172729492, + -26.011390686035156, + -26.011390686035156, + -26.532140731811523, + -26.58421516418457, + -27.651750564575195, + -27.75589942932129, + -28.055330276489258 + ], + [ + -1.1408883333206177, + -0.38580334186553955, + -7.494022369384766, + -12.519245147705078, + -14.576202392578125, + -16.034297943115234, + -16.945608139038086, + -17.908992767333984, + -18.664077758789062, + -19.34105110168457 + ], + [ + 0.0, + -26.688365936279297, + -29.83889389038086, + -30.177383422851562, + -30.64605712890625, + -31.244916915893555, + -31.270954132080078, + -32.83319854736328, + -34.655818939208984, + -34.89015579223633 + ], + [ + 0.0, + -18.929210662841797, + -19.16354751586914, + -23.589908599853516, + -24.683481216430664, + -24.995929718017578, + -25.516677856445312, + -25.542715072631836, + -25.77705192565918, + -26.063465118408203 + ], + [ + -0.2519786059856415, + -1.5017764568328857, + -12.437495231628418, + -15.457839012145996, + -15.744250297546387, + -16.837820053100586, + -17.41064453125, + -17.56686782836914, + -17.61894416809082, + -18.035541534423828 + ], + [ + 0.0, + -20.517494201660156, + -24.683483123779297, + -25.67290496826172, + -26.58421516418457, + -27.651750564575195, + -27.781936645507812, + -27.912124633789062, + -28.09438705444336, + 
-28.445892333984375 + ], + [ + -3.40932747349143e-05, + -10.284820556640625, + -18.252273559570312, + -20.17904281616211, + -21.663175582885742, + -22.027700424194336, + -22.288074493408203, + -22.704673767089844, + -23.12127113342285, + -23.277496337890625 + ], + [ + 0.0, + -22.60049057006836, + -25.46460723876953, + -25.829130172729492, + -26.063467025756836, + -27.287227630615234, + -27.391376495361328, + -27.4694881439209, + -27.67778778076172, + -28.055330276489258 + ], + [ + 0.0, + -23.902362823486328, + -28.823436737060547, + -29.240036010742188, + -29.31814956665039, + -29.917007446289062, + -30.021160125732422, + -31.21887969970703, + -32.416603088378906, + -32.416603088378906 + ], + [ + 0.0, + -28.641170501708984, + -31.947925567626953, + -32.59886169433594, + -33.848655700683594, + -34.109031677246094, + -34.73393249511719, + -35.02033996582031, + -35.02033996582031, + -36.074859619140625 + ], + [ + -0.013183215633034706, + -4.335395336151123, + -19.619365692138672, + -20.035964965820312, + -20.244266510009766, + -21.311800003051758, + -21.441987991333008, + -22.561595916748047, + -23.108383178710938, + -23.264606475830078 + ], + [ + -8.344646857949556e-07, + -14.190400123596191, + -15.9088716506958, + -18.17412567138672, + -18.46053695678711, + -18.46053695678711, + -18.512611389160156, + -18.90317153930664, + -19.059398651123047, + -19.085433959960938 + ], + [ + 0.0, + -17.70545196533203, + -18.903175354003906, + -20.829944610595703, + -22.574451446533203, + -22.860862731933594, + -23.069162368774414, + -23.32953643798828, + -23.694061279296875, + -24.188772201538086 + ], + [ + 0.0, + -20.022781372070312, + -21.038240432739258, + -21.220502853393555, + -22.496337890625, + -22.769729614257812, + -23.589908599853516, + -23.65500259399414, + -23.94141387939453, + -24.266881942749023 + ] + ] + } ] diff --git a/model_server/app/tests/test_cli_stop_server.py b/model_server/app/tests/test_cli_stop_server.py index 4f3955a7..d5bad43f 100644 --- 
a/model_server/app/tests/test_cli_stop_server.py +++ b/model_server/app/tests/test_cli_stop_server.py @@ -1,7 +1,6 @@ import unittest + from unittest.mock import patch, MagicMock -import subprocess -import time from app.cli import kill_process diff --git a/model_server/app/tests/test_function_calling.py b/model_server/app/tests/test_function_calling.py index 251007d3..b3c7bf2f 100644 --- a/model_server/app/tests/test_function_calling.py +++ b/model_server/app/tests/test_function_calling.py @@ -1,13 +1,12 @@ +import json import pytest -from unittest.mock import AsyncMock, MagicMock, patch -import app.commons.constants as const + from fastapi import Response -from app.function_calling.model_utils import ( - process_messages, - chat_completion, +from unittest.mock import AsyncMock, MagicMock, patch +from app.commons.globals import handler_map +from app.model_handler.function_calling import ( Message, ChatMessage, - Choice, ChatCompletionResponse, ) @@ -31,14 +30,27 @@ def sample_messages(): def sample_request(sample_messages): return ChatMessage( messages=sample_messages, - tools=[{"name": "sample_tool", "description": "A sample tool"}], + tools=[ + { + "type": "function", + "function": { + "name": "sample_tool", + "description": "A sample tool", + "parameters": { + "type": "object", + "properties": {}, + "required": [], + }, + }, + } + ], ) -@patch("app.commons.constants.arch_function_hanlder") +@patch("app.commons.globals.handler_map") def test_process_messages(mock_hanlder): messages = sample_messages() - processed = process_messages(messages) + processed = handler_map["Arch-Function"]._process_messages(messages) assert len(processed) == 3 assert processed[0] == {"role": "user", "content": "Hello!"} @@ -48,10 +60,11 @@ def test_process_messages(mock_hanlder): } assert processed[2] == { "role": "user", - "content": "\nResponse from tool\n", + "content": f"\n{json.dumps('Response from tool')}\n", } +# [TODO] Review: Update the following test 
@patch("app.commons.constants.arch_function_client") @patch("app.commons.constants.arch_function_hanlder") @pytest.mark.asyncio diff --git a/model_server/app/tests/test_guardrails.py b/model_server/app/tests/test_guardrails.py new file mode 100644 index 00000000..c490a662 --- /dev/null +++ b/model_server/app/tests/test_guardrails.py @@ -0,0 +1,47 @@ +import os +import pytest +from unittest.mock import patch, MagicMock +import app.commons.globals as glb + +# Mock constants +glb.DEVICE = "cpu" # Adjust as needed for your test case +arch_guard_model_type = { + "cpu": "katanemo/Arch-Guard-cpu", + "cuda": "katanemo/Arch-Guard", + "mps": "katanemo/Arch-Guard", +} + + +# [TODO] Review: update the following code to test under `cpu`, `cuda`, and `mps` +# Test for get_prompt_guard function +@patch("app.loader.AutoTokenizer.from_pretrained") +@patch("app.loader.OVModelForSequenceClassification.from_pretrained") +@patch("app.loader.AutoModelForSequenceClassification.from_pretrained") +def test_get_prompt_guard(mock_auto_model, mock_ov_model, mock_tokenizer): + # Mock model based on device + if glb.DEVICE == "cpu": + mock_ov_model.return_value = MagicMock() + else: + mock_auto_model.return_value = MagicMock() + + mock_tokenizer.return_value = MagicMock() + + prompt_guard = get_prompt_guard(arch_guard_model_type[glb.DEVICE]) + + # Assertions + assert prompt_guard["model_name"] == arch_guard_model_type[glb.DEVICE] + mock_tokenizer.assert_called_once_with( + arch_guard_model_type[glb.DEVICE], trust_remote_code=True + ) + if glb.DEVICE == "cpu": + mock_ov_model.assert_called_once_with( + arch_guard_model_type[glb.DEVICE], + device_map=glb.DEVICE, + low_cpu_mem_usage=True, + ) + else: + mock_auto_model.assert_called_once_with( + arch_guard_model_type[glb.DEVICE], + device_map=glb.DEVICE, + low_cpu_mem_usage=True, + ) diff --git a/model_server/app/tests/test_hallucination.py b/model_server/app/tests/test_hallucination.py index 8b6c387e..2afb85c9 100644 --- 
a/model_server/app/tests/test_hallucination.py +++ b/model_server/app/tests/test_hallucination.py @@ -1,8 +1,11 @@ import json -from app.function_calling.hallucination_handler import HallucinationStateHandler import pytest import os + +from app.model_handler.hallucination_handler import HallucinationStateHandler + + # Get the directory of the current file current_dir = os.path.dirname(__file__) @@ -45,6 +48,7 @@ if type(function_description) != list: function_description = [get_weather_api["function"]] +# [TODO] Review: update the following code @pytest.mark.parametrize("case", test_cases) def test_hallucination(case): state = HallucinationStateHandler( @@ -58,6 +62,7 @@ def test_hallucination(case): assert state.hallucination == case["expect"] +# [TODO] Review: update the following code @pytest.mark.parametrize("is_hallucinate_sample", [True, False]) def test_hallucination_prompt(is_hallucinate_sample): TASK_PROMPT = """ diff --git a/model_server/app/tests/test_loaders_cpu.py b/model_server/app/tests/test_loaders_cpu.py deleted file mode 100644 index ce9bf5d2..00000000 --- a/model_server/app/tests/test_loaders_cpu.py +++ /dev/null @@ -1,102 +0,0 @@ -import os -import pytest -from unittest.mock import patch, MagicMock -import app.commons.globals as glb -from app.loader import get_embedding_model, get_zero_shot_model, get_prompt_guard - -# Mock constants -glb.DEVICE = "cpu" # Adjust as needed for your test case -arch_guard_model_type = { - "cpu": "katanemo/Arch-Guard-cpu", - "cuda": "katanemo/Arch-Guard", - "mps": "katanemo/Arch-Guard", -} - - -@pytest.fixture -def mock_env(): - # Mock environment variables - os.environ["MODELS"] = "katanemo/bge-large-en-v1.5" - os.environ["ZERO_SHOT_MODELS"] = "katanemo/bart-large-mnli" - - -# Test for get_embedding_model function -@patch("app.loader.ORTModelForFeatureExtraction.from_pretrained") -@patch("app.loader.AutoModel.from_pretrained") -@patch("app.loader.AutoTokenizer.from_pretrained") -def 
test_get_embedding_model(mock_tokenizer, mock_automodel, mock_ort_model, mock_env): - mock_automodel.return_value = MagicMock() - mock_ort_model.return_value = MagicMock() - mock_tokenizer.return_value = MagicMock() - - embedding_model = get_embedding_model() - - # Assertions - assert embedding_model["model_name"] == "katanemo/bge-large-en-v1.5" - mock_tokenizer.assert_called_once_with( - "katanemo/bge-large-en-v1.5", trust_remote_code=True - ) - if glb.DEVICE != "cuda": - mock_ort_model.assert_called_once_with( - "katanemo/bge-large-en-v1.5", file_name="onnx/model.onnx" - ) - else: - mock_automodel.assert_called_once_with( - "katanemo/bge-large-en-v1.5", device_map=glb.DEVICE - ) - - -# Test for get_zero_shot_model function -@patch("app.loader.ORTModelForSequenceClassification.from_pretrained") -@patch("app.loader.pipeline") -@patch("app.loader.AutoTokenizer.from_pretrained") -def test_get_zero_shot_model(mock_tokenizer, mock_pipeline, mock_ort_model, mock_env): - mock_pipeline.return_value = MagicMock() - mock_ort_model.return_value = MagicMock() - mock_tokenizer.return_value = MagicMock() - - zero_shot_model = get_zero_shot_model() - - # Assertions - assert zero_shot_model["model_name"] == "katanemo/bart-large-mnli" - mock_tokenizer.assert_called_once_with("katanemo/bart-large-mnli") - if glb.DEVICE != "cuda": - mock_ort_model.assert_called_once_with( - "katanemo/bart-large-mnli", file_name="onnx/model.onnx" - ) - else: - assert mock_pipeline.called_once() - - -# Test for get_prompt_guard function -@patch("app.loader.AutoTokenizer.from_pretrained") -@patch("app.loader.OVModelForSequenceClassification.from_pretrained") -@patch("app.loader.AutoModelForSequenceClassification.from_pretrained") -def test_get_prompt_guard(mock_auto_model, mock_ov_model, mock_tokenizer): - # Mock model based on device - if glb.DEVICE == "cpu": - mock_ov_model.return_value = MagicMock() - else: - mock_auto_model.return_value = MagicMock() - - mock_tokenizer.return_value = MagicMock() - 
- prompt_guard = get_prompt_guard(arch_guard_model_type[glb.DEVICE]) - - # Assertions - assert prompt_guard["model_name"] == arch_guard_model_type[glb.DEVICE] - mock_tokenizer.assert_called_once_with( - arch_guard_model_type[glb.DEVICE], trust_remote_code=True - ) - if glb.DEVICE == "cpu": - mock_ov_model.assert_called_once_with( - arch_guard_model_type[glb.DEVICE], - device_map=glb.DEVICE, - low_cpu_mem_usage=True, - ) - else: - mock_auto_model.assert_called_once_with( - arch_guard_model_type[glb.DEVICE], - device_map=glb.DEVICE, - low_cpu_mem_usage=True, - ) diff --git a/model_server/app/tests/test_loaders_gpu.py b/model_server/app/tests/test_loaders_gpu.py deleted file mode 100644 index 4d5875e9..00000000 --- a/model_server/app/tests/test_loaders_gpu.py +++ /dev/null @@ -1,102 +0,0 @@ -import os -import pytest -from unittest.mock import patch, MagicMock -import app.commons.globals as glb -from app.loader import get_embedding_model, get_zero_shot_model, get_prompt_guard - -# Mock constants -glb.DEVICE = "cuda" # Adjust as needed for your test case -arch_guard_model_type = { - "cpu": "katanemo/Arch-Guard-cpu", - "cuda": "katanemo/Arch-Guard", - "mps": "katanemo/Arch-Guard", -} - - -@pytest.fixture -def mock_env(): - # Mock environment variables - os.environ["MODELS"] = "katanemo/bge-large-en-v1.5" - os.environ["ZERO_SHOT_MODELS"] = "katanemo/bart-large-mnli" - - -# Test for get_embedding_model function -@patch("app.loader.ORTModelForFeatureExtraction.from_pretrained") -@patch("app.loader.AutoModel.from_pretrained") -@patch("app.loader.AutoTokenizer.from_pretrained") -def test_get_embedding_model(mock_tokenizer, mock_automodel, mock_ort_model, mock_env): - mock_automodel.return_value = MagicMock() - mock_ort_model.return_value = MagicMock() - mock_tokenizer.return_value = MagicMock() - - embedding_model = get_embedding_model() - - # Assertions - assert embedding_model["model_name"] == "katanemo/bge-large-en-v1.5" - mock_tokenizer.assert_called_once_with( - 
"katanemo/bge-large-en-v1.5", trust_remote_code=True - ) - if glb.DEVICE != "cuda": - mock_ort_model.assert_called_once_with( - "katanemo/bge-large-en-v1.5", file_name="onnx/model.onnx" - ) - else: - mock_automodel.assert_called_once_with( - "katanemo/bge-large-en-v1.5", device_map=glb.DEVICE - ) - - -# Test for get_zero_shot_model function -@patch("app.loader.ORTModelForSequenceClassification.from_pretrained") -@patch("app.loader.pipeline") -@patch("app.loader.AutoTokenizer.from_pretrained") -def test_get_zero_shot_model(mock_tokenizer, mock_pipeline, mock_ort_model, mock_env): - mock_pipeline.return_value = MagicMock() - mock_ort_model.return_value = MagicMock() - mock_tokenizer.return_value = MagicMock() - - zero_shot_model = get_zero_shot_model() - - # Assertions - assert zero_shot_model["model_name"] == "katanemo/bart-large-mnli" - mock_tokenizer.assert_called_once_with("katanemo/bart-large-mnli") - if glb.DEVICE != "cuda": - mock_ort_model.assert_called_once_with( - "katanemo/bart-large-mnli", file_name="onnx/model.onnx" - ) - else: - assert mock_pipeline.called_once() - - -# Test for get_prompt_guard function -@patch("app.loader.AutoTokenizer.from_pretrained") -@patch("app.loader.OVModelForSequenceClassification.from_pretrained") -@patch("app.loader.AutoModelForSequenceClassification.from_pretrained") -def test_get_prompt_guard(mock_auto_model, mock_ov_model, mock_tokenizer): - # Mock model based on device - if glb.DEVICE == "cpu": - mock_ov_model.return_value = MagicMock() - else: - mock_auto_model.return_value = MagicMock() - - mock_tokenizer.return_value = MagicMock() - - prompt_guard = get_prompt_guard(arch_guard_model_type[glb.DEVICE]) - - # Assertions - assert prompt_guard["model_name"] == arch_guard_model_type[glb.DEVICE] - mock_tokenizer.assert_called_once_with( - arch_guard_model_type[glb.DEVICE], trust_remote_code=True - ) - if glb.DEVICE == "cpu": - mock_ov_model.assert_called_once_with( - arch_guard_model_type[glb.DEVICE], - 
device_map=glb.DEVICE, - low_cpu_mem_usage=True, - ) - else: - mock_auto_model.assert_called_once_with( - arch_guard_model_type[glb.DEVICE], - device_map=glb.DEVICE, - low_cpu_mem_usage=True, - ) diff --git a/model_server/app/tests/test_loaders_mps.py b/model_server/app/tests/test_loaders_mps.py deleted file mode 100644 index 41289c7d..00000000 --- a/model_server/app/tests/test_loaders_mps.py +++ /dev/null @@ -1,102 +0,0 @@ -import os -import pytest -from unittest.mock import patch, MagicMock -import app.commons.globals as glb -from app.loader import get_embedding_model, get_zero_shot_model, get_prompt_guard - -# Mock constants -glb.DEVICE = "mps" # Adjust as needed for your test case -arch_guard_model_type = { - "cpu": "katanemo/Arch-Guard-cpu", - "cuda": "katanemo/Arch-Guard", - "mps": "katanemo/Arch-Guard", -} - - -@pytest.fixture -def mock_env(): - # Mock environment variables - os.environ["MODELS"] = "katanemo/bge-large-en-v1.5" - os.environ["ZERO_SHOT_MODELS"] = "katanemo/bart-large-mnli" - - -# Test for get_embedding_model function -@patch("app.loader.ORTModelForFeatureExtraction.from_pretrained") -@patch("app.loader.AutoModel.from_pretrained") -@patch("app.loader.AutoTokenizer.from_pretrained") -def test_get_embedding_model(mock_tokenizer, mock_automodel, mock_ort_model, mock_env): - mock_automodel.return_value = MagicMock() - mock_ort_model.return_value = MagicMock() - mock_tokenizer.return_value = MagicMock() - - embedding_model = get_embedding_model() - - # Assertions - assert embedding_model["model_name"] == "katanemo/bge-large-en-v1.5" - mock_tokenizer.assert_called_once_with( - "katanemo/bge-large-en-v1.5", trust_remote_code=True - ) - if glb.DEVICE != "cuda": - mock_ort_model.assert_called_once_with( - "katanemo/bge-large-en-v1.5", file_name="onnx/model.onnx" - ) - else: - mock_automodel.assert_called_once_with( - "katanemo/bge-large-en-v1.5", device_map=glb.DEVICE - ) - - -# Test for get_zero_shot_model function 
-@patch("app.loader.ORTModelForSequenceClassification.from_pretrained") -@patch("app.loader.pipeline") -@patch("app.loader.AutoTokenizer.from_pretrained") -def test_get_zero_shot_model(mock_tokenizer, mock_pipeline, mock_ort_model, mock_env): - mock_pipeline.return_value = MagicMock() - mock_ort_model.return_value = MagicMock() - mock_tokenizer.return_value = MagicMock() - - zero_shot_model = get_zero_shot_model() - - # Assertions - assert zero_shot_model["model_name"] == "katanemo/bart-large-mnli" - mock_tokenizer.assert_called_once_with("katanemo/bart-large-mnli") - if glb.DEVICE != "cuda": - mock_ort_model.assert_called_once_with( - "katanemo/bart-large-mnli", file_name="onnx/model.onnx" - ) - else: - assert mock_pipeline.called_once() - - -# Test for get_prompt_guard function -@patch("app.loader.AutoTokenizer.from_pretrained") -@patch("app.loader.OVModelForSequenceClassification.from_pretrained") -@patch("app.loader.AutoModelForSequenceClassification.from_pretrained") -def test_get_prompt_guard(mock_auto_model, mock_ov_model, mock_tokenizer): - # Mock model based on device - if glb.DEVICE == "cpu": - mock_ov_model.return_value = MagicMock() - else: - mock_auto_model.return_value = MagicMock() - - mock_tokenizer.return_value = MagicMock() - - prompt_guard = get_prompt_guard(arch_guard_model_type[glb.DEVICE]) - - # Assertions - assert prompt_guard["model_name"] == arch_guard_model_type[glb.DEVICE] - mock_tokenizer.assert_called_once_with( - arch_guard_model_type[glb.DEVICE], trust_remote_code=True - ) - if glb.DEVICE == "cpu": - mock_ov_model.assert_called_once_with( - arch_guard_model_type[glb.DEVICE], - device_map=glb.DEVICE, - low_cpu_mem_usage=True, - ) - else: - mock_auto_model.assert_called_once_with( - arch_guard_model_type[glb.DEVICE], - device_map=glb.DEVICE, - low_cpu_mem_usage=True, - ) diff --git a/model_server/app/tests/test_state.py b/model_server/app/tests/test_state.py index 9eb72c8c..dcb3a808 100644 --- a/model_server/app/tests/test_state.py +++ 
b/model_server/app/tests/test_state.py @@ -1,66 +1,50 @@ -from typing import List -import pytest -import json -from app.function_calling.model_utils import Message, process_messages +from app.commons.globals import handler_map +from app.model_handler.function_calling import Message -test_input_history = """ -[ + +test_input_history = [ + {"role": "user", "content": "how is the weather in chicago for next 5 days?"}, { - "role": "user", - "content": "how is the weather in chicago for next 5 days?" + "role": "assistant", + "model": "Arch-Function", + "tool_calls": [ + { + "id": "call_3394", + "type": "function", + "function": { + "name": "weather_forecast", + "arguments": {"city": "Chicago", "days": 5}, + }, + } + ], }, + {"role": "tool", "content": "--", "tool_call_id": "call_3394"}, + {"role": "assistant", "content": "--", "model": "gpt-3.5-turbo-0125"}, + {"role": "user", "content": "how is the weather in chicago for next 5 days?"}, { - "role": "assistant", - "model": "Arch-Function-1.5B", - "tool_calls": [ - { - "id": "call_3394", - "type": "function", - "function": { - "name": "weather_forecast", - "arguments": { "city": "Chicago", "days": 5 } - } - } - ] + "role": "assistant", + "tool_calls": [ + { + "id": "call_5306", + "type": "function", + "function": { + "name": "weather_forecast", + "arguments": {"city": "Chicago", "days": 5}, + }, + } + ], }, - { - "role": "tool", - "content": "--", - "tool_call_id": "call_3394" - }, - { - "role": "assistant", - "content": "--", - "model": "gpt-3.5-turbo-0125" - }, - { - "role": "user", - "content": "how is the weather in chicago for next 5 days?" 
- }, - { - "role": "assistant", - "tool_calls": [ - { - "id": "call_5306", - "type": "function", - "function": { - "name": "weather_forecast", - "arguments": { "city": "Chicago", "days": 5 } - } - } - ] - } - ] - """ + {"role": "tool", "content": "--", "tool_call_id": "call_5306"}, +] def test_update_fc_history(): - history = json.loads(test_input_history) message_history = [] - for h in history: + + for h in test_input_history: message_history.append(Message(**h)) - updated_history = process_messages(message_history) - assert len(updated_history) == 6 + updated_history = handler_map["Arch-Function"]._process_messages(message_history) + assert len(updated_history) == 7 # ensure that tool role does not exist anymore assert all([h["role"] != "tool" for h in updated_history])