diff --git a/model_server/app/model_handler/function_calling.py b/model_server/app/model_handler/function_calling.py
index 109a882f..167f52ff 100644
--- a/model_server/app/model_handler/function_calling.py
+++ b/model_server/app/model_handler/function_calling.py
@@ -12,6 +12,7 @@ from app.model_handler.base_handler import (
ChatCompletionResponse,
ArchBaseHandler,
)
+from app.function_calling.hallucination_handler import HallucinationStateHandler
SUPPORT_DATA_TYPES = ["int", "float", "bool", "str", "list", "tuple", "set", "dict"]
@@ -342,6 +343,36 @@ class ArchFunctionHandler(ArchBaseHandler):
return is_valid, error_tool_call, error_message
+ def _prefill_response(self, messages: List[Dict[str, str]]):
+ """
+ Prefills the response with the tool call prefix.
+
+ Args:
+ messages (List[Dict[str, str]]): A list of messages.
+ tools (List[Dict[str, Any]]): A list of tools.
+
+ Returns:
+ List[Dict[str, str]]: A list of messages with the prefill prefix.
+ """
+
+ messages.append(
+ {
+ "role": "assistant",
+ "content": random.choice(self.prefill_prefix),
+ }
+ )
+ prefill_response = self.client.chat.completions.create(
+ messages=messages,
+ model=self.model_name,
+ stream=False,
+ extra_body={
+ **self.generation_params,
+ **self.prefill_params,
+ },
+ )
+
+ return prefill_response
+
@override
async def chat_completion(
self, req: ChatMessage, enable_prefilling=True
@@ -392,23 +423,7 @@ class ArchFunctionHandler(ArchBaseHandler):
# start parameter gathering if the model is not generating a tool call
if has_tool_call is False:
- messages.append(
- {
- "role": "assistant",
- "content": random.choice(self.prefill_prefix),
- }
- )
-
- prefill_response = self.client.chat.completions.create(
- messages=messages,
- model=self.model_name,
- stream=False,
- extra_body={
- **self.generation_params,
- **self.prefill_params,
- },
- )
-
+ prefill_response = self._prefill_response(messages)
model_response = prefill_response.choices[0].message.content
else:
model_response = response.choices[0].message.content
diff --git a/model_server/app/model_handler/hallucination_handler.py b/model_server/app/model_handler/hallucination_handler.py
index 09607db3..7353312a 100644
--- a/model_server/app/model_handler/hallucination_handler.py
+++ b/model_server/app/model_handler/hallucination_handler.py
@@ -72,6 +72,25 @@ def calculate_entropy(log_probs: List[float]) -> Tuple[float, float]:
return entropy.item(), varentropy.item()
+def is_parameter_required(
+ function_description: Dict,
+ parameter_name: str,
+) -> bool:
+ """
+ Check if a parameter in required list
+
+ Args:
+ function_description (dict): The API description in JSON format.
+ parameter_name (str): The name of the parameter to check.
+
+ Returns:
+ bool: True if the parameter has the specified property, False otherwise.
+ """
+ required_parameters = function_description.get("required", {})
+
+ return parameter_name in required_parameters
+
+
class HallucinationStateHandler:
"""
A class to handle the state of hallucination detection in token processing.
@@ -104,6 +123,7 @@ class HallucinationStateHandler:
self.parameter_name: List[str] = []
self.token_probs_map: List[Tuple[str, float, float]] = []
self.response_iterator = response_iterator
+ self.has_tool_call = False
def append_and_check_token_hallucination(self, token, logprob):
"""
@@ -118,7 +138,8 @@ class HallucinationStateHandler:
"""
self.tokens.append(token)
self.logprobs.append(logprob)
- self._process_token()
+ if self.has_tool_call:
+ self._process_token()
return self.hallucination
def __iter__(self):
@@ -164,7 +185,7 @@ class HallucinationStateHandler:
self.mask.append(MaskToken.FUNCTION_NAME)
else:
self.state = None
- self._is_function_name_hallucinated()
+ self._get_function_name()
# Check if the token is a function name start token, change the state
if content.endswith(FUNC_NAME_START_PATTERN):
@@ -182,8 +203,8 @@ class HallucinationStateHandler:
PARAMETER_NAME_END_TOKENS
):
self.state = None
- self._is_parameter_name_hallucinated()
self.parameter_name_done = True
+ self._get_parameter_name()
# if the parameter name is done and the token is a parameter name start token, change the state
elif self.parameter_name_done and content.endswith(
PARAMETER_NAME_START_PATTERN
@@ -208,11 +229,10 @@ class HallucinationStateHandler:
if (
len(self.mask) > 1
and self.mask[-2] != MaskToken.PARAMETER_VALUE
- # and not is_parameter_property(
- # self.function_properties[self.function_name],
- # self.parameter_name[-1],
- # "default",
- # )
+ and is_parameter_required(
+ self.function_properties[self.function_name],
+ self.parameter_name[-1],
+ )
):
self._check_logprob()
else:
@@ -266,3 +286,24 @@ class HallucinationStateHandler:
if self.mask and self.mask[-1] == token
else 0
)
+
+ def _get_parameter_name(self):
+ """
+ Get the parameter name from the tokens.
+
+ Returns:
+ str: The extracted parameter name.
+ """
+ p_len = self._count_consecutive_token(MaskToken.PARAMETER_NAME)
+ parameter_name = "".join(self.tokens[:-1][-p_len:])
+ self.parameter_name.append(parameter_name)
+
+ def _get_function_name(self):
+ """
+ Get the function name from the tokens.
+
+ Returns:
+ str: The extracted function name.
+ """
+ f_len = self._count_consecutive_token(MaskToken.FUNCTION_NAME)
+ self.function_name = "".join(self.tokens[:-1][-f_len:])
diff --git a/model_server/app/tests/test.ipynb b/model_server/app/tests/test.ipynb
new file mode 100644
index 00000000..1dfd455e
--- /dev/null
+++ b/model_server/app/tests/test.ipynb
@@ -0,0 +1,513 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/Users/co-tran/Documents/arch/new_venv/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+ " from .autonotebook import tqdm as notebook_tqdm\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Model: Arch-Function\n"
+ ]
+ }
+ ],
+ "source": [
+ "import pandas as pd\n",
+ "import numpy as np\n",
+ "import json\n",
+ "import ast\n",
+ "from datasets import load_dataset\n",
+ "import re\n",
+ "import subprocess\n",
+ "import threading\n",
+ "import requests\n",
+ "import time\n",
+ "from concurrent.futures import ThreadPoolExecutor\n",
+ "from tqdm import tqdm\n",
+ "import os\n",
+ "import json\n",
+ "import math\n",
+ "import torch\n",
+ "\n",
+ "from datasets import load_dataset\n",
+ "from typing import Any, Dict, List, Tuple\n",
+ "\n",
+ "from typing import Any, Dict, List\n",
+ "\n",
+ "def extract_tool_calls(content: str):\n",
+ " tool_calls = []\n",
+ "\n",
+ " flag = False\n",
+ " for line in content.split(\"\\n\"):\n",
+ " if \"\" == line:\n",
+ " flag = True\n",
+ " elif \"\" == line:\n",
+ " flag = False\n",
+ " else:\n",
+ " if flag:\n",
+ " try:\n",
+ " tool_content = json.loads(line)\n",
+ " except Exception as e:\n",
+ " print(e)\n",
+ " fixed_content = fix_json_string(line)\n",
+ " try:\n",
+ " tool_content = json.loads(fixed_content)\n",
+ " except json.JSONDecodeError:\n",
+ " print(\"json.JSONDecodeError\")\n",
+ " return content\n",
+ "\n",
+ " tool_calls.append(\n",
+ " {\n",
+ " \"id\": f\"call_{random.randint(1000, 10000)}\",\n",
+ " \"type\": \"function\",\n",
+ " \"function\": {\n",
+ " \"name\": tool_content[\"name\"],\n",
+ " \"arguments\": tool_content[\"arguments\"],\n",
+ " },\n",
+ " }\n",
+ " )\n",
+ "\n",
+ " flag = False\n",
+ "\n",
+ " return tool_calls\n",
+ "\n",
+ "def fix_json_string(json_str: str):\n",
+ " # Remove any leading or trailing whitespace or newline characters\n",
+ " json_str = json_str.strip()\n",
+ "\n",
+ " # Stack to keep track of brackets\n",
+ " stack = []\n",
+ "\n",
+ " # Clean string to collect valid characters\n",
+ " fixed_str = \"\"\n",
+ "\n",
+ " # Dictionary for matching brackets\n",
+ " matching_bracket = {\")\": \"(\", \"}\": \"{\", \"]\": \"[\"}\n",
+ "\n",
+ " # Dictionary for the opposite of matching_bracket\n",
+ " opening_bracket = {v: k for k, v in matching_bracket.items()}\n",
+ "\n",
+ " for char in json_str:\n",
+ " if char in \"{[(\":\n",
+ " stack.append(char)\n",
+ " fixed_str += char\n",
+ " elif char in \"}])\":\n",
+ " if stack and stack[-1] == matching_bracket[char]:\n",
+ " stack.pop()\n",
+ " fixed_str += char\n",
+ " else:\n",
+ " # Ignore the unmatched closing brackets\n",
+ " continue\n",
+ " else:\n",
+ " fixed_str += char\n",
+ "\n",
+ " # If there are unmatched opening brackets left in the stack, add corresponding closing brackets\n",
+ " while stack:\n",
+ " unmatched_opening = stack.pop()\n",
+ " fixed_str += opening_bracket[unmatched_opening]\n",
+ "\n",
+ " # Attempt to parse the corrected string to ensure it’s valid JSON\n",
+ " return fixed_str.replace(\"\\'\", \"\\\"\")\n",
+ "\n",
+ "\n",
+ "TASK_PROMPT = \"\"\"\n",
+ "You are a helpful assistant.\n",
+ "\"\"\".strip()\n",
+ "\n",
+ "TOOL_PROMPT = \"\"\"\n",
+ "# Tools\n",
+ "\n",
+ "You may call one or more functions to assist with the user query.\n",
+ "\n",
+ "You are provided with function signatures within XML tags:\n",
+ "\n",
+ "{tool_text}\n",
+ "\n",
+ "\"\"\".strip()\n",
+ "\n",
+ "FORMAT_PROMPT = \"\"\"\n",
+ "For each function call, return a json object with function name and arguments within XML tags:\n",
+ "\n",
+ "{\"name\": , \"arguments\": }\n",
+ "\n",
+ "\"\"\".strip()\n",
+ "\n",
+ "\n",
+ "get_weather_api = {\n",
+ "\t\"type\": \"function\",\n",
+ "\t\"function\": {\n",
+ "\t\t\"name\": \"get_current_weather\",\n",
+ "\t\t\"description\": \"Get current weather at a location.\",\n",
+ "\t\t\"parameters\": {\n",
+ "\t\t\t\"type\": \"object\",\n",
+ "\t\t\t\"properties\": {\n",
+ "\t\t\t\t\"location\": {\n",
+ "\t\t\t\t\t\"type\": \"str\",\n",
+ "\t\t\t\t\t\"description\": \"The location to get the weather for\",\n",
+ " \"format\": \"City, State, Country\",\n",
+ "\t\t\t\t},\n",
+ "\t\t\t\t\"unit\": {\n",
+ "\t\t\t\t\t\"type\": \"str\",\n",
+ " \"description\": \"The unit to return the weather in.\",\n",
+ "\t\t\t\t\t\"enum\": [\n",
+ "\t\t\t\t\t\t\"celsius\",\n",
+ "\t\t\t\t\t\t\"fahrenheit\"\n",
+ "\t\t\t\t\t],\n",
+ " \"default\": \"celsius\"\n",
+ "\t\t\t\t},\n",
+ " \"days\": {\n",
+ "\t\t\t\t\t\"type\": \"str\",\n",
+ " \"description\": \"the number of days for the request.\",\n",
+ "\t\t\t\t}\n",
+ "\t\t\t},\n",
+ "\t\t\t\"required\": [\n",
+ "\t\t\t\t\"location\",\n",
+ " \"days\"\n",
+ "\t\t\t]\n",
+ "\t\t}\n",
+ "\t}\n",
+ "}\n",
+ "def check_parameter_property(api_description, parameter_name, property_name):\n",
+ " \"\"\"\n",
+ " Check if a parameter in an API description has a specific property.\n",
+ "\n",
+ " Args:\n",
+ " api_description (dict): The API description in JSON format.\n",
+ " parameter_name (str): The name of the parameter to check.\n",
+ " property_name (str): The property to look for (e.g., 'format', 'default').\n",
+ "\n",
+ " Returns:\n",
+ " bool: True if the parameter has the specified property, False otherwise.\n",
+ " \"\"\"\n",
+ " parameters = api_description.get(\"parameters\", {}).get(\"properties\", {})\n",
+ " parameter_info = parameters.get(parameter_name, {})\n",
+ "\n",
+ " return property_name in parameter_info\n",
+ "\n",
+ "\n",
+ "# Example usage\n",
+ "\n",
+ "def convert_tools(tools: List[Dict[str, Any]]):\n",
+ " return \"\\n\".join([json.dumps(tool) for tool in tools])\n",
+ "\n",
+ "# Helper function to create the system prompt for our model\n",
+ "def format_prompt(tools: List[Dict[str, Any]]):\n",
+ " tool_text = convert_tools(tools)\n",
+ "\n",
+ " return (\n",
+ " TASK_PROMPT\n",
+ " + \"\\n\\n\"\n",
+ " + TOOL_PROMPT.format(tool_text=tool_text)\n",
+ " + \"\\n\\n\"\n",
+ " + FORMAT_PROMPT\n",
+ " + \"\\n\"\n",
+ " )\n",
+ "\n",
+ "openai_format_tools = [get_weather_api]\n",
+ "\n",
+ "system_prompt = format_prompt(openai_format_tools)\n",
+ "\n",
+ "\n",
+ "from openai import OpenAI\n",
+ "\n",
+ "client = OpenAI(base_url=\"https://api.fc.archgw.com/v1\", api_key=\"EMPTY\")\n",
+ "\n",
+ "# List models API\n",
+ "model = client.models.list().data[0].id\n",
+ "print(\"Model:\", model)\n",
+ "messages = [\n",
+ " {\"role\": \"system\", \"content\": system_prompt},\n",
+ " # {\"role\": \"user\", \"content\": \"can you help me check weather?\"},\n",
+ " {\"role\": \"user\", \"content\": \"How is the weather in Seattle in 7 days?\"},\n",
+ " #{\"role\": \"assistant\", \"content\": \"Of course!\"},\n",
+ " # {\"role\": \"user\", \"content\": \"Seattle please\"}\n",
+ "]\n",
+ "\n",
+ "extra_body = {\n",
+ " \"temperature\": 0.6,\n",
+ " \"top_p\": 1.0,\n",
+ " \"top_k\": 50,\n",
+ " # \"continue_final_message\": True,\n",
+ " # \"add_generation_prompt\": False,\n",
+ " \"logprobs\": True,\n",
+ " \"top_logprobs\": 10\n",
+ "}\n",
+ "\n",
+ "resp = client.chat.completions.create(\n",
+ " model=\"Arch-Function\",\n",
+ " messages=messages,\n",
+ " extra_body=extra_body,\n",
+ " stream = True\n",
+ ")\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "The ONNX file onnx/model.onnx is not a regular name used in optimum.onnxruntime, the ORTModel might not behave as expected.\n",
+ "The ONNX file onnx/model.onnx is not a regular name used in optimum.onnxruntime, the ORTModel might not behave as expected.\n"
+ ]
+ }
+ ],
+ "source": [
+ "import json\n",
+ "from app.function_calling.hallucination_handler import HallucinationStateHandler\n",
+ "import pytest\n",
+ "import os"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "get_weather_api = {\n",
+ " \"type\": \"function\",\n",
+ " \"function\": {\n",
+ " \"name\": \"get_current_weather\",\n",
+ " \"description\": \"Get current weather at a location.\",\n",
+ " \"parameters\": {\n",
+ " \"type\": \"object\",\n",
+ " \"properties\": {\n",
+ " \"location\": {\n",
+ " \"type\": \"str\",\n",
+ " \"description\": \"The location to get the weather for\",\n",
+ " \"format\": \"City, State\",\n",
+ " },\n",
+ " \"unit\": {\n",
+ " \"type\": \"str\",\n",
+ " \"description\": \"The unit to return the weather in.\",\n",
+ " \"enum\": [\"celsius\", \"fahrenheit\"],\n",
+ " \"default\": \"celsius\",\n",
+ " },\n",
+ " \"days\": {\n",
+ " \"type\": \"str\",\n",
+ " \"description\": \"the number of days for the request.\",\n",
+ " },\n",
+ " },\n",
+ " \"required\": [\"location\", \"days\"],\n",
+ " },\n",
+ " },\n",
+ "}\n",
+ "function_description = get_weather_api[\"function\"]\n",
+ "if type(function_description) != list:\n",
+ " function_description = [get_weather_api[\"function\"]]\n",
+ "\n",
+ "hallu = HallucinationStateHandler(response_iterator=resp, apis = function_description)\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "False\n"
+ ]
+ }
+ ],
+ "source": [
+ "r = next(hallu.response_iterator)\n",
+ "if hasattr(r.choices[0].delta, \"content\"):\n",
+ " token_content = r.choices[0].delta.content\n",
+ " if token_content:\n",
+ " logprobs = [p.logprob for p in r.choices[0].logprobs.content[0].top_logprobs]\n",
+ " print(hallu.check_token_hallucination(token_content, logprobs), hallu.tokens)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Token: None\n",
+ "Hallucination: False\n",
+ "Hallucination Message: \n",
+ "Token: \n",
+ "Hallucination: False\n",
+ "Hallucination Message: \n",
+ "Token: \n",
+ "\n",
+ "Hallucination: False\n",
+ "Hallucination Message: \n",
+ "Token: {\"\n",
+ "Hallucination: False\n",
+ "Hallucination Message: \n",
+ "Token: name\n",
+ "Hallucination: False\n",
+ "Hallucination Message: \n",
+ "Token: \":\n",
+ "Hallucination: False\n",
+ "Hallucination Message: \n",
+ "function name entered\n",
+ "Token: \"\n",
+ "Hallucination: False\n",
+ "Hallucination Message: \n",
+ "Token: get\n",
+ "Hallucination: False\n",
+ "Hallucination Message: \n",
+ "Token: _current\n",
+ "Hallucination: False\n",
+ "Hallucination Message: \n",
+ "Token: _weather\n",
+ "Hallucination: False\n",
+ "Hallucination Message: \n",
+ "Token: \",\n",
+ "Hallucination: False\n",
+ "Hallucination Message: \n",
+ "Token: \"\n",
+ "Hallucination: False\n",
+ "Hallucination Message: \n",
+ "Token: arguments\n",
+ "Hallucination: False\n",
+ "Hallucination Message: \n",
+ "Token: \":\n",
+ "Hallucination: False\n",
+ "Hallucination Message: \n",
+ "Token: {\"\n",
+ "Hallucination: False\n",
+ "Hallucination Message: \n",
+ "Token: location\n",
+ "Hallucination: False\n",
+ "Hallucination Message: \n",
+ "Token: \":\n",
+ "Hallucination: False\n",
+ "Hallucination Message: \n",
+ "Token: \"\n",
+ "Hallucination: False\n",
+ "Hallucination Message: \n",
+ "Token: Seattle\n",
+ "Hallucination: False\n",
+ "Hallucination Message: \n",
+ "Token: ,\n",
+ "Hallucination: False\n",
+ "Hallucination Message: \n",
+ "Token: WA\n",
+ "Hallucination: False\n",
+ "Hallucination Message: \n",
+ "Token: ,\n",
+ "Hallucination: False\n",
+ "Hallucination Message: \n",
+ "Token: USA\n",
+ "Hallucination: False\n",
+ "Hallucination Message: \n",
+ "Token: \",\n",
+ "Hallucination: False\n",
+ "Hallucination Message: \n",
+ "Token: \"\n",
+ "Hallucination: False\n",
+ "Hallucination Message: \n",
+ "Token: days\n",
+ "Hallucination: False\n",
+ "Hallucination Message: \n",
+ "Token: \":\n",
+ "Hallucination: False\n",
+ "Hallucination Message: \n",
+ "Token: \"\n",
+ "Hallucination: False\n",
+ "Hallucination Message: \n",
+ "Token: 7\n",
+ "Hallucination: False\n",
+ "Hallucination Message: \n",
+ "Token: \"}}\n",
+ "\n",
+ "Hallucination: False\n",
+ "Hallucination Message: \n",
+ "Token: \n",
+ "Hallucination: False\n",
+ "Hallucination Message: \n",
+ "Token: None\n",
+ "Hallucination: False\n",
+ "Hallucination Message: \n"
+ ]
+ }
+ ],
+ "source": [
+ "for token in hallu:\n",
+ " print(f\"Token: {token}\")\n",
+ " print(f\"Hallucination: {hallu.hallucination}\")\n",
+ " print(f\"Hallucination Message: {hallu.hallucination_message}\")\n",
+ "\n",
+ "# Access the tokens processed"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "for token in resp:\n",
+ " if hasattr(token.choices[0].delta, \"content\"):\n",
+ " \n",
+ " token_content = token.choices[0].delta.content\n",
+ " print(token_content)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[('', 0.0034013802651315928, -0.0032856864854693413),\n",
+ " ('Seattle', 1.1522446584422141e-05, -1.1521117812662851e-05),\n",
+ " ('7', 1.553274159959983e-05, -1.5530327800661325e-05)]"
+ ]
+ },
+ "execution_count": 6,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "hallu.token_probs_map"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "language_info": {
+ "name": "python"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}