mirror of
https://github.com/katanemo/plano.git
synced 2026-05-03 04:42:49 +02:00
add embedding store (#10)
This commit is contained in:
parent
cc2a496f90
commit
7bf77afa0e
16 changed files with 409 additions and 11 deletions
54
envoyfilter/src/common_types.rs
Normal file
54
envoyfilter/src/common_types.rs
Normal file
|
|
@ -0,0 +1,54 @@
|
|||
use std::collections::HashMap;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::configuration;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct CreateEmbeddingRequest {
|
||||
pub input: String,
|
||||
pub model: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct CreateEmbeddingResponse {
|
||||
pub object: String,
|
||||
pub model: String,
|
||||
pub data: Vec<Embedding>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Embedding {
|
||||
pub object: String,
|
||||
pub index: i32,
|
||||
pub embedding: Vec<f32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct EmbeddingRequest {
|
||||
pub create_embedding_request: CreateEmbeddingRequest,
|
||||
pub prompt_target: configuration::PromptTarget,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub enum MessageType {
|
||||
EmbeddingRequest(EmbeddingRequest),
|
||||
CreateVectorStorePoints(CreateVectorStorePoints),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct CalloutData {
|
||||
pub message: MessageType,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct VectorPoint {
|
||||
pub id: String,
|
||||
pub payload: HashMap<String, String>,
|
||||
pub vector: Vec<f32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct CreateVectorStorePoints {
|
||||
pub points: Vec<VectorPoint>,
|
||||
}
|
||||
|
|
@ -29,6 +29,7 @@ pub struct PromptConfig {
|
|||
pub embedding_provider: EmbeddingProviver,
|
||||
pub llm_providers: Vec<LlmProvider>,
|
||||
pub system_prompt: String,
|
||||
pub prompt_targets: Vec<PromptTarget>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
|
|
|
|||
1
envoyfilter/src/consts.rs
Normal file
1
envoyfilter/src/consts.rs
Normal file
|
|
@ -0,0 +1 @@
|
|||
pub const DEFAULT_EMBEDDING_MODEL: &str = "BAAI/bge-large-en-v1.5";
|
||||
|
|
@ -1,9 +1,14 @@
|
|||
mod common_types;
|
||||
mod configuration;
|
||||
mod consts;
|
||||
|
||||
use common_types::EmbeddingRequest;
|
||||
use log::info;
|
||||
use serde_json::to_string;
|
||||
use stats::IncrementingMetric;
|
||||
use stats::Metric;
|
||||
use stats::RecordingMetric;
|
||||
use std::collections::HashMap;
|
||||
use std::time::Duration;
|
||||
|
||||
use proxy_wasm::traits::*;
|
||||
|
|
@ -15,6 +20,7 @@ proxy_wasm::main! {{
|
|||
proxy_wasm::set_log_level(LogLevel::Trace);
|
||||
proxy_wasm::set_root_context(|_| -> Box<dyn RootContext> {
|
||||
Box::new(HttpHeaderRoot {
|
||||
callouts: HashMap::new(),
|
||||
config: None,
|
||||
metrics: WasmMetrics {
|
||||
counter: stats::Counter::new(String::from("wasm_counter")),
|
||||
|
|
@ -86,6 +92,7 @@ impl HttpContext for HttpHeader {
|
|||
impl Context for HttpHeader {
|
||||
// Note that the event driven model continues here from the return of the on_http_request_headers above.
|
||||
fn on_http_call_response(&mut self, _: u32, _: usize, body_size: usize, _: usize) {
|
||||
info!("on_http_call_response: body_size = {}", body_size);
|
||||
if let Some(body) = self.get_http_call_response_body(0, body_size) {
|
||||
if !body.is_empty() && body[0] % 2 == 0 {
|
||||
info!("Access granted.");
|
||||
|
|
@ -117,25 +124,111 @@ struct WasmMetrics {
|
|||
|
||||
struct HttpHeaderRoot {
|
||||
metrics: WasmMetrics,
|
||||
// callouts stores token_id to request mapping that we use during #on_http_call_response to match the response to the request.
|
||||
callouts: HashMap<u32, common_types::CalloutData>,
|
||||
config: Option<configuration::Configuration>,
|
||||
}
|
||||
|
||||
impl Context for HttpHeaderRoot {}
|
||||
impl Context for HttpHeaderRoot {
|
||||
fn on_http_call_response(
|
||||
&mut self,
|
||||
token_id: u32,
|
||||
_num_headers: usize,
|
||||
body_size: usize,
|
||||
_num_trailers: usize,
|
||||
) {
|
||||
info!("on_http_call_response: token_id = {}", token_id);
|
||||
|
||||
let callout_data = self.callouts.remove(&token_id).expect("invalid token_id");
|
||||
|
||||
match callout_data.message {
|
||||
common_types::MessageType::EmbeddingRequest(common_types::EmbeddingRequest {
|
||||
create_embedding_request,
|
||||
prompt_target,
|
||||
}) => {
|
||||
info!("response received for CreateEmbeddingRequest");
|
||||
if let Some(body) = self.get_http_call_response_body(0, body_size) {
|
||||
if !body.is_empty() {
|
||||
let embedding_response: common_types::CreateEmbeddingResponse =
|
||||
serde_json::from_slice(&body).unwrap();
|
||||
info!(
|
||||
"embedding_response model: {}, vector len: {}",
|
||||
embedding_response.model,
|
||||
embedding_response.data[0].embedding.len()
|
||||
);
|
||||
|
||||
let mut payload: HashMap<String, String> = HashMap::new();
|
||||
payload.insert(
|
||||
"prompt-target".to_string(),
|
||||
to_string(&prompt_target).unwrap(),
|
||||
);
|
||||
payload.insert(
|
||||
"few-shot-example".to_string(),
|
||||
create_embedding_request.input.clone(),
|
||||
);
|
||||
|
||||
let id = md5::compute(create_embedding_request.input);
|
||||
|
||||
let create_vector_store_points = common_types::CreateVectorStorePoints {
|
||||
points: vec![common_types::VectorPoint {
|
||||
id: format!("{:x}", id),
|
||||
payload,
|
||||
vector: embedding_response.data[0].embedding.clone(),
|
||||
}],
|
||||
};
|
||||
let json_data = to_string(&create_vector_store_points).unwrap(); // Handle potential errors
|
||||
info!(
|
||||
"create_vector_store_points: points length: {}",
|
||||
embedding_response.data[0].embedding.len()
|
||||
);
|
||||
let token_id = match self.dispatch_http_call(
|
||||
"qdrant",
|
||||
vec![
|
||||
(":method", "PUT"),
|
||||
(":path", "/collections/prompt_vector_store/points"),
|
||||
(":authority", "qdrant"),
|
||||
("content-type", "application/json"),
|
||||
],
|
||||
Some(json_data.as_bytes()),
|
||||
vec![],
|
||||
Duration::from_secs(5),
|
||||
) {
|
||||
Ok(token_id) => token_id,
|
||||
Err(e) => {
|
||||
panic!("Error dispatching HTTP call: {:?}", e);
|
||||
}
|
||||
};
|
||||
info!("on_tick: dispatched HTTP call with token_id = {}", token_id);
|
||||
|
||||
let callout_message = common_types::CalloutData {
|
||||
message: common_types::MessageType::CreateVectorStorePoints(
|
||||
create_vector_store_points,
|
||||
),
|
||||
};
|
||||
if self.callouts.insert(token_id, callout_message).is_some() {
|
||||
panic!("duplicate token_id")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
common_types::MessageType::CreateVectorStorePoints(_) => {
|
||||
info!("response received for CreateVectorStorePoints");
|
||||
if let Some(body) = self.get_http_call_response_body(0, body_size) {
|
||||
if !body.is_empty() {
|
||||
info!("response body: {:?}", String::from_utf8(body).unwrap());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RootContext allows the Rust code to reach into the Envoy Config
|
||||
impl RootContext for HttpHeaderRoot {
|
||||
fn on_configure(&mut self, plugin_configuration_size: usize) -> bool {
|
||||
info!(
|
||||
"on_configure: plugin_configuration_size is {}",
|
||||
plugin_configuration_size
|
||||
);
|
||||
|
||||
fn on_configure(&mut self, _: usize) -> bool {
|
||||
if let Some(config_bytes) = self.get_plugin_configuration() {
|
||||
let config_str = String::from_utf8(config_bytes).unwrap();
|
||||
info!("on_configure: plugin configuration is {:?}", config_str);
|
||||
self.config = serde_yaml::from_str(&config_str).unwrap();
|
||||
self.config = serde_yaml::from_slice(&config_bytes).unwrap();
|
||||
info!("on_configure: plugin configuration loaded");
|
||||
info!("on_configure: {:?}", self.config);
|
||||
}
|
||||
true
|
||||
}
|
||||
|
|
@ -151,4 +244,56 @@ impl RootContext for HttpHeaderRoot {
|
|||
fn get_type(&self) -> Option<ContextType> {
|
||||
Some(ContextType::HttpContext)
|
||||
}
|
||||
|
||||
fn on_vm_start(&mut self, _: usize) -> bool {
|
||||
info!("on_vm_start: setting up tick timeout");
|
||||
self.set_tick_period(Duration::from_secs(1));
|
||||
true
|
||||
}
|
||||
|
||||
fn on_tick(&mut self) {
|
||||
info!("on_tick: starting to process prompt targets");
|
||||
for prompt_target in &self.config.as_ref().unwrap().prompt_config.prompt_targets {
|
||||
for few_shot_example in &prompt_target.few_shot_examples {
|
||||
info!("few_shot_example: {:?}", few_shot_example);
|
||||
let embeddings_input = common_types::CreateEmbeddingRequest {
|
||||
input: few_shot_example.to_string(),
|
||||
model: String::from(consts::DEFAULT_EMBEDDING_MODEL),
|
||||
};
|
||||
|
||||
// TODO: Handle potential errors
|
||||
let json_data = to_string(&embeddings_input).unwrap();
|
||||
|
||||
let token_id = match self.dispatch_http_call(
|
||||
"embeddingserver",
|
||||
vec![
|
||||
(":method", "POST"),
|
||||
(":path", "/embeddings"),
|
||||
(":authority", "embeddingserver"),
|
||||
("content-type", "application/json"),
|
||||
],
|
||||
Some(json_data.as_bytes()),
|
||||
vec![],
|
||||
Duration::from_secs(5),
|
||||
) {
|
||||
Ok(token_id) => token_id,
|
||||
Err(e) => {
|
||||
panic!("Error dispatching HTTP call: {:?}", e);
|
||||
}
|
||||
};
|
||||
info!("on_tick: dispatched HTTP call with token_id = {}", token_id);
|
||||
let embedding_request = EmbeddingRequest {
|
||||
create_embedding_request: embeddings_input,
|
||||
prompt_target: prompt_target.clone(),
|
||||
};
|
||||
let callout_message = common_types::CalloutData {
|
||||
message: common_types::MessageType::EmbeddingRequest(embedding_request),
|
||||
};
|
||||
if self.callouts.insert(token_id, callout_message).is_some() {
|
||||
panic!("duplicate token_id")
|
||||
}
|
||||
}
|
||||
}
|
||||
self.set_tick_period(Duration::from_secs(0));
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue