add embedding store (#10)

This commit is contained in:
Adil Hafeez 2024-07-18 14:04:51 -07:00 committed by GitHub
parent cc2a496f90
commit 7bf77afa0e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 409 additions and 11 deletions

View file

@ -0,0 +1,54 @@
use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use crate::configuration;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CreateEmbeddingRequest {
pub input: String,
pub model: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CreateEmbeddingResponse {
pub object: String,
pub model: String,
pub data: Vec<Embedding>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Embedding {
pub object: String,
pub index: i32,
pub embedding: Vec<f32>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingRequest {
pub create_embedding_request: CreateEmbeddingRequest,
pub prompt_target: configuration::PromptTarget,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum MessageType {
EmbeddingRequest(EmbeddingRequest),
CreateVectorStorePoints(CreateVectorStorePoints),
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CalloutData {
pub message: MessageType,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorPoint {
pub id: String,
pub payload: HashMap<String, String>,
pub vector: Vec<f32>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CreateVectorStorePoints {
pub points: Vec<VectorPoint>,
}

View file

@ -29,6 +29,7 @@ pub struct PromptConfig {
pub embedding_provider: EmbeddingProviver,
pub llm_providers: Vec<LlmProvider>,
pub system_prompt: String,
pub prompt_targets: Vec<PromptTarget>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]

View file

@ -0,0 +1 @@
pub const DEFAULT_EMBEDDING_MODEL: &str = "BAAI/bge-large-en-v1.5";

View file

@ -1,9 +1,14 @@
mod common_types;
mod configuration;
mod consts;
use common_types::EmbeddingRequest;
use log::info;
use serde_json::to_string;
use stats::IncrementingMetric;
use stats::Metric;
use stats::RecordingMetric;
use std::collections::HashMap;
use std::time::Duration;
use proxy_wasm::traits::*;
@ -15,6 +20,7 @@ proxy_wasm::main! {{
proxy_wasm::set_log_level(LogLevel::Trace);
proxy_wasm::set_root_context(|_| -> Box<dyn RootContext> {
Box::new(HttpHeaderRoot {
callouts: HashMap::new(),
config: None,
metrics: WasmMetrics {
counter: stats::Counter::new(String::from("wasm_counter")),
@ -86,6 +92,7 @@ impl HttpContext for HttpHeader {
impl Context for HttpHeader {
// Note that the event driven model continues here from the return of the on_http_request_headers above.
fn on_http_call_response(&mut self, _: u32, _: usize, body_size: usize, _: usize) {
info!("on_http_call_response: body_size = {}", body_size);
if let Some(body) = self.get_http_call_response_body(0, body_size) {
if !body.is_empty() && body[0] % 2 == 0 {
info!("Access granted.");
@ -117,25 +124,111 @@ struct WasmMetrics {
struct HttpHeaderRoot {
metrics: WasmMetrics,
// callouts stores token_id to request mapping that we use during #on_http_call_response to match the response to the request.
callouts: HashMap<u32, common_types::CalloutData>,
config: Option<configuration::Configuration>,
}
impl Context for HttpHeaderRoot {}
impl Context for HttpHeaderRoot {
fn on_http_call_response(
&mut self,
token_id: u32,
_num_headers: usize,
body_size: usize,
_num_trailers: usize,
) {
info!("on_http_call_response: token_id = {}", token_id);
let callout_data = self.callouts.remove(&token_id).expect("invalid token_id");
match callout_data.message {
common_types::MessageType::EmbeddingRequest(common_types::EmbeddingRequest {
create_embedding_request,
prompt_target,
}) => {
info!("response received for CreateEmbeddingRequest");
if let Some(body) = self.get_http_call_response_body(0, body_size) {
if !body.is_empty() {
let embedding_response: common_types::CreateEmbeddingResponse =
serde_json::from_slice(&body).unwrap();
info!(
"embedding_response model: {}, vector len: {}",
embedding_response.model,
embedding_response.data[0].embedding.len()
);
let mut payload: HashMap<String, String> = HashMap::new();
payload.insert(
"prompt-target".to_string(),
to_string(&prompt_target).unwrap(),
);
payload.insert(
"few-shot-example".to_string(),
create_embedding_request.input.clone(),
);
let id = md5::compute(create_embedding_request.input);
let create_vector_store_points = common_types::CreateVectorStorePoints {
points: vec![common_types::VectorPoint {
id: format!("{:x}", id),
payload,
vector: embedding_response.data[0].embedding.clone(),
}],
};
let json_data = to_string(&create_vector_store_points).unwrap(); // Handle potential errors
info!(
"create_vector_store_points: points length: {}",
embedding_response.data[0].embedding.len()
);
let token_id = match self.dispatch_http_call(
"qdrant",
vec![
(":method", "PUT"),
(":path", "/collections/prompt_vector_store/points"),
(":authority", "qdrant"),
("content-type", "application/json"),
],
Some(json_data.as_bytes()),
vec![],
Duration::from_secs(5),
) {
Ok(token_id) => token_id,
Err(e) => {
panic!("Error dispatching HTTP call: {:?}", e);
}
};
info!("on_tick: dispatched HTTP call with token_id = {}", token_id);
let callout_message = common_types::CalloutData {
message: common_types::MessageType::CreateVectorStorePoints(
create_vector_store_points,
),
};
if self.callouts.insert(token_id, callout_message).is_some() {
panic!("duplicate token_id")
}
}
}
}
common_types::MessageType::CreateVectorStorePoints(_) => {
info!("response received for CreateVectorStorePoints");
if let Some(body) = self.get_http_call_response_body(0, body_size) {
if !body.is_empty() {
info!("response body: {:?}", String::from_utf8(body).unwrap());
}
}
}
}
}
}
// RootContext allows the Rust code to reach into the Envoy Config
impl RootContext for HttpHeaderRoot {
fn on_configure(&mut self, plugin_configuration_size: usize) -> bool {
info!(
"on_configure: plugin_configuration_size is {}",
plugin_configuration_size
);
fn on_configure(&mut self, _: usize) -> bool {
if let Some(config_bytes) = self.get_plugin_configuration() {
let config_str = String::from_utf8(config_bytes).unwrap();
info!("on_configure: plugin configuration is {:?}", config_str);
self.config = serde_yaml::from_str(&config_str).unwrap();
self.config = serde_yaml::from_slice(&config_bytes).unwrap();
info!("on_configure: plugin configuration loaded");
info!("on_configure: {:?}", self.config);
}
true
}
@ -151,4 +244,56 @@ impl RootContext for HttpHeaderRoot {
fn get_type(&self) -> Option<ContextType> {
Some(ContextType::HttpContext)
}
fn on_vm_start(&mut self, _: usize) -> bool {
info!("on_vm_start: setting up tick timeout");
self.set_tick_period(Duration::from_secs(1));
true
}
fn on_tick(&mut self) {
info!("on_tick: starting to process prompt targets");
for prompt_target in &self.config.as_ref().unwrap().prompt_config.prompt_targets {
for few_shot_example in &prompt_target.few_shot_examples {
info!("few_shot_example: {:?}", few_shot_example);
let embeddings_input = common_types::CreateEmbeddingRequest {
input: few_shot_example.to_string(),
model: String::from(consts::DEFAULT_EMBEDDING_MODEL),
};
// TODO: Handle potential errors
let json_data = to_string(&embeddings_input).unwrap();
let token_id = match self.dispatch_http_call(
"embeddingserver",
vec![
(":method", "POST"),
(":path", "/embeddings"),
(":authority", "embeddingserver"),
("content-type", "application/json"),
],
Some(json_data.as_bytes()),
vec![],
Duration::from_secs(5),
) {
Ok(token_id) => token_id,
Err(e) => {
panic!("Error dispatching HTTP call: {:?}", e);
}
};
info!("on_tick: dispatched HTTP call with token_id = {}", token_id);
let embedding_request = EmbeddingRequest {
create_embedding_request: embeddings_input,
prompt_target: prompt_target.clone(),
};
let callout_message = common_types::CalloutData {
message: common_types::MessageType::EmbeddingRequest(embedding_request),
};
if self.callouts.insert(token_id, callout_message).is_some() {
panic!("duplicate token_id")
}
}
}
self.set_tick_period(Duration::from_secs(0));
}
}