mirror of
https://github.com/katanemo/plano.git
synced 2026-06-29 15:49:40 +02:00
model server build (#127)
* first commit to have model_server not be dependent on Docker * making changes to fix the docker-compose file for archgw to set DNS_V4 and minor fixes with the build * additional fixes for model server to be separated out in the build * additional fixes for model server to be separated out in the build * fix to get model_server to be built as a separate python process. TODO: fix the embeddings logs after cli completes * fixing init to pull tempfile using the tempfile python package --------- Co-authored-by: Salman Paracha <salmanparacha@MacBook-Pro-261.local>
This commit is contained in:
parent
7d21359f5b
commit
b60ceb9168
21 changed files with 3390 additions and 154 deletions
4
.gitignore
vendored
4
.gitignore
vendored
|
|
@ -24,4 +24,8 @@ demos/network_copilot/ollama/models/
|
||||||
arch_log/
|
arch_log/
|
||||||
arch/tools/*.egg-info
|
arch/tools/*.egg-info
|
||||||
arch/tools/config
|
arch/tools/config
|
||||||
|
arch/tools/build
|
||||||
|
model_server/model_server.egg-info
|
||||||
model_server/venv_model_server
|
model_server/venv_model_server
|
||||||
|
model_server/build
|
||||||
|
model_server/dist
|
||||||
|
|
|
||||||
|
|
@ -7,24 +7,5 @@ services:
|
||||||
volumes:
|
volumes:
|
||||||
- ${ARCH_CONFIG_FILE:-./demos/function_calling/arch_confg.yaml}:/config/arch_config.yaml
|
- ${ARCH_CONFIG_FILE:-./demos/function_calling/arch_confg.yaml}:/config/arch_config.yaml
|
||||||
- /etc/ssl/cert.pem:/etc/ssl/cert.pem
|
- /etc/ssl/cert.pem:/etc/ssl/cert.pem
|
||||||
depends_on:
|
|
||||||
model_server:
|
|
||||||
condition: service_healthy
|
|
||||||
env_file:
|
env_file:
|
||||||
- stage.env
|
- stage.env
|
||||||
|
|
||||||
model_server:
|
|
||||||
image: model_server:latest
|
|
||||||
ports:
|
|
||||||
- "18081:80"
|
|
||||||
healthcheck:
|
|
||||||
test: ["CMD", "curl" ,"http://localhost/healthz"]
|
|
||||||
interval: 5s
|
|
||||||
retries: 20
|
|
||||||
volumes:
|
|
||||||
- ~/.cache/huggingface:/root/.cache/huggingface
|
|
||||||
environment:
|
|
||||||
- OLLAMA_ENDPOINT=${OLLAMA_ENDPOINT:-host.docker.internal}
|
|
||||||
- OLLAMA_MODEL=Arch-Function-Calling-3B-Q4_K_M
|
|
||||||
- MODE=${MODE:-cloud}
|
|
||||||
- FC_URL=${FC_URL:-https://arch-fc-free-trial-4mzywewe.uc.gateway.dev/v1}
|
|
||||||
|
|
|
||||||
|
|
@ -123,8 +123,8 @@ static_resources:
|
||||||
- endpoint:
|
- endpoint:
|
||||||
address:
|
address:
|
||||||
socket_address:
|
socket_address:
|
||||||
address: model_server
|
address: host.docker.internal
|
||||||
port_value: 80
|
port_value: 51000
|
||||||
hostname: "model_server"
|
hostname: "model_server"
|
||||||
- name: mistral_7b_instruct
|
- name: mistral_7b_instruct
|
||||||
connect_timeout: 5s
|
connect_timeout: 5s
|
||||||
|
|
@ -153,8 +153,8 @@ static_resources:
|
||||||
- endpoint:
|
- endpoint:
|
||||||
address:
|
address:
|
||||||
socket_address:
|
socket_address:
|
||||||
address: model_server
|
address: host.docker.internal
|
||||||
port_value: 80
|
port_value: 51000
|
||||||
hostname: "arch_fc"
|
hostname: "arch_fc"
|
||||||
{% for _, cluster in arch_clusters.items() %}
|
{% for _, cluster in arch_clusters.items() %}
|
||||||
- name: {{ cluster.name }}
|
- name: {{ cluster.name }}
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ import config_generator
|
||||||
import pkg_resources
|
import pkg_resources
|
||||||
import sys
|
import sys
|
||||||
import subprocess
|
import subprocess
|
||||||
from core import start_arch, stop_arch
|
from core import start_arch_modelserver, stop_arch_modelserver, start_arch, stop_arch
|
||||||
from utils import get_llm_provider_access_keys, load_env_file_to_dict
|
from utils import get_llm_provider_access_keys, load_env_file_to_dict
|
||||||
|
|
||||||
logo = r"""
|
logo = r"""
|
||||||
|
|
@ -26,7 +26,7 @@ def main(ctx):
|
||||||
|
|
||||||
# Command to build archgw and model_server Docker images
|
# Command to build archgw and model_server Docker images
|
||||||
ARCHGW_DOCKERFILE = "./arch/Dockerfile"
|
ARCHGW_DOCKERFILE = "./arch/Dockerfile"
|
||||||
MODEL_SERVER_DOCKERFILE = "./model_server/Dockerfile"
|
MODEL_SERVER_BUILD_FILE = "./model_server/pyproject.toml"
|
||||||
|
|
||||||
@click.command()
|
@click.command()
|
||||||
def build():
|
def build():
|
||||||
|
|
@ -44,21 +44,22 @@ def build():
|
||||||
click.echo("Error: Dockerfile not found in /arch")
|
click.echo("Error: Dockerfile not found in /arch")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
# Check if /model_server/Dockerfile exists
|
click.echo("All images built successfully.")
|
||||||
if os.path.exists(MODEL_SERVER_DOCKERFILE):
|
|
||||||
click.echo("Building model_server image...")
|
"""Install the model server dependencies using Poetry."""
|
||||||
|
# Check if pyproject.toml exists
|
||||||
|
if os.path.exists(MODEL_SERVER_BUILD_FILE):
|
||||||
|
click.echo("Installing model server dependencies with Poetry...")
|
||||||
try:
|
try:
|
||||||
subprocess.run(["docker", "build", "-f", MODEL_SERVER_DOCKERFILE, "-t", "model_server:latest", "./model_server"], check=True)
|
subprocess.run(["poetry", "install", "--no-cache"], cwd=os.path.dirname(MODEL_SERVER_BUILD_FILE), check=True)
|
||||||
click.echo("model_server image built successfully.")
|
click.echo("Model server dependencies installed successfully.")
|
||||||
except subprocess.CalledProcessError as e:
|
except subprocess.CalledProcessError as e:
|
||||||
click.echo(f"Error building model_server image: {e}")
|
click.echo(f"Error installing model server dependencies: {e}")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
else:
|
else:
|
||||||
click.echo("Error: Dockerfile not found in /model_server")
|
click.echo(f"Error: pyproject.toml not found in {MODEL_SERVER_BUILD_FILE}")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
click.echo("All images built successfully.")
|
|
||||||
|
|
||||||
@click.command()
|
@click.command()
|
||||||
@click.argument('file', required=False) # Optional file argument
|
@click.argument('file', required=False) # Optional file argument
|
||||||
@click.option('-path', default='.', help='Path to the directory containing arch_config.yml')
|
@click.option('-path', default='.', help='Path to the directory containing arch_config.yml')
|
||||||
|
|
@ -120,11 +121,14 @@ def up(file, path):
|
||||||
env = os.environ.copy()
|
env = os.environ.copy()
|
||||||
env.update(env_stage)
|
env.update(env_stage)
|
||||||
env['ARCH_CONFIG_FILE'] = arch_config_file
|
env['ARCH_CONFIG_FILE'] = arch_config_file
|
||||||
|
|
||||||
|
start_arch_modelserver()
|
||||||
start_arch(arch_config_file, env)
|
start_arch(arch_config_file, env)
|
||||||
|
|
||||||
@click.command()
|
@click.command()
|
||||||
def down():
|
def down():
|
||||||
"""Stops Arch."""
|
"""Stops Arch."""
|
||||||
|
stop_arch_modelserver()
|
||||||
stop_arch()
|
stop_arch()
|
||||||
|
|
||||||
@click.command()
|
@click.command()
|
||||||
|
|
|
||||||
|
|
@ -5,14 +5,13 @@ import pkg_resources
|
||||||
import select
|
import select
|
||||||
from utils import run_docker_compose_ps, print_service_status, check_services_state
|
from utils import run_docker_compose_ps, print_service_status, check_services_state
|
||||||
|
|
||||||
def start_arch(arch_config_file, env, log_timeout=120, check_interval=1):
|
def start_arch(arch_config_file, env, log_timeout=120):
|
||||||
"""
|
"""
|
||||||
Start Docker Compose in detached mode and stream logs until services are healthy.
|
Start Docker Compose in detached mode and stream logs until services are healthy.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
path (str): The path where the prompt_confi.yml file is located.
|
path (str): The path where the prompt_confi.yml file is located.
|
||||||
log_timeout (int): Time in seconds to show logs before checking for healthy state.
|
log_timeout (int): Time in seconds to show logs before checking for healthy state.
|
||||||
check_interval (int): Time in seconds between health status checks.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
compose_file = pkg_resources.resource_filename(__name__, 'config/docker-compose.yaml')
|
compose_file = pkg_resources.resource_filename(__name__, 'config/docker-compose.yaml')
|
||||||
|
|
@ -96,3 +95,33 @@ def stop_arch():
|
||||||
|
|
||||||
except subprocess.CalledProcessError as e:
|
except subprocess.CalledProcessError as e:
|
||||||
print(f"Failed to shut down services: {str(e)}")
|
print(f"Failed to shut down services: {str(e)}")
|
||||||
|
|
||||||
|
def start_arch_modelserver():
|
||||||
|
"""
|
||||||
|
Start the model server. This assumes that the archgw_modelserver package is installed locally
|
||||||
|
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
subprocess.run(
|
||||||
|
['archgw_modelserver', 'restart'],
|
||||||
|
check=True,
|
||||||
|
)
|
||||||
|
print("Successfull run the archgw model_server")
|
||||||
|
except subprocess.CalledProcessError as e:
|
||||||
|
print (f"Failed to start model_server. Please check archgw_modelserver logs")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
def stop_arch_modelserver():
|
||||||
|
"""
|
||||||
|
Stop the model server. This assumes that the archgw_modelserver package is installed locally
|
||||||
|
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
subprocess.run(
|
||||||
|
['archgw_modelserver', 'stop'],
|
||||||
|
check=True,
|
||||||
|
)
|
||||||
|
print("Successfull stopped the archgw model_server")
|
||||||
|
except subprocess.CalledProcessError as e:
|
||||||
|
print (f"Failed to start model_server. Please check archgw_modelserver logs")
|
||||||
|
sys.exit(1)
|
||||||
|
|
|
||||||
|
|
@ -18,8 +18,8 @@ WORKDIR /src
|
||||||
ENV MODELS="BAAI/bge-large-en-v1.5"
|
ENV MODELS="BAAI/bge-large-en-v1.5"
|
||||||
|
|
||||||
COPY ./app ./app
|
COPY ./app ./app
|
||||||
COPY ./guard_model_config.yaml .
|
COPY ./app/guard_model_config.yaml .
|
||||||
COPY ./openai_params.yaml .
|
COPY ./app/openai_params.yaml .
|
||||||
|
|
||||||
# comment it out for now as we don't want to download the model every time we build the image
|
# comment it out for now as we don't want to download the model every time we build the image
|
||||||
# we will mount host cache to docker image to avoid downloading the model every time
|
# we will mount host cache to docker image to avoid downloading the model every time
|
||||||
|
|
|
||||||
|
|
@ -44,14 +44,8 @@ RUN if command -v nvcc >/dev/null 2>&1; then \
|
||||||
|
|
||||||
COPY . /src
|
COPY . /src
|
||||||
|
|
||||||
#
|
|
||||||
# output
|
|
||||||
#
|
|
||||||
|
|
||||||
|
|
||||||
# Specify list of models that will go into the image as a comma separated list
|
# Specify list of models that will go into the image as a comma separated list
|
||||||
ENV MODELS="BAAI/bge-large-en-v1.5"
|
ENV MODELS="BAAI/bge-large-en-v1.5"
|
||||||
ENV NER_MODELS="urchade/gliner_large-v2.1"
|
|
||||||
ENV DEBIAN_FRONTEND=noninteractive
|
ENV DEBIAN_FRONTEND=noninteractive
|
||||||
|
|
||||||
COPY /app /app
|
COPY /app /app
|
||||||
|
|
|
||||||
1
model_server/README.md
Normal file
1
model_server/README.md
Normal file
|
|
@ -0,0 +1 @@
|
||||||
|
## Model Server Package ##
|
||||||
0
model_server/__init__.py
Normal file
0
model_server/__init__.py
Normal file
|
|
@ -0,0 +1,99 @@
|
||||||
|
import sys
|
||||||
|
import subprocess
|
||||||
|
import os
|
||||||
|
import signal
|
||||||
|
import time
|
||||||
|
import requests
|
||||||
|
import psutil
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
# Path to the file where the server process ID will be stored
|
||||||
|
PID_FILE = os.path.join(tempfile.gettempdir(), "model_server.pid")
|
||||||
|
|
||||||
|
def run_server():
|
||||||
|
"""Start, stop, or restart the Uvicorn server based on command-line arguments."""
|
||||||
|
if len(sys.argv) > 1:
|
||||||
|
action = sys.argv[1]
|
||||||
|
else:
|
||||||
|
action = "start"
|
||||||
|
|
||||||
|
if action == "start":
|
||||||
|
start_server()
|
||||||
|
elif action == "stop":
|
||||||
|
stop_server()
|
||||||
|
elif action == "restart":
|
||||||
|
restart_server()
|
||||||
|
else:
|
||||||
|
print(f"Unknown action: {action}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
def start_server():
|
||||||
|
"""Start the Uvicorn server and save the process ID."""
|
||||||
|
if os.path.exists(PID_FILE):
|
||||||
|
print("Server is already running. Use 'model_server restart' to restart it.")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
print(f"Starting Archgw Model Server")
|
||||||
|
process = subprocess.Popen(
|
||||||
|
["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "51000"],
|
||||||
|
)
|
||||||
|
|
||||||
|
if wait_for_health_check("http://0.0.0.0:51000/healthz"):
|
||||||
|
# Write the process ID to the PID file
|
||||||
|
with open(PID_FILE, "w") as f:
|
||||||
|
f.write(str(process.pid))
|
||||||
|
print(f"ARCH GW Model Server started with PID {process.pid}")
|
||||||
|
else:
|
||||||
|
#Add model_server boot-up logs
|
||||||
|
print(f"ARCH GW Model Server - Didn't Sart In Time. Shutting Down")
|
||||||
|
process.terminate()
|
||||||
|
|
||||||
|
def wait_for_health_check(url, timeout=180):
|
||||||
|
"""Wait for the Uvicorn server to respond to health-check requests."""
|
||||||
|
start_time = time.time()
|
||||||
|
while time.time() - start_time < timeout:
|
||||||
|
try:
|
||||||
|
response = requests.get(url)
|
||||||
|
if response.status_code == 200:
|
||||||
|
return True
|
||||||
|
except requests.ConnectionError:
|
||||||
|
time.sleep(1)
|
||||||
|
print("Timed out waiting for ARCH GW Model Server to respond.")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def stop_server():
|
||||||
|
"""Stop the running Uvicorn server."""
|
||||||
|
if not os.path.exists(PID_FILE):
|
||||||
|
print("Status: Archgw Model Server not running")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Read the process ID from the PID file
|
||||||
|
with open(PID_FILE, "r") as f:
|
||||||
|
pid = int(f.read())
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Get process by PID
|
||||||
|
process = psutil.Process(pid)
|
||||||
|
|
||||||
|
# Gracefully terminate the process
|
||||||
|
process.terminate() # Sends SIGTERM by default
|
||||||
|
process.wait(timeout=10) # Wait for up to 10 seconds for the process to exit
|
||||||
|
|
||||||
|
print(f"Server with PID {pid} stopped.")
|
||||||
|
os.remove(PID_FILE)
|
||||||
|
|
||||||
|
except psutil.NoSuchProcess:
|
||||||
|
print(f"Process with PID {pid} not found. Cleaning up PID file.")
|
||||||
|
os.remove(PID_FILE)
|
||||||
|
except psutil.TimeoutExpired:
|
||||||
|
print(f"Process with PID {pid} did not terminate in time. Forcing shutdown.")
|
||||||
|
process.kill() # Forcefully kill the process
|
||||||
|
os.remove(PID_FILE)
|
||||||
|
|
||||||
|
def restart_server():
|
||||||
|
"""Restart the Uvicorn server."""
|
||||||
|
print("Check: Is Archgw Model Server running?")
|
||||||
|
stop_server()
|
||||||
|
start_server()
|
||||||
|
|
@ -1,9 +1,10 @@
|
||||||
import json
|
import json
|
||||||
import random
|
import random
|
||||||
from fastapi import FastAPI, Response
|
from fastapi import FastAPI, Response
|
||||||
from app.arch_fc.arch_handler import ArchHandler
|
from .common import ChatMessage, Message
|
||||||
from app.arch_fc.bolt_handler import BoltHandler
|
from .arch_handler import ArchHandler
|
||||||
from app.arch_fc.common import ChatMessage, Message
|
from .bolt_handler import BoltHandler
|
||||||
|
from app.utils import load_yaml_config
|
||||||
import logging
|
import logging
|
||||||
import yaml
|
import yaml
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
|
|
@ -14,17 +15,14 @@ logging.basicConfig(
|
||||||
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
|
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
|
||||||
)
|
)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
params = load_yaml_config("openai_params.yaml")
|
||||||
with open("openai_params.yaml") as f:
|
|
||||||
params = yaml.safe_load(f)
|
|
||||||
|
|
||||||
ollama_endpoint = os.getenv("OLLAMA_ENDPOINT", "localhost")
|
ollama_endpoint = os.getenv("OLLAMA_ENDPOINT", "localhost")
|
||||||
ollama_model = os.getenv("OLLAMA_MODEL", "Arch-Function-Calling-1.5B-Q4_K_M")
|
ollama_model = os.getenv("OLLAMA_MODEL", "Arch-Function-Calling-1.5B-Q4_K_M")
|
||||||
fc_url = os.getenv("FC_URL", ollama_endpoint)
|
fc_url = os.getenv("FC_URL", "https://arch-fc-free-trial-4mzywewe.uc.gateway.dev/v1")
|
||||||
|
|
||||||
mode = os.getenv("MODE", "cloud")
|
mode = os.getenv("MODE", "cloud")
|
||||||
if mode not in ["cloud", "local-gpu", "local-cpu"]:
|
if mode not in ["cloud", "local-gpu", "local-cpu"]:
|
||||||
raise ValueError(f"Invalid mode: {mode}")
|
raise ValueError(f"Invalid mode: {mode}")
|
||||||
arch_api_key = os.getenv("ARCH_API_KEY", "vllm")
|
|
||||||
|
|
||||||
handler = None
|
handler = None
|
||||||
if ollama_model.startswith("Arch"):
|
if ollama_model.startswith("Arch"):
|
||||||
|
|
|
||||||
|
|
@ -1,83 +0,0 @@
|
||||||
import pandas as pd
|
|
||||||
import random
|
|
||||||
import datetime
|
|
||||||
|
|
||||||
|
|
||||||
def generate_employee_data(conn):
|
|
||||||
# List of possible names, positions, departments, and locations
|
|
||||||
names = [
|
|
||||||
"Alice",
|
|
||||||
"Bob",
|
|
||||||
"Charlie",
|
|
||||||
"David",
|
|
||||||
"Eve",
|
|
||||||
"Frank",
|
|
||||||
"Grace",
|
|
||||||
"Hank",
|
|
||||||
"Ivy",
|
|
||||||
"Jack",
|
|
||||||
]
|
|
||||||
positions = [
|
|
||||||
"Manager",
|
|
||||||
"Engineer",
|
|
||||||
"Salesperson",
|
|
||||||
"HR Specialist",
|
|
||||||
"Marketing Analyst",
|
|
||||||
]
|
|
||||||
departments = ["Engineering", "Marketing", "HR", "Sales", "Finance"]
|
|
||||||
locations = ["New York", "San Francisco", "Austin", "Boston", "Chicago"]
|
|
||||||
|
|
||||||
# Function to generate random hire date
|
|
||||||
def random_hire_date():
|
|
||||||
start_date = datetime.date(2000, 1, 1)
|
|
||||||
end_date = datetime.date(2023, 12, 31)
|
|
||||||
time_between_dates = end_date - start_date
|
|
||||||
days_between_dates = time_between_dates.days
|
|
||||||
random_number_of_days = random.randrange(days_between_dates)
|
|
||||||
hire_date = start_date + datetime.timedelta(days=random_number_of_days)
|
|
||||||
return hire_date
|
|
||||||
|
|
||||||
# Function to generate random employee data
|
|
||||||
def generate_employee_records(count):
|
|
||||||
employees = []
|
|
||||||
|
|
||||||
for _ in range(count):
|
|
||||||
name = random.choice(names)
|
|
||||||
position = random.choice(positions)
|
|
||||||
salary = round(
|
|
||||||
random.uniform(50000, 150000), 2
|
|
||||||
) # Salary between 50,000 and 150,000
|
|
||||||
department = random.choice(departments)
|
|
||||||
location = random.choice(locations)
|
|
||||||
hire_date = random_hire_date()
|
|
||||||
performance_score = round(
|
|
||||||
random.uniform(1, 5), 2
|
|
||||||
) # Performance score between 1.0 and 5.0
|
|
||||||
years_of_experience = random.randint(
|
|
||||||
1, 30
|
|
||||||
) # Years of experience between 1 and 30
|
|
||||||
|
|
||||||
employee = {
|
|
||||||
"position": position,
|
|
||||||
"name": name,
|
|
||||||
"salary": salary,
|
|
||||||
"department": department,
|
|
||||||
"location": location,
|
|
||||||
"hire_date": hire_date,
|
|
||||||
"performance_score": performance_score,
|
|
||||||
"years_of_experience": years_of_experience,
|
|
||||||
}
|
|
||||||
|
|
||||||
employees.append(employee)
|
|
||||||
|
|
||||||
return employees
|
|
||||||
|
|
||||||
# Generate 10 random employee records
|
|
||||||
employee_records = generate_employee_records(200)
|
|
||||||
|
|
||||||
# Convert the list of dictionaries to a DataFrame
|
|
||||||
df = pd.DataFrame(employee_records)
|
|
||||||
|
|
||||||
df.to_sql("employees", conn, index=False)
|
|
||||||
|
|
||||||
return
|
|
||||||
|
|
@ -17,8 +17,6 @@ def get_device():
|
||||||
def load_transformers(models=os.getenv("MODELS", "BAAI/bge-large-en-v1.5")):
|
def load_transformers(models=os.getenv("MODELS", "BAAI/bge-large-en-v1.5")):
|
||||||
transformers = {}
|
transformers = {}
|
||||||
device = get_device()
|
device = get_device()
|
||||||
|
|
||||||
print(f"Using device: {device}")
|
|
||||||
for model in models.split(","):
|
for model in models.split(","):
|
||||||
transformers[model] = sentence_transformers.SentenceTransformer(model, device=device)
|
transformers[model] = sentence_transformers.SentenceTransformer(model, device=device)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -5,9 +5,10 @@ from app.load_models import (
|
||||||
load_transformers,
|
load_transformers,
|
||||||
load_guard_model,
|
load_guard_model,
|
||||||
load_zero_shot_models,
|
load_zero_shot_models,
|
||||||
|
get_device
|
||||||
)
|
)
|
||||||
import os
|
import os
|
||||||
from app.utils import GuardHandler, split_text_into_chunks
|
from app.utils import GuardHandler, split_text_into_chunks, load_yaml_config
|
||||||
import torch
|
import torch
|
||||||
import yaml
|
import yaml
|
||||||
import string
|
import string
|
||||||
|
|
@ -23,9 +24,7 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
transformers = load_transformers()
|
transformers = load_transformers()
|
||||||
zero_shot_models = load_zero_shot_models()
|
zero_shot_models = load_zero_shot_models()
|
||||||
|
guard_model_config = load_yaml_config("guard_model_config.yaml")
|
||||||
with open("guard_model_config.yaml") as f:
|
|
||||||
guard_model_config = yaml.safe_load(f)
|
|
||||||
|
|
||||||
mode = os.getenv("MODE", "cloud")
|
mode = os.getenv("MODE", "cloud")
|
||||||
logger.info(f"Serving model mode: {mode}")
|
logger.info(f"Serving model mode: {mode}")
|
||||||
|
|
@ -48,12 +47,8 @@ class EmbeddingRequest(BaseModel):
|
||||||
|
|
||||||
@app.get("/healthz")
|
@app.get("/healthz")
|
||||||
async def healthz():
|
async def healthz():
|
||||||
import os
|
|
||||||
|
|
||||||
print(os.getcwd())
|
|
||||||
return {"status": "ok"}
|
return {"status": "ok"}
|
||||||
|
|
||||||
|
|
||||||
@app.get("/models")
|
@app.get("/models")
|
||||||
async def models():
|
async def models():
|
||||||
models = []
|
models = []
|
||||||
|
|
@ -66,6 +61,7 @@ async def models():
|
||||||
|
|
||||||
@app.post("/embeddings")
|
@app.post("/embeddings")
|
||||||
async def embedding(req: EmbeddingRequest, res: Response):
|
async def embedding(req: EmbeddingRequest, res: Response):
|
||||||
|
print(f"Embedding Call Start Time: {time.time()}")
|
||||||
if req.model not in transformers:
|
if req.model not in transformers:
|
||||||
raise HTTPException(status_code=400, detail="unknown model: " + req.model)
|
raise HTTPException(status_code=400, detail="unknown model: " + req.model)
|
||||||
|
|
||||||
|
|
@ -80,6 +76,7 @@ async def embedding(req: EmbeddingRequest, res: Response):
|
||||||
"prompt_tokens": 0,
|
"prompt_tokens": 0,
|
||||||
"total_tokens": 0,
|
"total_tokens": 0,
|
||||||
}
|
}
|
||||||
|
print(f"Embedding Call Complete Time: {time.time()}")
|
||||||
return {"data": data, "model": req.model, "object": "list", "usage": usage}
|
return {"data": data, "model": req.model, "object": "list", "usage": usage}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,14 @@ import numpy as np
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
import time
|
import time
|
||||||
import torch
|
import torch
|
||||||
|
import pkg_resources
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
def load_yaml_config(file_name):
|
||||||
|
# Load the YAML file from the package
|
||||||
|
yaml_path = pkg_resources.resource_filename('app', file_name)
|
||||||
|
with open(yaml_path, 'r') as yaml_file:
|
||||||
|
return yaml.safe_load(yaml_file)
|
||||||
|
|
||||||
|
|
||||||
def split_text_into_chunks(text, max_words=300):
|
def split_text_into_chunks(text, max_words=300):
|
||||||
|
|
@ -21,7 +29,6 @@ def split_text_into_chunks(text, max_words=300):
|
||||||
def softmax(x):
|
def softmax(x):
|
||||||
return np.exp(x) / np.exp(x).sum(axis=0)
|
return np.exp(x) / np.exp(x).sum(axis=0)
|
||||||
|
|
||||||
|
|
||||||
class PredictionHandler:
|
class PredictionHandler:
|
||||||
def __init__(self, model, tokenizer, device, task="toxic", hardware_config="cpu"):
|
def __init__(self, model, tokenizer, device, task="toxic", hardware_config="cpu"):
|
||||||
self.model = model
|
self.model = model
|
||||||
|
|
|
||||||
3144
model_server/poetry.lock
generated
Normal file
3144
model_server/poetry.lock
generated
Normal file
File diff suppressed because it is too large
Load diff
37
model_server/pyproject.toml
Normal file
37
model_server/pyproject.toml
Normal file
|
|
@ -0,0 +1,37 @@
|
||||||
|
[tool.poetry]
|
||||||
|
name = "archgw_modelserver"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = "A model server for serving models"
|
||||||
|
authors = ["Katanemo Labs, Inc <archgw@katanemo.com>"]
|
||||||
|
license = "Apache 2.0"
|
||||||
|
readme = "README.md"
|
||||||
|
packages = [
|
||||||
|
{ include = "app" }, # Include the 'app' package
|
||||||
|
{ include = "app/arch_fc" }, # Include the 'app' package
|
||||||
|
]
|
||||||
|
include = ["app/*.yaml"]
|
||||||
|
|
||||||
|
[tool.poetry.dependencies]
|
||||||
|
python = ">=3.10"
|
||||||
|
fastapi = "0.115.0"
|
||||||
|
sentence-transformers = "3.1.1"
|
||||||
|
torch = "2.4.1"
|
||||||
|
uvicorn = "0.31.0"
|
||||||
|
transformers = "*"
|
||||||
|
pyyaml = "6.0.2"
|
||||||
|
accelerate = "*"
|
||||||
|
psutil = "6.0.0"
|
||||||
|
optimum-intel = "*"
|
||||||
|
openvino = "*"
|
||||||
|
pandas = "*"
|
||||||
|
dateparser = "*"
|
||||||
|
openai = "1.50.2"
|
||||||
|
tf-keras = "*"
|
||||||
|
onnx = "*"
|
||||||
|
|
||||||
|
[tool.poetry.scripts]
|
||||||
|
archgw_modelserver = "app:run_server"
|
||||||
|
|
||||||
|
[build-system]
|
||||||
|
requires = ["poetry-core>=1.0.0"]
|
||||||
|
build-backend = "poetry.core.masonry.api"
|
||||||
|
|
@ -1,19 +1,19 @@
|
||||||
#TOOD: pin versions
|
#TOOD: pin versions
|
||||||
fastapi
|
fastapi==0.115.0
|
||||||
sentence-transformers
|
sentence-transformers==3.1.1
|
||||||
torch
|
torch==2.4.1
|
||||||
uvicorn
|
uvicorn==0.31.0
|
||||||
gliner
|
|
||||||
transformers
|
transformers
|
||||||
pyyaml
|
pyyaml==6.0.2
|
||||||
accelerate
|
accelerate
|
||||||
|
psutil==6.0.0
|
||||||
# guard inference packages
|
# guard inference packages
|
||||||
optimum-intel
|
optimum-intel
|
||||||
openvino
|
openvino
|
||||||
psutil
|
psutil
|
||||||
pandas
|
pandas
|
||||||
dateparser
|
dateparser
|
||||||
openai
|
openai==1.50.2
|
||||||
pandas
|
pandas
|
||||||
tf-keras
|
tf-keras
|
||||||
onnx
|
onnx
|
||||||
|
|
|
||||||
26
model_server/setup.py
Normal file
26
model_server/setup.py
Normal file
|
|
@ -0,0 +1,26 @@
|
||||||
|
from setuptools import setup, find_packages
|
||||||
|
|
||||||
|
# Function to read requirements.txt
|
||||||
|
def parse_requirements(filename):
|
||||||
|
with open(filename, 'r') as file:
|
||||||
|
return [line.strip() for line in file if line.strip() and not line.startswith("#")]
|
||||||
|
|
||||||
|
# Call the parse_requirements function to get the list of dependencies
|
||||||
|
requirements = parse_requirements('requirements.txt')
|
||||||
|
print(f"packages to install: {find_packages()}")
|
||||||
|
|
||||||
|
setup(
|
||||||
|
name="model_server",
|
||||||
|
version="0.1",
|
||||||
|
packages=find_packages(),
|
||||||
|
install_requires=requirements,
|
||||||
|
package_data={
|
||||||
|
# Specify the package and the data files you want to include
|
||||||
|
'app': ['/*.yaml'], # Includes all .yaml files in the config/ folder
|
||||||
|
},
|
||||||
|
entry_points={
|
||||||
|
'console_scripts': [
|
||||||
|
'model_server=app:run_server',
|
||||||
|
],
|
||||||
|
},
|
||||||
|
)
|
||||||
Loading…
Add table
Add a link
Reference in a new issue