diff --git a/.gitignore b/.gitignore
index 4fa6c320..f67b2130 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1 +1,3 @@
 envoyfilter/target
+envoyfilter/qdrant_data/
+embedding-server/venv/
diff --git a/embedding-server/Dockerfile b/embedding-server/Dockerfile
new file mode 100644
index 00000000..0ec28ba7
--- /dev/null
+++ b/embedding-server/Dockerfile
@@ -0,0 +1,42 @@
+# copied from https://github.com/bergos/embedding-server
+
+FROM python:3 AS base
+
+#
+# builder
+#
+FROM base AS builder
+
+WORKDIR /src
+
+COPY requirements.txt /src/
+RUN pip install --prefix=/runtime --force-reinstall -r requirements.txt
+
+COPY . /src
+
+#
+# output
+#
+
+FROM python:3-slim AS output
+
+# specify list of models that will go into the image as a comma separated list
+# following models have been tested to work with this image
+# "sentence-transformers/all-MiniLM-L6-v2,sentence-transformers/all-mpnet-base-v2,thenlper/gte-base,thenlper/gte-large,thenlper/gte-small"
+ENV MODELS="BAAI/bge-large-en-v1.5"
+
+COPY --from=builder /runtime /usr/local
+
+COPY /app /app
+WORKDIR /app
+
+RUN apt-get update && apt-get install -y \
+    curl \
+    && rm -rf /var/lib/apt/lists/*
+
+RUN python install.py
+# RUN python install.py && \
+#     find /root/.cache/torch/sentence_transformers/ -name onnx -exec rm -rf {} +
+
+
+CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "80"]
diff --git a/embedding-server/app/install.py b/embedding-server/app/install.py
new file mode 100644
index 00000000..15cacc91
--- /dev/null
+++ b/embedding-server/app/install.py
@@ -0,0 +1,3 @@
+from load_transformers import load_transformers
+
+load_transformers()
diff --git a/embedding-server/app/load_transformers.py b/embedding-server/app/load_transformers.py
new file mode 100644
index 00000000..052f7e0e
--- /dev/null
+++ b/embedding-server/app/load_transformers.py
@@ -0,0 +1,10 @@
+import os
+import sentence_transformers
+
+def load_transformers(models = os.getenv("MODELS", "sentence-transformers/all-MiniLM-L6-v2")):
+    transformers = {}
+
+    for model in models.split(','):
+        transformers[model] = sentence_transformers.SentenceTransformer(model)
+
+    return transformers
diff --git a/embedding-server/app/main.py b/embedding-server/app/main.py
new file mode 100644
index 00000000..4a2a7a67
--- /dev/null
+++ b/embedding-server/app/main.py
@@ -0,0 +1,48 @@
+from fastapi import FastAPI, Response, HTTPException
+from pydantic import BaseModel
+from load_transformers import load_transformers
+
+transformers = load_transformers()
+
+app = FastAPI()
+
+class EmbeddingRequest(BaseModel):
+    input: str
+    model: str
+
+@app.get("/models")
+async def models():
+    models = []
+
+    for model in transformers.keys():
+        models.append({
+            "id": model,
+            "object": "model"
+        })
+
+    return {
+        "data": models,
+        "object": "list"
+    }
+
+@app.post("/embeddings")
+async def embedding(req: EmbeddingRequest, res: Response):
+    if req.model not in transformers:
+        raise HTTPException(status_code=400, detail="unknown model: " + req.model)
+
+    embeddings = transformers[req.model].encode([req.input])
+
+    data = []
+
+    for embedding in embeddings.tolist():
+        data.append({
+            "object": "embedding",
+            "embedding": embedding,
+            "index": len(data)
+        })
+
+    return {
+        "data": data,
+        "model": req.model,
+        "object": "list"
+    }
diff --git a/embedding-server/requirements.txt b/embedding-server/requirements.txt
new file mode 100644
index 00000000..613aa60e
--- /dev/null
+++ b/embedding-server/requirements.txt
@@ -0,0 +1,5 @@
+# TODO: pin versions
+fastapi
+sentence-transformers
+torch
+uvicorn
diff --git a/envoyfilter/Cargo.lock b/envoyfilter/Cargo.lock
index 89a4b636..e77cf8ee 100644
--- a/envoyfilter/Cargo.lock
+++ b/envoyfilter/Cargo.lock
@@ -56,6 +56,7 @@ name = "intelligent-prompt-gateway"
 version = "0.1.0"
 dependencies = [
  "log",
+ "md5",
  "proxy-wasm",
  "serde",
  "serde_json",
@@ -74,6 +75,12 @@ version = "0.4.22"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24"
 
+[[package]]
+name = "md5"
+version = "0.7.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "490cc448043f947bae3cbee9c203358d62dbee0db12107a74be5c30ccfd09771"
+
 [[package]]
 name = "once_cell"
 version = "1.19.0"
diff --git a/envoyfilter/Cargo.toml b/envoyfilter/Cargo.toml
index 1e91806b..51e436af 100644
--- a/envoyfilter/Cargo.toml
+++ b/envoyfilter/Cargo.toml
@@ -13,3 +13,4 @@ log = "0.4"
 serde = { version = "1.0", features = ["derive"] }
 serde_yaml = "0.9.34"
 serde_json = "1.0"
+md5 = "0.7.0"
diff --git a/envoyfilter/docker-compose.yaml b/envoyfilter/docker-compose.yaml
index 124d903f..4f237ed0 100644
--- a/envoyfilter/docker-compose.yaml
+++ b/envoyfilter/docker-compose.yaml
@@ -10,6 +10,33 @@ services:
       - ./target/wasm32-wasi/release:/etc/envoy/proxy-wasm-plugins
     networks:
       - envoymesh
+    depends_on:
+      embeddingserver:
+        condition: service_healthy
+
+  embeddingserver:
+    build:
+      context: ../embedding-server
+      dockerfile: Dockerfile
+    ports:
+      - "18080:80"
+    healthcheck:
+      test: ["CMD", "curl", "http://localhost:80"]
+      interval: 5s
+      retries: 20
+    networks:
+      - envoymesh
+
+  qdrant:
+    image: qdrant/qdrant
+    hostname: vector-db
+    ports:
+      - 16333:6333
+      - 16334:6334
+    volumes:
+      - ./qdrant_data:/qdrant/storage
+    networks:
+      - envoymesh
 
 networks:
   envoymesh: {}
diff --git a/envoyfilter/envoy.yaml b/envoyfilter/envoy.yaml
index 0cc20a90..3b616062 100644
--- a/envoyfilter/envoy.yaml
+++ b/envoyfilter/envoy.yaml
@@ -21,6 +21,10 @@ static_resources:
               domains:
               - "*"
               routes:
+              - match:
+                  prefix: "/embeddings"
+                route:
+                  cluster: embeddingserver
               - match:
                   prefix: "/inline"
                 route:
@@ -98,3 +102,31 @@ static_resources:
                 address: httpbin.org
                 port_value: 80
             hostname: "httpbin.org"
+  - name: embeddingserver
+    connect_timeout: 5s
+    type: STRICT_DNS
+    lb_policy: ROUND_ROBIN
+    load_assignment:
+      cluster_name: embeddingserver
+      endpoints:
+      - lb_endpoints:
+        - endpoint:
+            address:
+              socket_address:
+                address: embeddingserver
+                port_value: 80
+            hostname: "embeddingserver"
+  - name: qdrant
+    connect_timeout: 5s
+    type: STRICT_DNS
+    lb_policy: ROUND_ROBIN
+    load_assignment:
+      cluster_name: qdrant
+      endpoints:
+      - lb_endpoints:
+        - endpoint:
+            address:
+              socket_address:
+                address: qdrant
+                port_value: 6333
+            hostname: "qdrant"
diff --git a/envoyfilter/init_vector_store.sh b/envoyfilter/init_vector_store.sh
new file mode 100644
index 00000000..e22f46b0
--- /dev/null
+++ b/envoyfilter/init_vector_store.sh
@@ -0,0 +1,16 @@
+#!/bin/sh
+
+echo 'Deleting prompt_vector_store collection'
+curl -X DELETE http://localhost:16333/collections/prompt_vector_store
+echo
+echo 'Creating prompt_vector_store collection'
+curl -X PUT 'http://localhost:16333/collections/prompt_vector_store' \
+  -H 'Content-Type: application/json' \
+  --data-raw '{
+    "vectors": {
+      "size": 1024,
+      "distance": "Cosine"
+    }
+  }'
+echo
+echo 'Created prompt_vector_store collection'
diff --git a/envoyfilter/src/common_types.rs b/envoyfilter/src/common_types.rs
new file mode 100644
index 00000000..58621b81
--- /dev/null
+++ b/envoyfilter/src/common_types.rs
@@ -0,0 +1,54 @@
+use std::collections::HashMap;
+
+use serde::{Deserialize, Serialize};
+
+use crate::configuration;
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct CreateEmbeddingRequest {
+    pub input: String,
+    pub model: String,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct CreateEmbeddingResponse {
+    pub object: String,
+    pub model: String,
+    pub data: Vec<Embedding>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct Embedding {
+    pub object: String,
+    pub index: i32,
+    pub embedding: Vec<f64>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct EmbeddingRequest {
+    pub create_embedding_request: CreateEmbeddingRequest,
+    pub prompt_target: configuration::PromptTarget,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub enum MessageType {
+    EmbeddingRequest(EmbeddingRequest),
+    CreateVectorStorePoints(CreateVectorStorePoints),
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct CalloutData {
+    pub message: MessageType,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct VectorPoint {
+    pub id: String,
+    pub payload: HashMap<String, String>,
+    pub vector: Vec<f64>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct CreateVectorStorePoints {
+    pub points: Vec<VectorPoint>,
+}
diff --git a/envoyfilter/src/configuration.rs b/envoyfilter/src/configuration.rs
index 4b805595..e01162f3 100644
--- a/envoyfilter/src/configuration.rs
+++ b/envoyfilter/src/configuration.rs
@@ -29,6 +29,7 @@ pub struct PromptConfig {
     pub embedding_provider: EmbeddingProviver,
     pub llm_providers: Vec<LlmProvider>,
     pub system_prompt: String,
+    pub prompt_targets: Vec<PromptTarget>,
 }
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
diff --git a/envoyfilter/src/consts.rs b/envoyfilter/src/consts.rs
new file mode 100644
index 00000000..a403fbb8
--- /dev/null
+++ b/envoyfilter/src/consts.rs
@@ -0,0 +1 @@
+pub const DEFAULT_EMBEDDING_MODEL: &str = "BAAI/bge-large-en-v1.5";
diff --git a/envoyfilter/src/lib.rs b/envoyfilter/src/lib.rs
index 9cf4729a..4eb33c4f 100644
--- a/envoyfilter/src/lib.rs
+++ b/envoyfilter/src/lib.rs
@@ -1,9 +1,14 @@
+mod common_types;
 mod configuration;
+mod consts;
 
+use common_types::EmbeddingRequest;
 use log::info;
+use serde_json::to_string;
 use stats::IncrementingMetric;
 use stats::Metric;
 use stats::RecordingMetric;
+use std::collections::HashMap;
 use std::time::Duration;
 
 use proxy_wasm::traits::*;
@@ -15,6 +20,7 @@ proxy_wasm::main! {{
     proxy_wasm::set_log_level(LogLevel::Trace);
     proxy_wasm::set_root_context(|_| -> Box<dyn RootContext> {
         Box::new(HttpHeaderRoot {
+            callouts: HashMap::new(),
             config: None,
             metrics: WasmMetrics {
                 counter: stats::Counter::new(String::from("wasm_counter")),
@@ -86,6 +92,7 @@ impl HttpContext for HttpHeader {
 impl Context for HttpHeader {
     // Note that the event driven model continues here from the return of the on_http_request_headers above.
     fn on_http_call_response(&mut self, _: u32, _: usize, body_size: usize, _: usize) {
+        info!("on_http_call_response: body_size = {}", body_size);
         if let Some(body) = self.get_http_call_response_body(0, body_size) {
             if !body.is_empty() && body[0] % 2 == 0 {
                 info!("Access granted.");
@@ -117,25 +124,111 @@ struct WasmMetrics {
 
 struct HttpHeaderRoot {
     metrics: WasmMetrics,
+    // callouts stores token_id to request mapping that we use during #on_http_call_response to match the response to the request.
+    callouts: HashMap<u32, common_types::CalloutData>,
     config: Option<configuration::Configuration>,
 }
 
-impl Context for HttpHeaderRoot {}
+impl Context for HttpHeaderRoot {
+    fn on_http_call_response(
+        &mut self,
+        token_id: u32,
+        _num_headers: usize,
+        body_size: usize,
+        _num_trailers: usize,
+    ) {
+        info!("on_http_call_response: token_id = {}", token_id);
+
+        let callout_data = self.callouts.remove(&token_id).expect("invalid token_id");
+
+        match callout_data.message {
+            common_types::MessageType::EmbeddingRequest(common_types::EmbeddingRequest {
+                create_embedding_request,
+                prompt_target,
+            }) => {
+                info!("response received for CreateEmbeddingRequest");
+                if let Some(body) = self.get_http_call_response_body(0, body_size) {
+                    if !body.is_empty() {
+                        let embedding_response: common_types::CreateEmbeddingResponse =
+                            serde_json::from_slice(&body).unwrap();
+                        info!(
+                            "embedding_response model: {}, vector len: {}",
+                            embedding_response.model,
+                            embedding_response.data[0].embedding.len()
+                        );
+
+                        let mut payload: HashMap<String, String> = HashMap::new();
+                        payload.insert(
+                            "prompt-target".to_string(),
+                            to_string(&prompt_target).unwrap(),
+                        );
+                        payload.insert(
+                            "few-shot-example".to_string(),
+                            create_embedding_request.input.clone(),
+                        );
+
+                        let id = md5::compute(create_embedding_request.input);
+
+                        let create_vector_store_points = common_types::CreateVectorStorePoints {
+                            points: vec![common_types::VectorPoint {
+                                id: format!("{:x}", id),
+                                payload,
+                                vector: embedding_response.data[0].embedding.clone(),
+                            }],
+                        };
+                        let json_data = to_string(&create_vector_store_points).unwrap(); // TODO: handle potential errors
+                        info!(
+                            "create_vector_store_points: points length: {}",
+                            embedding_response.data[0].embedding.len()
+                        );
+                        let token_id = match self.dispatch_http_call(
+                            "qdrant",
+                            vec![
+                                (":method", "PUT"),
+                                (":path", "/collections/prompt_vector_store/points"),
+                                (":authority", "qdrant"),
+                                ("content-type", "application/json"),
+                            ],
+                            Some(json_data.as_bytes()),
+                            vec![],
+                            Duration::from_secs(5),
+                        ) {
+                            Ok(token_id) => token_id,
+                            Err(e) => {
+                                panic!("Error dispatching HTTP call: {:?}", e);
+                            }
+                        };
+                        info!("on_http_call_response: dispatched HTTP call with token_id = {}", token_id);
+
+                        let callout_message = common_types::CalloutData {
+                            message: common_types::MessageType::CreateVectorStorePoints(
+                                create_vector_store_points,
+                            ),
+                        };
+                        if self.callouts.insert(token_id, callout_message).is_some() {
+                            panic!("duplicate token_id")
+                        }
+                    }
+                }
+            }
+            common_types::MessageType::CreateVectorStorePoints(_) => {
+                info!("response received for CreateVectorStorePoints");
+                if let Some(body) = self.get_http_call_response_body(0, body_size) {
+                    if !body.is_empty() {
+                        info!("response body: {:?}", String::from_utf8(body).unwrap());
+                    }
+                }
+            }
+        }
+    }
+}
 
 // RootContext allows the Rust code to reach into the Envoy Config
 impl RootContext for HttpHeaderRoot {
-    fn on_configure(&mut self, plugin_configuration_size: usize) -> bool {
-        info!(
-            "on_configure: plugin_configuration_size is {}",
-            plugin_configuration_size
-        );
-
+    fn on_configure(&mut self, _: usize) -> bool {
         if let Some(config_bytes) = self.get_plugin_configuration() {
-            let config_str = String::from_utf8(config_bytes).unwrap();
-            info!("on_configure: plugin configuration is {:?}", config_str);
-            self.config = serde_yaml::from_str(&config_str).unwrap();
+            self.config = serde_yaml::from_slice(&config_bytes).unwrap();
             info!("on_configure: plugin configuration loaded");
-            info!("on_configure: {:?}", self.config);
         }
         true
     }
@@ -151,4 +244,56 @@ impl RootContext for HttpHeaderRoot {
     fn get_type(&self) -> Option<ContextType> {
         Some(ContextType::HttpContext)
     }
+
+    fn on_vm_start(&mut self, _: usize) -> bool {
+        info!("on_vm_start: setting up tick timeout");
+        self.set_tick_period(Duration::from_secs(1));
+        true
+    }
+
+    fn on_tick(&mut self) {
+        info!("on_tick: starting to process prompt targets");
+        for prompt_target in &self.config.as_ref().unwrap().prompt_config.prompt_targets {
+            for few_shot_example in &prompt_target.few_shot_examples {
+                info!("few_shot_example: {:?}", few_shot_example);
+                let embeddings_input = common_types::CreateEmbeddingRequest {
+                    input: few_shot_example.to_string(),
+                    model: String::from(consts::DEFAULT_EMBEDDING_MODEL),
+                };
+
+                // TODO: Handle potential errors
+                let json_data = to_string(&embeddings_input).unwrap();
+
+                let token_id = match self.dispatch_http_call(
+                    "embeddingserver",
+                    vec![
+                        (":method", "POST"),
+                        (":path", "/embeddings"),
+                        (":authority", "embeddingserver"),
+                        ("content-type", "application/json"),
+                    ],
+                    Some(json_data.as_bytes()),
+                    vec![],
+                    Duration::from_secs(5),
+                ) {
+                    Ok(token_id) => token_id,
+                    Err(e) => {
+                        panic!("Error dispatching HTTP call: {:?}", e);
+                    }
+                };
+                info!("on_tick: dispatched HTTP call with token_id = {}", token_id);
+                let embedding_request = EmbeddingRequest {
+                    create_embedding_request: embeddings_input,
+                    prompt_target: prompt_target.clone(),
+                };
+                let callout_message = common_types::CalloutData {
+                    message: common_types::MessageType::EmbeddingRequest(embedding_request),
+                };
+                if self.callouts.insert(token_id, callout_message).is_some() {
+                    panic!("duplicate token_id")
+                }
+            }
+        }
+        self.set_tick_period(Duration::from_secs(0));
+    }
 }
diff --git a/gateway.code-workspace b/gateway.code-workspace
index 98d02f91..939a7d1a 100644
--- a/gateway.code-workspace
+++ b/gateway.code-workspace
@@ -8,6 +8,10 @@
             "name": "envoyfilter",
             "path": "envoyfilter"
         },
+        {
+            "name": "embedding-server",
+            "path": "embedding-server"
+        },
         {
             "name": "demos",
             "path": "./demos"