add support for default target (#111)

* add support for default target

* add more fixes
This commit is contained in:
Adil Hafeez 2024-10-02 20:43:16 -07:00 committed by GitHub
parent c8d0dbec26
commit 1b57a49c9d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 215 additions and 88 deletions

View file

@ -1,41 +1,25 @@
FROM python:3.10 AS base
FROM python:3.10 AS builder
#
# builder
#
FROM base AS builder
WORKDIR /src
RUN pip install --upgrade pip
# Install git (needed for cloning the repository)
RUN apt-get update && apt-get install -y git && apt-get clean
COPY requirements.txt /src/
RUN pip install --prefix=/runtime --force-reinstall -r requirements.txt
COPY . /src
#
# output
#
COPY requirements.txt .
RUN pip install --prefix=/runtime -r requirements.txt
FROM python:3.10-slim AS output
# curl is needed for the health check in docker-compose
RUN apt-get update && apt-get install -y curl && apt-get clean && rm -rf /var/lib/apt/lists/*
COPY --from=builder /runtime /usr/local
WORKDIR /src
# Specify the models that will go into the image as a comma-separated list.
# The following models have been tested to work with this image:
# "sentence-transformers/all-MiniLM-L6-v2,sentence-transformers/all-mpnet-base-v2,thenlper/gte-base,thenlper/gte-large,thenlper/gte-small"
ENV MODELS="BAAI/bge-large-en-v1.5"
COPY --from=builder /runtime /usr/local
COPY ./ /app
WORKDIR /app
RUN apt-get update && apt-get install -y \
curl \
&& rm -rf /var/lib/apt/lists/*
COPY ./app ./app
COPY ./guard_model_config.yaml .
COPY ./openai_params.yaml .
# Commented out for now, as we don't want to download the model every time we build the image.
# We will mount the host cache into the Docker image to avoid downloading the model every time.

View file

@ -9,6 +9,10 @@ import yaml
from openai import OpenAI
import os
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
with open("openai_params.yaml") as f:
params = yaml.safe_load(f)
@ -20,7 +24,6 @@ mode = os.getenv("MODE", "cloud")
if mode not in ["cloud", "local-gpu", "local-cpu"]:
raise ValueError(f"Invalid mode: {mode}")
arch_api_key = os.getenv("ARCH_API_KEY", "vllm")
logger = logging.getLogger("uvicorn.error")
handler = None
if ollama_model.startswith("Arch"):
@ -28,17 +31,12 @@ if ollama_model.startswith("Arch"):
else:
handler = BoltHandler()
# app = FastAPI()
if mode == "cloud":
client = OpenAI(
base_url=fc_url,
api_key="EMPTY",
)
models = client.models.list()
model = models.data[0].id
chosen_model = model
chosen_model = "fc-cloud"
endpoint = fc_url
else:
client = OpenAI(
@ -47,12 +45,12 @@ else:
)
chosen_model = ollama_model
endpoint = ollama_endpoint
logger.info(f"serving mode: {mode}")
logger.info(f"using model: {chosen_model}")
logger.info(f"using endpoint: {endpoint}")
async def chat_completion(req: ChatMessage, res: Response):
logger.info("starting request")
tools_encoded = handler._format_system(req.tools)