mirror of
https://github.com/katanemo/plano.git
synced 2026-04-25 00:36:34 +02:00
add embedding store (#10)
This commit is contained in:
parent
cc2a496f90
commit
7bf77afa0e
16 changed files with 409 additions and 11 deletions
2
.gitignore
vendored
2
.gitignore
vendored
|
|
@ -1 +1,3 @@
|
|||
envoyfilter/target
|
||||
envoyfilter/qdrant_data/
|
||||
embedding-server/venv/
|
||||
|
|
|
|||
42
embedding-server/Dockerfile
Normal file
42
embedding-server/Dockerfile
Normal file
|
|
@ -0,0 +1,42 @@
|
|||
# copied from https://github.com/bergos/embedding-server
# Multi-stage build: dependencies are installed into an isolated /runtime
# prefix in the builder stage, then copied into a slim runtime image.

FROM python:3 AS base

#
# builder
#
FROM base AS builder

WORKDIR /src

# Install dependencies into /runtime so only installed packages (no pip
# caches or build tooling) are carried over to the final image.
COPY requirements.txt /src/
RUN pip install --prefix=/runtime --force-reinstall -r requirements.txt

COPY . /src

#
# output
#

FROM python:3-slim AS output

# specify list of models that will go into the image as a comma separated list
# following models have been tested to work with this image
# "sentence-transformers/all-MiniLM-L6-v2,sentence-transformers/all-mpnet-base-v2,thenlper/gte-base,thenlper/gte-large,thenlper/gte-small"
ENV MODELS="BAAI/bge-large-en-v1.5"

# Bring the pre-installed site-packages from the builder stage.
COPY --from=builder /runtime /usr/local

COPY /app /app
WORKDIR /app

# curl is required by the compose healthcheck that probes this container.
RUN apt-get update && apt-get install -y \
    curl \
    && rm -rf /var/lib/apt/lists/*

# Pre-download the configured models at build time so the container does not
# need network access to fetch them on first request.
RUN python install.py
# RUN python install.py && \
#     find /root/.cache/torch/sentence_transformers/ -name onnx -exec rm -rf {} +

CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "80"]
||||
3
embedding-server/app/install.py
Normal file
3
embedding-server/app/install.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
# Build-time helper: loading each configured model once forces
# sentence-transformers to download it into the image's local cache.
from load_transformers import load_transformers

load_transformers()
10
embedding-server/app/load_transformers.py
Normal file
10
embedding-server/app/load_transformers.py
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
import os
import sentence_transformers

# Model list used when no argument is given and MODELS is unset.
_DEFAULT_MODELS = "sentence-transformers/all-MiniLM-L6-v2"


def load_transformers(models=None):
    """Load one SentenceTransformer per configured model name.

    Args:
        models: Comma-separated model names. When ``None`` (the default),
            the ``MODELS`` environment variable is consulted at call time,
            falling back to ``_DEFAULT_MODELS``.

    Returns:
        dict mapping each model name to its loaded
        ``sentence_transformers.SentenceTransformer`` instance.
    """
    # Read the environment lazily: the original evaluated os.getenv() in the
    # default-argument expression, which freezes the value at import time and
    # ignores any later changes to MODELS.
    if models is None:
        models = os.getenv("MODELS", _DEFAULT_MODELS)

    transformers = {}

    for model in models.split(','):
        transformers[model] = sentence_transformers.SentenceTransformer(model)

    return transformers
48
embedding-server/app/main.py
Normal file
48
embedding-server/app/main.py
Normal file
|
|
@ -0,0 +1,48 @@
|
|||
from fastapi import FastAPI, Response, HTTPException
from pydantic import BaseModel
from load_transformers import load_transformers

# Load every configured model eagerly at startup; request handlers only
# perform dictionary lookups afterwards.
transformers = load_transformers()

app = FastAPI()


class EmbeddingRequest(BaseModel):
    # Text to embed, and the name of the loaded model to embed it with
    # (mirrors OpenAI's embeddings request shape).
    input: str
    model: str


@app.get("/models")
async def models():
    """List the loaded models in OpenAI's ``/models`` list format."""
    models = [{"id": model, "object": "model"} for model in transformers]

    return {
        "data": models,
        "object": "list"
    }


@app.post("/embeddings")
async def embedding(req: EmbeddingRequest, res: Response):
    """Embed ``req.input`` with ``req.model``; OpenAI-style response.

    Raises:
        HTTPException: 400 when the requested model was not loaded.
    """
    # idiomatic `not in` (was `not req.model in transformers`)
    if req.model not in transformers:
        raise HTTPException(status_code=400, detail="unknown model: " + req.model)

    embeddings = transformers[req.model].encode([req.input])

    # enumerate() supplies the index directly instead of re-reading len(data)
    # on every append.
    data = [
        {
            "object": "embedding",
            "embedding": vector,
            "index": index,
        }
        for index, vector in enumerate(embeddings.tolist())
    ]

    return {
        "data": data,
        "model": req.model,
        "object": "list"
    }
||||
5
embedding-server/requirements.txt
Normal file
5
embedding-server/requirements.txt
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
#TODO: pin versions
|
||||
fastapi
|
||||
sentence-transformers
|
||||
torch
|
||||
uvicorn
|
||||
7
envoyfilter/Cargo.lock
generated
7
envoyfilter/Cargo.lock
generated
|
|
@ -56,6 +56,7 @@ name = "intelligent-prompt-gateway"
|
|||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"log",
|
||||
"md5",
|
||||
"proxy-wasm",
|
||||
"serde",
|
||||
"serde_json",
|
||||
|
|
@ -74,6 +75,12 @@ version = "0.4.22"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24"
|
||||
|
||||
[[package]]
|
||||
name = "md5"
|
||||
version = "0.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "490cc448043f947bae3cbee9c203358d62dbee0db12107a74be5c30ccfd09771"
|
||||
|
||||
[[package]]
|
||||
name = "once_cell"
|
||||
version = "1.19.0"
|
||||
|
|
|
|||
|
|
@ -13,3 +13,4 @@ log = "0.4"
|
|||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_yaml = "0.9.34"
|
||||
serde_json = "1.0"
|
||||
md5 = "0.7.0"
|
||||
|
|
|
|||
|
|
@ -10,6 +10,33 @@ services:
|
|||
- ./target/wasm32-wasi/release:/etc/envoy/proxy-wasm-plugins
|
||||
networks:
|
||||
- envoymesh
|
||||
depends_on:
|
||||
embeddingserver:
|
||||
condition: service_healthy
|
||||
|
||||
embeddingserver:
|
||||
build:
|
||||
context: ../embedding-server
|
||||
dockerfile: Dockerfile
|
||||
ports:
|
||||
- "18080:80"
|
||||
healthcheck:
|
||||
test: ["CMD", "curl" ,"http://localhost:80"]
|
||||
interval: 5s
|
||||
retries: 20
|
||||
networks:
|
||||
- envoymesh
|
||||
|
||||
qdrant:
|
||||
image: qdrant/qdrant
|
||||
hostname: vector-db
|
||||
ports:
|
||||
- 16333:6333
|
||||
- 16334:6334
|
||||
volumes:
|
||||
- ./qdrant_data:/qdrant/storage
|
||||
networks:
|
||||
- envoymesh
|
||||
|
||||
networks:
|
||||
envoymesh: {}
|
||||
|
|
|
|||
|
|
@ -21,6 +21,10 @@ static_resources:
|
|||
domains:
|
||||
- "*"
|
||||
routes:
|
||||
- match:
|
||||
prefix: "/embeddings"
|
||||
route:
|
||||
cluster: embeddingserver
|
||||
- match:
|
||||
prefix: "/inline"
|
||||
route:
|
||||
|
|
@ -98,3 +102,31 @@ static_resources:
|
|||
address: httpbin.org
|
||||
port_value: 80
|
||||
hostname: "httpbin.org"
|
||||
- name: embeddingserver
|
||||
connect_timeout: 5s
|
||||
type: STRICT_DNS
|
||||
lb_policy: ROUND_ROBIN
|
||||
load_assignment:
|
||||
cluster_name: embeddingserver
|
||||
endpoints:
|
||||
- lb_endpoints:
|
||||
- endpoint:
|
||||
address:
|
||||
socket_address:
|
||||
address: embeddingserver
|
||||
port_value: 80
|
||||
hostname: "embeddingserver"
|
||||
- name: qdrant
|
||||
connect_timeout: 5s
|
||||
type: STRICT_DNS
|
||||
lb_policy: ROUND_ROBIN
|
||||
load_assignment:
|
||||
cluster_name: qdrant
|
||||
endpoints:
|
||||
- lb_endpoints:
|
||||
- endpoint:
|
||||
address:
|
||||
socket_address:
|
||||
address: qdrant
|
||||
port_value: 6333
|
||||
hostname: "qdrant"
|
||||
|
|
|
|||
16
envoyfilter/init_vector_store.sh
Normal file
16
envoyfilter/init_vector_store.sh
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
#!/bin/sh
# (Re)create the Qdrant collection that stores prompt embeddings.
# Deleting first makes the script idempotent across runs.

COLLECTION_URL='http://localhost:16333/collections/prompt_vector_store'

echo 'Deleting prompt_vector_store collection'
curl -X DELETE "$COLLECTION_URL"
echo

# size 1024 matches the dimensionality of the default embedding model.
echo 'Creating prompt_vector_store collection'
curl -X PUT "$COLLECTION_URL" \
  -H 'Content-Type: application/json' \
  --data-raw '{
    "vectors": {
      "size": 1024,
      "distance": "Cosine"
    }
  }'
echo
echo 'Created prompt_vector_store collection'
54
envoyfilter/src/common_types.rs
Normal file
54
envoyfilter/src/common_types.rs
Normal file
|
|
@ -0,0 +1,54 @@
|
|||
use std::collections::HashMap;

use serde::{Deserialize, Serialize};

use crate::configuration;

// Request body sent to the embedding server's /embeddings endpoint.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CreateEmbeddingRequest {
    pub input: String,
    pub model: String,
}

// Response body returned by the embedding server (OpenAI-style list object).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CreateEmbeddingResponse {
    pub object: String,
    pub model: String,
    pub data: Vec<Embedding>,
}

// One embedding vector inside a CreateEmbeddingResponse.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Embedding {
    pub object: String,
    pub index: i32,
    pub embedding: Vec<f32>,
}

// Pairs an outgoing embedding request with the prompt target it was issued
// for, so the response handler knows which target the vector belongs to.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingRequest {
    pub create_embedding_request: CreateEmbeddingRequest,
    pub prompt_target: configuration::PromptTarget,
}

// Discriminates which kind of HTTP callout a stored dispatch token refers to.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum MessageType {
    EmbeddingRequest(EmbeddingRequest),
    CreateVectorStorePoints(CreateVectorStorePoints),
}

// Per-callout bookkeeping kept in the token_id -> CalloutData map.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CalloutData {
    pub message: MessageType,
}

// One point (id + string payload + vector) for the Qdrant vector store.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorPoint {
    pub id: String,
    pub payload: HashMap<String, String>,
    pub vector: Vec<f32>,
}

// Body of Qdrant's "upsert points" PUT request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CreateVectorStorePoints {
    pub points: Vec<VectorPoint>,
}
|
|
@ -29,6 +29,7 @@ pub struct PromptConfig {
|
|||
pub embedding_provider: EmbeddingProviver,
|
||||
pub llm_providers: Vec<LlmProvider>,
|
||||
pub system_prompt: String,
|
||||
pub prompt_targets: Vec<PromptTarget>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
|
|
|
|||
1
envoyfilter/src/consts.rs
Normal file
1
envoyfilter/src/consts.rs
Normal file
|
|
@ -0,0 +1 @@
|
|||
// Embedding model requested from the embedding server; must be one of the
// models baked into the embedding-server image (see its MODELS env var).
pub const DEFAULT_EMBEDDING_MODEL: &str = "BAAI/bge-large-en-v1.5";
|
|
@ -1,9 +1,14 @@
|
|||
mod common_types;
|
||||
mod configuration;
|
||||
mod consts;
|
||||
|
||||
use common_types::EmbeddingRequest;
|
||||
use log::info;
|
||||
use serde_json::to_string;
|
||||
use stats::IncrementingMetric;
|
||||
use stats::Metric;
|
||||
use stats::RecordingMetric;
|
||||
use std::collections::HashMap;
|
||||
use std::time::Duration;
|
||||
|
||||
use proxy_wasm::traits::*;
|
||||
|
|
@ -15,6 +20,7 @@ proxy_wasm::main! {{
|
|||
proxy_wasm::set_log_level(LogLevel::Trace);
|
||||
proxy_wasm::set_root_context(|_| -> Box<dyn RootContext> {
|
||||
Box::new(HttpHeaderRoot {
|
||||
callouts: HashMap::new(),
|
||||
config: None,
|
||||
metrics: WasmMetrics {
|
||||
counter: stats::Counter::new(String::from("wasm_counter")),
|
||||
|
|
@ -86,6 +92,7 @@ impl HttpContext for HttpHeader {
|
|||
impl Context for HttpHeader {
|
||||
// Note that the event driven model continues here from the return of the on_http_request_headers above.
|
||||
fn on_http_call_response(&mut self, _: u32, _: usize, body_size: usize, _: usize) {
|
||||
info!("on_http_call_response: body_size = {}", body_size);
|
||||
if let Some(body) = self.get_http_call_response_body(0, body_size) {
|
||||
if !body.is_empty() && body[0] % 2 == 0 {
|
||||
info!("Access granted.");
|
||||
|
|
@ -117,25 +124,111 @@ struct WasmMetrics {
|
|||
|
||||
struct HttpHeaderRoot {
|
||||
metrics: WasmMetrics,
|
||||
// callouts stores token_id to request mapping that we use during #on_http_call_response to match the response to the request.
|
||||
callouts: HashMap<u32, common_types::CalloutData>,
|
||||
config: Option<configuration::Configuration>,
|
||||
}
|
||||
|
||||
impl Context for HttpHeaderRoot {}
|
||||
impl Context for HttpHeaderRoot {
    // Handles responses to callouts dispatched by this root context.
    // The token_id returned at dispatch time keys the callouts map, which
    // tells us whether this response is for an embedding request or for a
    // Qdrant points upsert.
    fn on_http_call_response(
        &mut self,
        token_id: u32,
        _num_headers: usize,
        body_size: usize,
        _num_trailers: usize,
    ) {
        info!("on_http_call_response: token_id = {}", token_id);

        // NOTE(review): panics on an unknown token_id — assumes every
        // response corresponds to a callout recorded in self.callouts.
        let callout_data = self.callouts.remove(&token_id).expect("invalid token_id");

        match callout_data.message {
            common_types::MessageType::EmbeddingRequest(common_types::EmbeddingRequest {
                create_embedding_request,
                prompt_target,
            }) => {
                info!("response received for CreateEmbeddingRequest");
                if let Some(body) = self.get_http_call_response_body(0, body_size) {
                    if !body.is_empty() {
                        let embedding_response: common_types::CreateEmbeddingResponse =
                            serde_json::from_slice(&body).unwrap();
                        info!(
                            "embedding_response model: {}, vector len: {}",
                            embedding_response.model,
                            embedding_response.data[0].embedding.len()
                        );

                        // Store the originating prompt target (as JSON) and
                        // the example text alongside the vector.
                        let mut payload: HashMap<String, String> = HashMap::new();
                        payload.insert(
                            "prompt-target".to_string(),
                            to_string(&prompt_target).unwrap(),
                        );
                        payload.insert(
                            "few-shot-example".to_string(),
                            create_embedding_request.input.clone(),
                        );

                        // md5 of the input text serves as a stable point id,
                        // so re-upserting the same example overwrites itself.
                        let id = md5::compute(create_embedding_request.input);

                        let create_vector_store_points = common_types::CreateVectorStorePoints {
                            points: vec![common_types::VectorPoint {
                                id: format!("{:x}", id),
                                payload,
                                vector: embedding_response.data[0].embedding.clone(),
                            }],
                        };
                        let json_data = to_string(&create_vector_store_points).unwrap(); // Handle potential errors
                        info!(
                            "create_vector_store_points: points length: {}",
                            embedding_response.data[0].embedding.len()
                        );
                        // Second-stage callout: upsert the point into Qdrant.
                        let token_id = match self.dispatch_http_call(
                            "qdrant",
                            vec![
                                (":method", "PUT"),
                                (":path", "/collections/prompt_vector_store/points"),
                                (":authority", "qdrant"),
                                ("content-type", "application/json"),
                            ],
                            Some(json_data.as_bytes()),
                            vec![],
                            Duration::from_secs(5),
                        ) {
                            Ok(token_id) => token_id,
                            Err(e) => {
                                panic!("Error dispatching HTTP call: {:?}", e);
                            }
                        };
                        info!("on_tick: dispatched HTTP call with token_id = {}", token_id);

                        // Record the new callout so its response is matched
                        // in the CreateVectorStorePoints arm below.
                        let callout_message = common_types::CalloutData {
                            message: common_types::MessageType::CreateVectorStorePoints(
                                create_vector_store_points,
                            ),
                        };
                        if self.callouts.insert(token_id, callout_message).is_some() {
                            panic!("duplicate token_id")
                        }
                    }
                }
            }
            common_types::MessageType::CreateVectorStorePoints(_) => {
                // Qdrant acknowledged (or rejected) the upsert; just log it.
                info!("response received for CreateVectorStorePoints");
                if let Some(body) = self.get_http_call_response_body(0, body_size) {
                    if !body.is_empty() {
                        info!("response body: {:?}", String::from_utf8(body).unwrap());
                    }
                }
            }
        }
    }
}
|
||||
// RootContext allows the Rust code to reach into the Envoy Config
|
||||
impl RootContext for HttpHeaderRoot {
|
||||
fn on_configure(&mut self, plugin_configuration_size: usize) -> bool {
|
||||
info!(
|
||||
"on_configure: plugin_configuration_size is {}",
|
||||
plugin_configuration_size
|
||||
);
|
||||
|
||||
fn on_configure(&mut self, _: usize) -> bool {
|
||||
if let Some(config_bytes) = self.get_plugin_configuration() {
|
||||
let config_str = String::from_utf8(config_bytes).unwrap();
|
||||
info!("on_configure: plugin configuration is {:?}", config_str);
|
||||
self.config = serde_yaml::from_str(&config_str).unwrap();
|
||||
self.config = serde_yaml::from_slice(&config_bytes).unwrap();
|
||||
info!("on_configure: plugin configuration loaded");
|
||||
info!("on_configure: {:?}", self.config);
|
||||
}
|
||||
true
|
||||
}
|
||||
|
|
@ -151,4 +244,56 @@ impl RootContext for HttpHeaderRoot {
|
|||
    // Create an HttpContext (HttpHeader) for each HTTP stream handled by
    // this plugin.
    fn get_type(&self) -> Option<ContextType> {
        Some(ContextType::HttpContext)
    }
|
||||
    fn on_vm_start(&mut self, _: usize) -> bool {
        info!("on_vm_start: setting up tick timeout");
        // Schedule on_tick shortly after startup; it performs the one-shot
        // embedding seeding and then cancels the timer.
        self.set_tick_period(Duration::from_secs(1));
        true
    }
|
||||
    // One-shot seeding pass: for every few-shot example of every configured
    // prompt target, request its embedding from the embedding server. Each
    // dispatch token is recorded in self.callouts so the response can be
    // matched back to its request in on_http_call_response.
    fn on_tick(&mut self) {
        info!("on_tick: starting to process prompt targets");
        // NOTE(review): unwrap assumes on_configure has populated config
        // before the first tick fires — confirm this ordering is guaranteed.
        for prompt_target in &self.config.as_ref().unwrap().prompt_config.prompt_targets {
            for few_shot_example in &prompt_target.few_shot_examples {
                info!("few_shot_example: {:?}", few_shot_example);
                let embeddings_input = common_types::CreateEmbeddingRequest {
                    input: few_shot_example.to_string(),
                    model: String::from(consts::DEFAULT_EMBEDDING_MODEL),
                };

                // TODO: Handle potential errors
                let json_data = to_string(&embeddings_input).unwrap();

                let token_id = match self.dispatch_http_call(
                    "embeddingserver",
                    vec![
                        (":method", "POST"),
                        (":path", "/embeddings"),
                        (":authority", "embeddingserver"),
                        ("content-type", "application/json"),
                    ],
                    Some(json_data.as_bytes()),
                    vec![],
                    Duration::from_secs(5),
                ) {
                    Ok(token_id) => token_id,
                    Err(e) => {
                        panic!("Error dispatching HTTP call: {:?}", e);
                    }
                };
                info!("on_tick: dispatched HTTP call with token_id = {}", token_id);
                // Remember which request this token belongs to.
                let embedding_request = EmbeddingRequest {
                    create_embedding_request: embeddings_input,
                    prompt_target: prompt_target.clone(),
                };
                let callout_message = common_types::CalloutData {
                    message: common_types::MessageType::EmbeddingRequest(embedding_request),
                };
                if self.callouts.insert(token_id, callout_message).is_some() {
                    panic!("duplicate token_id")
                }
            }
        }
        // A period of 0 cancels the timer, so this seeding runs exactly once.
        self.set_tick_period(Duration::from_secs(0));
    }
||||
}
|
||||
|
|
|
|||
|
|
@ -8,6 +8,10 @@
|
|||
"name": "envoyfilter",
|
||||
"path": "envoyfilter"
|
||||
},
|
||||
{
|
||||
"name": "embedding-server",
|
||||
"path": "embedding-server"
|
||||
},
|
||||
{
|
||||
"name": "demos",
|
||||
"path": "./demos"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue