add embedding store (#10)

This commit is contained in:
Adil Hafeez 2024-07-18 14:04:51 -07:00 committed by GitHub
parent cc2a496f90
commit 7bf77afa0e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 409 additions and 11 deletions

View file

@ -0,0 +1,42 @@
# copied from https://github.com/bergos/embedding-server
FROM python:3 AS base
#
# builder
#
FROM base AS builder
WORKDIR /src
COPY requirements.txt /src/
RUN pip install --prefix=/runtime --force-reinstall -r requirements.txt
COPY . /src
#
# output
#
FROM python:3-slim AS output
# specify list of models that will go into the image as a comma separated list
# following models have been tested to work with this image
# "sentence-transformers/all-MiniLM-L6-v2,sentence-transformers/all-mpnet-base-v2,thenlper/gte-base,thenlper/gte-large,thenlper/gte-small"
ENV MODELS="BAAI/bge-large-en-v1.5"
COPY --from=builder /runtime /usr/local
COPY /app /app
WORKDIR /app
RUN apt-get update && apt-get install -y \
curl \
&& rm -rf /var/lib/apt/lists/*
RUN python install.py
# RUN python install.py && \
# find /root/.cache/torch/sentence_transformers/ -name onnx -exec rm -rf {} +
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "80"]

View file

@ -0,0 +1,3 @@
from load_transformers import load_transformers
load_transformers()

View file

@ -0,0 +1,10 @@
import os
import sentence_transformers
def load_transformers(models = os.getenv("MODELS", "sentence-transformers/all-MiniLM-L6-v2")):
transformers = {}
for model in models.split(','):
transformers[model] = sentence_transformers.SentenceTransformer(model)
return transformers

View file

@ -0,0 +1,48 @@
from fastapi import FastAPI, Response, HTTPException
from pydantic import BaseModel
from load_transformers import load_transformers
transformers = load_transformers()
app = FastAPI()
class EmbeddingRequest(BaseModel):
input: str
model: str
@app.get("/models")
async def models():
models = []
for model in transformers.keys():
models.append({
"id": model,
"object": "model"
})
return {
"data": models,
"object": "list"
}
@app.post("/embeddings")
async def embedding(req: EmbeddingRequest, res: Response):
if not req.model in transformers:
raise HTTPException(status_code=400, detail="unknown model: " + req.model)
embeddings = transformers[req.model].encode([req.input])
data = []
for embedding in embeddings.tolist():
data.append({
"object": "embedding",
"embedding": embedding,
"index": len(data)
})
return {
"data": data,
"model": req.model,
"object": "list"
}

View file

@ -0,0 +1,5 @@
#TOOD: pin versions
fastapi
sentence-transformers
torch
uvicorn