add embedding store (#10)

This commit is contained in:
Adil Hafeez 2024-07-18 14:04:51 -07:00 committed by GitHub
parent cc2a496f90
commit 7bf77afa0e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 409 additions and 11 deletions

2
.gitignore vendored
View file

@ -1 +1,3 @@
envoyfilter/target
envoyfilter/qdrant_data/
embedding-server/venv/

View file

@ -0,0 +1,42 @@
# copied from https://github.com/bergos/embedding-server
FROM python:3 AS base
#
# builder
#
# Install Python dependencies into /runtime so the final image gets only
# installed packages, not pip caches or build tooling.
FROM base AS builder
WORKDIR /src
COPY requirements.txt /src/
RUN pip install --prefix=/runtime --force-reinstall -r requirements.txt
COPY . /src
#
# output
#
FROM python:3-slim AS output
# specify list of models that will go into the image as a comma separated list
# following models have been tested to work with this image
# "sentence-transformers/all-MiniLM-L6-v2,sentence-transformers/all-mpnet-base-v2,thenlper/gte-base,thenlper/gte-large,thenlper/gte-small"
ENV MODELS="BAAI/bge-large-en-v1.5"
COPY --from=builder /runtime /usr/local
# NOTE(review): this copies ./app from the build context, not from the
# builder stage (/src) — confirm the build context actually has an app/ dir,
# otherwise this should likely be `COPY --from=builder /src /app`.
COPY /app /app
WORKDIR /app
# curl is required by the container healthcheck defined in docker-compose.
RUN apt-get update && apt-get install -y \
curl \
&& rm -rf /var/lib/apt/lists/*
# Pre-download the configured models at build time so the container can
# serve immediately without fetching models on first request.
RUN python install.py
# RUN python install.py && \
# find /root/.cache/torch/sentence_transformers/ -name onnx -exec rm -rf {} +
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "80"]

View file

@ -0,0 +1,3 @@
# Build-time warm-up script: invoking the loader forces every configured
# SentenceTransformer model to be downloaded into the image's cache.
import load_transformers

load_transformers.load_transformers()

View file

@ -0,0 +1,10 @@
import os

import sentence_transformers


def load_transformers(models=None):
    """Load every configured SentenceTransformer model into memory.

    Args:
        models: Comma-separated model identifiers. Defaults to the ``MODELS``
            environment variable, falling back to
            ``sentence-transformers/all-MiniLM-L6-v2``.

    Returns:
        Dict mapping each model id to its loaded ``SentenceTransformer``.
    """
    # Read the env var lazily: the original default was captured once at
    # function-definition (import) time, so later changes to MODELS were
    # silently ignored.
    if models is None:
        models = os.getenv("MODELS", "sentence-transformers/all-MiniLM-L6-v2")
    transformers = {}
    for model in models.split(','):
        # Tolerate whitespace around commas and skip empty entries
        # (e.g. a trailing comma) instead of failing on a blank model id.
        model = model.strip()
        if not model:
            continue
        transformers[model] = sentence_transformers.SentenceTransformer(model)
    return transformers

View file

@ -0,0 +1,48 @@
from fastapi import FastAPI, Response, HTTPException
from pydantic import BaseModel

from load_transformers import load_transformers

# Models are loaded once at startup; every request reuses the cached instances.
transformers = load_transformers()

app = FastAPI()


class EmbeddingRequest(BaseModel):
    """OpenAI-style embedding request: one input string and a model id."""
    input: str
    model: str


@app.get("/models")
async def models():
    """List the models loaded on this server (OpenAI-compatible shape)."""
    models = [{"id": model, "object": "model"} for model in transformers.keys()]
    return {
        "data": models,
        "object": "list"
    }


@app.post("/embeddings")
async def embedding(req: EmbeddingRequest, res: Response):
    """Encode ``req.input`` with the requested model.

    Raises:
        HTTPException: 400 when the requested model is not loaded here.
    """
    # `not in` instead of the original `not req.model in` anti-idiom.
    if req.model not in transformers:
        raise HTTPException(status_code=400, detail="unknown model: " + req.model)
    embeddings = transformers[req.model].encode([req.input])
    # enumerate() replaces the original pattern of tracking the index via
    # len(data) while appending.
    data = [
        {
            "object": "embedding",
            "embedding": vector,
            "index": index
        }
        for index, vector in enumerate(embeddings.tolist())
    ]
    return {
        "data": data,
        "model": req.model,
        "object": "list"
    }

View file

@ -0,0 +1,5 @@
#TODO: pin versions
fastapi
sentence-transformers
torch
uvicorn

View file

@ -56,6 +56,7 @@ name = "intelligent-prompt-gateway"
version = "0.1.0"
dependencies = [
"log",
"md5",
"proxy-wasm",
"serde",
"serde_json",
@ -74,6 +75,12 @@ version = "0.4.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24"
[[package]]
name = "md5"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "490cc448043f947bae3cbee9c203358d62dbee0db12107a74be5c30ccfd09771"
[[package]]
name = "once_cell"
version = "1.19.0"

View file

@ -13,3 +13,4 @@ log = "0.4"
serde = { version = "1.0", features = ["derive"] }
serde_yaml = "0.9.34"
serde_json = "1.0"
md5 = "0.7.0"

View file

@ -10,6 +10,33 @@ services:
- ./target/wasm32-wasi/release:/etc/envoy/proxy-wasm-plugins
networks:
- envoymesh
depends_on:
embeddingserver:
condition: service_healthy
embeddingserver:
build:
context: ../embedding-server
dockerfile: Dockerfile
ports:
- "18080:80"
healthcheck:
test: ["CMD", "curl" ,"http://localhost:80"]
interval: 5s
retries: 20
networks:
- envoymesh
qdrant:
image: qdrant/qdrant
hostname: vector-db
ports:
- 16333:6333
- 16334:6334
volumes:
- ./qdrant_data:/qdrant/storage
networks:
- envoymesh
networks:
envoymesh: {}

View file

@ -21,6 +21,10 @@ static_resources:
domains:
- "*"
routes:
- match:
prefix: "/embeddings"
route:
cluster: embeddingserver
- match:
prefix: "/inline"
route:
@ -98,3 +102,31 @@ static_resources:
address: httpbin.org
port_value: 80
hostname: "httpbin.org"
- name: embeddingserver
connect_timeout: 5s
type: STRICT_DNS
lb_policy: ROUND_ROBIN
load_assignment:
cluster_name: embeddingserver
endpoints:
- lb_endpoints:
- endpoint:
address:
socket_address:
address: embeddingserver
port_value: 80
hostname: "embeddingserver"
- name: qdrant
connect_timeout: 5s
type: STRICT_DNS
lb_policy: ROUND_ROBIN
load_assignment:
cluster_name: qdrant
endpoints:
- lb_endpoints:
- endpoint:
address:
socket_address:
address: qdrant
port_value: 6333
hostname: "qdrant"

View file

@ -0,0 +1,16 @@
#!/bin/sh
# Recreate the qdrant collection that stores prompt embeddings.
# Talks to qdrant through the host port published in docker-compose
# (16333 -> container 6333).
#
# Abort on the first failing command (e.g. qdrant unreachable) instead of
# blindly continuing to the create step. Note: curl exits 0 on HTTP error
# statuses without --fail, so a 404 from the DELETE (collection not yet
# created) still lets the script proceed.
set -e

echo 'Deleting prompt_vector_store collection'
curl -X DELETE http://localhost:16333/collections/prompt_vector_store
echo

echo 'Creating prompt_vector_store collection'
# size 1024 is meant to match the output dimension of the configured
# embedding model (BAAI/bge-large-en-v1.5) — confirm if the model changes.
curl -X PUT 'http://localhost:16333/collections/prompt_vector_store' \
  -H 'Content-Type: application/json' \
  --data-raw '{
    "vectors": {
      "size": 1024,
      "distance": "Cosine"
    }
  }'
echo
echo 'Created prompt_vector_store collection'

View file

@ -0,0 +1,54 @@
use std::collections::HashMap;

use serde::{Deserialize, Serialize};

use crate::configuration;

/// Body of a POST /embeddings call to the embedding server: the text to
/// embed and the model that should embed it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CreateEmbeddingRequest {
    pub input: String,
    pub model: String,
}

/// Response from the embedding server; mirrors the OpenAI-style embeddings
/// response shape (`object`/`model`/`data`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CreateEmbeddingResponse {
    pub object: String,
    pub model: String,
    pub data: Vec<Embedding>,
}

/// One embedding vector plus its position within the request batch.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Embedding {
    pub object: String,
    pub index: i32,
    pub embedding: Vec<f32>,
}

/// An in-flight embedding callout: the request that was dispatched and the
/// prompt target it was generated for.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingRequest {
    pub create_embedding_request: CreateEmbeddingRequest,
    pub prompt_target: configuration::PromptTarget,
}

/// Discriminates which kind of HTTP callout a token_id belongs to so the
/// response handler knows how to interpret the response body.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum MessageType {
    EmbeddingRequest(EmbeddingRequest),
    CreateVectorStorePoints(CreateVectorStorePoints),
}

/// Per-callout bookkeeping keyed by the dispatch token id.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CalloutData {
    pub message: MessageType,
}

/// One point to upsert into the vector store: id, string payload map, and
/// the embedding vector itself.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorPoint {
    pub id: String,
    pub payload: HashMap<String, String>,
    pub vector: Vec<f32>,
}

/// Body of the qdrant points-upsert request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CreateVectorStorePoints {
    pub points: Vec<VectorPoint>,
}

View file

@ -29,6 +29,7 @@ pub struct PromptConfig {
pub embedding_provider: EmbeddingProviver,
pub llm_providers: Vec<LlmProvider>,
pub system_prompt: String,
pub prompt_targets: Vec<PromptTarget>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]

View file

@ -0,0 +1 @@
pub const DEFAULT_EMBEDDING_MODEL: &str = "BAAI/bge-large-en-v1.5";

View file

@ -1,9 +1,14 @@
mod common_types;
mod configuration;
mod consts;
use common_types::EmbeddingRequest;
use log::info;
use serde_json::to_string;
use stats::IncrementingMetric;
use stats::Metric;
use stats::RecordingMetric;
use std::collections::HashMap;
use std::time::Duration;
use proxy_wasm::traits::*;
@ -15,6 +20,7 @@ proxy_wasm::main! {{
proxy_wasm::set_log_level(LogLevel::Trace);
proxy_wasm::set_root_context(|_| -> Box<dyn RootContext> {
Box::new(HttpHeaderRoot {
callouts: HashMap::new(),
config: None,
metrics: WasmMetrics {
counter: stats::Counter::new(String::from("wasm_counter")),
@ -86,6 +92,7 @@ impl HttpContext for HttpHeader {
impl Context for HttpHeader {
// Note that the event driven model continues here from the return of the on_http_request_headers above.
fn on_http_call_response(&mut self, _: u32, _: usize, body_size: usize, _: usize) {
info!("on_http_call_response: body_size = {}", body_size);
if let Some(body) = self.get_http_call_response_body(0, body_size) {
if !body.is_empty() && body[0] % 2 == 0 {
info!("Access granted.");
@ -117,25 +124,111 @@ struct WasmMetrics {
/// Root context: holds plugin metrics, the parsed plugin configuration, and
/// bookkeeping for HTTP callouts dispatched from this context.
struct HttpHeaderRoot {
    metrics: WasmMetrics,
    // callouts stores token_id to request mapping that we use during #on_http_call_response to match the response to the request.
    callouts: HashMap<u32, common_types::CalloutData>,
    // Parsed plugin configuration; None until on_configure has run.
    config: Option<configuration::Configuration>,
}
impl Context for HttpHeaderRoot {}
impl Context for HttpHeaderRoot {
    /// Handles responses to callouts dispatched from the root context.
    /// The `token_id` is used to recover (and remove) the matching entry
    /// from `self.callouts`, which tells us which request this response
    /// answers and how to interpret the body.
    fn on_http_call_response(
        &mut self,
        token_id: u32,
        _num_headers: usize,
        body_size: usize,
        _num_trailers: usize,
    ) {
        info!("on_http_call_response: token_id = {}", token_id);
        // A response for an unknown token means our bookkeeping is broken.
        let callout_data = self.callouts.remove(&token_id).expect("invalid token_id");
        match callout_data.message {
            // Embedding server replied: build a vector-store point from the
            // returned embedding and upsert it into qdrant.
            common_types::MessageType::EmbeddingRequest(common_types::EmbeddingRequest {
                create_embedding_request,
                prompt_target,
            }) => {
                info!("response received for CreateEmbeddingRequest");
                if let Some(body) = self.get_http_call_response_body(0, body_size) {
                    if !body.is_empty() {
                        // NOTE(review): unwrap panics on malformed JSON from the
                        // embedding server — consider graceful error handling.
                        let embedding_response: common_types::CreateEmbeddingResponse =
                            serde_json::from_slice(&body).unwrap();
                        info!(
                            "embedding_response model: {}, vector len: {}",
                            embedding_response.model,
                            embedding_response.data[0].embedding.len()
                        );
                        // Payload stored alongside the vector: the serialized
                        // prompt target and the raw few-shot example text.
                        let mut payload: HashMap<String, String> = HashMap::new();
                        payload.insert(
                            "prompt-target".to_string(),
                            to_string(&prompt_target).unwrap(),
                        );
                        payload.insert(
                            "few-shot-example".to_string(),
                            create_embedding_request.input.clone(),
                        );
                        // Use the md5 of the input as a stable point id so
                        // re-embedding the same example overwrites its point.
                        let id = md5::compute(create_embedding_request.input);
                        let create_vector_store_points = common_types::CreateVectorStorePoints {
                            points: vec![common_types::VectorPoint {
                                id: format!("{:x}", id),
                                payload,
                                vector: embedding_response.data[0].embedding.clone(),
                            }],
                        };
                        // TODO: handle serialization errors instead of unwrap.
                        let json_data = to_string(&create_vector_store_points).unwrap();
                        info!(
                            "create_vector_store_points: points length: {}",
                            embedding_response.data[0].embedding.len()
                        );
                        // Upsert the point into the qdrant cluster defined in
                        // envoy.yaml; the response comes back through this
                        // same handler as a CreateVectorStorePoints message.
                        let token_id = match self.dispatch_http_call(
                            "qdrant",
                            vec![
                                (":method", "PUT"),
                                (":path", "/collections/prompt_vector_store/points"),
                                (":authority", "qdrant"),
                                ("content-type", "application/json"),
                            ],
                            Some(json_data.as_bytes()),
                            vec![],
                            Duration::from_secs(5),
                        ) {
                            Ok(token_id) => token_id,
                            Err(e) => {
                                panic!("Error dispatching HTTP call: {:?}", e);
                            }
                        };
                        // NOTE(review): log message says "on_tick" but this is
                        // on_http_call_response — misleading prefix.
                        info!("on_tick: dispatched HTTP call with token_id = {}", token_id);
                        // Record the new callout so the qdrant response can be
                        // matched back here.
                        let callout_message = common_types::CalloutData {
                            message: common_types::MessageType::CreateVectorStorePoints(
                                create_vector_store_points,
                            ),
                        };
                        if self.callouts.insert(token_id, callout_message).is_some() {
                            panic!("duplicate token_id")
                        }
                    }
                }
            }
            // qdrant replied to the upsert: just log the response body.
            common_types::MessageType::CreateVectorStorePoints(_) => {
                info!("response received for CreateVectorStorePoints");
                if let Some(body) = self.get_http_call_response_body(0, body_size) {
                    if !body.is_empty() {
                        info!("response body: {:?}", String::from_utf8(body).unwrap());
                    }
                }
            }
        }
    }
}
// RootContext allows the Rust code to reach into the Envoy Config
impl RootContext for HttpHeaderRoot {
fn on_configure(&mut self, plugin_configuration_size: usize) -> bool {
info!(
"on_configure: plugin_configuration_size is {}",
plugin_configuration_size
);
fn on_configure(&mut self, _: usize) -> bool {
if let Some(config_bytes) = self.get_plugin_configuration() {
let config_str = String::from_utf8(config_bytes).unwrap();
info!("on_configure: plugin configuration is {:?}", config_str);
self.config = serde_yaml::from_str(&config_str).unwrap();
self.config = serde_yaml::from_slice(&config_bytes).unwrap();
info!("on_configure: plugin configuration loaded");
info!("on_configure: {:?}", self.config);
}
true
}
@ -151,4 +244,56 @@ impl RootContext for HttpHeaderRoot {
/// Tells the host this root context also creates per-request HTTP contexts.
fn get_type(&self) -> Option<ContextType> {
    Some(ContextType::HttpContext)
}
/// Arm a 1-second tick; on_tick performs the one-time embedding bootstrap
/// and then disables the timer.
fn on_vm_start(&mut self, _: usize) -> bool {
    info!("on_vm_start: setting up tick timeout");
    self.set_tick_period(Duration::from_secs(1));
    true
}
/// One-shot bootstrap: for each few-shot example of every configured prompt
/// target, dispatch an embedding request to the embedding server. Responses
/// are matched back through `self.callouts` in on_http_call_response. The
/// tick period is zeroed at the end so this body runs only once.
fn on_tick(&mut self) {
    info!("on_tick: starting to process prompt targets");
    // NOTE(review): unwrap panics if on_configure has not populated config
    // before the first tick fires — confirm ordering guarantees.
    for prompt_target in &self.config.as_ref().unwrap().prompt_config.prompt_targets {
        for few_shot_example in &prompt_target.few_shot_examples {
            info!("few_shot_example: {:?}", few_shot_example);
            let embeddings_input = common_types::CreateEmbeddingRequest {
                input: few_shot_example.to_string(),
                model: String::from(consts::DEFAULT_EMBEDDING_MODEL),
            };
            // TODO: Handle potential errors
            let json_data = to_string(&embeddings_input).unwrap();
            // Ask the "embeddingserver" Envoy cluster to embed this example;
            // the response arrives in on_http_call_response.
            let token_id = match self.dispatch_http_call(
                "embeddingserver",
                vec![
                    (":method", "POST"),
                    (":path", "/embeddings"),
                    (":authority", "embeddingserver"),
                    ("content-type", "application/json"),
                ],
                Some(json_data.as_bytes()),
                vec![],
                Duration::from_secs(5),
            ) {
                Ok(token_id) => token_id,
                Err(e) => {
                    panic!("Error dispatching HTTP call: {:?}", e);
                }
            };
            info!("on_tick: dispatched HTTP call with token_id = {}", token_id);
            // Remember which request this token belongs to so the response
            // handler can reconstruct its context.
            let embedding_request = EmbeddingRequest {
                create_embedding_request: embeddings_input,
                prompt_target: prompt_target.clone(),
            };
            let callout_message = common_types::CalloutData {
                message: common_types::MessageType::EmbeddingRequest(embedding_request),
            };
            if self.callouts.insert(token_id, callout_message).is_some() {
                panic!("duplicate token_id")
            }
        }
    }
    // Period 0 disables further ticks, making this a one-shot job.
    self.set_tick_period(Duration::from_secs(0));
}
}

View file

@ -8,6 +8,10 @@
"name": "envoyfilter",
"path": "envoyfilter"
},
{
"name": "embedding-server",
"path": "embedding-server"
},
{
"name": "demos",
"path": "./demos"