plano/model_server/app/tests/test.ipynb
2024-12-06 15:50:03 -08:00

513 lines
16 KiB
Text
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/co-tran/Documents/arch/new_venv/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: Arch-Function\n"
]
}
],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"import json\n",
"import ast\n",
"from datasets import load_dataset\n",
"import re\n",
"import subprocess\n",
"import threading\n",
"import requests\n",
"import time\n",
"from concurrent.futures import ThreadPoolExecutor\n",
"from tqdm import tqdm\n",
"import os\n",
"import json\n",
"import math\n",
"import torch\n",
"\n",
"from datasets import load_dataset\n",
"from typing import Any, Dict, List, Tuple\n",
"\n",
"from typing import Any, Dict, List\n",
"\n",
"def extract_tool_calls(content: str):\n",
" tool_calls = []\n",
"\n",
" flag = False\n",
" for line in content.split(\"\\n\"):\n",
" if \"<tool_call>\" == line:\n",
" flag = True\n",
" elif \"</tool_call>\" == line:\n",
" flag = False\n",
" else:\n",
" if flag:\n",
" try:\n",
" tool_content = json.loads(line)\n",
" except Exception as e:\n",
" print(e)\n",
" fixed_content = fix_json_string(line)\n",
" try:\n",
" tool_content = json.loads(fixed_content)\n",
" except json.JSONDecodeError:\n",
" print(\"json.JSONDecodeError\")\n",
" return content\n",
"\n",
" tool_calls.append(\n",
" {\n",
" \"id\": f\"call_{random.randint(1000, 10000)}\",\n",
" \"type\": \"function\",\n",
" \"function\": {\n",
" \"name\": tool_content[\"name\"],\n",
" \"arguments\": tool_content[\"arguments\"],\n",
" },\n",
" }\n",
" )\n",
"\n",
" flag = False\n",
"\n",
" return tool_calls\n",
"\n",
"def fix_json_string(json_str: str):\n",
" # Remove any leading or trailing whitespace or newline characters\n",
" json_str = json_str.strip()\n",
"\n",
" # Stack to keep track of brackets\n",
" stack = []\n",
"\n",
" # Clean string to collect valid characters\n",
" fixed_str = \"\"\n",
"\n",
" # Dictionary for matching brackets\n",
" matching_bracket = {\")\": \"(\", \"}\": \"{\", \"]\": \"[\"}\n",
"\n",
" # Dictionary for the opposite of matching_bracket\n",
" opening_bracket = {v: k for k, v in matching_bracket.items()}\n",
"\n",
" for char in json_str:\n",
" if char in \"{[(\":\n",
" stack.append(char)\n",
" fixed_str += char\n",
" elif char in \"}])\":\n",
" if stack and stack[-1] == matching_bracket[char]:\n",
" stack.pop()\n",
" fixed_str += char\n",
" else:\n",
" # Ignore the unmatched closing brackets\n",
" continue\n",
" else:\n",
" fixed_str += char\n",
"\n",
" # If there are unmatched opening brackets left in the stack, add corresponding closing brackets\n",
" while stack:\n",
" unmatched_opening = stack.pop()\n",
" fixed_str += opening_bracket[unmatched_opening]\n",
"\n",
" # Attempt to parse the corrected string to ensure its valid JSON\n",
" return fixed_str.replace(\"\\'\", \"\\\"\")\n",
"\n",
"\n",
"TASK_PROMPT = \"\"\"\n",
"You are a helpful assistant.\n",
"\"\"\".strip()\n",
"\n",
"TOOL_PROMPT = \"\"\"\n",
"# Tools\n",
"\n",
"You may call one or more functions to assist with the user query.\n",
"\n",
"You are provided with function signatures within <tools></tools> XML tags:\n",
"<tools>\n",
"{tool_text}\n",
"</tools>\n",
"\"\"\".strip()\n",
"\n",
"FORMAT_PROMPT = \"\"\"\n",
"For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n",
"<tool_call>\n",
"{\"name\": <function-name>, \"arguments\": <args-json-object>}\n",
"</tool_call>\n",
"\"\"\".strip()\n",
"\n",
"\n",
"get_weather_api = {\n",
"\t\"type\": \"function\",\n",
"\t\"function\": {\n",
"\t\t\"name\": \"get_current_weather\",\n",
"\t\t\"description\": \"Get current weather at a location.\",\n",
"\t\t\"parameters\": {\n",
"\t\t\t\"type\": \"object\",\n",
"\t\t\t\"properties\": {\n",
"\t\t\t\t\"location\": {\n",
"\t\t\t\t\t\"type\": \"str\",\n",
"\t\t\t\t\t\"description\": \"The location to get the weather for\",\n",
" \"format\": \"City, State, Country\",\n",
"\t\t\t\t},\n",
"\t\t\t\t\"unit\": {\n",
"\t\t\t\t\t\"type\": \"str\",\n",
" \"description\": \"The unit to return the weather in.\",\n",
"\t\t\t\t\t\"enum\": [\n",
"\t\t\t\t\t\t\"celsius\",\n",
"\t\t\t\t\t\t\"fahrenheit\"\n",
"\t\t\t\t\t],\n",
" \"default\": \"celsius\"\n",
"\t\t\t\t},\n",
" \"days\": {\n",
"\t\t\t\t\t\"type\": \"str\",\n",
" \"description\": \"the number of days for the request.\",\n",
"\t\t\t\t}\n",
"\t\t\t},\n",
"\t\t\t\"required\": [\n",
"\t\t\t\t\"location\",\n",
" \"days\"\n",
"\t\t\t]\n",
"\t\t}\n",
"\t}\n",
"}\n",
"def check_parameter_property(api_description, parameter_name, property_name):\n",
" \"\"\"\n",
" Check if a parameter in an API description has a specific property.\n",
"\n",
" Args:\n",
" api_description (dict): The API description in JSON format.\n",
" parameter_name (str): The name of the parameter to check.\n",
" property_name (str): The property to look for (e.g., 'format', 'default').\n",
"\n",
" Returns:\n",
" bool: True if the parameter has the specified property, False otherwise.\n",
" \"\"\"\n",
" parameters = api_description.get(\"parameters\", {}).get(\"properties\", {})\n",
" parameter_info = parameters.get(parameter_name, {})\n",
"\n",
" return property_name in parameter_info\n",
"\n",
"\n",
"# Example usage\n",
"\n",
"def convert_tools(tools: List[Dict[str, Any]]):\n",
" return \"\\n\".join([json.dumps(tool) for tool in tools])\n",
"\n",
"# Helper function to create the system prompt for our model\n",
"def format_prompt(tools: List[Dict[str, Any]]):\n",
" tool_text = convert_tools(tools)\n",
"\n",
" return (\n",
" TASK_PROMPT\n",
" + \"\\n\\n\"\n",
" + TOOL_PROMPT.format(tool_text=tool_text)\n",
" + \"\\n\\n\"\n",
" + FORMAT_PROMPT\n",
" + \"\\n\"\n",
" )\n",
"\n",
"openai_format_tools = [get_weather_api]\n",
"\n",
"system_prompt = format_prompt(openai_format_tools)\n",
"\n",
"\n",
"from openai import OpenAI\n",
"\n",
"client = OpenAI(base_url=\"https://api.fc.archgw.com/v1\", api_key=\"EMPTY\")\n",
"\n",
"# List models API\n",
"model = client.models.list().data[0].id\n",
"print(\"Model:\", model)\n",
"messages = [\n",
" {\"role\": \"system\", \"content\": system_prompt},\n",
" # {\"role\": \"user\", \"content\": \"can you help me check weather?\"},\n",
" {\"role\": \"user\", \"content\": \"How is the weather in Seattle in 7 days?\"},\n",
" #{\"role\": \"assistant\", \"content\": \"Of course!\"},\n",
" # {\"role\": \"user\", \"content\": \"Seattle please\"}\n",
"]\n",
"\n",
"extra_body = {\n",
" \"temperature\": 0.6,\n",
" \"top_p\": 1.0,\n",
" \"top_k\": 50,\n",
" # \"continue_final_message\": True,\n",
" # \"add_generation_prompt\": False,\n",
" \"logprobs\": True,\n",
" \"top_logprobs\": 10\n",
"}\n",
"\n",
"resp = client.chat.completions.create(\n",
" model=\"Arch-Function\",\n",
" messages=messages,\n",
" extra_body=extra_body,\n",
" stream = True\n",
")\n",
"\n",
"\n",
"\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"The ONNX file onnx/model.onnx is not a regular name used in optimum.onnxruntime, the ORTModel might not behave as expected.\n",
"The ONNX file onnx/model.onnx is not a regular name used in optimum.onnxruntime, the ORTModel might not behave as expected.\n"
]
}
],
"source": [
"import json\n",
"from app.function_calling.hallucination_handler import HallucinationStateHandler\n",
"import pytest\n",
"import os"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"get_weather_api = {\n",
" \"type\": \"function\",\n",
" \"function\": {\n",
" \"name\": \"get_current_weather\",\n",
" \"description\": \"Get current weather at a location.\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"location\": {\n",
" \"type\": \"str\",\n",
" \"description\": \"The location to get the weather for\",\n",
" \"format\": \"City, State\",\n",
" },\n",
" \"unit\": {\n",
" \"type\": \"str\",\n",
" \"description\": \"The unit to return the weather in.\",\n",
" \"enum\": [\"celsius\", \"fahrenheit\"],\n",
" \"default\": \"celsius\",\n",
" },\n",
" \"days\": {\n",
" \"type\": \"str\",\n",
" \"description\": \"the number of days for the request.\",\n",
" },\n",
" },\n",
" \"required\": [\"location\", \"days\"],\n",
" },\n",
" },\n",
"}\n",
"function_description = get_weather_api[\"function\"]\n",
"if type(function_description) != list:\n",
" function_description = [get_weather_api[\"function\"]]\n",
"\n",
"hallu = HallucinationStateHandler(response_iterator=resp, apis = function_description)\n"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"False\n"
]
}
],
"source": [
"r = next(hallu.response_iterator)\n",
"if hasattr(r.choices[0].delta, \"content\"):\n",
" token_content = r.choices[0].delta.content\n",
" if token_content:\n",
" logprobs = [p.logprob for p in r.choices[0].logprobs.content[0].top_logprobs]\n",
" print(hallu.check_token_hallucination(token_content, logprobs), hallu.tokens)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Token: None\n",
"Hallucination: False\n",
"Hallucination Message: \n",
"Token: <tool_call>\n",
"Hallucination: False\n",
"Hallucination Message: \n",
"Token: \n",
"\n",
"Hallucination: False\n",
"Hallucination Message: \n",
"Token: {\"\n",
"Hallucination: False\n",
"Hallucination Message: \n",
"Token: name\n",
"Hallucination: False\n",
"Hallucination Message: \n",
"Token: \":\n",
"Hallucination: False\n",
"Hallucination Message: \n",
"function name entered\n",
"Token: \"\n",
"Hallucination: False\n",
"Hallucination Message: \n",
"Token: get\n",
"Hallucination: False\n",
"Hallucination Message: \n",
"Token: _current\n",
"Hallucination: False\n",
"Hallucination Message: \n",
"Token: _weather\n",
"Hallucination: False\n",
"Hallucination Message: \n",
"Token: \",\n",
"Hallucination: False\n",
"Hallucination Message: \n",
"Token: \"\n",
"Hallucination: False\n",
"Hallucination Message: \n",
"Token: arguments\n",
"Hallucination: False\n",
"Hallucination Message: \n",
"Token: \":\n",
"Hallucination: False\n",
"Hallucination Message: \n",
"Token: {\"\n",
"Hallucination: False\n",
"Hallucination Message: \n",
"Token: location\n",
"Hallucination: False\n",
"Hallucination Message: \n",
"Token: \":\n",
"Hallucination: False\n",
"Hallucination Message: \n",
"Token: \"\n",
"Hallucination: False\n",
"Hallucination Message: \n",
"Token: Seattle\n",
"Hallucination: False\n",
"Hallucination Message: \n",
"Token: ,\n",
"Hallucination: False\n",
"Hallucination Message: \n",
"Token: WA\n",
"Hallucination: False\n",
"Hallucination Message: \n",
"Token: ,\n",
"Hallucination: False\n",
"Hallucination Message: \n",
"Token: USA\n",
"Hallucination: False\n",
"Hallucination Message: \n",
"Token: \",\n",
"Hallucination: False\n",
"Hallucination Message: \n",
"Token: \"\n",
"Hallucination: False\n",
"Hallucination Message: \n",
"Token: days\n",
"Hallucination: False\n",
"Hallucination Message: \n",
"Token: \":\n",
"Hallucination: False\n",
"Hallucination Message: \n",
"Token: \"\n",
"Hallucination: False\n",
"Hallucination Message: \n",
"Token: 7\n",
"Hallucination: False\n",
"Hallucination Message: \n",
"Token: \"}}\n",
"\n",
"Hallucination: False\n",
"Hallucination Message: \n",
"Token: </tool_call>\n",
"Hallucination: False\n",
"Hallucination Message: \n",
"Token: None\n",
"Hallucination: False\n",
"Hallucination Message: \n"
]
}
],
"source": [
"for token in hallu:\n",
" print(f\"Token: {token}\")\n",
" print(f\"Hallucination: {hallu.hallucination}\")\n",
" print(f\"Hallucination Message: {hallu.hallucination_message}\")\n",
"\n",
"# Access the tokens processed"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"for token in resp:\n",
" if hasattr(token.choices[0].delta, \"content\"):\n",
" \n",
" token_content = token.choices[0].delta.content\n",
" print(token_content)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[('<tool_call>', 0.0034013802651315928, -0.0032856864854693413),\n",
" ('Seattle', 1.1522446584422141e-05, -1.1521117812662851e-05),\n",
" ('7', 1.553274159959983e-05, -1.5530327800661325e-05)]"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"hallu.token_probs_map"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 2
}