diff --git a/model_server/app/model_handler/function_calling.py b/model_server/app/model_handler/function_calling.py index 109a882f..167f52ff 100644 --- a/model_server/app/model_handler/function_calling.py +++ b/model_server/app/model_handler/function_calling.py @@ -12,6 +12,7 @@ from app.model_handler.base_handler import ( ChatCompletionResponse, ArchBaseHandler, ) +from app.function_calling.hallucination_handler import HallucinationStateHandler SUPPORT_DATA_TYPES = ["int", "float", "bool", "str", "list", "tuple", "set", "dict"] @@ -342,6 +343,36 @@ class ArchFunctionHandler(ArchBaseHandler): return is_valid, error_tool_call, error_message + def _prefill_response(self, messages: List[Dict[str, str]]): + """ + Prefills the response with the tool call prefix. + + Args: + messages (List[Dict[str, str]]): A list of messages. + tools (List[Dict[str, Any]]): A list of tools. + + Returns: + List[Dict[str, str]]: A list of messages with the prefill prefix. + """ + + messages.append( + { + "role": "assistant", + "content": random.choice(self.prefill_prefix), + } + ) + prefill_response = self.client.chat.completions.create( + messages=messages, + model=self.model_name, + stream=False, + extra_body={ + **self.generation_params, + **self.prefill_params, + }, + ) + + return prefill_response + @override async def chat_completion( self, req: ChatMessage, enable_prefilling=True @@ -392,23 +423,7 @@ class ArchFunctionHandler(ArchBaseHandler): # start parameter gathering if the model is not generating a tool call if has_tool_call is False: - messages.append( - { - "role": "assistant", - "content": random.choice(self.prefill_prefix), - } - ) - - prefill_response = self.client.chat.completions.create( - messages=messages, - model=self.model_name, - stream=False, - extra_body={ - **self.generation_params, - **self.prefill_params, - }, - ) - + prefill_response = self._prefill_response(messages) model_response = prefill_response.choices[0].message.content else: model_response = response.choices[0].message.content diff --git a/model_server/app/model_handler/hallucination_handler.py b/model_server/app/model_handler/hallucination_handler.py index 09607db3..7353312a 100644 --- a/model_server/app/model_handler/hallucination_handler.py +++ b/model_server/app/model_handler/hallucination_handler.py @@ -72,6 +72,25 @@ def calculate_entropy(log_probs: List[float]) -> Tuple[float, float]: return entropy.item(), varentropy.item() +def is_parameter_required( + function_description: Dict, + parameter_name: str, +) -> bool: + """ + Check if a parameter in required list + + Args: + function_description (dict): The API description in JSON format. + parameter_name (str): The name of the parameter to check. + + Returns: + bool: True if the parameter has the specified property, False otherwise. + """ + required_parameters = function_description.get("required", {}) + + return parameter_name in required_parameters + + class HallucinationStateHandler: """ A class to handle the state of hallucination detection in token processing. @@ -104,6 +123,7 @@ class HallucinationStateHandler: self.parameter_name: List[str] = [] self.token_probs_map: List[Tuple[str, float, float]] = [] self.response_iterator = response_iterator + self.has_tool_call = False def append_and_check_token_hallucination(self, token, logprob): """ @@ -118,7 +138,8 @@ class HallucinationStateHandler: """ self.tokens.append(token) self.logprobs.append(logprob) - self._process_token() + if self.has_tool_call: + self._process_token() return self.hallucination def __iter__(self): @@ -164,7 +185,7 @@ class HallucinationStateHandler: self.mask.append(MaskToken.FUNCTION_NAME) else: self.state = None - self._is_function_name_hallucinated() + self._get_function_name() # Check if the token is a function name start token, change the state if content.endswith(FUNC_NAME_START_PATTERN): @@ -182,8 +203,8 @@ class HallucinationStateHandler: PARAMETER_NAME_END_TOKENS ): self.state = None - self._is_parameter_name_hallucinated() self.parameter_name_done = True + self._get_parameter_name() # if the parameter name is done and the token is a parameter name start token, change the state elif self.parameter_name_done and content.endswith( PARAMETER_NAME_START_PATTERN @@ -208,11 +229,10 @@ class HallucinationStateHandler: if ( len(self.mask) > 1 and self.mask[-2] != MaskToken.PARAMETER_VALUE - # and not is_parameter_property( - # self.function_properties[self.function_name], - # self.parameter_name[-1], - # "default", - # ) + and is_parameter_required( + self.function_properties[self.function_name], + self.parameter_name[-1], + ) ): self._check_logprob() else: @@ -266,3 +286,24 @@ class HallucinationStateHandler: if self.mask and self.mask[-1] == token else 0 ) + + def _get_parameter_name(self): + """ + Get the parameter name from the tokens. + + Returns: + str: The extracted parameter name. + """ + p_len = self._count_consecutive_token(MaskToken.PARAMETER_NAME) + parameter_name = "".join(self.tokens[:-1][-p_len:]) + self.parameter_name.append(parameter_name) + + def _get_function_name(self): + """ + Get the function name from the tokens. + + Returns: + str: The extracted function name. + """ + f_len = self._count_consecutive_token(MaskToken.FUNCTION_NAME) + self.function_name = "".join(self.tokens[:-1][-f_len:]) diff --git a/model_server/app/tests/test.ipynb b/model_server/app/tests/test.ipynb new file mode 100644 index 00000000..1dfd455e --- /dev/null +++ b/model_server/app/tests/test.ipynb @@ -0,0 +1,513 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/co-tran/Documents/arch/new_venv/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Model: Arch-Function\n" + ] + } + ], + "source": [ + "import pandas as pd\n", + "import numpy as np\n", + "import json\n", + "import ast\n", + "from datasets import load_dataset\n", + "import re\n", + "import subprocess\n", + "import threading\n", + "import requests\n", + "import time\n", + "from concurrent.futures import ThreadPoolExecutor\n", + "from tqdm import tqdm\n", + "import os\n", + "import json\n", + "import math\n", + "import torch\n", + "\n", + "from datasets import load_dataset\n", + "from typing import Any, Dict, List, Tuple\n", + "\n", + "from typing import Any, Dict, List\n", + "\n", + "def extract_tool_calls(content: str):\n", + " tool_calls = []\n", + "\n", + " flag = False\n", + " for line in content.split(\"\\n\"):\n", + " if \"\" == line:\n", + " flag = True\n", + " elif \"\" == line:\n", + " flag = False\n", + " else:\n", + " if flag:\n", + " try:\n", + " tool_content = json.loads(line)\n", + " except Exception as e:\n", + " print(e)\n", + " fixed_content = fix_json_string(line)\n", + " try:\n", + " tool_content = json.loads(fixed_content)\n", + " except json.JSONDecodeError:\n", + " print(\"json.JSONDecodeError\")\n", + " return content\n", + "\n", + " tool_calls.append(\n", + " {\n", + " \"id\": f\"call_{random.randint(1000, 10000)}\",\n", + " \"type\": \"function\",\n", + " \"function\": {\n", + " \"name\": tool_content[\"name\"],\n", + " \"arguments\": tool_content[\"arguments\"],\n", + " },\n", + " }\n", + " )\n", + "\n", + " flag = False\n", + "\n", + " return tool_calls\n", + "\n", + "def fix_json_string(json_str: str):\n", + " # Remove any leading or trailing whitespace or newline characters\n", + " json_str = json_str.strip()\n", + "\n", + " # Stack to keep track of brackets\n", + " stack = []\n", + "\n", + " # Clean string to collect valid characters\n", + " fixed_str = \"\"\n", + "\n", + " # Dictionary for matching brackets\n", + " matching_bracket = {\")\": \"(\", \"}\": \"{\", \"]\": \"[\"}\n", + "\n", + " # Dictionary for the opposite of matching_bracket\n", + " opening_bracket = {v: k for k, v in matching_bracket.items()}\n", + "\n", + " for char in json_str:\n", + " if char in \"{[(\":\n", + " stack.append(char)\n", + " fixed_str += char\n", + " elif char in \"}])\":\n", + " if stack and stack[-1] == matching_bracket[char]:\n", + " stack.pop()\n", + " fixed_str += char\n", + " else:\n", + " # Ignore the unmatched closing brackets\n", + " continue\n", + " else:\n", + " fixed_str += char\n", + "\n", + " # If there are unmatched opening brackets left in the stack, add corresponding closing brackets\n", + " while stack:\n", + " unmatched_opening = stack.pop()\n", + " fixed_str += opening_bracket[unmatched_opening]\n", + "\n", + " # Attempt to parse the corrected string to ensure it’s valid JSON\n", + " return fixed_str.replace(\"\\'\", \"\\\"\")\n", + "\n", + "\n", + "TASK_PROMPT = \"\"\"\n", + "You are a helpful assistant.\n", + "\"\"\".strip()\n", + "\n", + "TOOL_PROMPT = \"\"\"\n", + "# Tools\n", + "\n", + "You may call one or more functions to assist with the user query.\n", + "\n", + "You are provided with function signatures within XML tags:\n", + "\n", + "{tool_text}\n", + "\n", + "\"\"\".strip()\n", + "\n", + "FORMAT_PROMPT = \"\"\"\n", + "For each function call, return a json object with function name and arguments within XML tags:\n", + "\n", + "{\"name\": , \"arguments\": }\n", + "\n", + "\"\"\".strip()\n", + "\n", + "\n", + "get_weather_api = {\n", + "\t\"type\": \"function\",\n", + "\t\"function\": {\n", + "\t\t\"name\": \"get_current_weather\",\n", + "\t\t\"description\": \"Get current weather at a location.\",\n", + "\t\t\"parameters\": {\n", + "\t\t\t\"type\": \"object\",\n", + "\t\t\t\"properties\": {\n", + "\t\t\t\t\"location\": {\n", + "\t\t\t\t\t\"type\": \"str\",\n", + "\t\t\t\t\t\"description\": \"The location to get the weather for\",\n", + " \"format\": \"City, State, Country\",\n", + "\t\t\t\t},\n", + "\t\t\t\t\"unit\": {\n", + "\t\t\t\t\t\"type\": \"str\",\n", + " \"description\": \"The unit to return the weather in.\",\n", + "\t\t\t\t\t\"enum\": [\n", + "\t\t\t\t\t\t\"celsius\",\n", + "\t\t\t\t\t\t\"fahrenheit\"\n", + "\t\t\t\t\t],\n", + " \"default\": \"celsius\"\n", + "\t\t\t\t},\n", + " \"days\": {\n", + "\t\t\t\t\t\"type\": \"str\",\n", + " \"description\": \"the number of days for the request.\",\n", + "\t\t\t\t}\n", + "\t\t\t},\n", + "\t\t\t\"required\": [\n", + "\t\t\t\t\"location\",\n", + " \"days\"\n", + "\t\t\t]\n", + "\t\t}\n", + "\t}\n", + "}\n", + "def check_parameter_property(api_description, parameter_name, property_name):\n", + " \"\"\"\n", + " Check if a parameter in an API description has a specific property.\n", + "\n", + " Args:\n", + " api_description (dict): The API description in JSON format.\n", + " parameter_name (str): The name of the parameter to check.\n", + " property_name (str): The property to look for (e.g., 'format', 'default').\n", + "\n", + " Returns:\n", + " bool: True if the parameter has the specified property, False otherwise.\n", + " \"\"\"\n", + " parameters = api_description.get(\"parameters\", {}).get(\"properties\", {})\n", + " parameter_info = parameters.get(parameter_name, {})\n", + "\n", + " return property_name in parameter_info\n", + "\n", + "\n", + "# Example usage\n", + "\n", + "def convert_tools(tools: List[Dict[str, Any]]):\n", + " return \"\\n\".join([json.dumps(tool) for tool in tools])\n", + "\n", + "# Helper function to create the system prompt for our model\n", + "def format_prompt(tools: List[Dict[str, Any]]):\n", + " tool_text = convert_tools(tools)\n", + "\n", + " return (\n", + " TASK_PROMPT\n", + " + \"\\n\\n\"\n", + " + TOOL_PROMPT.format(tool_text=tool_text)\n", + " + \"\\n\\n\"\n", + " + FORMAT_PROMPT\n", + " + \"\\n\"\n", + " )\n", + "\n", + "openai_format_tools = [get_weather_api]\n", + "\n", + "system_prompt = format_prompt(openai_format_tools)\n", + "\n", + "\n", + "from openai import OpenAI\n", + "\n", + "client = OpenAI(base_url=\"https://api.fc.archgw.com/v1\", api_key=\"EMPTY\")\n", + "\n", + "# List models API\n", + "model = client.models.list().data[0].id\n", + "print(\"Model:\", model)\n", + "messages = [\n", + " {\"role\": \"system\", \"content\": system_prompt},\n", + " # {\"role\": \"user\", \"content\": \"can you help me check weather?\"},\n", + " {\"role\": \"user\", \"content\": \"How is the weather in Seattle in 7 days?\"},\n", + " #{\"role\": \"assistant\", \"content\": \"Of course!\"},\n", + " # {\"role\": \"user\", \"content\": \"Seattle please\"}\n", + "]\n", + "\n", + "extra_body = {\n", + " \"temperature\": 0.6,\n", + " \"top_p\": 1.0,\n", + " \"top_k\": 50,\n", + " # \"continue_final_message\": True,\n", + " # \"add_generation_prompt\": False,\n", + " \"logprobs\": True,\n", + " \"top_logprobs\": 10\n", + "}\n", + "\n", + "resp = client.chat.completions.create(\n", + " model=\"Arch-Function\",\n", + " messages=messages,\n", + " extra_body=extra_body,\n", + " stream = True\n", + ")\n", + "\n", + "\n", + "\n", + "\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "The ONNX file onnx/model.onnx is not a regular name used in optimum.onnxruntime, the ORTModel might not behave as expected.\n", + "The ONNX file onnx/model.onnx is not a regular name used in optimum.onnxruntime, the ORTModel might not behave as expected.\n" + ] + } + ], + "source": [ + "import json\n", + "from app.function_calling.hallucination_handler import HallucinationStateHandler\n", + "import pytest\n", + "import os" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "get_weather_api = {\n", + " \"type\": \"function\",\n", + " \"function\": {\n", + " \"name\": \"get_current_weather\",\n", + " \"description\": \"Get current weather at a location.\",\n", + " \"parameters\": {\n", + " \"type\": \"object\",\n", + " \"properties\": {\n", + " \"location\": {\n", + " \"type\": \"str\",\n", + " \"description\": \"The location to get the weather for\",\n", + " \"format\": \"City, State\",\n", + " },\n", + " \"unit\": {\n", + " \"type\": \"str\",\n", + " \"description\": \"The unit to return the weather in.\",\n", + " \"enum\": [\"celsius\", \"fahrenheit\"],\n", + " \"default\": \"celsius\",\n", + " },\n", + " \"days\": {\n", + " \"type\": \"str\",\n", + " \"description\": \"the number of days for the request.\",\n", + " },\n", + " },\n", + " \"required\": [\"location\", \"days\"],\n", + " },\n", + " },\n", + "}\n", + "function_description = get_weather_api[\"function\"]\n", + "if type(function_description) != list:\n", + " function_description = [get_weather_api[\"function\"]]\n", + "\n", + "hallu = HallucinationStateHandler(response_iterator=resp, apis = function_description)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "False\n" + ] + } + ], + "source": [ + "r = next(hallu.response_iterator)\n", + "if hasattr(r.choices[0].delta, \"content\"):\n", + " token_content = r.choices[0].delta.content\n", + " if token_content:\n", + " logprobs = [p.logprob for p in r.choices[0].logprobs.content[0].top_logprobs]\n", + " print(hallu.check_token_hallucination(token_content, logprobs), hallu.tokens)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Token: None\n", + "Hallucination: False\n", + "Hallucination Message: \n", + "Token: \n", + "Hallucination: False\n", + "Hallucination Message: \n", + "Token: \n", + "\n", + "Hallucination: False\n", + "Hallucination Message: \n", + "Token: {\"\n", + "Hallucination: False\n", + "Hallucination Message: \n", + "Token: name\n", + "Hallucination: False\n", + "Hallucination Message: \n", + "Token: \":\n", + "Hallucination: False\n", + "Hallucination Message: \n", + "function name entered\n", + "Token: \"\n", + "Hallucination: False\n", + "Hallucination Message: \n", + "Token: get\n", + "Hallucination: False\n", + "Hallucination Message: \n", + "Token: _current\n", + "Hallucination: False\n", + "Hallucination Message: \n", + "Token: _weather\n", + "Hallucination: False\n", + "Hallucination Message: \n", + "Token: \",\n", + "Hallucination: False\n", + "Hallucination Message: \n", + "Token: \"\n", + "Hallucination: False\n", + "Hallucination Message: \n", + "Token: arguments\n", + "Hallucination: False\n", + "Hallucination Message: \n", + "Token: \":\n", + "Hallucination: False\n", + "Hallucination Message: \n", + "Token: {\"\n", + "Hallucination: False\n", + "Hallucination Message: \n", + "Token: location\n", + "Hallucination: False\n", + "Hallucination Message: \n", + "Token: \":\n", + "Hallucination: False\n", + "Hallucination Message: \n", + "Token: \"\n", + "Hallucination: False\n", + "Hallucination Message: \n", + "Token: Seattle\n", + "Hallucination: False\n", + "Hallucination Message: \n", + "Token: ,\n", + "Hallucination: False\n", + "Hallucination Message: \n", + "Token: WA\n", + "Hallucination: False\n", + "Hallucination Message: \n", + "Token: ,\n", + "Hallucination: False\n", + "Hallucination Message: \n", + "Token: USA\n", + "Hallucination: False\n", + "Hallucination Message: \n", + "Token: \",\n", + "Hallucination: False\n", + "Hallucination Message: \n", + "Token: \"\n", + "Hallucination: False\n", + "Hallucination Message: \n", + "Token: days\n", + "Hallucination: False\n", + "Hallucination Message: \n", + "Token: \":\n", + "Hallucination: False\n", + "Hallucination Message: \n", + "Token: \"\n", + "Hallucination: False\n", + "Hallucination Message: \n", + "Token: 7\n", + "Hallucination: False\n", + "Hallucination Message: \n", + "Token: \"}}\n", + "\n", + "Hallucination: False\n", + "Hallucination Message: \n", + "Token: \n", + "Hallucination: False\n", + "Hallucination Message: \n", + "Token: None\n", + "Hallucination: False\n", + "Hallucination Message: \n" + ] + } + ], + "source": [ + "for token in hallu:\n", + " print(f\"Token: {token}\")\n", + " print(f\"Hallucination: {hallu.hallucination}\")\n", + " print(f\"Hallucination Message: {hallu.hallucination_message}\")\n", + "\n", + "# Access the tokens processed" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "for token in resp:\n", + " if hasattr(token.choices[0].delta, \"content\"):\n", + " \n", + " token_content = token.choices[0].delta.content\n", + " print(token_content)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[('', 0.0034013802651315928, -0.0032856864854693413),\n", + " ('Seattle', 1.1522446584422141e-05, -1.1521117812662851e-05),\n", + " ('7', 1.553274159959983e-05, -1.5530327800661325e-05)]" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "hallu.token_probs_map" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}