From 67e458bc5af28a73464ded9a8598d498665bff69 Mon Sep 17 00:00:00 2001 From: cotran Date: Fri, 6 Dec 2024 15:57:06 -0800 Subject: [PATCH] remove test --- model_server/app/tests/test.ipynb | 513 ------------------------------ 1 file changed, 513 deletions(-) delete mode 100644 model_server/app/tests/test.ipynb diff --git a/model_server/app/tests/test.ipynb b/model_server/app/tests/test.ipynb deleted file mode 100644 index 1dfd455e..00000000 --- a/model_server/app/tests/test.ipynb +++ /dev/null @@ -1,513 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/co-tran/Documents/arch/new_venv/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Model: Arch-Function\n" - ] - } - ], - "source": [ - "import pandas as pd\n", - "import numpy as np\n", - "import json\n", - "import ast\n", - "from datasets import load_dataset\n", - "import re\n", - "import subprocess\n", - "import threading\n", - "import requests\n", - "import time\n", - "from concurrent.futures import ThreadPoolExecutor\n", - "from tqdm import tqdm\n", - "import os\n", - "import json\n", - "import math\n", - "import torch\n", - "\n", - "from datasets import load_dataset\n", - "from typing import Any, Dict, List, Tuple\n", - "\n", - "from typing import Any, Dict, List\n", - "\n", - "def extract_tool_calls(content: str):\n", - " tool_calls = []\n", - "\n", - " flag = False\n", - " for line in content.split(\"\\n\"):\n", - " if \"\" == line:\n", - " flag = True\n", - " elif \"\" == line:\n", - " flag = False\n", - " else:\n", - " if flag:\n", - " try:\n", - " tool_content = json.loads(line)\n", - " except Exception as e:\n", - " print(e)\n", - " fixed_content = fix_json_string(line)\n", - " try:\n", - " tool_content = json.loads(fixed_content)\n", - " except json.JSONDecodeError:\n", - " print(\"json.JSONDecodeError\")\n", - " return content\n", - "\n", - " tool_calls.append(\n", - " {\n", - " \"id\": f\"call_{random.randint(1000, 10000)}\",\n", - " \"type\": \"function\",\n", - " \"function\": {\n", - " \"name\": tool_content[\"name\"],\n", - " \"arguments\": tool_content[\"arguments\"],\n", - " },\n", - " }\n", - " )\n", - "\n", - " flag = False\n", - "\n", - " return tool_calls\n", - "\n", - "def fix_json_string(json_str: str):\n", - " # Remove any leading or trailing whitespace or newline characters\n", - " json_str = json_str.strip()\n", - "\n", - " # Stack to keep track of brackets\n", - " stack = []\n", - "\n", - " # Clean string to collect valid characters\n", - " fixed_str = \"\"\n", - "\n", - " # Dictionary for matching brackets\n", - " matching_bracket = {\")\": \"(\", \"}\": \"{\", \"]\": \"[\"}\n", - "\n", - " # Dictionary for the opposite of matching_bracket\n", - " opening_bracket = {v: k for k, v in matching_bracket.items()}\n", - "\n", - " for char in json_str:\n", - " if char in \"{[(\":\n", - " stack.append(char)\n", - " fixed_str += char\n", - " elif char in \"}])\":\n", - " if stack and stack[-1] == matching_bracket[char]:\n", - " stack.pop()\n", - " fixed_str += char\n", - " else:\n", - " # Ignore the unmatched closing brackets\n", - " continue\n", - " else:\n", - " fixed_str += char\n", - "\n", - " # If there are unmatched opening brackets left in the stack, add corresponding closing brackets\n", - " while stack:\n", - " unmatched_opening = stack.pop()\n", - " fixed_str += opening_bracket[unmatched_opening]\n", - "\n", - " # Attempt to parse the corrected string to ensure it’s valid JSON\n", - " return fixed_str.replace(\"\\'\", \"\\\"\")\n", - "\n", - "\n", - "TASK_PROMPT = \"\"\"\n", - "You are a helpful assistant.\n", - "\"\"\".strip()\n", - "\n", - "TOOL_PROMPT = \"\"\"\n", - "# Tools\n", - "\n", - "You may call one or more functions to assist with the user query.\n", - "\n", - "You are provided with function signatures within XML tags:\n", - "\n", - "{tool_text}\n", - "\n", - "\"\"\".strip()\n", - "\n", - "FORMAT_PROMPT = \"\"\"\n", - "For each function call, return a json object with function name and arguments within XML tags:\n", - "\n", - "{\"name\": , \"arguments\": }\n", - "\n", - "\"\"\".strip()\n", - "\n", - "\n", - "get_weather_api = {\n", - "\t\"type\": \"function\",\n", - "\t\"function\": {\n", - "\t\t\"name\": \"get_current_weather\",\n", - "\t\t\"description\": \"Get current weather at a location.\",\n", - "\t\t\"parameters\": {\n", - "\t\t\t\"type\": \"object\",\n", - "\t\t\t\"properties\": {\n", - "\t\t\t\t\"location\": {\n", - "\t\t\t\t\t\"type\": \"str\",\n", - "\t\t\t\t\t\"description\": \"The location to get the weather for\",\n", - " \"format\": \"City, State, Country\",\n", - "\t\t\t\t},\n", - "\t\t\t\t\"unit\": {\n", - "\t\t\t\t\t\"type\": \"str\",\n", - " \"description\": \"The unit to return the weather in.\",\n", - "\t\t\t\t\t\"enum\": [\n", - "\t\t\t\t\t\t\"celsius\",\n", - "\t\t\t\t\t\t\"fahrenheit\"\n", - "\t\t\t\t\t],\n", - " \"default\": \"celsius\"\n", - "\t\t\t\t},\n", - " \"days\": {\n", - "\t\t\t\t\t\"type\": \"str\",\n", - " \"description\": \"the number of days for the request.\",\n", - "\t\t\t\t}\n", - "\t\t\t},\n", - "\t\t\t\"required\": [\n", - "\t\t\t\t\"location\",\n", - " \"days\"\n", - "\t\t\t]\n", - "\t\t}\n", - "\t}\n", - "}\n", - "def check_parameter_property(api_description, parameter_name, property_name):\n", - " \"\"\"\n", - " Check if a parameter in an API description has a specific property.\n", - "\n", - " Args:\n", - " api_description (dict): The API description in JSON format.\n", - " parameter_name (str): The name of the parameter to check.\n", - " property_name (str): The property to look for (e.g., 'format', 'default').\n", - "\n", - " Returns:\n", - " bool: True if the parameter has the specified property, False otherwise.\n", - " \"\"\"\n", - " parameters = api_description.get(\"parameters\", {}).get(\"properties\", {})\n", - " parameter_info = parameters.get(parameter_name, {})\n", - "\n", - " return property_name in parameter_info\n", - "\n", - "\n", - "# Example usage\n", - "\n", - "def convert_tools(tools: List[Dict[str, Any]]):\n", - " return \"\\n\".join([json.dumps(tool) for tool in tools])\n", - "\n", - "# Helper function to create the system prompt for our model\n", - "def format_prompt(tools: List[Dict[str, Any]]):\n", - " tool_text = convert_tools(tools)\n", - "\n", - " return (\n", - " TASK_PROMPT\n", - " + \"\\n\\n\"\n", - " + TOOL_PROMPT.format(tool_text=tool_text)\n", - " + \"\\n\\n\"\n", - " + FORMAT_PROMPT\n", - " + \"\\n\"\n", - " )\n", - "\n", - "openai_format_tools = [get_weather_api]\n", - "\n", - "system_prompt = format_prompt(openai_format_tools)\n", - "\n", - "\n", - "from openai import OpenAI\n", - "\n", - "client = OpenAI(base_url=\"https://api.fc.archgw.com/v1\", api_key=\"EMPTY\")\n", - "\n", - "# List models API\n", - "model = client.models.list().data[0].id\n", - "print(\"Model:\", model)\n", - "messages = [\n", - " {\"role\": \"system\", \"content\": system_prompt},\n", - " # {\"role\": \"user\", \"content\": \"can you help me check weather?\"},\n", - " {\"role\": \"user\", \"content\": \"How is the weather in Seattle in 7 days?\"},\n", - " #{\"role\": \"assistant\", \"content\": \"Of course!\"},\n", - " # {\"role\": \"user\", \"content\": \"Seattle please\"}\n", - "]\n", - "\n", - "extra_body = {\n", - " \"temperature\": 0.6,\n", - " \"top_p\": 1.0,\n", - " \"top_k\": 50,\n", - " # \"continue_final_message\": True,\n", - " # \"add_generation_prompt\": False,\n", - " \"logprobs\": True,\n", - " \"top_logprobs\": 10\n", - "}\n", - "\n", - "resp = client.chat.completions.create(\n", - " model=\"Arch-Function\",\n", - " messages=messages,\n", - " extra_body=extra_body,\n", - " stream = True\n", - ")\n", - "\n", - "\n", - "\n", - "\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "The ONNX file onnx/model.onnx is not a regular name used in optimum.onnxruntime, the ORTModel might not behave as expected.\n", - "The ONNX file onnx/model.onnx is not a regular name used in optimum.onnxruntime, the ORTModel might not behave as expected.\n" - ] - } - ], - "source": [ - "import json\n", - "from app.function_calling.hallucination_handler import HallucinationStateHandler\n", - "import pytest\n", - "import os" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "get_weather_api = {\n", - " \"type\": \"function\",\n", - " \"function\": {\n", - " \"name\": \"get_current_weather\",\n", - " \"description\": \"Get current weather at a location.\",\n", - " \"parameters\": {\n", - " \"type\": \"object\",\n", - " \"properties\": {\n", - " \"location\": {\n", - " \"type\": \"str\",\n", - " \"description\": \"The location to get the weather for\",\n", - " \"format\": \"City, State\",\n", - " },\n", - " \"unit\": {\n", - " \"type\": \"str\",\n", - " \"description\": \"The unit to return the weather in.\",\n", - " \"enum\": [\"celsius\", \"fahrenheit\"],\n", - " \"default\": \"celsius\",\n", - " },\n", - " \"days\": {\n", - " \"type\": \"str\",\n", - " \"description\": \"the number of days for the request.\",\n", - " },\n", - " },\n", - " \"required\": [\"location\", \"days\"],\n", - " },\n", - " },\n", - "}\n", - "function_description = get_weather_api[\"function\"]\n", - "if type(function_description) != list:\n", - " function_description = [get_weather_api[\"function\"]]\n", - "\n", - "hallu = HallucinationStateHandler(response_iterator=resp, apis = function_description)\n" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "False\n" - ] - } - ], - "source": [ - "r = next(hallu.response_iterator)\n", - "if hasattr(r.choices[0].delta, \"content\"):\n", - " token_content = r.choices[0].delta.content\n", - " if token_content:\n", - " logprobs = [p.logprob for p in r.choices[0].logprobs.content[0].top_logprobs]\n", - " print(hallu.check_token_hallucination(token_content, logprobs), hallu.tokens)" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Token: None\n", - "Hallucination: False\n", - "Hallucination Message: \n", - "Token: \n", - "Hallucination: False\n", - "Hallucination Message: \n", - "Token: \n", - "\n", - "Hallucination: False\n", - "Hallucination Message: \n", - "Token: {\"\n", - "Hallucination: False\n", - "Hallucination Message: \n", - "Token: name\n", - "Hallucination: False\n", - "Hallucination Message: \n", - "Token: \":\n", - "Hallucination: False\n", - "Hallucination Message: \n", - "function name entered\n", - "Token: \"\n", - "Hallucination: False\n", - "Hallucination Message: \n", - "Token: get\n", - "Hallucination: False\n", - "Hallucination Message: \n", - "Token: _current\n", - "Hallucination: False\n", - "Hallucination Message: \n", - "Token: _weather\n", - "Hallucination: False\n", - "Hallucination Message: \n", - "Token: \",\n", - "Hallucination: False\n", - "Hallucination Message: \n", - "Token: \"\n", - "Hallucination: False\n", - "Hallucination Message: \n", - "Token: arguments\n", - "Hallucination: False\n", - "Hallucination Message: \n", - "Token: \":\n", - "Hallucination: False\n", - "Hallucination Message: \n", - "Token: {\"\n", - "Hallucination: False\n", - "Hallucination Message: \n", - "Token: location\n", - "Hallucination: False\n", - "Hallucination Message: \n", - "Token: \":\n", - "Hallucination: False\n", - "Hallucination Message: \n", - "Token: \"\n", - "Hallucination: False\n", - "Hallucination Message: \n", - "Token: Seattle\n", - "Hallucination: False\n", - "Hallucination Message: \n", - "Token: ,\n", - "Hallucination: False\n", - "Hallucination Message: \n", - "Token: WA\n", - "Hallucination: False\n", - "Hallucination Message: \n", - "Token: ,\n", - "Hallucination: False\n", - "Hallucination Message: \n", - "Token: USA\n", - "Hallucination: False\n", - "Hallucination Message: \n", - "Token: \",\n", - "Hallucination: False\n", - "Hallucination Message: \n", - "Token: \"\n", - "Hallucination: False\n", - "Hallucination Message: \n", - "Token: days\n", - "Hallucination: False\n", - "Hallucination Message: \n", - "Token: \":\n", - "Hallucination: False\n", - "Hallucination Message: \n", - "Token: \"\n", - "Hallucination: False\n", - "Hallucination Message: \n", - "Token: 7\n", - "Hallucination: False\n", - "Hallucination Message: \n", - "Token: \"}}\n", - "\n", - "Hallucination: False\n", - "Hallucination Message: \n", - "Token: \n", - "Hallucination: False\n", - "Hallucination Message: \n", - "Token: None\n", - "Hallucination: False\n", - "Hallucination Message: \n" - ] - } - ], - "source": [ - "for token in hallu:\n", - " print(f\"Token: {token}\")\n", - " print(f\"Hallucination: {hallu.hallucination}\")\n", - " print(f\"Hallucination Message: {hallu.hallucination_message}\")\n", - "\n", - "# Access the tokens processed" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [], - "source": [ - "for token in resp:\n", - " if hasattr(token.choices[0].delta, \"content\"):\n", - " \n", - " token_content = token.choices[0].delta.content\n", - " print(token_content)" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[('', 0.0034013802651315928, -0.0032856864854693413),\n", - " ('Seattle', 1.1522446584422141e-05, -1.1521117812662851e-05),\n", - " ('7', 1.553274159959983e-05, -1.5530327800661325e-05)]" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "hallu.token_probs_map" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "language_info": { - "name": "python" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -}