diff --git a/envoyfilter/Cargo.lock b/envoyfilter/Cargo.lock index 34142f27..c1f5edfe 100644 --- a/envoyfilter/Cargo.lock +++ b/envoyfilter/Cargo.lock @@ -266,6 +266,12 @@ dependencies = [ "allocator-api2", ] +[[package]] +name = "hermit-abi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024" + [[package]] name = "http" version = "1.1.0" @@ -490,13 +496,14 @@ dependencies = [ [[package]] name = "mio" -version = "0.8.11" +version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a4a650543ca06a924e8b371db273b2756685faae30f8487da1b56505a8f78b0c" +checksum = "4569e456d394deccd22ce1c1913e6ea0e54519f577285001215d33557431afe4" dependencies = [ + "hermit-abi", "libc", "wasi", - "windows-sys 0.48.0", + "windows-sys 0.52.0", ] [[package]] @@ -738,9 +745,9 @@ dependencies = [ [[package]] name = "rustls" -version = "0.23.11" +version = "0.23.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4828ea528154ae444e5a642dbb7d5623354030dc9822b83fd9bb79683c7399d0" +checksum = "c58f8c84392efc0a126acce10fa59ff7b3d2ac06ab451a33f2741989b806b044" dependencies = [ "once_cell", "rustls-pki-types", @@ -974,9 +981,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.38.1" +version = "1.39.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eb2caba9f80616f438e09748d5acda951967e1ea58508ef53d9c6402485a46df" +checksum = "d040ac2b29ab03b09d4129c2f5bbd012a3ac2f79d38ff506a4bf8dd34b0eac8a" dependencies = [ "backtrace", "bytes", @@ -984,7 +991,7 @@ dependencies = [ "mio", "pin-project-lite", "socket2", - "windows-sys 0.48.0", + "windows-sys 0.52.0", ] [[package]] diff --git a/envoyfilter/src/common_types.rs b/envoyfilter/src/common_types.rs index c80475f8..4ee8f7fa 100644 --- a/envoyfilter/src/common_types.rs +++ b/envoyfilter/src/common_types.rs @@ -1,7 +1,6 @@ use open_message_format::models::CreateEmbeddingRequest; -use std::collections::HashMap; - use serde::{Deserialize, Serialize}; +use std::collections::HashMap; use crate::configuration; @@ -11,17 +10,6 @@ pub struct EmbeddingRequest { pub prompt_target: configuration::PromptTarget, } -#[derive(Debug, Clone, Serialize, Deserialize)] -pub enum MessageType { - EmbeddingRequest(EmbeddingRequest), - CreateVectorStorePoints(CreateVectorStorePoints), -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct CalloutData { - pub message: MessageType, -} - #[derive(Debug, Clone, Serialize, Deserialize)] pub struct VectorPoint { pub id: String, @@ -30,10 +18,16 @@ pub struct VectorPoint { } #[derive(Debug, Clone, Serialize, Deserialize)] -pub struct CreateVectorStorePoints { +pub struct StoreVectorEmbeddingsRequest { pub points: Vec, } +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum CallContext { + EmbeddingRequest(EmbeddingRequest), + StoreVectorEmbeddings(StoreVectorEmbeddingsRequest), +} + pub mod open_ai { use serde::{Deserialize, Serialize}; diff --git a/envoyfilter/src/lib.rs b/envoyfilter/src/lib.rs index 86a05ed4..49d22ebf 100644 --- a/envoyfilter/src/lib.rs +++ b/envoyfilter/src/lib.rs @@ -1,4 +1,4 @@ -use common_types::EmbeddingRequest; +use common_types::{CallContext, EmbeddingRequest}; use configuration::PromptTarget; use log::info; use md5::Digest; @@ -6,8 +6,6 @@ use open_message_format::models::{ CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse, }; use serde_json::to_string; -use stats::IncrementingMetric; -use stats::Metric; use stats::RecordingMetric; use std::collections::HashMap; use std::time::Duration; @@ -109,33 +107,21 @@ impl Context for StreamContext {} #[derive(Copy, Clone)] struct WasmMetrics { - // TODO: remove example metrics once real ones are created. These ones exist in order to silence the dead code warning. - example_counter: stats::Counter, - example_gauge: stats::Gauge, - example_histogram: stats::Histogram, + active_http_calls: stats::Gauge, } impl WasmMetrics { fn new() -> WasmMetrics { - let new_metrics = WasmMetrics { - example_counter: stats::Counter::new(String::from("example_counter")), - example_gauge: stats::Gauge::new(String::from("example_gauge")), - example_histogram: stats::Histogram::new(String::from("example_histogram")), - }; - - new_metrics.example_counter.increment(10); - new_metrics.example_counter.value(); - new_metrics.example_gauge.record(20); - new_metrics.example_histogram.record(30); - - new_metrics + WasmMetrics { + active_http_calls: stats::Gauge::new(String::from("active_http_calls")), + } } } struct FilterContext { - _metrics: WasmMetrics, + metrics: WasmMetrics, // callouts stores token_id to request mapping that we use during #on_http_call_response to match the response to the request. - callouts: HashMap, + callouts: HashMap, config: Option, } @@ -144,7 +130,7 @@ impl FilterContext { FilterContext { callouts: HashMap::new(), config: None, - _metrics: WasmMetrics::new(), + metrics: WasmMetrics::new(), } } @@ -164,7 +150,7 @@ impl FilterContext { }; // TODO: Handle potential errors - let json_data = to_string(&embeddings_input).unwrap(); + let json_data: String = to_string(&embeddings_input).unwrap(); let token_id = match self.dispatch_http_call( "embeddingserver", @@ -188,12 +174,18 @@ impl FilterContext { create_embedding_request: embeddings_input, prompt_target: prompt_target.clone(), }; - let callout_message = common_types::CalloutData { - message: common_types::MessageType::EmbeddingRequest(embedding_request), - }; - if self.callouts.insert(token_id, callout_message).is_some() { + if self + .callouts + .insert(token_id, { + CallContext::EmbeddingRequest(embedding_request) + }) + .is_some() + { panic!("duplicate token_id") } + self.metrics + .active_http_calls + .record(self.callouts.len().try_into().unwrap()); } } } @@ -229,7 +221,7 @@ impl FilterContext { CreateEmbeddingRequestInput::Array(_) => todo!(), } - let create_vector_store_points = common_types::CreateVectorStorePoints { + let create_vector_store_points = common_types::StoreVectorEmbeddingsRequest { points: vec![common_types::VectorPoint { id: format!("{:x}", id.unwrap()), payload, @@ -260,14 +252,19 @@ impl FilterContext { }; info!("on_tick: dispatched HTTP call with token_id = {}", token_id); - let callout_message = common_types::CalloutData { - message: common_types::MessageType::CreateVectorStorePoints( - create_vector_store_points, - ), - }; - if self.callouts.insert(token_id, callout_message).is_some() { + if self + .callouts + .insert( + token_id, + CallContext::StoreVectorEmbeddings(create_vector_store_points), + ) + .is_some() + { panic!("duplicate token_id") } + self.metrics + .active_http_calls + .record(self.callouts.len().try_into().unwrap()); } } } @@ -294,14 +291,18 @@ impl Context for FilterContext { let callout_data = self.callouts.remove(&token_id).expect("invalid token_id"); - match callout_data.message { - common_types::MessageType::EmbeddingRequest(common_types::EmbeddingRequest { + self.metrics + .active_http_calls + .record(self.callouts.len().try_into().unwrap()); + + match callout_data { + common_types::CallContext::EmbeddingRequest(common_types::EmbeddingRequest { create_embedding_request, prompt_target, }) => { self.embedding_request_handler(body_size, create_embedding_request, prompt_target) } - common_types::MessageType::CreateVectorStorePoints(_) => { + common_types::CallContext::StoreVectorEmbeddings(_) => { self.create_vector_store_points_handler(body_size) } } diff --git a/envoyfilter/src/stats.rs b/envoyfilter/src/stats.rs index 2322da01..27d2f413 100644 --- a/envoyfilter/src/stats.rs +++ b/envoyfilter/src/stats.rs @@ -1,9 +1,9 @@ use proxy_wasm::hostcalls; use proxy_wasm::types::*; +#[allow(unused)] pub trait Metric { fn id(&self) -> u32; - fn value(&self) -> u64 { match hostcalls::get_metric(self.id()) { Ok(value) => value, @@ -13,6 +13,7 @@ pub trait Metric { } } +#[allow(unused)] pub trait IncrementingMetric: Metric { fn increment(&self, offset: i64) { match hostcalls::increment_metric(self.id(), offset) { @@ -38,6 +39,7 @@ pub struct Counter { id: u32, } +#[allow(unused)] impl Counter { pub fn new(name: String) -> Counter { let returned_id = hostcalls::define_metric(MetricType::Counter, &name) @@ -80,6 +82,7 @@ pub struct Histogram { id: u32, } +#[allow(unused)] impl Histogram { pub fn new(name: String) -> Histogram { let returned_id = hostcalls::define_metric(MetricType::Histogram, &name) diff --git a/gateway.code-workspace b/gateway.code-workspace index 12273773..bc33ccde 100644 --- a/gateway.code-workspace +++ b/gateway.code-workspace @@ -12,6 +12,10 @@ "name": "embedding-server", "path": "embedding-server" }, + { + "name": "chatbot-client", + "path": "chatbot-client" + }, { "name": "open-message-format", "path": "open-message-format" diff --git a/open-message-format b/open-message-format index f70179ae..1469e1ec 160000 --- a/open-message-format +++ b/open-message-format @@ -1 +1 @@ -Subproject commit f70179ae3de7110cb40412c902b275b7900b40a1 +Subproject commit 1469e1ec49dd9b7fcf91e16fea4781294ced7660