mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
513 lines
16 KiB
Text
513 lines
16 KiB
Text
{
|
||
"cells": [
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 1,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"/Users/co-tran/Documents/arch/new_venv/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
||
" from .autonotebook import tqdm as notebook_tqdm\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Model: Arch-Function\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"import pandas as pd\n",
|
||
"import numpy as np\n",
|
||
"import json\n",
|
||
"import ast\n",
|
||
"from datasets import load_dataset\n",
|
||
"import re\n",
|
||
"import subprocess\n",
|
||
"import threading\n",
|
||
"import requests\n",
|
||
"import time\n",
|
||
"from concurrent.futures import ThreadPoolExecutor\n",
|
||
"from tqdm import tqdm\n",
|
||
"import os\n",
|
||
"import json\n",
|
||
"import math\n",
|
||
"import torch\n",
|
||
"\n",
|
||
"from datasets import load_dataset\n",
|
||
"from typing import Any, Dict, List, Tuple\n",
|
||
"\n",
|
||
"from typing import Any, Dict, List\n",
|
||
"\n",
|
||
"def extract_tool_calls(content: str):\n",
|
||
" tool_calls = []\n",
|
||
"\n",
|
||
" flag = False\n",
|
||
" for line in content.split(\"\\n\"):\n",
|
||
" if \"<tool_call>\" == line:\n",
|
||
" flag = True\n",
|
||
" elif \"</tool_call>\" == line:\n",
|
||
" flag = False\n",
|
||
" else:\n",
|
||
" if flag:\n",
|
||
" try:\n",
|
||
" tool_content = json.loads(line)\n",
|
||
" except Exception as e:\n",
|
||
" print(e)\n",
|
||
" fixed_content = fix_json_string(line)\n",
|
||
" try:\n",
|
||
" tool_content = json.loads(fixed_content)\n",
|
||
" except json.JSONDecodeError:\n",
|
||
" print(\"json.JSONDecodeError\")\n",
|
||
" return content\n",
|
||
"\n",
|
||
" tool_calls.append(\n",
|
||
" {\n",
|
||
" \"id\": f\"call_{random.randint(1000, 10000)}\",\n",
|
||
" \"type\": \"function\",\n",
|
||
" \"function\": {\n",
|
||
" \"name\": tool_content[\"name\"],\n",
|
||
" \"arguments\": tool_content[\"arguments\"],\n",
|
||
" },\n",
|
||
" }\n",
|
||
" )\n",
|
||
"\n",
|
||
" flag = False\n",
|
||
"\n",
|
||
" return tool_calls\n",
|
||
"\n",
|
||
"def fix_json_string(json_str: str):\n",
|
||
" # Remove any leading or trailing whitespace or newline characters\n",
|
||
" json_str = json_str.strip()\n",
|
||
"\n",
|
||
" # Stack to keep track of brackets\n",
|
||
" stack = []\n",
|
||
"\n",
|
||
" # Clean string to collect valid characters\n",
|
||
" fixed_str = \"\"\n",
|
||
"\n",
|
||
" # Dictionary for matching brackets\n",
|
||
" matching_bracket = {\")\": \"(\", \"}\": \"{\", \"]\": \"[\"}\n",
|
||
"\n",
|
||
" # Dictionary for the opposite of matching_bracket\n",
|
||
" opening_bracket = {v: k for k, v in matching_bracket.items()}\n",
|
||
"\n",
|
||
" for char in json_str:\n",
|
||
" if char in \"{[(\":\n",
|
||
" stack.append(char)\n",
|
||
" fixed_str += char\n",
|
||
" elif char in \"}])\":\n",
|
||
" if stack and stack[-1] == matching_bracket[char]:\n",
|
||
" stack.pop()\n",
|
||
" fixed_str += char\n",
|
||
" else:\n",
|
||
" # Ignore the unmatched closing brackets\n",
|
||
" continue\n",
|
||
" else:\n",
|
||
" fixed_str += char\n",
|
||
"\n",
|
||
" # If there are unmatched opening brackets left in the stack, add corresponding closing brackets\n",
|
||
" while stack:\n",
|
||
" unmatched_opening = stack.pop()\n",
|
||
" fixed_str += opening_bracket[unmatched_opening]\n",
|
||
"\n",
|
||
" # Attempt to parse the corrected string to ensure it’s valid JSON\n",
|
||
" return fixed_str.replace(\"\\'\", \"\\\"\")\n",
|
||
"\n",
|
||
"\n",
|
||
"TASK_PROMPT = \"\"\"\n",
|
||
"You are a helpful assistant.\n",
|
||
"\"\"\".strip()\n",
|
||
"\n",
|
||
"TOOL_PROMPT = \"\"\"\n",
|
||
"# Tools\n",
|
||
"\n",
|
||
"You may call one or more functions to assist with the user query.\n",
|
||
"\n",
|
||
"You are provided with function signatures within <tools></tools> XML tags:\n",
|
||
"<tools>\n",
|
||
"{tool_text}\n",
|
||
"</tools>\n",
|
||
"\"\"\".strip()\n",
|
||
"\n",
|
||
"FORMAT_PROMPT = \"\"\"\n",
|
||
"For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n",
|
||
"<tool_call>\n",
|
||
"{\"name\": <function-name>, \"arguments\": <args-json-object>}\n",
|
||
"</tool_call>\n",
|
||
"\"\"\".strip()\n",
|
||
"\n",
|
||
"\n",
|
||
"get_weather_api = {\n",
|
||
"\t\"type\": \"function\",\n",
|
||
"\t\"function\": {\n",
|
||
"\t\t\"name\": \"get_current_weather\",\n",
|
||
"\t\t\"description\": \"Get current weather at a location.\",\n",
|
||
"\t\t\"parameters\": {\n",
|
||
"\t\t\t\"type\": \"object\",\n",
|
||
"\t\t\t\"properties\": {\n",
|
||
"\t\t\t\t\"location\": {\n",
|
||
"\t\t\t\t\t\"type\": \"str\",\n",
|
||
"\t\t\t\t\t\"description\": \"The location to get the weather for\",\n",
|
||
" \"format\": \"City, State, Country\",\n",
|
||
"\t\t\t\t},\n",
|
||
"\t\t\t\t\"unit\": {\n",
|
||
"\t\t\t\t\t\"type\": \"str\",\n",
|
||
" \"description\": \"The unit to return the weather in.\",\n",
|
||
"\t\t\t\t\t\"enum\": [\n",
|
||
"\t\t\t\t\t\t\"celsius\",\n",
|
||
"\t\t\t\t\t\t\"fahrenheit\"\n",
|
||
"\t\t\t\t\t],\n",
|
||
" \"default\": \"celsius\"\n",
|
||
"\t\t\t\t},\n",
|
||
" \"days\": {\n",
|
||
"\t\t\t\t\t\"type\": \"str\",\n",
|
||
" \"description\": \"the number of days for the request.\",\n",
|
||
"\t\t\t\t}\n",
|
||
"\t\t\t},\n",
|
||
"\t\t\t\"required\": [\n",
|
||
"\t\t\t\t\"location\",\n",
|
||
" \"days\"\n",
|
||
"\t\t\t]\n",
|
||
"\t\t}\n",
|
||
"\t}\n",
|
||
"}\n",
|
||
"def check_parameter_property(api_description, parameter_name, property_name):\n",
|
||
" \"\"\"\n",
|
||
" Check if a parameter in an API description has a specific property.\n",
|
||
"\n",
|
||
" Args:\n",
|
||
" api_description (dict): The API description in JSON format.\n",
|
||
" parameter_name (str): The name of the parameter to check.\n",
|
||
" property_name (str): The property to look for (e.g., 'format', 'default').\n",
|
||
"\n",
|
||
" Returns:\n",
|
||
" bool: True if the parameter has the specified property, False otherwise.\n",
|
||
" \"\"\"\n",
|
||
" parameters = api_description.get(\"parameters\", {}).get(\"properties\", {})\n",
|
||
" parameter_info = parameters.get(parameter_name, {})\n",
|
||
"\n",
|
||
" return property_name in parameter_info\n",
|
||
"\n",
|
||
"\n",
|
||
"# Example usage\n",
|
||
"\n",
|
||
"def convert_tools(tools: List[Dict[str, Any]]):\n",
|
||
" return \"\\n\".join([json.dumps(tool) for tool in tools])\n",
|
||
"\n",
|
||
"# Helper function to create the system prompt for our model\n",
|
||
"def format_prompt(tools: List[Dict[str, Any]]):\n",
|
||
" tool_text = convert_tools(tools)\n",
|
||
"\n",
|
||
" return (\n",
|
||
" TASK_PROMPT\n",
|
||
" + \"\\n\\n\"\n",
|
||
" + TOOL_PROMPT.format(tool_text=tool_text)\n",
|
||
" + \"\\n\\n\"\n",
|
||
" + FORMAT_PROMPT\n",
|
||
" + \"\\n\"\n",
|
||
" )\n",
|
||
"\n",
|
||
"openai_format_tools = [get_weather_api]\n",
|
||
"\n",
|
||
"system_prompt = format_prompt(openai_format_tools)\n",
|
||
"\n",
|
||
"\n",
|
||
"from openai import OpenAI\n",
|
||
"\n",
|
||
"client = OpenAI(base_url=\"https://api.fc.archgw.com/v1\", api_key=\"EMPTY\")\n",
|
||
"\n",
|
||
"# List models API\n",
|
||
"model = client.models.list().data[0].id\n",
|
||
"print(\"Model:\", model)\n",
|
||
"messages = [\n",
|
||
" {\"role\": \"system\", \"content\": system_prompt},\n",
|
||
" # {\"role\": \"user\", \"content\": \"can you help me check weather?\"},\n",
|
||
" {\"role\": \"user\", \"content\": \"How is the weather in Seattle in 7 days?\"},\n",
|
||
" #{\"role\": \"assistant\", \"content\": \"Of course!\"},\n",
|
||
" # {\"role\": \"user\", \"content\": \"Seattle please\"}\n",
|
||
"]\n",
|
||
"\n",
|
||
"extra_body = {\n",
|
||
" \"temperature\": 0.6,\n",
|
||
" \"top_p\": 1.0,\n",
|
||
" \"top_k\": 50,\n",
|
||
" # \"continue_final_message\": True,\n",
|
||
" # \"add_generation_prompt\": False,\n",
|
||
" \"logprobs\": True,\n",
|
||
" \"top_logprobs\": 10\n",
|
||
"}\n",
|
||
"\n",
|
||
"resp = client.chat.completions.create(\n",
|
||
" model=\"Arch-Function\",\n",
|
||
" messages=messages,\n",
|
||
" extra_body=extra_body,\n",
|
||
" stream = True\n",
|
||
")\n",
|
||
"\n",
|
||
"\n",
|
||
"\n",
|
||
"\n",
|
||
"\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 2,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"The ONNX file onnx/model.onnx is not a regular name used in optimum.onnxruntime, the ORTModel might not behave as expected.\n",
|
||
"The ONNX file onnx/model.onnx is not a regular name used in optimum.onnxruntime, the ORTModel might not behave as expected.\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"import json\n",
|
||
"from app.function_calling.hallucination_handler import HallucinationStateHandler\n",
|
||
"import pytest\n",
|
||
"import os"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 3,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"get_weather_api = {\n",
|
||
" \"type\": \"function\",\n",
|
||
" \"function\": {\n",
|
||
" \"name\": \"get_current_weather\",\n",
|
||
" \"description\": \"Get current weather at a location.\",\n",
|
||
" \"parameters\": {\n",
|
||
" \"type\": \"object\",\n",
|
||
" \"properties\": {\n",
|
||
" \"location\": {\n",
|
||
" \"type\": \"str\",\n",
|
||
" \"description\": \"The location to get the weather for\",\n",
|
||
" \"format\": \"City, State\",\n",
|
||
" },\n",
|
||
" \"unit\": {\n",
|
||
" \"type\": \"str\",\n",
|
||
" \"description\": \"The unit to return the weather in.\",\n",
|
||
" \"enum\": [\"celsius\", \"fahrenheit\"],\n",
|
||
" \"default\": \"celsius\",\n",
|
||
" },\n",
|
||
" \"days\": {\n",
|
||
" \"type\": \"str\",\n",
|
||
" \"description\": \"the number of days for the request.\",\n",
|
||
" },\n",
|
||
" },\n",
|
||
" \"required\": [\"location\", \"days\"],\n",
|
||
" },\n",
|
||
" },\n",
|
||
"}\n",
|
||
"function_description = get_weather_api[\"function\"]\n",
|
||
"if type(function_description) != list:\n",
|
||
" function_description = [get_weather_api[\"function\"]]\n",
|
||
"\n",
|
||
"hallu = HallucinationStateHandler(response_iterator=resp, apis = function_description)\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 13,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"False\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"r = next(hallu.response_iterator)\n",
|
||
"if hasattr(r.choices[0].delta, \"content\"):\n",
|
||
" token_content = r.choices[0].delta.content\n",
|
||
" if token_content:\n",
|
||
" logprobs = [p.logprob for p in r.choices[0].logprobs.content[0].top_logprobs]\n",
|
||
" print(hallu.check_token_hallucination(token_content, logprobs), hallu.tokens)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 4,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Token: None\n",
|
||
"Hallucination: False\n",
|
||
"Hallucination Message: \n",
|
||
"Token: <tool_call>\n",
|
||
"Hallucination: False\n",
|
||
"Hallucination Message: \n",
|
||
"Token: \n",
|
||
"\n",
|
||
"Hallucination: False\n",
|
||
"Hallucination Message: \n",
|
||
"Token: {\"\n",
|
||
"Hallucination: False\n",
|
||
"Hallucination Message: \n",
|
||
"Token: name\n",
|
||
"Hallucination: False\n",
|
||
"Hallucination Message: \n",
|
||
"Token: \":\n",
|
||
"Hallucination: False\n",
|
||
"Hallucination Message: \n",
|
||
"function name entered\n",
|
||
"Token: \"\n",
|
||
"Hallucination: False\n",
|
||
"Hallucination Message: \n",
|
||
"Token: get\n",
|
||
"Hallucination: False\n",
|
||
"Hallucination Message: \n",
|
||
"Token: _current\n",
|
||
"Hallucination: False\n",
|
||
"Hallucination Message: \n",
|
||
"Token: _weather\n",
|
||
"Hallucination: False\n",
|
||
"Hallucination Message: \n",
|
||
"Token: \",\n",
|
||
"Hallucination: False\n",
|
||
"Hallucination Message: \n",
|
||
"Token: \"\n",
|
||
"Hallucination: False\n",
|
||
"Hallucination Message: \n",
|
||
"Token: arguments\n",
|
||
"Hallucination: False\n",
|
||
"Hallucination Message: \n",
|
||
"Token: \":\n",
|
||
"Hallucination: False\n",
|
||
"Hallucination Message: \n",
|
||
"Token: {\"\n",
|
||
"Hallucination: False\n",
|
||
"Hallucination Message: \n",
|
||
"Token: location\n",
|
||
"Hallucination: False\n",
|
||
"Hallucination Message: \n",
|
||
"Token: \":\n",
|
||
"Hallucination: False\n",
|
||
"Hallucination Message: \n",
|
||
"Token: \"\n",
|
||
"Hallucination: False\n",
|
||
"Hallucination Message: \n",
|
||
"Token: Seattle\n",
|
||
"Hallucination: False\n",
|
||
"Hallucination Message: \n",
|
||
"Token: ,\n",
|
||
"Hallucination: False\n",
|
||
"Hallucination Message: \n",
|
||
"Token: WA\n",
|
||
"Hallucination: False\n",
|
||
"Hallucination Message: \n",
|
||
"Token: ,\n",
|
||
"Hallucination: False\n",
|
||
"Hallucination Message: \n",
|
||
"Token: USA\n",
|
||
"Hallucination: False\n",
|
||
"Hallucination Message: \n",
|
||
"Token: \",\n",
|
||
"Hallucination: False\n",
|
||
"Hallucination Message: \n",
|
||
"Token: \"\n",
|
||
"Hallucination: False\n",
|
||
"Hallucination Message: \n",
|
||
"Token: days\n",
|
||
"Hallucination: False\n",
|
||
"Hallucination Message: \n",
|
||
"Token: \":\n",
|
||
"Hallucination: False\n",
|
||
"Hallucination Message: \n",
|
||
"Token: \"\n",
|
||
"Hallucination: False\n",
|
||
"Hallucination Message: \n",
|
||
"Token: 7\n",
|
||
"Hallucination: False\n",
|
||
"Hallucination Message: \n",
|
||
"Token: \"}}\n",
|
||
"\n",
|
||
"Hallucination: False\n",
|
||
"Hallucination Message: \n",
|
||
"Token: </tool_call>\n",
|
||
"Hallucination: False\n",
|
||
"Hallucination Message: \n",
|
||
"Token: None\n",
|
||
"Hallucination: False\n",
|
||
"Hallucination Message: \n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"for token in hallu:\n",
|
||
" print(f\"Token: {token}\")\n",
|
||
" print(f\"Hallucination: {hallu.hallucination}\")\n",
|
||
" print(f\"Hallucination Message: {hallu.hallucination_message}\")\n",
|
||
"\n",
|
||
"# Access the tokens processed"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 9,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"for token in resp:\n",
|
||
" if hasattr(token.choices[0].delta, \"content\"):\n",
|
||
" \n",
|
||
" token_content = token.choices[0].delta.content\n",
|
||
" print(token_content)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 6,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"[('<tool_call>', 0.0034013802651315928, -0.0032856864854693413),\n",
|
||
" ('Seattle', 1.1522446584422141e-05, -1.1521117812662851e-05),\n",
|
||
" ('7', 1.553274159959983e-05, -1.5530327800661325e-05)]"
|
||
]
|
||
},
|
||
"execution_count": 6,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"hallu.token_probs_map"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": []
|
||
}
|
||
],
|
||
"metadata": {
|
||
"language_info": {
|
||
"name": "python"
|
||
}
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 2
|
||
}
|