diff --git a/model_server/app/tests/test.ipynb b/model_server/app/tests/test.ipynb
deleted file mode 100644
index 1dfd455e..00000000
--- a/model_server/app/tests/test.ipynb
+++ /dev/null
@@ -1,513 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/Users/co-tran/Documents/arch/new_venv/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
- " from .autonotebook import tqdm as notebook_tqdm\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Model: Arch-Function\n"
- ]
- }
- ],
- "source": [
- "import pandas as pd\n",
- "import numpy as np\n",
- "import json\n",
- "import ast\n",
- "from datasets import load_dataset\n",
- "import re\n",
- "import subprocess\n",
- "import threading\n",
- "import requests\n",
- "import time\n",
- "from concurrent.futures import ThreadPoolExecutor\n",
- "from tqdm import tqdm\n",
- "import os\n",
- "import json\n",
- "import math\n",
- "import torch\n",
- "\n",
- "from datasets import load_dataset\n",
- "from typing import Any, Dict, List, Tuple\n",
- "\n",
- "from typing import Any, Dict, List\n",
- "\n",
- "def extract_tool_calls(content: str):\n",
- " tool_calls = []\n",
- "\n",
- " flag = False\n",
- " for line in content.split(\"\\n\"):\n",
- " if \"\" == line:\n",
- " flag = True\n",
- " elif \"\" == line:\n",
- " flag = False\n",
- " else:\n",
- " if flag:\n",
- " try:\n",
- " tool_content = json.loads(line)\n",
- " except Exception as e:\n",
- " print(e)\n",
- " fixed_content = fix_json_string(line)\n",
- " try:\n",
- " tool_content = json.loads(fixed_content)\n",
- " except json.JSONDecodeError:\n",
- " print(\"json.JSONDecodeError\")\n",
- " return content\n",
- "\n",
- " tool_calls.append(\n",
- " {\n",
- " \"id\": f\"call_{random.randint(1000, 10000)}\",\n",
- " \"type\": \"function\",\n",
- " \"function\": {\n",
- " \"name\": tool_content[\"name\"],\n",
- " \"arguments\": tool_content[\"arguments\"],\n",
- " },\n",
- " }\n",
- " )\n",
- "\n",
- " flag = False\n",
- "\n",
- " return tool_calls\n",
- "\n",
- "def fix_json_string(json_str: str):\n",
- " # Remove any leading or trailing whitespace or newline characters\n",
- " json_str = json_str.strip()\n",
- "\n",
- " # Stack to keep track of brackets\n",
- " stack = []\n",
- "\n",
- " # Clean string to collect valid characters\n",
- " fixed_str = \"\"\n",
- "\n",
- " # Dictionary for matching brackets\n",
- " matching_bracket = {\")\": \"(\", \"}\": \"{\", \"]\": \"[\"}\n",
- "\n",
- " # Dictionary for the opposite of matching_bracket\n",
- " opening_bracket = {v: k for k, v in matching_bracket.items()}\n",
- "\n",
- " for char in json_str:\n",
- " if char in \"{[(\":\n",
- " stack.append(char)\n",
- " fixed_str += char\n",
- " elif char in \"}])\":\n",
- " if stack and stack[-1] == matching_bracket[char]:\n",
- " stack.pop()\n",
- " fixed_str += char\n",
- " else:\n",
- " # Ignore the unmatched closing brackets\n",
- " continue\n",
- " else:\n",
- " fixed_str += char\n",
- "\n",
- " # If there are unmatched opening brackets left in the stack, add corresponding closing brackets\n",
- " while stack:\n",
- " unmatched_opening = stack.pop()\n",
- " fixed_str += opening_bracket[unmatched_opening]\n",
- "\n",
- " # Attempt to parse the corrected string to ensure it’s valid JSON\n",
- " return fixed_str.replace(\"\\'\", \"\\\"\")\n",
- "\n",
- "\n",
- "TASK_PROMPT = \"\"\"\n",
- "You are a helpful assistant.\n",
- "\"\"\".strip()\n",
- "\n",
- "TOOL_PROMPT = \"\"\"\n",
- "# Tools\n",
- "\n",
- "You may call one or more functions to assist with the user query.\n",
- "\n",
- "You are provided with function signatures within XML tags:\n",
- "\n",
- "{tool_text}\n",
- "\n",
- "\"\"\".strip()\n",
- "\n",
- "FORMAT_PROMPT = \"\"\"\n",
- "For each function call, return a json object with function name and arguments within XML tags:\n",
- "\n",
- "{\"name\": , \"arguments\": }\n",
- "\n",
- "\"\"\".strip()\n",
- "\n",
- "\n",
- "get_weather_api = {\n",
- "\t\"type\": \"function\",\n",
- "\t\"function\": {\n",
- "\t\t\"name\": \"get_current_weather\",\n",
- "\t\t\"description\": \"Get current weather at a location.\",\n",
- "\t\t\"parameters\": {\n",
- "\t\t\t\"type\": \"object\",\n",
- "\t\t\t\"properties\": {\n",
- "\t\t\t\t\"location\": {\n",
- "\t\t\t\t\t\"type\": \"str\",\n",
- "\t\t\t\t\t\"description\": \"The location to get the weather for\",\n",
- " \"format\": \"City, State, Country\",\n",
- "\t\t\t\t},\n",
- "\t\t\t\t\"unit\": {\n",
- "\t\t\t\t\t\"type\": \"str\",\n",
- " \"description\": \"The unit to return the weather in.\",\n",
- "\t\t\t\t\t\"enum\": [\n",
- "\t\t\t\t\t\t\"celsius\",\n",
- "\t\t\t\t\t\t\"fahrenheit\"\n",
- "\t\t\t\t\t],\n",
- " \"default\": \"celsius\"\n",
- "\t\t\t\t},\n",
- " \"days\": {\n",
- "\t\t\t\t\t\"type\": \"str\",\n",
- " \"description\": \"the number of days for the request.\",\n",
- "\t\t\t\t}\n",
- "\t\t\t},\n",
- "\t\t\t\"required\": [\n",
- "\t\t\t\t\"location\",\n",
- " \"days\"\n",
- "\t\t\t]\n",
- "\t\t}\n",
- "\t}\n",
- "}\n",
- "def check_parameter_property(api_description, parameter_name, property_name):\n",
- " \"\"\"\n",
- " Check if a parameter in an API description has a specific property.\n",
- "\n",
- " Args:\n",
- " api_description (dict): The API description in JSON format.\n",
- " parameter_name (str): The name of the parameter to check.\n",
- " property_name (str): The property to look for (e.g., 'format', 'default').\n",
- "\n",
- " Returns:\n",
- " bool: True if the parameter has the specified property, False otherwise.\n",
- " \"\"\"\n",
- " parameters = api_description.get(\"parameters\", {}).get(\"properties\", {})\n",
- " parameter_info = parameters.get(parameter_name, {})\n",
- "\n",
- " return property_name in parameter_info\n",
- "\n",
- "\n",
- "# Example usage\n",
- "\n",
- "def convert_tools(tools: List[Dict[str, Any]]):\n",
- " return \"\\n\".join([json.dumps(tool) for tool in tools])\n",
- "\n",
- "# Helper function to create the system prompt for our model\n",
- "def format_prompt(tools: List[Dict[str, Any]]):\n",
- " tool_text = convert_tools(tools)\n",
- "\n",
- " return (\n",
- " TASK_PROMPT\n",
- " + \"\\n\\n\"\n",
- " + TOOL_PROMPT.format(tool_text=tool_text)\n",
- " + \"\\n\\n\"\n",
- " + FORMAT_PROMPT\n",
- " + \"\\n\"\n",
- " )\n",
- "\n",
- "openai_format_tools = [get_weather_api]\n",
- "\n",
- "system_prompt = format_prompt(openai_format_tools)\n",
- "\n",
- "\n",
- "from openai import OpenAI\n",
- "\n",
- "client = OpenAI(base_url=\"https://api.fc.archgw.com/v1\", api_key=\"EMPTY\")\n",
- "\n",
- "# List models API\n",
- "model = client.models.list().data[0].id\n",
- "print(\"Model:\", model)\n",
- "messages = [\n",
- " {\"role\": \"system\", \"content\": system_prompt},\n",
- " # {\"role\": \"user\", \"content\": \"can you help me check weather?\"},\n",
- " {\"role\": \"user\", \"content\": \"How is the weather in Seattle in 7 days?\"},\n",
- " #{\"role\": \"assistant\", \"content\": \"Of course!\"},\n",
- " # {\"role\": \"user\", \"content\": \"Seattle please\"}\n",
- "]\n",
- "\n",
- "extra_body = {\n",
- " \"temperature\": 0.6,\n",
- " \"top_p\": 1.0,\n",
- " \"top_k\": 50,\n",
- " # \"continue_final_message\": True,\n",
- " # \"add_generation_prompt\": False,\n",
- " \"logprobs\": True,\n",
- " \"top_logprobs\": 10\n",
- "}\n",
- "\n",
- "resp = client.chat.completions.create(\n",
- " model=\"Arch-Function\",\n",
- " messages=messages,\n",
- " extra_body=extra_body,\n",
- " stream = True\n",
- ")\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "The ONNX file onnx/model.onnx is not a regular name used in optimum.onnxruntime, the ORTModel might not behave as expected.\n",
- "The ONNX file onnx/model.onnx is not a regular name used in optimum.onnxruntime, the ORTModel might not behave as expected.\n"
- ]
- }
- ],
- "source": [
- "import json\n",
- "from app.function_calling.hallucination_handler import HallucinationStateHandler\n",
- "import pytest\n",
- "import os"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {},
- "outputs": [],
- "source": [
- "get_weather_api = {\n",
- " \"type\": \"function\",\n",
- " \"function\": {\n",
- " \"name\": \"get_current_weather\",\n",
- " \"description\": \"Get current weather at a location.\",\n",
- " \"parameters\": {\n",
- " \"type\": \"object\",\n",
- " \"properties\": {\n",
- " \"location\": {\n",
- " \"type\": \"str\",\n",
- " \"description\": \"The location to get the weather for\",\n",
- " \"format\": \"City, State\",\n",
- " },\n",
- " \"unit\": {\n",
- " \"type\": \"str\",\n",
- " \"description\": \"The unit to return the weather in.\",\n",
- " \"enum\": [\"celsius\", \"fahrenheit\"],\n",
- " \"default\": \"celsius\",\n",
- " },\n",
- " \"days\": {\n",
- " \"type\": \"str\",\n",
- " \"description\": \"the number of days for the request.\",\n",
- " },\n",
- " },\n",
- " \"required\": [\"location\", \"days\"],\n",
- " },\n",
- " },\n",
- "}\n",
- "function_description = get_weather_api[\"function\"]\n",
- "if type(function_description) != list:\n",
- " function_description = [get_weather_api[\"function\"]]\n",
- "\n",
- "hallu = HallucinationStateHandler(response_iterator=resp, apis = function_description)\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 13,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "False\n"
- ]
- }
- ],
- "source": [
- "r = next(hallu.response_iterator)\n",
- "if hasattr(r.choices[0].delta, \"content\"):\n",
- " token_content = r.choices[0].delta.content\n",
- " if token_content:\n",
- " logprobs = [p.logprob for p in r.choices[0].logprobs.content[0].top_logprobs]\n",
- " print(hallu.check_token_hallucination(token_content, logprobs), hallu.tokens)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Token: None\n",
- "Hallucination: False\n",
- "Hallucination Message: \n",
- "Token: \n",
- "Hallucination: False\n",
- "Hallucination Message: \n",
- "Token: \n",
- "\n",
- "Hallucination: False\n",
- "Hallucination Message: \n",
- "Token: {\"\n",
- "Hallucination: False\n",
- "Hallucination Message: \n",
- "Token: name\n",
- "Hallucination: False\n",
- "Hallucination Message: \n",
- "Token: \":\n",
- "Hallucination: False\n",
- "Hallucination Message: \n",
- "function name entered\n",
- "Token: \"\n",
- "Hallucination: False\n",
- "Hallucination Message: \n",
- "Token: get\n",
- "Hallucination: False\n",
- "Hallucination Message: \n",
- "Token: _current\n",
- "Hallucination: False\n",
- "Hallucination Message: \n",
- "Token: _weather\n",
- "Hallucination: False\n",
- "Hallucination Message: \n",
- "Token: \",\n",
- "Hallucination: False\n",
- "Hallucination Message: \n",
- "Token: \"\n",
- "Hallucination: False\n",
- "Hallucination Message: \n",
- "Token: arguments\n",
- "Hallucination: False\n",
- "Hallucination Message: \n",
- "Token: \":\n",
- "Hallucination: False\n",
- "Hallucination Message: \n",
- "Token: {\"\n",
- "Hallucination: False\n",
- "Hallucination Message: \n",
- "Token: location\n",
- "Hallucination: False\n",
- "Hallucination Message: \n",
- "Token: \":\n",
- "Hallucination: False\n",
- "Hallucination Message: \n",
- "Token: \"\n",
- "Hallucination: False\n",
- "Hallucination Message: \n",
- "Token: Seattle\n",
- "Hallucination: False\n",
- "Hallucination Message: \n",
- "Token: ,\n",
- "Hallucination: False\n",
- "Hallucination Message: \n",
- "Token: WA\n",
- "Hallucination: False\n",
- "Hallucination Message: \n",
- "Token: ,\n",
- "Hallucination: False\n",
- "Hallucination Message: \n",
- "Token: USA\n",
- "Hallucination: False\n",
- "Hallucination Message: \n",
- "Token: \",\n",
- "Hallucination: False\n",
- "Hallucination Message: \n",
- "Token: \"\n",
- "Hallucination: False\n",
- "Hallucination Message: \n",
- "Token: days\n",
- "Hallucination: False\n",
- "Hallucination Message: \n",
- "Token: \":\n",
- "Hallucination: False\n",
- "Hallucination Message: \n",
- "Token: \"\n",
- "Hallucination: False\n",
- "Hallucination Message: \n",
- "Token: 7\n",
- "Hallucination: False\n",
- "Hallucination Message: \n",
- "Token: \"}}\n",
- "\n",
- "Hallucination: False\n",
- "Hallucination Message: \n",
- "Token: \n",
- "Hallucination: False\n",
- "Hallucination Message: \n",
- "Token: None\n",
- "Hallucination: False\n",
- "Hallucination Message: \n"
- ]
- }
- ],
- "source": [
- "for token in hallu:\n",
- " print(f\"Token: {token}\")\n",
- " print(f\"Hallucination: {hallu.hallucination}\")\n",
- " print(f\"Hallucination Message: {hallu.hallucination_message}\")\n",
- "\n",
- "# Access the tokens processed"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 9,
- "metadata": {},
- "outputs": [],
- "source": [
- "for token in resp:\n",
- " if hasattr(token.choices[0].delta, \"content\"):\n",
- " \n",
- " token_content = token.choices[0].delta.content\n",
- " print(token_content)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "[('', 0.0034013802651315928, -0.0032856864854693413),\n",
- " ('Seattle', 1.1522446584422141e-05, -1.1521117812662851e-05),\n",
- " ('7', 1.553274159959983e-05, -1.5530327800661325e-05)]"
- ]
- },
- "execution_count": 6,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "hallu.token_probs_map"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
- }
- ],
- "metadata": {
- "language_info": {
- "name": "python"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
-}