integrate hallucination

This commit is contained in:
cotran 2024-12-06 15:50:03 -08:00
parent f7d69d52a7
commit 5e164e8e3c
3 changed files with 594 additions and 25 deletions

View file

@ -12,6 +12,7 @@ from app.model_handler.base_handler import (
ChatCompletionResponse,
ArchBaseHandler,
)
from app.function_calling.hallucination_handler import HallucinationStateHandler
SUPPORT_DATA_TYPES = ["int", "float", "bool", "str", "list", "tuple", "set", "dict"]
@ -342,6 +343,36 @@ class ArchFunctionHandler(ArchBaseHandler):
return is_valid, error_tool_call, error_message
def _prefill_response(self, messages: List[Dict[str, str]]):
    """
    Request a completion that continues from a prefill prefix.

    An assistant message whose content is one randomly chosen entry of
    ``self.prefill_prefix`` is appended to ``messages`` (the list is
    modified in place), and a non-streaming chat completion is then
    requested for the extended conversation.

    Args:
        messages (List[Dict[str, str]]): The conversation so far. The
            prefill assistant message is appended to this list.

    Returns:
        The response object returned by
        ``self.client.chat.completions.create``.
    """
    prefix = random.choice(self.prefill_prefix)
    messages.append({"role": "assistant", "content": prefix})

    # Later keys win: prefill parameters override the general generation
    # parameters when both define the same option.
    request_options = {**self.generation_params, **self.prefill_params}

    return self.client.chat.completions.create(
        messages=messages,
        model=self.model_name,
        stream=False,
        extra_body=request_options,
    )
@override
async def chat_completion(
self, req: ChatMessage, enable_prefilling=True
@ -392,23 +423,7 @@ class ArchFunctionHandler(ArchBaseHandler):
# start parameter gathering if the model is not generating a tool call
if has_tool_call is False:
messages.append(
{
"role": "assistant",
"content": random.choice(self.prefill_prefix),
}
)
prefill_response = self.client.chat.completions.create(
messages=messages,
model=self.model_name,
stream=False,
extra_body={
**self.generation_params,
**self.prefill_params,
},
)
prefill_response = self._prefill_response(messages)
model_response = prefill_response.choices[0].message.content
else:
model_response = response.choices[0].message.content

View file

@ -72,6 +72,25 @@ def calculate_entropy(log_probs: List[float]) -> Tuple[float, float]:
return entropy.item(), varentropy.item()
def is_parameter_required(
    function_description: Dict,
    parameter_name: str,
) -> bool:
    """
    Check whether a parameter is listed as required.

    Args:
        function_description (dict): A description whose top-level
            ``"required"`` entry holds the required parameter names.
        parameter_name (str): The name of the parameter to check.

    Returns:
        bool: True if ``parameter_name`` appears in the ``"required"``
            entry; False otherwise, including when the entry is missing
            or explicitly ``None``.
    """
    required_parameters = function_description.get("required")
    # A schema may carry ``"required": None``; a membership test against
    # None would raise TypeError, so treat it like a missing entry.
    if required_parameters is None:
        return False
    return parameter_name in required_parameters
class HallucinationStateHandler:
"""
A class to handle the state of hallucination detection in token processing.
@ -104,6 +123,7 @@ class HallucinationStateHandler:
self.parameter_name: List[str] = []
self.token_probs_map: List[Tuple[str, float, float]] = []
self.response_iterator = response_iterator
self.has_tool_call = False
def append_and_check_token_hallucination(self, token, logprob):
"""
@ -118,7 +138,8 @@ class HallucinationStateHandler:
"""
self.tokens.append(token)
self.logprobs.append(logprob)
self._process_token()
if self.has_tool_call:
self._process_token()
return self.hallucination
def __iter__(self):
@ -164,7 +185,7 @@ class HallucinationStateHandler:
self.mask.append(MaskToken.FUNCTION_NAME)
else:
self.state = None
self._is_function_name_hallucinated()
self._get_function_name()
# Check if the token is a function name start token, change the state
if content.endswith(FUNC_NAME_START_PATTERN):
@ -182,8 +203,8 @@ class HallucinationStateHandler:
PARAMETER_NAME_END_TOKENS
):
self.state = None
self._is_parameter_name_hallucinated()
self.parameter_name_done = True
self._get_parameter_name()
# if the parameter name is done and the token is a parameter name start token, change the state
elif self.parameter_name_done and content.endswith(
PARAMETER_NAME_START_PATTERN
@ -208,11 +229,10 @@ class HallucinationStateHandler:
if (
len(self.mask) > 1
and self.mask[-2] != MaskToken.PARAMETER_VALUE
# and not is_parameter_property(
# self.function_properties[self.function_name],
# self.parameter_name[-1],
# "default",
# )
and is_parameter_required(
self.function_properties[self.function_name],
self.parameter_name[-1],
)
):
self._check_logprob()
else:
@ -266,3 +286,24 @@ class HallucinationStateHandler:
if self.mask and self.mask[-1] == token
else 0
)
def _get_parameter_name(self):
    """
    Record the parameter name that has just been completed.

    The length of the name is the count reported by
    ``self._count_consecutive_token(MaskToken.PARAMETER_NAME)``. That many
    tokens, taken from the end of ``self.tokens`` with the most recent
    token excluded (presumably the token that terminated the name), are
    joined and appended to ``self.parameter_name``.

    Returns:
        None. The extracted name is stored on the instance.
    """
    p_len = self._count_consecutive_token(MaskToken.PARAMETER_NAME)
    # ``seq[-0:]`` selects the whole sequence, so a zero count must be
    # handled explicitly; otherwise every earlier token would be joined
    # into the parameter name.
    name_tokens = self.tokens[:-1][-p_len:] if p_len > 0 else []
    self.parameter_name.append("".join(name_tokens))
def _get_function_name(self):
    """
    Record the function name that has just been completed.

    The length of the name is the count reported by
    ``self._count_consecutive_token(MaskToken.FUNCTION_NAME)``. That many
    tokens, taken from the end of ``self.tokens`` with the most recent
    token excluded (presumably the token that terminated the name), are
    joined and stored in ``self.function_name``.

    Returns:
        None. The extracted name is stored on the instance.
    """
    f_len = self._count_consecutive_token(MaskToken.FUNCTION_NAME)
    # ``seq[-0:]`` selects the whole sequence, so a zero count must be
    # handled explicitly; otherwise every earlier token would be joined
    # into the function name.
    name_tokens = self.tokens[:-1][-f_len:] if f_len > 0 else []
    self.function_name = "".join(name_tokens)

View file

@ -0,0 +1,513 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/co-tran/Documents/arch/new_venv/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: Arch-Function\n"
]
}
],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"import json\n",
"import ast\n",
"from datasets import load_dataset\n",
"import re\n",
"import subprocess\n",
"import threading\n",
"import requests\n",
"import time\n",
"from concurrent.futures import ThreadPoolExecutor\n",
"from tqdm import tqdm\n",
"import os\n",
"import json\n",
"import math\n",
"import torch\n",
"\n",
"from datasets import load_dataset\n",
"from typing import Any, Dict, List, Tuple\n",
"\n",
"from typing import Any, Dict, List\n",
"\n",
"def extract_tool_calls(content: str):\n",
" tool_calls = []\n",
"\n",
" flag = False\n",
" for line in content.split(\"\\n\"):\n",
" if \"<tool_call>\" == line:\n",
" flag = True\n",
" elif \"</tool_call>\" == line:\n",
" flag = False\n",
" else:\n",
" if flag:\n",
" try:\n",
" tool_content = json.loads(line)\n",
" except Exception as e:\n",
" print(e)\n",
" fixed_content = fix_json_string(line)\n",
" try:\n",
" tool_content = json.loads(fixed_content)\n",
" except json.JSONDecodeError:\n",
" print(\"json.JSONDecodeError\")\n",
" return content\n",
"\n",
" tool_calls.append(\n",
" {\n",
" \"id\": f\"call_{random.randint(1000, 10000)}\",\n",
" \"type\": \"function\",\n",
" \"function\": {\n",
" \"name\": tool_content[\"name\"],\n",
" \"arguments\": tool_content[\"arguments\"],\n",
" },\n",
" }\n",
" )\n",
"\n",
" flag = False\n",
"\n",
" return tool_calls\n",
"\n",
"def fix_json_string(json_str: str):\n",
" # Remove any leading or trailing whitespace or newline characters\n",
" json_str = json_str.strip()\n",
"\n",
" # Stack to keep track of brackets\n",
" stack = []\n",
"\n",
" # Clean string to collect valid characters\n",
" fixed_str = \"\"\n",
"\n",
" # Dictionary for matching brackets\n",
" matching_bracket = {\")\": \"(\", \"}\": \"{\", \"]\": \"[\"}\n",
"\n",
" # Dictionary for the opposite of matching_bracket\n",
" opening_bracket = {v: k for k, v in matching_bracket.items()}\n",
"\n",
" for char in json_str:\n",
" if char in \"{[(\":\n",
" stack.append(char)\n",
" fixed_str += char\n",
" elif char in \"}])\":\n",
" if stack and stack[-1] == matching_bracket[char]:\n",
" stack.pop()\n",
" fixed_str += char\n",
" else:\n",
" # Ignore the unmatched closing brackets\n",
" continue\n",
" else:\n",
" fixed_str += char\n",
"\n",
" # If there are unmatched opening brackets left in the stack, add corresponding closing brackets\n",
" while stack:\n",
" unmatched_opening = stack.pop()\n",
" fixed_str += opening_bracket[unmatched_opening]\n",
"\n",
" # Attempt to parse the corrected string to ensure its valid JSON\n",
" return fixed_str.replace(\"\\'\", \"\\\"\")\n",
"\n",
"\n",
"TASK_PROMPT = \"\"\"\n",
"You are a helpful assistant.\n",
"\"\"\".strip()\n",
"\n",
"TOOL_PROMPT = \"\"\"\n",
"# Tools\n",
"\n",
"You may call one or more functions to assist with the user query.\n",
"\n",
"You are provided with function signatures within <tools></tools> XML tags:\n",
"<tools>\n",
"{tool_text}\n",
"</tools>\n",
"\"\"\".strip()\n",
"\n",
"FORMAT_PROMPT = \"\"\"\n",
"For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n",
"<tool_call>\n",
"{\"name\": <function-name>, \"arguments\": <args-json-object>}\n",
"</tool_call>\n",
"\"\"\".strip()\n",
"\n",
"\n",
"get_weather_api = {\n",
"\t\"type\": \"function\",\n",
"\t\"function\": {\n",
"\t\t\"name\": \"get_current_weather\",\n",
"\t\t\"description\": \"Get current weather at a location.\",\n",
"\t\t\"parameters\": {\n",
"\t\t\t\"type\": \"object\",\n",
"\t\t\t\"properties\": {\n",
"\t\t\t\t\"location\": {\n",
"\t\t\t\t\t\"type\": \"str\",\n",
"\t\t\t\t\t\"description\": \"The location to get the weather for\",\n",
" \"format\": \"City, State, Country\",\n",
"\t\t\t\t},\n",
"\t\t\t\t\"unit\": {\n",
"\t\t\t\t\t\"type\": \"str\",\n",
" \"description\": \"The unit to return the weather in.\",\n",
"\t\t\t\t\t\"enum\": [\n",
"\t\t\t\t\t\t\"celsius\",\n",
"\t\t\t\t\t\t\"fahrenheit\"\n",
"\t\t\t\t\t],\n",
" \"default\": \"celsius\"\n",
"\t\t\t\t},\n",
" \"days\": {\n",
"\t\t\t\t\t\"type\": \"str\",\n",
" \"description\": \"the number of days for the request.\",\n",
"\t\t\t\t}\n",
"\t\t\t},\n",
"\t\t\t\"required\": [\n",
"\t\t\t\t\"location\",\n",
" \"days\"\n",
"\t\t\t]\n",
"\t\t}\n",
"\t}\n",
"}\n",
"def check_parameter_property(api_description, parameter_name, property_name):\n",
" \"\"\"\n",
" Check if a parameter in an API description has a specific property.\n",
"\n",
" Args:\n",
" api_description (dict): The API description in JSON format.\n",
" parameter_name (str): The name of the parameter to check.\n",
" property_name (str): The property to look for (e.g., 'format', 'default').\n",
"\n",
" Returns:\n",
" bool: True if the parameter has the specified property, False otherwise.\n",
" \"\"\"\n",
" parameters = api_description.get(\"parameters\", {}).get(\"properties\", {})\n",
" parameter_info = parameters.get(parameter_name, {})\n",
"\n",
" return property_name in parameter_info\n",
"\n",
"\n",
"# Example usage\n",
"\n",
"def convert_tools(tools: List[Dict[str, Any]]):\n",
" return \"\\n\".join([json.dumps(tool) for tool in tools])\n",
"\n",
"# Helper function to create the system prompt for our model\n",
"def format_prompt(tools: List[Dict[str, Any]]):\n",
" tool_text = convert_tools(tools)\n",
"\n",
" return (\n",
" TASK_PROMPT\n",
" + \"\\n\\n\"\n",
" + TOOL_PROMPT.format(tool_text=tool_text)\n",
" + \"\\n\\n\"\n",
" + FORMAT_PROMPT\n",
" + \"\\n\"\n",
" )\n",
"\n",
"openai_format_tools = [get_weather_api]\n",
"\n",
"system_prompt = format_prompt(openai_format_tools)\n",
"\n",
"\n",
"from openai import OpenAI\n",
"\n",
"client = OpenAI(base_url=\"https://api.fc.archgw.com/v1\", api_key=\"EMPTY\")\n",
"\n",
"# List models API\n",
"model = client.models.list().data[0].id\n",
"print(\"Model:\", model)\n",
"messages = [\n",
" {\"role\": \"system\", \"content\": system_prompt},\n",
" # {\"role\": \"user\", \"content\": \"can you help me check weather?\"},\n",
" {\"role\": \"user\", \"content\": \"How is the weather in Seattle in 7 days?\"},\n",
" #{\"role\": \"assistant\", \"content\": \"Of course!\"},\n",
" # {\"role\": \"user\", \"content\": \"Seattle please\"}\n",
"]\n",
"\n",
"extra_body = {\n",
" \"temperature\": 0.6,\n",
" \"top_p\": 1.0,\n",
" \"top_k\": 50,\n",
" # \"continue_final_message\": True,\n",
" # \"add_generation_prompt\": False,\n",
" \"logprobs\": True,\n",
" \"top_logprobs\": 10\n",
"}\n",
"\n",
"resp = client.chat.completions.create(\n",
" model=\"Arch-Function\",\n",
" messages=messages,\n",
" extra_body=extra_body,\n",
" stream = True\n",
")\n",
"\n",
"\n",
"\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"The ONNX file onnx/model.onnx is not a regular name used in optimum.onnxruntime, the ORTModel might not behave as expected.\n",
"The ONNX file onnx/model.onnx is not a regular name used in optimum.onnxruntime, the ORTModel might not behave as expected.\n"
]
}
],
"source": [
"import json\n",
"from app.function_calling.hallucination_handler import HallucinationStateHandler\n",
"import pytest\n",
"import os"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"get_weather_api = {\n",
" \"type\": \"function\",\n",
" \"function\": {\n",
" \"name\": \"get_current_weather\",\n",
" \"description\": \"Get current weather at a location.\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"location\": {\n",
" \"type\": \"str\",\n",
" \"description\": \"The location to get the weather for\",\n",
" \"format\": \"City, State\",\n",
" },\n",
" \"unit\": {\n",
" \"type\": \"str\",\n",
" \"description\": \"The unit to return the weather in.\",\n",
" \"enum\": [\"celsius\", \"fahrenheit\"],\n",
" \"default\": \"celsius\",\n",
" },\n",
" \"days\": {\n",
" \"type\": \"str\",\n",
" \"description\": \"the number of days for the request.\",\n",
" },\n",
" },\n",
" \"required\": [\"location\", \"days\"],\n",
" },\n",
" },\n",
"}\n",
"function_description = get_weather_api[\"function\"]\n",
"if type(function_description) != list:\n",
" function_description = [get_weather_api[\"function\"]]\n",
"\n",
"hallu = HallucinationStateHandler(response_iterator=resp, apis = function_description)\n"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"False\n"
]
}
],
"source": [
"r = next(hallu.response_iterator)\n",
"if hasattr(r.choices[0].delta, \"content\"):\n",
" token_content = r.choices[0].delta.content\n",
" if token_content:\n",
" logprobs = [p.logprob for p in r.choices[0].logprobs.content[0].top_logprobs]\n",
" print(hallu.check_token_hallucination(token_content, logprobs), hallu.tokens)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Token: None\n",
"Hallucination: False\n",
"Hallucination Message: \n",
"Token: <tool_call>\n",
"Hallucination: False\n",
"Hallucination Message: \n",
"Token: \n",
"\n",
"Hallucination: False\n",
"Hallucination Message: \n",
"Token: {\"\n",
"Hallucination: False\n",
"Hallucination Message: \n",
"Token: name\n",
"Hallucination: False\n",
"Hallucination Message: \n",
"Token: \":\n",
"Hallucination: False\n",
"Hallucination Message: \n",
"function name entered\n",
"Token: \"\n",
"Hallucination: False\n",
"Hallucination Message: \n",
"Token: get\n",
"Hallucination: False\n",
"Hallucination Message: \n",
"Token: _current\n",
"Hallucination: False\n",
"Hallucination Message: \n",
"Token: _weather\n",
"Hallucination: False\n",
"Hallucination Message: \n",
"Token: \",\n",
"Hallucination: False\n",
"Hallucination Message: \n",
"Token: \"\n",
"Hallucination: False\n",
"Hallucination Message: \n",
"Token: arguments\n",
"Hallucination: False\n",
"Hallucination Message: \n",
"Token: \":\n",
"Hallucination: False\n",
"Hallucination Message: \n",
"Token: {\"\n",
"Hallucination: False\n",
"Hallucination Message: \n",
"Token: location\n",
"Hallucination: False\n",
"Hallucination Message: \n",
"Token: \":\n",
"Hallucination: False\n",
"Hallucination Message: \n",
"Token: \"\n",
"Hallucination: False\n",
"Hallucination Message: \n",
"Token: Seattle\n",
"Hallucination: False\n",
"Hallucination Message: \n",
"Token: ,\n",
"Hallucination: False\n",
"Hallucination Message: \n",
"Token: WA\n",
"Hallucination: False\n",
"Hallucination Message: \n",
"Token: ,\n",
"Hallucination: False\n",
"Hallucination Message: \n",
"Token: USA\n",
"Hallucination: False\n",
"Hallucination Message: \n",
"Token: \",\n",
"Hallucination: False\n",
"Hallucination Message: \n",
"Token: \"\n",
"Hallucination: False\n",
"Hallucination Message: \n",
"Token: days\n",
"Hallucination: False\n",
"Hallucination Message: \n",
"Token: \":\n",
"Hallucination: False\n",
"Hallucination Message: \n",
"Token: \"\n",
"Hallucination: False\n",
"Hallucination Message: \n",
"Token: 7\n",
"Hallucination: False\n",
"Hallucination Message: \n",
"Token: \"}}\n",
"\n",
"Hallucination: False\n",
"Hallucination Message: \n",
"Token: </tool_call>\n",
"Hallucination: False\n",
"Hallucination Message: \n",
"Token: None\n",
"Hallucination: False\n",
"Hallucination Message: \n"
]
}
],
"source": [
"for token in hallu:\n",
" print(f\"Token: {token}\")\n",
" print(f\"Hallucination: {hallu.hallucination}\")\n",
" print(f\"Hallucination Message: {hallu.hallucination_message}\")\n",
"\n",
"# Access the tokens processed"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"for token in resp:\n",
" if hasattr(token.choices[0].delta, \"content\"):\n",
" \n",
" token_content = token.choices[0].delta.content\n",
" print(token_content)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[('<tool_call>', 0.0034013802651315928, -0.0032856864854693413),\n",
" ('Seattle', 1.1522446584422141e-05, -1.1521117812662851e-05),\n",
" ('7', 1.553274159959983e-05, -1.5530327800661325e-05)]"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"hallu.token_probs_map"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 2
}