Refacotr model configs

This commit is contained in:
Shuguang Chen 2024-12-08 15:43:19 -08:00
parent 320f4612b8
commit 95e167c2f6
12 changed files with 1144 additions and 206 deletions

View file

@ -8,8 +8,8 @@ from src.commons.utils import (
wait_for_health_check,
check_lsof,
install_lsof,
find_process_by_port,
kill_process_by_port,
find_processes_by_port,
kill_processes,
)
@ -23,7 +23,7 @@ def start_server(port=51000):
"python",
"-m",
"uvicorn",
"app.main:app",
"src.main:app",
"--host",
"0.0.0.0",
"--port",
@ -56,14 +56,16 @@ def stop_server(port=51000, wait=True, timeout=10):
sys.exit(1)
logger.info(f"Stopping processes on port {port}...")
port_processes = find_process_by_port(port)
port_processes = find_processes_by_port(port)
if port_processes is None:
logger.info(f"No processes found listening on port {port}.")
else:
if len(port_processes):
process_killed = kill_process_by_port(port_processes, wait, timeout)
process_killed = kill_processes(port_processes, wait, timeout)
if not process_killed:
logger.error(f"Unable to kill all processes on {port}")
else:
logger.info(f"All processes on port {port} have been killed.")
else:
logger.error(f"Unable to find processes on {port}")

View file

@ -1,79 +0,0 @@
# ========================== Arch-Intent Default Params ==========================
ARCH_INTENT_MODEL_ALIAS = "Arch-Intent"
ARCH_INTENT_INSTRUCTION = "Are there any tools can help?"
ARCH_INTENT_TASK_PROMPT = """
You are a helpful assistant.
"""
ARCH_INTENT_TOOL_PROMPT_TEMPLATE = """
You task is to check if there are any tools that can be used to help the last user message in conversations according to the available tools listed below.
<tools>
{tool_text}
</tools>
"""
ARCH_INTENT_FORMAT_PROMPT = """
Provide your tool assessment for ONLY THE LAST USER MESSAGE in the above conversation:
- First line must read 'Yes' or 'No'.
- If yes, a second line must include a comma-separated list of tool indexes.
"""
ARCH_INTENT_GENERATION_CONFIG = {
"generation_params": {"max_tokens": 1, "stop_token_ids": [151645]}
}
# ========================== Arch-Function Default Params ==========================
ARCH_FUNCTION_MODEL_ALIAS = "Arch-Function"
ARCH_FUNCTION_TASK_PROMPT = """
You are a helpful assistant.
"""
ARCH_FUNCTION_TOOL_PROMPT_TEMPLATE = """
# Tools
You may call one or more functions to assist with the user query.
You are provided with function signatures within <tools></tools> XML tags:
<tools>
{tool_text}
</tools>
"""
ARCH_FUNCTION_FORMAT_PROMPT = """
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
<tool_call>
{"name": <function-name>, "arguments": <args-json-object>}
</tool_call>
"""
ARCH_FUNCTION_GENERATION_CONFIG = {
"generation_params": {
"temperature": 0.2,
"top_p": 1.0,
"top_k": 50,
"max_tokens": 512,
"stop_token_ids": [151645],
},
"prefill_params": {
"continue_final_message": True,
"add_generation_prompt": False,
},
"prefill_prefix": [
"May",
"Could",
"Sure",
"Definitely",
"Certainly",
"Of course",
"Can",
],
}

View file

@ -1,36 +1,36 @@
from openai import OpenAI
from src.commons.constants import *
from src.core.function_calling import ArchIntentHandler, ArchFunctionHandler
from src.core.guardrails import get_guardrail_handler
from src.commons.utils import get_model_server_logger
from src.core.guardrails import get_guardrail_handler
from src.core.function_calling import (
ArchIntentConfig,
ArchIntentHandler,
ArchFunctionConfig,
ArchFunctionHandler,
)
# Define logger
logger = get_model_server_logger()
# Define the client
ARCH_ENDPOINT = "https://api.fc.archgw.com/v1"
ARCH_API_KEY = "EMPTY"
ARCH_CLIENT = OpenAI(base_url=ARCH_ENDPOINT, api_key=ARCH_API_KEY)
# Define model names
ARCH_INTENT_MODEL_ALIAS = "Arch-Intent"
ARCH_FUNCTION_MODEL_ALIAS = "Arch-Function"
# Define model handlers
handler_map = {
"Arch-Intent": ArchIntentHandler(
ARCH_CLIENT,
ARCH_INTENT_MODEL_ALIAS,
ARCH_INTENT_TASK_PROMPT,
ARCH_INTENT_TOOL_PROMPT_TEMPLATE,
ARCH_INTENT_FORMAT_PROMPT,
ARCH_INTENT_INSTRUCTION,
**ARCH_INTENT_GENERATION_CONFIG,
ARCH_CLIENT, ARCH_INTENT_MODEL_ALIAS, ArchIntentConfig
),
"Arch-Function": ArchFunctionHandler(
ARCH_CLIENT,
ARCH_FUNCTION_MODEL_ALIAS,
ARCH_FUNCTION_TASK_PROMPT,
ARCH_FUNCTION_TOOL_PROMPT_TEMPLATE,
ARCH_FUNCTION_FORMAT_PROMPT,
**ARCH_FUNCTION_GENERATION_CONFIG,
ARCH_CLIENT, ARCH_FUNCTION_MODEL_ALIAS, ArchFunctionConfig
),
"Arch-Guard": get_guardrail_handler(),
}

View file

@ -12,7 +12,7 @@ PROJ_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__fil
# Default log directory and file
DEFAULT_LOG_DIR = os.path.join(PROJ_DIR, "logs")
DEFAULT_LOG_DIR = os.path.join(PROJ_DIR, ".logs")
DEFAULT_LOG_FILE = "modelserver.log"
@ -50,7 +50,7 @@ def get_model_server_logger(log_dir=None, log_file=None):
Get or initialize the logger instance for the model server.
Parameters:
- log_dir (str): Custom directory to store the log file. Defaults to `~/archgw_logs`.
- log_dir (str): Custom directory to store the log file. Defaults to `./.logs`.
- log_file (str): Custom log file name. Defaults to `modelserver.log`.
Returns:
@ -146,13 +146,13 @@ def terminate_process_by_pid(pid, timeout):
subprocess.run(["kill", "-9", str(pid)], check=False)
def find_process_by_port(port=51000):
def find_processes_by_port(port=51000):
"""Find processes listening on a specific port."""
port_processes = []
try:
lsof_command = f"lsof -n -i:{port} | grep LISTEN"
lsof_command = f"lsof -n | grep {port} | grep -i LISTEN"
result = subprocess.run(
lsof_command, shell=True, capture_output=True, text=True
)
@ -167,7 +167,7 @@ def find_process_by_port(port=51000):
return []
def kill_process_by_port(port_processes=51000, wait=True, timeout=10):
def kill_processes(port_processes, wait=True, timeout=10):
"""Kill processes on a specific port."""
try:

View file

@ -1,11 +1,12 @@
import json
import random
import builtins
import textwrap
from openai import OpenAI
from typing import Any, Dict, List, Tuple, Union
from overrides import override
from src.core.base_handler import (
from src.core.model_utils import (
Message,
ChatMessage,
Choice,
@ -14,43 +15,57 @@ from src.core.base_handler import (
)
SUPPORT_DATA_TYPES = ["int", "float", "bool", "str", "list", "tuple", "set", "dict"]
class ArchIntentConfig:
TASK_PROMPT = textwrap.dedent(
"""
You are a helpful assistant.
"""
).strip()
TOOL_PROMPT_TEMPLATE = textwrap.dedent(
"""
You task is to check if there are any tools that can be used to help the last user message in conversations according to the available tools listed below.
<tools>
{tool_text}
</tools>
"""
).strip()
FORMAT_PROMPT = textwrap.dedent(
"""
Provide your tool assessment for ONLY THE LAST USER MESSAGE in the above conversation:
- First line must read 'Yes' or 'No'.
- If yes, a second line must include a comma-separated list of tool indexes.
"""
).strip()
EXTRA_INSTRUCTION = "Are there any tools can help?"
GENERATION_PARAMS = {"max_tokens": 1, "stop_token_ids": [151645]}
class ArchIntentHandler(ArchBaseHandler):
def __init__(
self,
client: OpenAI,
model_name: str,
task_prompt: str,
tool_prompt_template: str,
format_prompt: str,
extra_instruction: str,
generation_params: Dict,
):
def __init__(self, client: OpenAI, model_name: str, config: ArchIntentConfig):
"""
Initializes the intent handler.
Args:
client (OpenAI): An OpenAI client instance.
model_name (str): Name of the model to use.
task_prompt (str): The main task prompt for the system.
tool_prompt_template (str): A prompt to describe tools.
format_prompt (str): A prompt specifying the desired output format.
extra_instruction (str): Instructions specific to intent handling.
generation_params (Dict): Generation parameters for the model.
config (ArchIntentConfig): The configuration for Arch-Intent.
"""
super().__init__(
client,
model_name,
task_prompt,
tool_prompt_template,
format_prompt,
generation_params,
config.TASK_PROMPT,
config.TOOL_PROMPT_TEMPLATE,
config.FORMAT_PROMPT,
config.GENERATION_PARAMS,
)
self.extra_instruction = extra_instruction
self.extra_instruction = config.EXTRA_INSTRUCTION
@override
def _convert_tools(self, tools: List[Dict[str, Any]]) -> str:
@ -125,17 +140,73 @@ class ArchIntentHandler(ArchBaseHandler):
return chat_completion_response
# =============================================================================================================
class ArchFunctionConfig:
TASK_PROMPT = textwrap.dedent(
"""
You are a helpful assistant.
"""
).strip()
TOOL_PROMPT_TEMPLATE = textwrap.dedent(
"""
# Tools
You may call one or more functions to assist with the user query.
You are provided with function signatures within <tools></tools> XML tags:
<tools>
{tool_text}
</tools>
"""
).strip()
FORMAT_PROMPT = textwrap.dedent(
"""
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
<tool_call>
{"name": <function-name>, "arguments": <args-json-object>}
</tool_call>
"""
).strip()
GENERATION_PARAMS = (
{
"temperature": 0.2,
"top_p": 1.0,
"top_k": 50,
"max_tokens": 512,
"stop_token_ids": [151645],
},
)
PREFILL_CONFIG = {
"prefill_params": {
"continue_final_message": True,
"add_generation_prompt": False,
},
"prefill_prefix": [
"May",
"Could",
"Sure",
"Definitely",
"Certainly",
"Of course",
"Can",
],
}
SUPPORT_DATA_TYPES = ["int", "float", "bool", "str", "list", "tuple", "set", "dict"]
class ArchFunctionHandler(ArchBaseHandler):
def __init__(
self,
client: OpenAI,
model_name: str,
task_prompt: str,
tool_prompt_template: str,
format_prompt: str,
generation_params: Dict,
prefill_params: Dict,
prefill_prefix: List,
config: ArchFunctionConfig,
):
"""
Initializes the function handler.
@ -143,30 +214,26 @@ class ArchFunctionHandler(ArchBaseHandler):
Args:
client (OpenAI): An OpenAI client instance.
model_name (str): Name of the model to use.
task_prompt (str): The main task prompt for the system.
tool_prompt_template (str): A prompt to describe tools.
format_prompt (str): A prompt specifying the desired output format.
generation_params (Dict): Generation parameters for the model.
prefill_params (Dict): Additional parameters for prefilling responses.
prefill_prefix (List[str]): List of prefixes for prefill responses.
config (ArchFunctionConfig): The configuration for Arch-Function
"""
super().__init__(
client,
model_name,
task_prompt,
tool_prompt_template,
format_prompt,
generation_params,
config.TASK_PROMPT,
config.TOOL_PROMPT_TEMPLATE,
config.FORMAT_PROMPT,
config.GENERATION_PARAMS,
)
self.prefill_params = prefill_params
self.prefill_prefix = prefill_prefix
self.prefill_params = config.PREFILL_CONFIG["prefill_params"]
self.prefill_prefix = config.PREFILL_CONFIG["prefill_prefix"]
# Predefine data types for verification. Only support Python for now.
# [TODO] Extend the list of support data types
self.support_data_types = {
type_name: getattr(builtins, type_name) for type_name in SUPPORT_DATA_TYPES
type_name: getattr(builtins, type_name)
for type_name in config.SUPPORT_DATA_TYPES
}
@override

View file

@ -3,22 +3,9 @@ import torch
import numpy as np
import src.commons.utils as utils
from typing import List
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from optimum.intel import OVModelForSequenceClassification
class GuardRequest(BaseModel):
input: str
task: str
class GuardResponse(BaseModel):
prob: List
verdict: bool
sentence: List
latency: float = 0
from src.core.model_utils import GuardRequest, GuardResponse
class ArchGuardHanlder:

View file

@ -32,6 +32,21 @@ class ChatCompletionResponse(BaseModel):
model: str
class GuardRequest(BaseModel):
input: str
task: str
class GuardResponse(BaseModel):
prob: List
verdict: bool
sentence: List
latency: float = 0
# ================================================================================================
class ArchBaseHandler:
def __init__(
self,
@ -53,9 +68,7 @@ class ArchBaseHandler:
format_prompt (str): A prompt specifying the desired output format.
generation_params (Dict): Generation parameters for the model.
"""
self.client = client
self.model_name = model_name
self.task_prompt = task_prompt

View file

@ -1,8 +1,7 @@
import os
from src.commons.globals import handler_map
from src.core.base_handler import ChatMessage
from src.core.guardrails import GuardRequest
from src.core.model_utils import ChatMessage, GuardRequest
from fastapi import FastAPI, Response
from opentelemetry import trace