mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
Refacotr model configs
This commit is contained in:
parent
320f4612b8
commit
95e167c2f6
12 changed files with 1144 additions and 206 deletions
|
|
@ -8,8 +8,8 @@ from src.commons.utils import (
|
|||
wait_for_health_check,
|
||||
check_lsof,
|
||||
install_lsof,
|
||||
find_process_by_port,
|
||||
kill_process_by_port,
|
||||
find_processes_by_port,
|
||||
kill_processes,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -23,7 +23,7 @@ def start_server(port=51000):
|
|||
"python",
|
||||
"-m",
|
||||
"uvicorn",
|
||||
"app.main:app",
|
||||
"src.main:app",
|
||||
"--host",
|
||||
"0.0.0.0",
|
||||
"--port",
|
||||
|
|
@ -56,14 +56,16 @@ def stop_server(port=51000, wait=True, timeout=10):
|
|||
sys.exit(1)
|
||||
|
||||
logger.info(f"Stopping processes on port {port}...")
|
||||
port_processes = find_process_by_port(port)
|
||||
port_processes = find_processes_by_port(port)
|
||||
if port_processes is None:
|
||||
logger.info(f"No processes found listening on port {port}.")
|
||||
else:
|
||||
if len(port_processes):
|
||||
process_killed = kill_process_by_port(port_processes, wait, timeout)
|
||||
process_killed = kill_processes(port_processes, wait, timeout)
|
||||
if not process_killed:
|
||||
logger.error(f"Unable to kill all processes on {port}")
|
||||
else:
|
||||
logger.info(f"All processes on port {port} have been killed.")
|
||||
else:
|
||||
logger.error(f"Unable to find processes on {port}")
|
||||
|
||||
|
|
|
|||
|
|
@ -1,79 +0,0 @@
|
|||
# ========================== Arch-Intent Default Params ==========================
|
||||
ARCH_INTENT_MODEL_ALIAS = "Arch-Intent"
|
||||
ARCH_INTENT_INSTRUCTION = "Are there any tools can help?"
|
||||
|
||||
ARCH_INTENT_TASK_PROMPT = """
|
||||
You are a helpful assistant.
|
||||
"""
|
||||
|
||||
|
||||
ARCH_INTENT_TOOL_PROMPT_TEMPLATE = """
|
||||
You task is to check if there are any tools that can be used to help the last user message in conversations according to the available tools listed below.
|
||||
|
||||
<tools>
|
||||
{tool_text}
|
||||
</tools>
|
||||
"""
|
||||
|
||||
|
||||
ARCH_INTENT_FORMAT_PROMPT = """
|
||||
Provide your tool assessment for ONLY THE LAST USER MESSAGE in the above conversation:
|
||||
- First line must read 'Yes' or 'No'.
|
||||
- If yes, a second line must include a comma-separated list of tool indexes.
|
||||
"""
|
||||
|
||||
|
||||
ARCH_INTENT_GENERATION_CONFIG = {
|
||||
"generation_params": {"max_tokens": 1, "stop_token_ids": [151645]}
|
||||
}
|
||||
|
||||
|
||||
# ========================== Arch-Function Default Params ==========================
|
||||
ARCH_FUNCTION_MODEL_ALIAS = "Arch-Function"
|
||||
|
||||
ARCH_FUNCTION_TASK_PROMPT = """
|
||||
You are a helpful assistant.
|
||||
"""
|
||||
|
||||
|
||||
ARCH_FUNCTION_TOOL_PROMPT_TEMPLATE = """
|
||||
# Tools
|
||||
|
||||
You may call one or more functions to assist with the user query.
|
||||
|
||||
You are provided with function signatures within <tools></tools> XML tags:
|
||||
<tools>
|
||||
{tool_text}
|
||||
</tools>
|
||||
"""
|
||||
|
||||
|
||||
ARCH_FUNCTION_FORMAT_PROMPT = """
|
||||
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
|
||||
<tool_call>
|
||||
{"name": <function-name>, "arguments": <args-json-object>}
|
||||
</tool_call>
|
||||
"""
|
||||
|
||||
ARCH_FUNCTION_GENERATION_CONFIG = {
|
||||
"generation_params": {
|
||||
"temperature": 0.2,
|
||||
"top_p": 1.0,
|
||||
"top_k": 50,
|
||||
"max_tokens": 512,
|
||||
"stop_token_ids": [151645],
|
||||
},
|
||||
"prefill_params": {
|
||||
"continue_final_message": True,
|
||||
"add_generation_prompt": False,
|
||||
},
|
||||
"prefill_prefix": [
|
||||
"May",
|
||||
"Could",
|
||||
"Sure",
|
||||
"Definitely",
|
||||
"Certainly",
|
||||
"Of course",
|
||||
"Can",
|
||||
],
|
||||
}
|
||||
|
|
@ -1,36 +1,36 @@
|
|||
from openai import OpenAI
|
||||
from src.commons.constants import *
|
||||
from src.core.function_calling import ArchIntentHandler, ArchFunctionHandler
|
||||
from src.core.guardrails import get_guardrail_handler
|
||||
from src.commons.utils import get_model_server_logger
|
||||
from src.core.guardrails import get_guardrail_handler
|
||||
from src.core.function_calling import (
|
||||
ArchIntentConfig,
|
||||
ArchIntentHandler,
|
||||
ArchFunctionConfig,
|
||||
ArchFunctionHandler,
|
||||
)
|
||||
|
||||
|
||||
# Define logger
|
||||
logger = get_model_server_logger()
|
||||
|
||||
|
||||
# Define the client
|
||||
ARCH_ENDPOINT = "https://api.fc.archgw.com/v1"
|
||||
ARCH_API_KEY = "EMPTY"
|
||||
ARCH_CLIENT = OpenAI(base_url=ARCH_ENDPOINT, api_key=ARCH_API_KEY)
|
||||
|
||||
|
||||
# Define model names
|
||||
ARCH_INTENT_MODEL_ALIAS = "Arch-Intent"
|
||||
ARCH_FUNCTION_MODEL_ALIAS = "Arch-Function"
|
||||
|
||||
|
||||
# Define model handlers
|
||||
handler_map = {
|
||||
"Arch-Intent": ArchIntentHandler(
|
||||
ARCH_CLIENT,
|
||||
ARCH_INTENT_MODEL_ALIAS,
|
||||
ARCH_INTENT_TASK_PROMPT,
|
||||
ARCH_INTENT_TOOL_PROMPT_TEMPLATE,
|
||||
ARCH_INTENT_FORMAT_PROMPT,
|
||||
ARCH_INTENT_INSTRUCTION,
|
||||
**ARCH_INTENT_GENERATION_CONFIG,
|
||||
ARCH_CLIENT, ARCH_INTENT_MODEL_ALIAS, ArchIntentConfig
|
||||
),
|
||||
"Arch-Function": ArchFunctionHandler(
|
||||
ARCH_CLIENT,
|
||||
ARCH_FUNCTION_MODEL_ALIAS,
|
||||
ARCH_FUNCTION_TASK_PROMPT,
|
||||
ARCH_FUNCTION_TOOL_PROMPT_TEMPLATE,
|
||||
ARCH_FUNCTION_FORMAT_PROMPT,
|
||||
**ARCH_FUNCTION_GENERATION_CONFIG,
|
||||
ARCH_CLIENT, ARCH_FUNCTION_MODEL_ALIAS, ArchFunctionConfig
|
||||
),
|
||||
"Arch-Guard": get_guardrail_handler(),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ PROJ_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__fil
|
|||
|
||||
|
||||
# Default log directory and file
|
||||
DEFAULT_LOG_DIR = os.path.join(PROJ_DIR, "logs")
|
||||
DEFAULT_LOG_DIR = os.path.join(PROJ_DIR, ".logs")
|
||||
DEFAULT_LOG_FILE = "modelserver.log"
|
||||
|
||||
|
||||
|
|
@ -50,7 +50,7 @@ def get_model_server_logger(log_dir=None, log_file=None):
|
|||
Get or initialize the logger instance for the model server.
|
||||
|
||||
Parameters:
|
||||
- log_dir (str): Custom directory to store the log file. Defaults to `~/archgw_logs`.
|
||||
- log_dir (str): Custom directory to store the log file. Defaults to `./.logs`.
|
||||
- log_file (str): Custom log file name. Defaults to `modelserver.log`.
|
||||
|
||||
Returns:
|
||||
|
|
@ -146,13 +146,13 @@ def terminate_process_by_pid(pid, timeout):
|
|||
subprocess.run(["kill", "-9", str(pid)], check=False)
|
||||
|
||||
|
||||
def find_process_by_port(port=51000):
|
||||
def find_processes_by_port(port=51000):
|
||||
"""Find processes listening on a specific port."""
|
||||
|
||||
port_processes = []
|
||||
|
||||
try:
|
||||
lsof_command = f"lsof -n -i:{port} | grep LISTEN"
|
||||
lsof_command = f"lsof -n | grep {port} | grep -i LISTEN"
|
||||
result = subprocess.run(
|
||||
lsof_command, shell=True, capture_output=True, text=True
|
||||
)
|
||||
|
|
@ -167,7 +167,7 @@ def find_process_by_port(port=51000):
|
|||
return []
|
||||
|
||||
|
||||
def kill_process_by_port(port_processes=51000, wait=True, timeout=10):
|
||||
def kill_processes(port_processes, wait=True, timeout=10):
|
||||
"""Kill processes on a specific port."""
|
||||
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -1,11 +1,12 @@
|
|||
import json
|
||||
import random
|
||||
import builtins
|
||||
import textwrap
|
||||
|
||||
from openai import OpenAI
|
||||
from typing import Any, Dict, List, Tuple, Union
|
||||
from overrides import override
|
||||
from src.core.base_handler import (
|
||||
from src.core.model_utils import (
|
||||
Message,
|
||||
ChatMessage,
|
||||
Choice,
|
||||
|
|
@ -14,43 +15,57 @@ from src.core.base_handler import (
|
|||
)
|
||||
|
||||
|
||||
SUPPORT_DATA_TYPES = ["int", "float", "bool", "str", "list", "tuple", "set", "dict"]
|
||||
class ArchIntentConfig:
|
||||
TASK_PROMPT = textwrap.dedent(
|
||||
"""
|
||||
You are a helpful assistant.
|
||||
"""
|
||||
).strip()
|
||||
|
||||
TOOL_PROMPT_TEMPLATE = textwrap.dedent(
|
||||
"""
|
||||
You task is to check if there are any tools that can be used to help the last user message in conversations according to the available tools listed below.
|
||||
|
||||
<tools>
|
||||
{tool_text}
|
||||
</tools>
|
||||
"""
|
||||
).strip()
|
||||
|
||||
FORMAT_PROMPT = textwrap.dedent(
|
||||
"""
|
||||
Provide your tool assessment for ONLY THE LAST USER MESSAGE in the above conversation:
|
||||
- First line must read 'Yes' or 'No'.
|
||||
- If yes, a second line must include a comma-separated list of tool indexes.
|
||||
"""
|
||||
).strip()
|
||||
|
||||
EXTRA_INSTRUCTION = "Are there any tools can help?"
|
||||
|
||||
GENERATION_PARAMS = {"max_tokens": 1, "stop_token_ids": [151645]}
|
||||
|
||||
|
||||
class ArchIntentHandler(ArchBaseHandler):
|
||||
def __init__(
|
||||
self,
|
||||
client: OpenAI,
|
||||
model_name: str,
|
||||
task_prompt: str,
|
||||
tool_prompt_template: str,
|
||||
format_prompt: str,
|
||||
extra_instruction: str,
|
||||
generation_params: Dict,
|
||||
):
|
||||
def __init__(self, client: OpenAI, model_name: str, config: ArchIntentConfig):
|
||||
"""
|
||||
Initializes the intent handler.
|
||||
|
||||
Args:
|
||||
client (OpenAI): An OpenAI client instance.
|
||||
model_name (str): Name of the model to use.
|
||||
task_prompt (str): The main task prompt for the system.
|
||||
tool_prompt_template (str): A prompt to describe tools.
|
||||
format_prompt (str): A prompt specifying the desired output format.
|
||||
extra_instruction (str): Instructions specific to intent handling.
|
||||
generation_params (Dict): Generation parameters for the model.
|
||||
config (ArchIntentConfig): The configuration for Arch-Intent.
|
||||
"""
|
||||
|
||||
super().__init__(
|
||||
client,
|
||||
model_name,
|
||||
task_prompt,
|
||||
tool_prompt_template,
|
||||
format_prompt,
|
||||
generation_params,
|
||||
config.TASK_PROMPT,
|
||||
config.TOOL_PROMPT_TEMPLATE,
|
||||
config.FORMAT_PROMPT,
|
||||
config.GENERATION_PARAMS,
|
||||
)
|
||||
|
||||
self.extra_instruction = extra_instruction
|
||||
self.extra_instruction = config.EXTRA_INSTRUCTION
|
||||
|
||||
@override
|
||||
def _convert_tools(self, tools: List[Dict[str, Any]]) -> str:
|
||||
|
|
@ -125,17 +140,73 @@ class ArchIntentHandler(ArchBaseHandler):
|
|||
return chat_completion_response
|
||||
|
||||
|
||||
# =============================================================================================================
|
||||
|
||||
|
||||
class ArchFunctionConfig:
|
||||
TASK_PROMPT = textwrap.dedent(
|
||||
"""
|
||||
You are a helpful assistant.
|
||||
"""
|
||||
).strip()
|
||||
|
||||
TOOL_PROMPT_TEMPLATE = textwrap.dedent(
|
||||
"""
|
||||
# Tools
|
||||
|
||||
You may call one or more functions to assist with the user query.
|
||||
|
||||
You are provided with function signatures within <tools></tools> XML tags:
|
||||
<tools>
|
||||
{tool_text}
|
||||
</tools>
|
||||
"""
|
||||
).strip()
|
||||
|
||||
FORMAT_PROMPT = textwrap.dedent(
|
||||
"""
|
||||
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
|
||||
<tool_call>
|
||||
{"name": <function-name>, "arguments": <args-json-object>}
|
||||
</tool_call>
|
||||
"""
|
||||
).strip()
|
||||
|
||||
GENERATION_PARAMS = (
|
||||
{
|
||||
"temperature": 0.2,
|
||||
"top_p": 1.0,
|
||||
"top_k": 50,
|
||||
"max_tokens": 512,
|
||||
"stop_token_ids": [151645],
|
||||
},
|
||||
)
|
||||
|
||||
PREFILL_CONFIG = {
|
||||
"prefill_params": {
|
||||
"continue_final_message": True,
|
||||
"add_generation_prompt": False,
|
||||
},
|
||||
"prefill_prefix": [
|
||||
"May",
|
||||
"Could",
|
||||
"Sure",
|
||||
"Definitely",
|
||||
"Certainly",
|
||||
"Of course",
|
||||
"Can",
|
||||
],
|
||||
}
|
||||
|
||||
SUPPORT_DATA_TYPES = ["int", "float", "bool", "str", "list", "tuple", "set", "dict"]
|
||||
|
||||
|
||||
class ArchFunctionHandler(ArchBaseHandler):
|
||||
def __init__(
|
||||
self,
|
||||
client: OpenAI,
|
||||
model_name: str,
|
||||
task_prompt: str,
|
||||
tool_prompt_template: str,
|
||||
format_prompt: str,
|
||||
generation_params: Dict,
|
||||
prefill_params: Dict,
|
||||
prefill_prefix: List,
|
||||
config: ArchFunctionConfig,
|
||||
):
|
||||
"""
|
||||
Initializes the function handler.
|
||||
|
|
@ -143,30 +214,26 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
Args:
|
||||
client (OpenAI): An OpenAI client instance.
|
||||
model_name (str): Name of the model to use.
|
||||
task_prompt (str): The main task prompt for the system.
|
||||
tool_prompt_template (str): A prompt to describe tools.
|
||||
format_prompt (str): A prompt specifying the desired output format.
|
||||
generation_params (Dict): Generation parameters for the model.
|
||||
prefill_params (Dict): Additional parameters for prefilling responses.
|
||||
prefill_prefix (List[str]): List of prefixes for prefill responses.
|
||||
config (ArchFunctionConfig): The configuration for Arch-Function
|
||||
"""
|
||||
|
||||
super().__init__(
|
||||
client,
|
||||
model_name,
|
||||
task_prompt,
|
||||
tool_prompt_template,
|
||||
format_prompt,
|
||||
generation_params,
|
||||
config.TASK_PROMPT,
|
||||
config.TOOL_PROMPT_TEMPLATE,
|
||||
config.FORMAT_PROMPT,
|
||||
config.GENERATION_PARAMS,
|
||||
)
|
||||
|
||||
self.prefill_params = prefill_params
|
||||
self.prefill_prefix = prefill_prefix
|
||||
self.prefill_params = config.PREFILL_CONFIG["prefill_params"]
|
||||
self.prefill_prefix = config.PREFILL_CONFIG["prefill_prefix"]
|
||||
|
||||
# Predefine data types for verification. Only support Python for now.
|
||||
# [TODO] Extend the list of support data types
|
||||
self.support_data_types = {
|
||||
type_name: getattr(builtins, type_name) for type_name in SUPPORT_DATA_TYPES
|
||||
type_name: getattr(builtins, type_name)
|
||||
for type_name in config.SUPPORT_DATA_TYPES
|
||||
}
|
||||
|
||||
@override
|
||||
|
|
|
|||
|
|
@ -3,22 +3,9 @@ import torch
|
|||
import numpy as np
|
||||
import src.commons.utils as utils
|
||||
|
||||
from typing import List
|
||||
from pydantic import BaseModel
|
||||
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
||||
from optimum.intel import OVModelForSequenceClassification
|
||||
|
||||
|
||||
class GuardRequest(BaseModel):
|
||||
input: str
|
||||
task: str
|
||||
|
||||
|
||||
class GuardResponse(BaseModel):
|
||||
prob: List
|
||||
verdict: bool
|
||||
sentence: List
|
||||
latency: float = 0
|
||||
from src.core.model_utils import GuardRequest, GuardResponse
|
||||
|
||||
|
||||
class ArchGuardHanlder:
|
||||
|
|
|
|||
|
|
@ -32,6 +32,21 @@ class ChatCompletionResponse(BaseModel):
|
|||
model: str
|
||||
|
||||
|
||||
class GuardRequest(BaseModel):
|
||||
input: str
|
||||
task: str
|
||||
|
||||
|
||||
class GuardResponse(BaseModel):
|
||||
prob: List
|
||||
verdict: bool
|
||||
sentence: List
|
||||
latency: float = 0
|
||||
|
||||
|
||||
# ================================================================================================
|
||||
|
||||
|
||||
class ArchBaseHandler:
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -53,9 +68,7 @@ class ArchBaseHandler:
|
|||
format_prompt (str): A prompt specifying the desired output format.
|
||||
generation_params (Dict): Generation parameters for the model.
|
||||
"""
|
||||
|
||||
self.client = client
|
||||
|
||||
self.model_name = model_name
|
||||
|
||||
self.task_prompt = task_prompt
|
||||
|
|
@ -1,8 +1,7 @@
|
|||
import os
|
||||
|
||||
from src.commons.globals import handler_map
|
||||
from src.core.base_handler import ChatMessage
|
||||
from src.core.guardrails import GuardRequest
|
||||
from src.core.model_utils import ChatMessage, GuardRequest
|
||||
|
||||
from fastapi import FastAPI, Response
|
||||
from opentelemetry import trace
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue