mirror of
https://github.com/katanemo/plano.git
synced 2026-06-14 15:15:15 +02:00
model server build (#127)
* first commit to have model_server not be dependent on Docker * making changes to fix the docker-compose file for archgw to set DNS_V4 and minor fixes with the build * additional fixes for model server to be separated out in the build * additional fixes for model server to be separated out in the build * fix to get model_server to be built as a separate python process. TODO: fix the embeddings logs after cli completes * fixing init to pull tempfile using the tempfile python package --------- Co-authored-by: Salman Paracha <salmanparacha@MacBook-Pro-261.local>
This commit is contained in:
parent
7d21359f5b
commit
b60ceb9168
21 changed files with 3390 additions and 154 deletions
|
|
@ -0,0 +1,99 @@
|
|||
import sys
|
||||
import subprocess
|
||||
import os
|
||||
import signal
|
||||
import time
|
||||
import requests
|
||||
import psutil
|
||||
import tempfile
|
||||
|
||||
# Path to the file where the server process ID will be stored
|
||||
PID_FILE = os.path.join(tempfile.gettempdir(), "model_server.pid")
|
||||
|
||||
def run_server():
|
||||
"""Start, stop, or restart the Uvicorn server based on command-line arguments."""
|
||||
if len(sys.argv) > 1:
|
||||
action = sys.argv[1]
|
||||
else:
|
||||
action = "start"
|
||||
|
||||
if action == "start":
|
||||
start_server()
|
||||
elif action == "stop":
|
||||
stop_server()
|
||||
elif action == "restart":
|
||||
restart_server()
|
||||
else:
|
||||
print(f"Unknown action: {action}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def start_server():
|
||||
"""Start the Uvicorn server and save the process ID."""
|
||||
if os.path.exists(PID_FILE):
|
||||
print("Server is already running. Use 'model_server restart' to restart it.")
|
||||
sys.exit(1)
|
||||
|
||||
print(f"Starting Archgw Model Server")
|
||||
process = subprocess.Popen(
|
||||
["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "51000"],
|
||||
)
|
||||
|
||||
if wait_for_health_check("http://0.0.0.0:51000/healthz"):
|
||||
# Write the process ID to the PID file
|
||||
with open(PID_FILE, "w") as f:
|
||||
f.write(str(process.pid))
|
||||
print(f"ARCH GW Model Server started with PID {process.pid}")
|
||||
else:
|
||||
#Add model_server boot-up logs
|
||||
print(f"ARCH GW Model Server - Didn't Sart In Time. Shutting Down")
|
||||
process.terminate()
|
||||
|
||||
def wait_for_health_check(url, timeout=180):
|
||||
"""Wait for the Uvicorn server to respond to health-check requests."""
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < timeout:
|
||||
try:
|
||||
response = requests.get(url)
|
||||
if response.status_code == 200:
|
||||
return True
|
||||
except requests.ConnectionError:
|
||||
time.sleep(1)
|
||||
print("Timed out waiting for ARCH GW Model Server to respond.")
|
||||
return False
|
||||
|
||||
|
||||
def stop_server():
|
||||
"""Stop the running Uvicorn server."""
|
||||
if not os.path.exists(PID_FILE):
|
||||
print("Status: Archgw Model Server not running")
|
||||
return
|
||||
|
||||
# Read the process ID from the PID file
|
||||
with open(PID_FILE, "r") as f:
|
||||
pid = int(f.read())
|
||||
|
||||
try:
|
||||
# Get process by PID
|
||||
process = psutil.Process(pid)
|
||||
|
||||
# Gracefully terminate the process
|
||||
process.terminate() # Sends SIGTERM by default
|
||||
process.wait(timeout=10) # Wait for up to 10 seconds for the process to exit
|
||||
|
||||
print(f"Server with PID {pid} stopped.")
|
||||
os.remove(PID_FILE)
|
||||
|
||||
except psutil.NoSuchProcess:
|
||||
print(f"Process with PID {pid} not found. Cleaning up PID file.")
|
||||
os.remove(PID_FILE)
|
||||
except psutil.TimeoutExpired:
|
||||
print(f"Process with PID {pid} did not terminate in time. Forcing shutdown.")
|
||||
process.kill() # Forcefully kill the process
|
||||
os.remove(PID_FILE)
|
||||
|
||||
def restart_server():
|
||||
"""Restart the Uvicorn server."""
|
||||
print("Check: Is Archgw Model Server running?")
|
||||
stop_server()
|
||||
start_server()
|
||||
|
|
@ -1,9 +1,10 @@
|
|||
import json
|
||||
import random
|
||||
from fastapi import FastAPI, Response
|
||||
from app.arch_fc.arch_handler import ArchHandler
|
||||
from app.arch_fc.bolt_handler import BoltHandler
|
||||
from app.arch_fc.common import ChatMessage, Message
|
||||
from .common import ChatMessage, Message
|
||||
from .arch_handler import ArchHandler
|
||||
from .bolt_handler import BoltHandler
|
||||
from app.utils import load_yaml_config
|
||||
import logging
|
||||
import yaml
|
||||
from openai import OpenAI
|
||||
|
|
@ -14,17 +15,14 @@ logging.basicConfig(
|
|||
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
with open("openai_params.yaml") as f:
|
||||
params = yaml.safe_load(f)
|
||||
|
||||
params = load_yaml_config("openai_params.yaml")
|
||||
ollama_endpoint = os.getenv("OLLAMA_ENDPOINT", "localhost")
|
||||
ollama_model = os.getenv("OLLAMA_MODEL", "Arch-Function-Calling-1.5B-Q4_K_M")
|
||||
fc_url = os.getenv("FC_URL", ollama_endpoint)
|
||||
fc_url = os.getenv("FC_URL", "https://arch-fc-free-trial-4mzywewe.uc.gateway.dev/v1")
|
||||
|
||||
mode = os.getenv("MODE", "cloud")
|
||||
if mode not in ["cloud", "local-gpu", "local-cpu"]:
|
||||
raise ValueError(f"Invalid mode: {mode}")
|
||||
arch_api_key = os.getenv("ARCH_API_KEY", "vllm")
|
||||
|
||||
handler = None
|
||||
if ollama_model.startswith("Arch"):
|
||||
|
|
|
|||
|
|
@ -1,83 +0,0 @@
|
|||
import pandas as pd
|
||||
import random
|
||||
import datetime
|
||||
|
||||
|
||||
def generate_employee_data(conn):
|
||||
# List of possible names, positions, departments, and locations
|
||||
names = [
|
||||
"Alice",
|
||||
"Bob",
|
||||
"Charlie",
|
||||
"David",
|
||||
"Eve",
|
||||
"Frank",
|
||||
"Grace",
|
||||
"Hank",
|
||||
"Ivy",
|
||||
"Jack",
|
||||
]
|
||||
positions = [
|
||||
"Manager",
|
||||
"Engineer",
|
||||
"Salesperson",
|
||||
"HR Specialist",
|
||||
"Marketing Analyst",
|
||||
]
|
||||
departments = ["Engineering", "Marketing", "HR", "Sales", "Finance"]
|
||||
locations = ["New York", "San Francisco", "Austin", "Boston", "Chicago"]
|
||||
|
||||
# Function to generate random hire date
|
||||
def random_hire_date():
|
||||
start_date = datetime.date(2000, 1, 1)
|
||||
end_date = datetime.date(2023, 12, 31)
|
||||
time_between_dates = end_date - start_date
|
||||
days_between_dates = time_between_dates.days
|
||||
random_number_of_days = random.randrange(days_between_dates)
|
||||
hire_date = start_date + datetime.timedelta(days=random_number_of_days)
|
||||
return hire_date
|
||||
|
||||
# Function to generate random employee data
|
||||
def generate_employee_records(count):
|
||||
employees = []
|
||||
|
||||
for _ in range(count):
|
||||
name = random.choice(names)
|
||||
position = random.choice(positions)
|
||||
salary = round(
|
||||
random.uniform(50000, 150000), 2
|
||||
) # Salary between 50,000 and 150,000
|
||||
department = random.choice(departments)
|
||||
location = random.choice(locations)
|
||||
hire_date = random_hire_date()
|
||||
performance_score = round(
|
||||
random.uniform(1, 5), 2
|
||||
) # Performance score between 1.0 and 5.0
|
||||
years_of_experience = random.randint(
|
||||
1, 30
|
||||
) # Years of experience between 1 and 30
|
||||
|
||||
employee = {
|
||||
"position": position,
|
||||
"name": name,
|
||||
"salary": salary,
|
||||
"department": department,
|
||||
"location": location,
|
||||
"hire_date": hire_date,
|
||||
"performance_score": performance_score,
|
||||
"years_of_experience": years_of_experience,
|
||||
}
|
||||
|
||||
employees.append(employee)
|
||||
|
||||
return employees
|
||||
|
||||
# Generate 10 random employee records
|
||||
employee_records = generate_employee_records(200)
|
||||
|
||||
# Convert the list of dictionaries to a DataFrame
|
||||
df = pd.DataFrame(employee_records)
|
||||
|
||||
df.to_sql("employees", conn, index=False)
|
||||
|
||||
return
|
||||
3
model_server/app/guard_model_config.yaml
Normal file
3
model_server/app/guard_model_config.yaml
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
jailbreak:
|
||||
cpu: "katanemolabs/Arch-Guard-cpu"
|
||||
gpu: "katanemolabs/Arch-Guard-gpu"
|
||||
|
|
@ -17,8 +17,6 @@ def get_device():
|
|||
def load_transformers(models=os.getenv("MODELS", "BAAI/bge-large-en-v1.5")):
|
||||
transformers = {}
|
||||
device = get_device()
|
||||
|
||||
print(f"Using device: {device}")
|
||||
for model in models.split(","):
|
||||
transformers[model] = sentence_transformers.SentenceTransformer(model, device=device)
|
||||
|
||||
|
|
|
|||
|
|
@ -5,9 +5,10 @@ from app.load_models import (
|
|||
load_transformers,
|
||||
load_guard_model,
|
||||
load_zero_shot_models,
|
||||
get_device
|
||||
)
|
||||
import os
|
||||
from app.utils import GuardHandler, split_text_into_chunks
|
||||
from app.utils import GuardHandler, split_text_into_chunks, load_yaml_config
|
||||
import torch
|
||||
import yaml
|
||||
import string
|
||||
|
|
@ -23,9 +24,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
transformers = load_transformers()
|
||||
zero_shot_models = load_zero_shot_models()
|
||||
|
||||
with open("guard_model_config.yaml") as f:
|
||||
guard_model_config = yaml.safe_load(f)
|
||||
guard_model_config = load_yaml_config("guard_model_config.yaml")
|
||||
|
||||
mode = os.getenv("MODE", "cloud")
|
||||
logger.info(f"Serving model mode: {mode}")
|
||||
|
|
@ -48,12 +47,8 @@ class EmbeddingRequest(BaseModel):
|
|||
|
||||
@app.get("/healthz")
|
||||
async def healthz():
|
||||
import os
|
||||
|
||||
print(os.getcwd())
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
@app.get("/models")
|
||||
async def models():
|
||||
models = []
|
||||
|
|
@ -66,6 +61,7 @@ async def models():
|
|||
|
||||
@app.post("/embeddings")
|
||||
async def embedding(req: EmbeddingRequest, res: Response):
|
||||
print(f"Embedding Call Start Time: {time.time()}")
|
||||
if req.model not in transformers:
|
||||
raise HTTPException(status_code=400, detail="unknown model: " + req.model)
|
||||
|
||||
|
|
@ -80,6 +76,7 @@ async def embedding(req: EmbeddingRequest, res: Response):
|
|||
"prompt_tokens": 0,
|
||||
"total_tokens": 0,
|
||||
}
|
||||
print(f"Embedding Call Complete Time: {time.time()}")
|
||||
return {"data": data, "model": req.model, "object": "list", "usage": usage}
|
||||
|
||||
|
||||
|
|
|
|||
6
model_server/app/openai_params.yaml
Normal file
6
model_server/app/openai_params.yaml
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
params:
|
||||
temperature: 0.01
|
||||
top_p : 0.5
|
||||
top_k: 50
|
||||
max_tokens: 2024
|
||||
stop_token_ids: [151645, 151643]
|
||||
|
|
@ -2,6 +2,14 @@ import numpy as np
|
|||
from concurrent.futures import ThreadPoolExecutor
|
||||
import time
|
||||
import torch
|
||||
import pkg_resources
|
||||
import yaml
|
||||
|
||||
def load_yaml_config(file_name):
|
||||
# Load the YAML file from the package
|
||||
yaml_path = pkg_resources.resource_filename('app', file_name)
|
||||
with open(yaml_path, 'r') as yaml_file:
|
||||
return yaml.safe_load(yaml_file)
|
||||
|
||||
|
||||
def split_text_into_chunks(text, max_words=300):
|
||||
|
|
@ -21,7 +29,6 @@ def split_text_into_chunks(text, max_words=300):
|
|||
def softmax(x):
|
||||
return np.exp(x) / np.exp(x).sum(axis=0)
|
||||
|
||||
|
||||
class PredictionHandler:
|
||||
def __init__(self, model, tokenizer, device, task="toxic", hardware_config="cpu"):
|
||||
self.model = model
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue