model server build (#127)

* first commit to have model_server not be dependent on Docker

* making changes to fix the docker-compose file for archgw to set DNS_V4 and minor fixes with the build

* additional fixes for model server to be separated out in the build

* additional fixes for model server to be separated out in the build

* fix to get model_server to be built as a separate python process. TODO: fix the embeddings logs after cli completes

* fixing init to pull tempfile using the tempfile python package

---------

Co-authored-by: Salman Paracha <salmanparacha@MacBook-Pro-261.local>
This commit is contained in:
Salman Paracha 2024-10-06 18:21:43 -07:00 committed by GitHub
parent 7d21359f5b
commit b60ceb9168
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
21 changed files with 3390 additions and 154 deletions

View file

@ -0,0 +1,99 @@
import sys
import subprocess
import os
import signal
import time
import requests
import psutil
import tempfile
# Path to the file where the server process ID will be stored
PID_FILE = os.path.join(tempfile.gettempdir(), "model_server.pid")
def run_server():
"""Start, stop, or restart the Uvicorn server based on command-line arguments."""
if len(sys.argv) > 1:
action = sys.argv[1]
else:
action = "start"
if action == "start":
start_server()
elif action == "stop":
stop_server()
elif action == "restart":
restart_server()
else:
print(f"Unknown action: {action}")
sys.exit(1)
def start_server():
"""Start the Uvicorn server and save the process ID."""
if os.path.exists(PID_FILE):
print("Server is already running. Use 'model_server restart' to restart it.")
sys.exit(1)
print(f"Starting Archgw Model Server")
process = subprocess.Popen(
["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "51000"],
)
if wait_for_health_check("http://0.0.0.0:51000/healthz"):
# Write the process ID to the PID file
with open(PID_FILE, "w") as f:
f.write(str(process.pid))
print(f"ARCH GW Model Server started with PID {process.pid}")
else:
#Add model_server boot-up logs
print(f"ARCH GW Model Server - Didn't Sart In Time. Shutting Down")
process.terminate()
def wait_for_health_check(url, timeout=180):
"""Wait for the Uvicorn server to respond to health-check requests."""
start_time = time.time()
while time.time() - start_time < timeout:
try:
response = requests.get(url)
if response.status_code == 200:
return True
except requests.ConnectionError:
time.sleep(1)
print("Timed out waiting for ARCH GW Model Server to respond.")
return False
def stop_server():
"""Stop the running Uvicorn server."""
if not os.path.exists(PID_FILE):
print("Status: Archgw Model Server not running")
return
# Read the process ID from the PID file
with open(PID_FILE, "r") as f:
pid = int(f.read())
try:
# Get process by PID
process = psutil.Process(pid)
# Gracefully terminate the process
process.terminate() # Sends SIGTERM by default
process.wait(timeout=10) # Wait for up to 10 seconds for the process to exit
print(f"Server with PID {pid} stopped.")
os.remove(PID_FILE)
except psutil.NoSuchProcess:
print(f"Process with PID {pid} not found. Cleaning up PID file.")
os.remove(PID_FILE)
except psutil.TimeoutExpired:
print(f"Process with PID {pid} did not terminate in time. Forcing shutdown.")
process.kill() # Forcefully kill the process
os.remove(PID_FILE)
def restart_server():
"""Restart the Uvicorn server."""
print("Check: Is Archgw Model Server running?")
stop_server()
start_server()

View file

@ -1,9 +1,10 @@
import json
import random
from fastapi import FastAPI, Response
from app.arch_fc.arch_handler import ArchHandler
from app.arch_fc.bolt_handler import BoltHandler
from app.arch_fc.common import ChatMessage, Message
from .common import ChatMessage, Message
from .arch_handler import ArchHandler
from .bolt_handler import BoltHandler
from app.utils import load_yaml_config
import logging
import yaml
from openai import OpenAI
@ -14,17 +15,14 @@ logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
with open("openai_params.yaml") as f:
params = yaml.safe_load(f)
params = load_yaml_config("openai_params.yaml")
ollama_endpoint = os.getenv("OLLAMA_ENDPOINT", "localhost")
ollama_model = os.getenv("OLLAMA_MODEL", "Arch-Function-Calling-1.5B-Q4_K_M")
fc_url = os.getenv("FC_URL", ollama_endpoint)
fc_url = os.getenv("FC_URL", "https://arch-fc-free-trial-4mzywewe.uc.gateway.dev/v1")
mode = os.getenv("MODE", "cloud")
if mode not in ["cloud", "local-gpu", "local-cpu"]:
raise ValueError(f"Invalid mode: {mode}")
arch_api_key = os.getenv("ARCH_API_KEY", "vllm")
handler = None
if ollama_model.startswith("Arch"):

View file

@ -1,83 +0,0 @@
import pandas as pd
import random
import datetime
def generate_employee_data(conn):
# List of possible names, positions, departments, and locations
names = [
"Alice",
"Bob",
"Charlie",
"David",
"Eve",
"Frank",
"Grace",
"Hank",
"Ivy",
"Jack",
]
positions = [
"Manager",
"Engineer",
"Salesperson",
"HR Specialist",
"Marketing Analyst",
]
departments = ["Engineering", "Marketing", "HR", "Sales", "Finance"]
locations = ["New York", "San Francisco", "Austin", "Boston", "Chicago"]
# Function to generate random hire date
def random_hire_date():
start_date = datetime.date(2000, 1, 1)
end_date = datetime.date(2023, 12, 31)
time_between_dates = end_date - start_date
days_between_dates = time_between_dates.days
random_number_of_days = random.randrange(days_between_dates)
hire_date = start_date + datetime.timedelta(days=random_number_of_days)
return hire_date
# Function to generate random employee data
def generate_employee_records(count):
employees = []
for _ in range(count):
name = random.choice(names)
position = random.choice(positions)
salary = round(
random.uniform(50000, 150000), 2
) # Salary between 50,000 and 150,000
department = random.choice(departments)
location = random.choice(locations)
hire_date = random_hire_date()
performance_score = round(
random.uniform(1, 5), 2
) # Performance score between 1.0 and 5.0
years_of_experience = random.randint(
1, 30
) # Years of experience between 1 and 30
employee = {
"position": position,
"name": name,
"salary": salary,
"department": department,
"location": location,
"hire_date": hire_date,
"performance_score": performance_score,
"years_of_experience": years_of_experience,
}
employees.append(employee)
return employees
# Generate 10 random employee records
employee_records = generate_employee_records(200)
# Convert the list of dictionaries to a DataFrame
df = pd.DataFrame(employee_records)
df.to_sql("employees", conn, index=False)
return

View file

@ -0,0 +1,3 @@
jailbreak:
cpu: "katanemolabs/Arch-Guard-cpu"
gpu: "katanemolabs/Arch-Guard-gpu"

View file

@ -17,8 +17,6 @@ def get_device():
def load_transformers(models=os.getenv("MODELS", "BAAI/bge-large-en-v1.5")):
transformers = {}
device = get_device()
print(f"Using device: {device}")
for model in models.split(","):
transformers[model] = sentence_transformers.SentenceTransformer(model, device=device)

View file

@ -5,9 +5,10 @@ from app.load_models import (
load_transformers,
load_guard_model,
load_zero_shot_models,
get_device
)
import os
from app.utils import GuardHandler, split_text_into_chunks
from app.utils import GuardHandler, split_text_into_chunks, load_yaml_config
import torch
import yaml
import string
@ -23,9 +24,7 @@ logger = logging.getLogger(__name__)
transformers = load_transformers()
zero_shot_models = load_zero_shot_models()
with open("guard_model_config.yaml") as f:
guard_model_config = yaml.safe_load(f)
guard_model_config = load_yaml_config("guard_model_config.yaml")
mode = os.getenv("MODE", "cloud")
logger.info(f"Serving model mode: {mode}")
@ -48,12 +47,8 @@ class EmbeddingRequest(BaseModel):
@app.get("/healthz")
async def healthz():
import os
print(os.getcwd())
return {"status": "ok"}
@app.get("/models")
async def models():
models = []
@ -66,6 +61,7 @@ async def models():
@app.post("/embeddings")
async def embedding(req: EmbeddingRequest, res: Response):
print(f"Embedding Call Start Time: {time.time()}")
if req.model not in transformers:
raise HTTPException(status_code=400, detail="unknown model: " + req.model)
@ -80,6 +76,7 @@ async def embedding(req: EmbeddingRequest, res: Response):
"prompt_tokens": 0,
"total_tokens": 0,
}
print(f"Embedding Call Complete Time: {time.time()}")
return {"data": data, "model": req.model, "object": "list", "usage": usage}

View file

@ -0,0 +1,6 @@
params:
temperature: 0.01
top_p : 0.5
top_k: 50
max_tokens: 2024
stop_token_ids: [151645, 151643]

View file

@ -2,6 +2,14 @@ import numpy as np
from concurrent.futures import ThreadPoolExecutor
import time
import torch
import pkg_resources
import yaml
def load_yaml_config(file_name):
# Load the YAML file from the package
yaml_path = pkg_resources.resource_filename('app', file_name)
with open(yaml_path, 'r') as yaml_file:
return yaml.safe_load(yaml_file)
def split_text_into_chunks(text, max_words=300):
@ -21,7 +29,6 @@ def split_text_into_chunks(text, max_words=300):
def softmax(x):
return np.exp(x) / np.exp(x).sum(axis=0)
class PredictionHandler:
def __init__(self, model, tokenizer, device, task="toxic", hardware_config="cpu"):
self.model = model