diff --git a/envoyfilter/Cargo.lock b/envoyfilter/Cargo.lock index e9e129a6..93c9e116 100644 --- a/envoyfilter/Cargo.lock +++ b/envoyfilter/Cargo.lock @@ -511,6 +511,21 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "futures" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "645c6916888f6cb6350d2550b80fb63e734897a8498abe35cfb732b6487804b0" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + [[package]] name = "futures-channel" version = "0.3.30" @@ -518,6 +533,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "eac8f7d7865dcb88bd4373ab671c8cf4508703796caa2b1985a9ca867b3fcb78" dependencies = [ "futures-core", + "futures-sink", ] [[package]] @@ -526,6 +542,23 @@ version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dfc6580bb841c5a68e9ef15c77ccc837b40a7504914d52e47b8b0e9bbda25a1d" +[[package]] +name = "futures-executor" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a576fc72ae164fca6b9db127eaa9a9dda0d61316034f33a0a0d4eda41f02b01d" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-io" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1" + [[package]] name = "futures-sink" version = "0.3.30" @@ -544,10 +577,15 @@ version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48" dependencies = [ + "futures-channel", "futures-core", + "futures-io", + "futures-sink", "futures-task", + "memchr", "pin-project-lite", "pin-utils", + "slab", ] [[package]] @@ -822,6 +860,7 @@ dependencies = [ name = "intelligent-prompt-gateway" version = "0.1.0" dependencies = [ + "http", "log", "md5", "open-message-format", @@ -830,6 +869,7 @@ dependencies = [ "serde", "serde_json", "serde_yaml", + "serial_test", ] [[package]] @@ -925,6 +965,16 @@ version = "0.4.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "78b3ae25bc7c8c38cec158d1f2757ee79e9b3740fbc7ccf0e59e4b08d793fa89" +[[package]] +name = "lock_api" +version = "0.4.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07af8b9cdd281b7915f413fa73f29ebd5d55d0d3f0155584dade1ff18cea1b17" +dependencies = [ + "autocfg", + "scopeguard", +] + [[package]] name = "log" version = "0.4.22" @@ -1112,6 +1162,29 @@ dependencies = [ "vcpkg", ] +[[package]] +name = "parking_lot" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1bf18183cf54e8d6059647fc3063646a1801cf30896933ec2311622cc4b9a27" +dependencies = [ + "lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.9.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e401f977ab385c9e4e3ab30627d6f26d00e2c73eef317493c4ec6d468726cf8" +dependencies = [ + "cfg-if 1.0.0", + "libc", + "redox_syscall", + "smallvec", + "windows-targets 0.52.6", +] + [[package]] name = "paste" version = "1.0.15" @@ -1293,6 +1366,15 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "redox_syscall" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2a908a6e00f1fdd0dfd9c0eb08ce85126f6d8bbda50017e74bc4a4b7d4a926a4" +dependencies = [ + "bitflags 2.6.0", +] + [[package]] name = "redox_users" version = "0.4.5" @@ -1447,6 +1529,15 @@ version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f" +[[package]] +name = "scc" +version = "2.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fadf67e3cf23f8b11a6c8c48a16cb2437381503615acd91094ec7b4686a5a53" +dependencies = [ + "sdd", +] + [[package]] name = "schannel" version = "0.1.23" @@ -1456,6 +1547,18 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "scopeguard" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" + +[[package]] +name = "sdd" +version = "1.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85f05a494052771fc5bd0619742363b5e24e5ad72ab3111ec2e27925b8edc5f3" + [[package]] name = "security-framework" version = "2.11.1" @@ -1550,6 +1653,31 @@ dependencies = [ "unsafe-libyaml", ] +[[package]] +name = "serial_test" +version = "3.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b4b487fe2acf240a021cf57c6b2b4903b1e78ca0ecd862a71b71d2a51fed77d" +dependencies = [ + "futures", + "log", + "once_cell", + "parking_lot", + "scc", + "serial_test_derive", +] + +[[package]] +name = "serial_test_derive" +version = "3.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "82fe9db325bcef1fbcde82e078a5cc4efdf787e96b3b9cf45b50b529f2083d67" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.72", +] + [[package]] name = "sha2" version = "0.10.8" diff --git a/envoyfilter/Cargo.toml b/envoyfilter/Cargo.toml index b208f577..f23e8fbb 100644 --- a/envoyfilter/Cargo.toml +++ b/envoyfilter/Cargo.toml @@ -15,6 +15,8 @@ serde_yaml = "0.9.34" serde_json = "1.0" md5 = "0.7.0" open-message-format = { path = "../open-message-format/clients/omf-rust" } +http = "1.1.0" [dev-dependencies] proxy-wasm-test-framework = { git = "https://github.com/katanemo/test-framework.git", branch = "main" } +serial_test = "3.1.1" diff --git a/envoyfilter/src/lib.rs b/envoyfilter/src/lib.rs index 49d22ebf..088e70e9 100644 --- a/envoyfilter/src/lib.rs +++ b/envoyfilter/src/lib.rs @@ -1,17 +1,17 @@ use common_types::{CallContext, EmbeddingRequest}; use configuration::PromptTarget; +use http::StatusCode; +use log::error; use log::info; use md5::Digest; use open_message_format::models::{ CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse, }; -use serde_json::to_string; -use stats::RecordingMetric; -use std::collections::HashMap; -use std::time::Duration; - use proxy_wasm::traits::*; use proxy_wasm::types::*; +use stats::{Gauge, RecordingMetric}; +use std::collections::HashMap; +use std::time::Duration; mod common_types; mod configuration; @@ -85,21 +85,47 @@ impl HttpContext for StreamContext { match self.get_http_request_body(0, body_size) { Some(body_bytes) => match serde_json::from_slice(&body_bytes) { Ok(deserialized) => deserialized, - Err(msg) => panic!("Failed to deserialize: {}", msg), + Err(msg) => { + self.send_http_response( + StatusCode::BAD_REQUEST.as_u16().into(), + vec![], + Some(format!("Failed to deserialize: {}", msg).as_bytes()), + ); + return Action::Pause; + } }, - None => panic!( - "Failed to obtain body bytes even though body_size is {}", - body_size - ), + None => { + self.send_http_response( + StatusCode::INTERNAL_SERVER_ERROR.as_u16().into(), + vec![], + None, + ); + error!( + "Failed to obtain body bytes even though body_size is {}", + body_size + ); + return Action::Pause; + } }; // Modify JSON payload deserialized_body.model = String::from("gpt-3.5-turbo"); - let json_string = serde_json::to_string(&deserialized_body).unwrap(); - self.set_http_request_body(0, body_size, &json_string.into_bytes()); - - Action::Continue + match serde_json::to_string(&deserialized_body) { + Ok(json_string) => { + self.set_http_request_body(0, body_size, &json_string.into_bytes()); + Action::Continue + } + Err(error) => { + self.send_http_response( + StatusCode::INTERNAL_SERVER_ERROR.as_u16().into(), + vec![], + None, + ); + error!("Failed to serialize body: {}", error); + Action::Pause + } + } } } @@ -107,13 +133,13 @@ impl Context for StreamContext {} #[derive(Copy, Clone)] struct WasmMetrics { - active_http_calls: stats::Gauge, + active_http_calls: Gauge, } impl WasmMetrics { fn new() -> WasmMetrics { WasmMetrics { - active_http_calls: stats::Gauge::new(String::from("active_http_calls")), + active_http_calls: Gauge::new(String::from("active_http_calls")), } } } @@ -135,7 +161,13 @@ impl FilterContext { } fn process_prompt_targets(&mut self) { - for prompt_target in &self.config.as_ref().unwrap().prompt_config.prompt_targets { + for prompt_target in &self + .config + .as_ref() + .expect("Gateway configuration cannot be non-existent") + .prompt_config + .prompt_targets + { for few_shot_example in &prompt_target.few_shot_examples { info!("few_shot_example: {:?}", few_shot_example); @@ -149,8 +181,12 @@ impl FilterContext { user: None, }; - // TODO: Handle potential errors - let json_data: String = to_string(&embeddings_input).unwrap(); + let json_data: String = match serde_json::to_string(&embeddings_input) { + Ok(json_data) => json_data, + Err(error) => { + panic!("Error serializing embeddings input: {}", error); + } + }; let token_id = match self.dispatch_http_call( "embeddingserver", @@ -159,6 +195,7 @@ impl FilterContext { (":path", "/embeddings"), (":authority", "embeddingserver"), ("content-type", "application/json"), + ("x-envoy-max-retries", "3"), ], Some(json_data.as_bytes()), vec![], @@ -166,7 +203,7 @@ impl FilterContext { ) { Ok(token_id) => token_id, Err(e) => { - panic!("Error dispatching HTTP call: {:?}", e); + panic!("Error dispatching embedding server HTTP call: {:?}", e); } }; info!("on_tick: dispatched HTTP call with token_id = {}", token_id); @@ -181,7 +218,10 @@ impl FilterContext { }) .is_some() { - panic!("duplicate token_id") + panic!( + "duplicate token_id={} in embedding server requests", + token_id + ) } self.metrics .active_http_calls @@ -197,78 +237,65 @@ impl FilterContext { prompt_target: PromptTarget, ) { info!("response received for CreateEmbeddingRequest"); - if let Some(body) = self.get_http_call_response_body(0, body_size) { - if !body.is_empty() { - let embedding_response: CreateEmbeddingResponse = - serde_json::from_slice(&body).unwrap(); - info!( - "embedding_response model: {}, vector len: {}", - embedding_response.model, - embedding_response.data[0].embedding.len() - ); - - let mut payload: HashMap = HashMap::new(); - payload.insert( - "prompt-target".to_string(), - to_string(&prompt_target).unwrap(), - ); - let id: Option; - match *create_embedding_request.input { - CreateEmbeddingRequestInput::String(input) => { - id = Some(md5::compute(&input)); - payload.insert("input".to_string(), input); - } - CreateEmbeddingRequestInput::Array(_) => todo!(), - } - - let create_vector_store_points = common_types::StoreVectorEmbeddingsRequest { - points: vec![common_types::VectorPoint { - id: format!("{:x}", id.unwrap()), - payload, - vector: embedding_response.data[0].embedding.clone(), - }], - }; - let json_data = to_string(&create_vector_store_points).unwrap(); // Handle potential errors - info!( - "create_vector_store_points: points length: {}", - embedding_response.data[0].embedding.len() - ); - let token_id = match self.dispatch_http_call( - "qdrant", - vec![ - (":method", "PUT"), - (":path", "/collections/prompt_vector_store/points"), - (":authority", "qdrant"), - ("content-type", "application/json"), - ], - Some(json_data.as_bytes()), - vec![], - Duration::from_secs(5), - ) { - Ok(token_id) => token_id, - Err(e) => { - panic!("Error dispatching HTTP call: {:?}", e); - } - }; - info!("on_tick: dispatched HTTP call with token_id = {}", token_id); - - if self - .callouts - .insert( - token_id, - CallContext::StoreVectorEmbeddings(create_vector_store_points), - ) - .is_some() - { - panic!("duplicate token_id") - } - self.metrics - .active_http_calls - .record(self.callouts.len().try_into().unwrap()); + let body = match self.get_http_call_response_body(0, body_size) { + Some(body) => body, + None => { + return; } + }; + + if body.is_empty() { + return; } + + let (json_data, create_vector_store_points) = + match build_qdrant_data(&body, create_embedding_request, &prompt_target) { + Ok(tuple) => tuple, + Err(error) => { + panic!( + "Error building qdrant data from embedding response {}", + error + ); + } + }; + + let token_id = match self.dispatch_http_call( + "qdrant", + vec![ + (":method", "PUT"), + (":path", "/collections/prompt_vector_store/points"), + (":authority", "qdrant"), + ("content-type", "application/json"), + ("x-envoy-max-retries", "3"), + ], + Some(json_data.as_bytes()), + vec![], + Duration::from_secs(5), + ) { + Ok(token_id) => token_id, + Err(e) => { + panic!("Error dispatching qdrant HTTP call: {:?}", e); + } + }; + info!("on_tick: dispatched HTTP call with token_id = {}", token_id); + + if self + .callouts + .insert( + token_id, + CallContext::StoreVectorEmbeddings(create_vector_store_points), + ) + .is_some() + { + panic!("duplicate token_id={} in qdrant requests", token_id) + } + + self.metrics + .active_http_calls + .record(self.callouts.len().try_into().unwrap()); } + // TODO: @adilhafeez implement. fn create_vector_store_points_handler(&self, body_size: usize) { info!("response received for CreateVectorStorePoints"); if let Some(body) = self.get_http_call_response_body(0, body_size) { @@ -294,7 +321,6 @@ impl Context for FilterContext { self.metrics .active_http_calls .record(self.callouts.len().try_into().unwrap()); - match callout_data { common_types::CallContext::EmbeddingRequest(common_types::EmbeddingRequest { create_embedding_request, @@ -313,7 +339,12 @@ impl Context for FilterContext { impl RootContext for FilterContext { fn on_configure(&mut self, _: usize) -> bool { if let Some(config_bytes) = self.get_plugin_configuration() { - self.config = serde_yaml::from_slice(&config_bytes).unwrap(); + self.config = match serde_yaml::from_slice(&config_bytes) { + Ok(config) => config, + Err(error) => { + panic!("Failed to deserialize plugin configuration: {}", error); + } + }; info!("on_configure: plugin configuration loaded"); } true @@ -339,3 +370,51 @@ impl RootContext for FilterContext { self.set_tick_period(Duration::from_secs(0)); } } + +fn build_qdrant_data( + embedding_data: &[u8], + create_embedding_request: CreateEmbeddingRequest, + prompt_target: &PromptTarget, +) -> Result<(String, common_types::StoreVectorEmbeddingsRequest), serde_json::Error> { + let embedding_response: CreateEmbeddingResponse = match serde_json::from_slice(embedding_data) { + Ok(embedding_response) => embedding_response, + Err(error) => { + panic!("Failed to deserialize embedding response: {}", error); + } + }; + info!( + "embedding_response model: {}, vector len: {}", + embedding_response.model, + embedding_response.data[0].embedding.len() + ); + + let mut payload: HashMap = HashMap::new(); + payload.insert( + "prompt-target".to_string(), + serde_json::to_string(&prompt_target)?, + ); + + let id: Option; + match *create_embedding_request.input { + CreateEmbeddingRequestInput::String(ref input) => { + id = Some(md5::compute(input)); + payload.insert("input".to_string(), input.clone()); + } + CreateEmbeddingRequestInput::Array(_) => todo!(), + } + + let create_vector_store_points = common_types::StoreVectorEmbeddingsRequest { + points: vec![common_types::VectorPoint { + id: format!("{:x}", id.unwrap()), + payload, + vector: embedding_response.data[0].embedding.clone(), + }], + }; + let json_data = serde_json::to_string(&create_vector_store_points)?; + info!( + "create_vector_store_points: points length: {}", + embedding_response.data[0].embedding.len() + ); + + Ok((json_data, create_vector_store_points)) +} diff --git a/envoyfilter/src/stats.rs b/envoyfilter/src/stats.rs index 27d2f413..693f24b5 100644 --- a/envoyfilter/src/stats.rs +++ b/envoyfilter/src/stats.rs @@ -1,14 +1,15 @@ +use log::error; use proxy_wasm::hostcalls; use proxy_wasm::types::*; #[allow(unused)] pub trait Metric { fn id(&self) -> u32; - fn value(&self) -> u64 { + fn value(&self) -> Result { match hostcalls::get_metric(self.id()) { - Ok(value) => value, - Err(Status::NotFound) => panic!("metric not found: {}", self.id()), - Err(err) => panic!("unexpected status: {:?}", err), + Ok(value) => Ok(value), + Err(Status::NotFound) => Err(format!("metric not found: {}", self.id())), + Err(err) => Err(format!("unexpected status: {:?}", err)), } } } @@ -17,9 +18,8 @@ pub trait Metric { pub trait IncrementingMetric: Metric { fn increment(&self, offset: i64) { match hostcalls::increment_metric(self.id(), offset) { - Ok(data) => data, - Err(Status::NotFound) => panic!("metric not found: {}", self.id()), - Err(err) => panic!("unexpected status: {:?}", err), + Ok(_) => (), + Err(err) => error!("error incrementing metric: {:?}", err), } } } @@ -27,9 +27,8 @@ pub trait IncrementingMetric: Metric { pub trait RecordingMetric: Metric { fn record(&self, value: u64) { match hostcalls::record_metric(self.id(), value) { - Ok(data) => data, - Err(Status::NotFound) => panic!("metric not found: {}", self.id()), - Err(err) => panic!("unexpected status: {:?}", err), + Ok(_) => (), + Err(err) => error!("error recording metric: {:?}", err), } } } diff --git a/envoyfilter/tests/integration.rs b/envoyfilter/tests/integration.rs index c044a59d..6dfb7675 100644 --- a/envoyfilter/tests/integration.rs +++ b/envoyfilter/tests/integration.rs @@ -1,5 +1,7 @@ +use http::StatusCode; use proxy_wasm_test_framework::tester; use proxy_wasm_test_framework::types::{Action, BufferType, MapType, MetricType, ReturnType}; +use serial_test::serial; use std::path::Path; fn wasm_module() -> String { @@ -12,7 +14,8 @@ fn wasm_module() -> String { } #[test] -fn request_to_open_ai_chat_completions() { +#[serial] +fn successful_request_to_open_ai_chat_completions() { let args = tester::MockSettings { wasm_path: wasm_module(), quiet: false, @@ -90,3 +93,87 @@ fn request_to_open_ai_chat_completions() { .execute_and_expect(ReturnType::Action(Action::Continue)) .unwrap(); } + +#[test] +#[serial] +fn bad_request_to_open_ai_chat_completions() { + let args = tester::MockSettings { + wasm_path: wasm_module(), + quiet: false, + allow_unexpected: false, + }; + let mut module = tester::mock(args).unwrap(); + + module + .call_start() + .execute_and_expect(ReturnType::None) + .unwrap(); + + // Setup Filter + let root_context = 1; + + module + .call_proxy_on_context_create(root_context, 0) + .expect_metric_creation(MetricType::Gauge, "active_http_calls") + .execute_and_expect(ReturnType::None) + .unwrap(); + + // Setup HTTP Stream + let http_context = 2; + + module + .call_proxy_on_context_create(http_context, root_context) + .execute_and_expect(ReturnType::None) + .unwrap(); + + // Request Headers + module + .call_proxy_on_request_headers(http_context, 0, false) + .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":host")) + .returning(Some("api.openai.com")) + .expect_add_header_map_value( + Some(MapType::HttpRequestHeaders), + Some("content-length"), + Some(""), + ) + .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":path")) + .returning(Some("/llmrouting")) + .expect_add_header_map_value( + Some(MapType::HttpRequestHeaders), + Some(":path"), + Some("/v1/chat/completions"), + ) + .execute_and_expect(ReturnType::Action(Action::Continue)) + .unwrap(); + + // Request Body + let incomplete_chat_completions_request_body = "\ + {\ + \"messages\": [\ + {\ + \"role\": \"system\",\ + },\ + {\ + \"role\": \"user\",\ + \"content\": \"Compose a poem that explains the concept of recursion in programming.\"\ + }\ + ]\ + }"; + + module + .call_proxy_on_request_body( + http_context, + incomplete_chat_completions_request_body.len() as i32, + true, + ) + .expect_get_buffer_bytes(Some(BufferType::HttpRequestBody)) + .returning(Some(incomplete_chat_completions_request_body)) + .expect_send_local_response( + Some(StatusCode::BAD_REQUEST.as_u16().into()), + None, + None, + None, + ) + .execute_and_expect(ReturnType::Action(Action::Pause)) + .unwrap(); +}