From 60471286bb1650feba102bf9e010cbe58f5e3701 Mon Sep 17 00:00:00 2001 From: Adil Hafeez Date: Wed, 5 Feb 2025 17:39:44 -0800 Subject: [PATCH] Encode parameter values in http path and ... - don't send param values in request body in http get request - send param values in http post request --- crates/Cargo.lock | 302 ++++++++ crates/common/Cargo.toml | 2 + crates/common/src/path.rs | 163 ++++- crates/prompt_gateway/src/stream_context.rs | 81 ++- crates/prompt_gateway/tests/integration.rs | 722 ++++++++++---------- 5 files changed, 859 insertions(+), 411 deletions(-) diff --git a/crates/Cargo.lock b/crates/Cargo.lock index 98157733..b585ef6e 100644 --- a/crates/Cargo.lock +++ b/crates/Cargo.lock @@ -234,6 +234,8 @@ dependencies = [ "serde_yaml", "thiserror", "tiktoken-rs", + "url", + "urlencoding", ] [[package]] @@ -477,6 +479,17 @@ dependencies = [ "winapi", ] +[[package]] +name = "displaydoc" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.79", +] + [[package]] name = "duration-string" version = "0.3.0" @@ -557,6 +570,15 @@ version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f81ec6369c545a7d40e4589b5597581fa1c441fe1cce96dd1de43159910a36a2" +[[package]] +name = "form_urlencoded" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e13624c2627564efccf4934284bdd98cbaa14e79b0b5a141218e507b3a823456" +dependencies = [ + "percent-encoding", +] + [[package]] name = "futures" version = "0.3.31" @@ -782,12 +804,151 @@ dependencies = [ "itoa", ] +[[package]] +name = "icu_collections" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db2fa452206ebee18c4b5c2274dbf1de17008e874b4dc4f0aea9d01ca79e4526" +dependencies = [ + "displaydoc", + "yoke", + "zerofrom", + "zerovec", +] + +[[package]] +name = "icu_locid" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13acbb8371917fc971be86fc8057c41a64b521c184808a698c02acc242dbf637" +dependencies = [ + "displaydoc", + "litemap", + "tinystr", + "writeable", + "zerovec", +] + +[[package]] +name = "icu_locid_transform" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "01d11ac35de8e40fdeda00d9e1e9d92525f3f9d887cdd7aa81d727596788b54e" +dependencies = [ + "displaydoc", + "icu_locid", + "icu_locid_transform_data", + "icu_provider", + "tinystr", + "zerovec", +] + +[[package]] +name = "icu_locid_transform_data" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fdc8ff3388f852bede6b579ad4e978ab004f139284d7b28715f773507b946f6e" + +[[package]] +name = "icu_normalizer" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "19ce3e0da2ec68599d193c93d088142efd7f9c5d6fc9b803774855747dc6a84f" +dependencies = [ + "displaydoc", + "icu_collections", + "icu_normalizer_data", + "icu_properties", + "icu_provider", + "smallvec", + "utf16_iter", + "utf8_iter", + "write16", + "zerovec", +] + +[[package]] +name = "icu_normalizer_data" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8cafbf7aa791e9b22bec55a167906f9e1215fd475cd22adfcf660e03e989516" + +[[package]] +name = "icu_properties" +version = "1.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93d6020766cfc6302c15dbbc9c8778c37e62c14427cb7f6e601d849e092aeef5" +dependencies = [ + "displaydoc", + "icu_collections", + "icu_locid_transform", + "icu_properties_data", + "icu_provider", + "tinystr", + "zerovec", +] + +[[package]] +name = "icu_properties_data" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67a8effbc3dd3e4ba1afa8ad918d5684b8868b3b26500753effea8d2eed19569" + +[[package]] +name = "icu_provider" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ed421c8a8ef78d3e2dbc98a973be2f3770cb42b606e3ab18d6237c4dfde68d9" +dependencies = [ + "displaydoc", + "icu_locid", + "icu_provider_macros", + "stable_deref_trait", + "tinystr", + "writeable", + "yoke", + "zerofrom", + "zerovec", +] + +[[package]] +name = "icu_provider_macros" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ec89e9337638ecdc08744df490b221a7399bf8d164eb52a665454e60e075ad6" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.79", +] + [[package]] name = "id-arena" version = "2.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "25a2bc672d1148e28034f176e01fffebb08b35768468cc954630da77a1449005" +[[package]] +name = "idna" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "686f825264d630750a544639377bae737628043f20d38bbc029e8f29ea968a7e" +dependencies = [ + "idna_adapter", + "smallvec", + "utf8_iter", +] + +[[package]] +name = "idna_adapter" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "daca1df1c957320b2cf139ac61e7bd64fed304c5040df000a745aa1de3b4ef71" +dependencies = [ + "icu_normalizer", + "icu_properties", +] + [[package]] name = "indexmap" version = "2.6.0" @@ -883,6 +1044,12 @@ version = "0.4.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "78b3ae25bc7c8c38cec158d1f2757ee79e9b3740fbc7ccf0e59e4b08d793fa89" +[[package]] +name = "litemap" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ee93343901ab17bd981295f2cf0026d4ad018c7c31ba84549a4ddbb47a45104" + [[package]] name = "llm_gateway" version = "0.1.0" @@ -1028,6 +1195,12 @@ version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" +[[package]] +name = "percent-encoding" +version = "2.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" + [[package]] name = "pin-project-lite" version = "0.2.14" @@ -1547,6 +1720,17 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "synstructure" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8af7666ab7b6390ab78131fb5b0fce11d6b7a6951602017c35fa82800708971" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.79", +] + [[package]] name = "target-lexicon" version = "0.12.16" @@ -1606,6 +1790,16 @@ dependencies = [ "rustc-hash", ] +[[package]] +name = "tinystr" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9117f5d4db391c1cf6927e7bea3db74b9a1c1add8f7eda9ffd5364f40f57b82f" +dependencies = [ + "displaydoc", + "zerovec", +] + [[package]] name = "toml" version = "0.8.19" @@ -1676,6 +1870,35 @@ version = "0.2.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "673aac59facbab8a9007c7f6108d11f63b603f7cabff99fabf650fea5c32b861" +[[package]] +name = "url" +version = "2.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32f8b686cadd1473f4bd0117a5d28d36b1ade384ea9b5069a1c40aefed7fda60" +dependencies = [ + "form_urlencoded", + "idna", + "percent-encoding", +] + +[[package]] +name = "urlencoding" +version = "2.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da" + +[[package]] +name = "utf16_iter" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8232dd3cdaed5356e0f716d285e4b40b932ac434100fe9b7e0e8e935b9e6246" + +[[package]] +name = "utf8_iter" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" + [[package]] name = "uuid" version = "1.11.0" @@ -2189,12 +2412,48 @@ dependencies = [ "wasmparser 0.212.0", ] +[[package]] +name = "write16" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d1890f4022759daae28ed4fe62859b1236caebfc61ede2f63ed4e695f3f6d936" + +[[package]] +name = "writeable" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e9df38ee2d2c3c5948ea468a8406ff0db0b29ae1ffde1bcf20ef305bcc95c51" + [[package]] name = "yansi" version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cfe53a6657fd280eaa890a3bc59152892ffa3e30101319d168b781ed6529b049" +[[package]] +name = "yoke" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "120e6aef9aa629e3d4f52dc8cc43a015c7724194c97dfaf45180d2daf2b77f40" +dependencies = [ + "serde", + "stable_deref_trait", + "yoke-derive", + "zerofrom", +] + +[[package]] +name = "yoke-derive" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2380878cad4ac9aac1e2435f3eb4020e8374b5f13c296cb75b4620ff8e229154" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.79", + "synstructure", +] + [[package]] name = "zerocopy" version = "0.7.35" @@ -2216,6 +2475,49 @@ dependencies = [ "syn 2.0.79", ] +[[package]] +name = "zerofrom" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cff3ee08c995dee1859d998dea82f7374f2826091dd9cd47def953cae446cd2e" +dependencies = [ + "zerofrom-derive", +] + +[[package]] +name = "zerofrom-derive" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "595eed982f7d355beb85837f651fa22e90b3c044842dc7f2c2842c086f295808" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.79", + "synstructure", +] + +[[package]] +name = "zerovec" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aa2b893d79df23bfb12d5461018d408ea19dfafe76c2c7ef6d4eba614f8ff079" +dependencies = [ + "yoke", + "zerofrom", + "zerovec-derive", +] + +[[package]] +name = "zerovec-derive" +version = "0.10.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6eafa6dfb17584ea3e2bd6e76e0cc15ad7af12b09abdd1ca55961bed9b1063c6" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.79", +] + [[package]] name = "zstd" version = "0.13.2" diff --git a/crates/common/Cargo.toml b/crates/common/Cargo.toml index 84aa636c..d8c35140 100644 --- a/crates/common/Cargo.toml +++ b/crates/common/Cargo.toml @@ -16,6 +16,8 @@ tiktoken-rs = "0.5.9" rand = "0.8.5" serde_json = "1.0" hex = "0.4.3" +urlencoding = "2.1.3" +url = "2.5.4" [dev-dependencies] pretty_assertions = "1.4.1" diff --git a/crates/common/src/path.rs b/crates/common/src/path.rs index 3bf2aed5..f11cb7b9 100644 --- a/crates/common/src/path.rs +++ b/crates/common/src/path.rs @@ -1,21 +1,30 @@ -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; +use url::Url; +use urlencoding; + +use crate::configuration::Parameter; pub fn replace_params_in_path( path: &str, - params: &HashMap, -) -> Result { - let mut result = String::new(); - let mut in_param = false; + tool_params: &HashMap, + prompt_target_params: &[Parameter], +) -> Result<(String, String, HashMap), String> { + let mut query_string_replaced = String::new(); let mut current_param = String::new(); + let mut vars_replaced = HashSet::new(); + let mut params: HashMap = HashMap::new(); + let mut in_param = false; for c in path.chars() { if c == '{' { in_param = true; } else if c == '}' { in_param = false; let param_name = current_param.clone(); - if let Some(value) = params.get(¶m_name) { - result.push_str(value); + if let Some(value) = tool_params.get(¶m_name) { + let value = urlencoding::encode(value); + query_string_replaced.push_str(value.into_owned().as_str()); + vars_replaced.insert(param_name.clone()); } else { return Err(format!("Missing value for parameter `{}`", param_name)); } @@ -23,31 +32,106 @@ pub fn replace_params_in_path( } else if in_param { current_param.push(c); } else { - result.push(c); + query_string_replaced.push(c); } } - Ok(result) + // add the remaining params in path + for (param_name, value) in tool_params.iter() { + let value = urlencoding::encode(value).into_owned(); + if !vars_replaced.contains(param_name) { + vars_replaced.insert(param_name.clone()); + params.insert(param_name.clone(), value.clone()); + if query_string_replaced.contains("?") { + query_string_replaced.push_str(&format!("&{}={}", param_name, value)); + } else { + query_string_replaced.push_str(&format!("?{}={}", param_name, value)); + } + } + } + + // add default values + for param in prompt_target_params.iter() { + if !vars_replaced.contains(¶m.name) && param.default.is_some() { + params.insert(param.name.clone(), param.default.clone().unwrap()); + if query_string_replaced.contains("?") { + query_string_replaced.push_str(&format!( + "&{}={}", + param.name, + param.default.as_ref().unwrap() + )); + } else { + query_string_replaced.push_str(&format!( + "?{}={}", + param.name, + param.default.as_ref().unwrap() + )); + } + } + } + + let parsed_uri = Url::parse("http://dummy.com").unwrap(); + let parsed_uri = parsed_uri + .join(&query_string_replaced) + .map_err(|e| e.to_string())?; + let query_string = parsed_uri.query().unwrap_or(""); + let path_uri = parsed_uri.path(); + + Ok((path_uri.to_string(), query_string.to_string(), params)) } #[cfg(test)] mod test { + use std::collections::HashMap; + + use crate::configuration::Parameter; + #[test] fn test_replace_path() { let path = "/cluster.open-cluster-management.io/v1/managedclusters/{cluster_name}"; - let params = vec![("cluster_name".to_string(), "test1".to_string())] - .into_iter() - .collect(); + let params = vec![ + ("cluster_name".to_string(), "test1".to_string()), + ("hello".to_string(), "hello world".to_string()), + ] + .into_iter() + .collect(); + let prompt_target_params = vec![Parameter { + name: "country".to_string(), + parameter_type: None, + description: "test target".to_string(), + required: None, + enum_values: None, + default: Some("US".to_string()), + in_path: None, + format: None, + }]; + + let out_params: HashMap = vec![ + ("country".to_string(), "US".to_string()), + ("hello".to_string(), "hello%20world".to_string()), + ] + .into_iter() + .collect(); assert_eq!( - super::replace_params_in_path(path, ¶ms), - Ok("/cluster.open-cluster-management.io/v1/managedclusters/test1".to_string()) + super::replace_params_in_path(path, ¶ms, &prompt_target_params), + Ok(( + "/cluster.open-cluster-management.io/v1/managedclusters/test1".to_string(), + "hello=hello%20world&country=US".to_string(), + out_params.clone() + )) ); + let out_params = HashMap::new(); + let prompt_target_params = vec![]; let path = "/cluster.open-cluster-management.io/v1/managedclusters"; let params = vec![].into_iter().collect(); assert_eq!( - super::replace_params_in_path(path, ¶ms), - Ok("/cluster.open-cluster-management.io/v1/managedclusters".to_string()) + super::replace_params_in_path(path, ¶ms, &prompt_target_params), + Ok(( + "/cluster.open-cluster-management.io/v1/managedclusters".to_string(), + "".to_string(), + out_params + )) ); let path = "/foo/{bar}/baz"; @@ -55,8 +139,8 @@ mod test { .into_iter() .collect(); assert_eq!( - super::replace_params_in_path(path, ¶ms), - Ok("/foo/qux/baz".to_string()) + super::replace_params_in_path(path, ¶ms, &prompt_target_params), + Ok(("/foo/qux/baz".to_string(), "".to_string(), HashMap::new())) ); let path = "/foo/{bar}/baz/{qux}"; @@ -67,8 +151,45 @@ mod test { .into_iter() .collect(); assert_eq!( - super::replace_params_in_path(path, ¶ms), - Ok("/foo/qux/baz/quux".to_string()) + super::replace_params_in_path(path, ¶ms, &prompt_target_params), + Ok(( + "/foo/qux/baz/quux".to_string(), + "".to_string(), + HashMap::new() + )) + ); + + let path = "/foo/{bar}/baz/{qux}?hello=world"; + let params = vec![ + ("bar".to_string(), "qux".to_string()), + ("qux".to_string(), "quux".to_string()), + ] + .into_iter() + .collect(); + assert_eq!( + super::replace_params_in_path(path, ¶ms, &prompt_target_params), + Ok(( + "/foo/qux/baz/quux".to_string(), + "hello=world".to_string(), + HashMap::new() + )) + ); + + let path = "/foo/{bar}/baz/{qux}?hello={hello}"; + let params = vec![ + ("bar".to_string(), "qux".to_string()), + ("qux".to_string(), "quux".to_string()), + ("hello".to_string(), "hello world".to_string()), + ] + .into_iter() + .collect(); + assert_eq!( + super::replace_params_in_path(path, ¶ms, &prompt_target_params), + Ok(( + "/foo/qux/baz/quux".to_string(), + "hello=hello%20world".to_string(), + HashMap::new() + )) ); let path = "/foo/{bar}/baz/{qux}"; @@ -76,7 +197,7 @@ mod test { .into_iter() .collect(); assert_eq!( - super::replace_params_in_path(path, ¶ms), + super::replace_params_in_path(path, ¶ms, &prompt_target_params), Err("Missing value for parameter `qux`".to_string()) ); } diff --git a/crates/prompt_gateway/src/stream_context.rs b/crates/prompt_gateway/src/stream_context.rs index 3e3fee5b..ce9b852f 100644 --- a/crates/prompt_gateway/src/stream_context.rs +++ b/crates/prompt_gateway/src/stream_context.rs @@ -3,7 +3,7 @@ use common::api::open_ai::{ to_server_events, ArchState, ChatCompletionStreamResponse, ChatCompletionsRequest, ChatCompletionsResponse, Message, ModelServerResponse, ToolCall, }; -use common::configuration::{Overrides, PromptTarget, Tracing}; +use common::configuration::{HttpMethod, Overrides, PromptTarget, Tracing}; use common::consts::{ ARCH_FC_MODEL_NAME, ARCH_FC_REQUEST_TIMEOUT_MS, ARCH_INTERNAL_CLUSTER_NAME, ARCH_UPSTREAM_HOST_HEADER, ASSISTANT_ROLE, MESSAGES_KEY, REQUEST_ID_HEADER, SYSTEM_ROLE, @@ -276,22 +276,18 @@ impl StreamContext { let prompt_target = self.prompt_targets.get(&tools_call_name).unwrap().clone(); - let mut tool_params = self.tool_calls.as_ref().unwrap()[0] + let tool_params = self.tool_calls.as_ref().unwrap()[0] .function .arguments .clone(); - tool_params.insert( - String::from(MESSAGES_KEY), - serde_yaml::to_value(&callout_context.request_body.messages).unwrap(), - ); - - let tool_params_json_str = serde_json::to_string(&tool_params).unwrap(); let endpoint = prompt_target.endpoint.unwrap(); let path: String = endpoint.path.unwrap_or(String::from("/")); + let prompt_target_params = prompt_target.parameters.unwrap_or_default(); + let http_method = endpoint.method.unwrap_or_default(); // only add params that are of string, number and bool type - let url_params = tool_params + let tool_url_params = tool_params .iter() .filter(|(_, value)| value.is_number() || value.is_string() || value.is_bool()) .map(|(key, value)| match value { @@ -305,50 +301,77 @@ impl StreamContext { }) .collect::>(); - let path = match common::path::replace_params_in_path(&path, &url_params) { - Ok(path) => path, - Err(e) => { - return self.send_server_error( - ServerError::BadRequest { - why: format!("error replacing params in path: {}", e), - }, - Some(StatusCode::BAD_REQUEST), - ); + let (path_with_params, query_string, additional_params) = + match common::path::replace_params_in_path( + &path, + &tool_url_params, + &prompt_target_params, + ) { + Ok((path, query_string, additional_params)) => { + (path, query_string, additional_params) + } + Err(e) => { + return self.send_server_error( + ServerError::BadRequest { + why: format!("error replacing params in path: {}", e), + }, + Some(StatusCode::BAD_REQUEST), + ); + } + }; + + let (path, body) = match http_method { + HttpMethod::Get => { + (format!("{}?{}", path_with_params, query_string), None) + } + HttpMethod::Post => { + let mut additional_params = additional_params; + if !query_string.is_empty() { + query_string.split("&").for_each(|param| { + let mut parts = param.split("="); + let key = parts.next().unwrap(); + let value = parts.next().unwrap(); + additional_params.insert(key.to_string(), value.to_string()); + }); + } + let body = serde_json::to_string(&additional_params).unwrap(); + (path_with_params, Some(body)) } }; - let http_method = endpoint.method.unwrap_or_default().to_string(); - let mut headers = vec![ + let http_method_str = http_method.to_string(); + let mut headers: HashMap<_, _> = [ (ARCH_UPSTREAM_HOST_HEADER, endpoint.name.as_str()), - (":method", &http_method), + (":method", &http_method_str), (":path", &path), (":authority", endpoint.name.as_str()), ("content-type", "application/json"), ("x-envoy-max-retries", "3"), - ]; + ] + .into_iter() + .collect(); if self.request_id.is_some() { - headers.push((REQUEST_ID_HEADER, self.request_id.as_ref().unwrap())); + headers.insert(REQUEST_ID_HEADER, self.request_id.as_ref().unwrap()); } if self.traceparent.is_some() { - headers.push((TRACE_PARENT_HEADER, self.traceparent.as_ref().unwrap())); + headers.insert(TRACE_PARENT_HEADER, self.traceparent.as_ref().unwrap()); } let call_args = CallArgs::new( ARCH_INTERNAL_CLUSTER_NAME, &path, - headers, - Some(tool_params_json_str.as_bytes()), + headers.into_iter().collect(), + body.as_deref().map(|s| s.as_bytes()), vec![], Duration::from_secs(5), ); debug!( - "dispatching api call to developer endpoint: {}, path: {}", - endpoint.name, path + "dispatching api call to developer endpoint: {}, path: {}, method: {}", + endpoint.name, path, http_method_str ); - trace!("request body: {}", tool_params_json_str); callout_context.upstream_cluster = Some(endpoint.name.to_owned()); callout_context.upstream_cluster_path = Some(path.to_owned()); diff --git a/crates/prompt_gateway/tests/integration.rs b/crates/prompt_gateway/tests/integration.rs index 9fcaf74d..e9ff2064 100644 --- a/crates/prompt_gateway/tests/integration.rs +++ b/crates/prompt_gateway/tests/integration.rs @@ -1,11 +1,11 @@ use common::api::open_ai::{ - ChatCompletionsResponse, Choice, FunctionCallDetail, Message, ToolCall, ToolType, Usage, + ChatCompletionsResponse, Choice, FunctionCallDetail, Message, ToolCall, ToolType, Usage, }; use common::configuration::Configuration; use http::StatusCode; use proxy_wasm_test_framework::tester::{self, Tester}; use proxy_wasm_test_framework::types::{ - Action, BufferType, LogLevel, MapType, MetricType, ReturnType, + Action, BufferType, LogLevel, MapType, MetricType, ReturnType, }; use serde_yaml::Value; use serial_test::serial; @@ -13,440 +13,440 @@ use std::collections::HashMap; use std::path::Path; fn wasm_module() -> String { - let wasm_file = Path::new("../target/wasm32-wasip1/release/prompt_gateway.wasm"); - assert!( - wasm_file.exists(), - "Run `cargo build --release --target=wasm32-wasip1` first" - ); - wasm_file.to_str().unwrap().to_string() + let wasm_file = Path::new("../target/wasm32-wasip1/release/prompt_gateway.wasm"); + assert!( + wasm_file.exists(), + "Run `cargo build --release --target=wasm32-wasip1` first" + ); + wasm_file.to_str().unwrap().to_string() } fn request_headers_expectations(module: &mut Tester, http_context: i32) { - module - .call_proxy_on_request_headers(http_context, 0, false) - .expect_remove_header_map_value(Some(MapType::HttpRequestHeaders), Some("content-length")) - .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":path")) - .returning(Some("/v1/chat/completions")) - .expect_get_header_map_pairs(Some(MapType::HttpRequestHeaders)) - .returning(None) - .expect_log(Some(LogLevel::Trace), None) - .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("x-request-id")) - .returning(None) - .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("traceparent")) - .returning(None) - .execute_and_expect(ReturnType::Action(Action::Continue)) - .unwrap(); + module + .call_proxy_on_request_headers(http_context, 0, false) + .expect_remove_header_map_value(Some(MapType::HttpRequestHeaders), Some("content-length")) + .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":path")) + .returning(Some("/v1/chat/completions")) + .expect_get_header_map_pairs(Some(MapType::HttpRequestHeaders)) + .returning(None) + .expect_log(Some(LogLevel::Trace), None) + .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("x-request-id")) + .returning(None) + .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("traceparent")) + .returning(None) + .execute_and_expect(ReturnType::Action(Action::Continue)) + .unwrap(); } fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) { - module - .call_proxy_on_context_create(http_context, filter_context) - .expect_log(Some(LogLevel::Trace), None) - .execute_and_expect(ReturnType::None) - .unwrap(); + module + .call_proxy_on_context_create(http_context, filter_context) + .expect_log(Some(LogLevel::Trace), None) + .execute_and_expect(ReturnType::None) + .unwrap(); - request_headers_expectations(module, http_context); + request_headers_expectations(module, http_context); - // Request Body - let chat_completions_request_body = "\ + // Request Body + let chat_completions_request_body = "\ {\ - \"messages\": [\ - {\ - \"role\": \"system\",\ - \"content\": \"You are a poetic assistant, skilled in explaining complex programming concepts with creative flair.\"\ - },\ - {\ - \"role\": \"user\",\ - \"content\": \"Compose a poem that explains the concept of recursion in programming.\"\ - }\ - ],\ - \"model\": \"gpt-4\"\ + \"messages\": [\ + {\ + \"role\": \"system\",\ + \"content\": \"You are a poetic assistant, skilled in explaining complex programming concepts with creative flair.\"\ + },\ + {\ + \"role\": \"user\",\ + \"content\": \"Compose a poem that explains the concept of recursion in programming.\"\ + }\ + ],\ + \"model\": \"gpt-4\"\ }"; - module - .call_proxy_on_request_body( - http_context, - chat_completions_request_body.len() as i32, - true, - ) - .expect_get_buffer_bytes(Some(BufferType::HttpRequestBody)) - .returning(Some(chat_completions_request_body)) - // The actual call is not important in this test, we just need to grab the token_id - .expect_log(Some(LogLevel::Trace), None) - .expect_http_call( - Some("arch_internal"), - Some(vec![ - ("x-arch-upstream", "model_server"), - (":method", "POST"), - (":path", "/function_calling"), - ("content-type", "application/json"), - (":authority", "model_server"), - ]), - None, - None, - None, - ) - .returning(Some(1)) - .expect_log(Some(LogLevel::Trace), None) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Trace), None) - .expect_log(Some(LogLevel::Trace), None) - .expect_metric_increment("active_http_calls", 1) - .execute_and_expect(ReturnType::Action(Action::Pause)) - .unwrap(); + module + .call_proxy_on_request_body( + http_context, + chat_completions_request_body.len() as i32, + true, + ) + .expect_get_buffer_bytes(Some(BufferType::HttpRequestBody)) + .returning(Some(chat_completions_request_body)) + // The actual call is not important in this test, we just need to grab the token_id + .expect_log(Some(LogLevel::Trace), None) + .expect_http_call( + Some("arch_internal"), + Some(vec![ + ("x-arch-upstream", "model_server"), + (":method", "POST"), + (":path", "/function_calling"), + ("content-type", "application/json"), + (":authority", "model_server"), + ]), + None, + None, + None, + ) + .returning(Some(1)) + .expect_log(Some(LogLevel::Trace), None) + .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Trace), None) + .expect_log(Some(LogLevel::Trace), None) + .expect_metric_increment("active_http_calls", 1) + .execute_and_expect(ReturnType::Action(Action::Pause)) + .unwrap(); } fn setup_filter(module: &mut Tester, config: &str) -> i32 { - let filter_context = 1; + let filter_context = 1; - module - .call_proxy_on_context_create(filter_context, 0) - .expect_metric_creation(MetricType::Gauge, "active_http_calls") - .execute_and_expect(ReturnType::None) - .unwrap(); + module + .call_proxy_on_context_create(filter_context, 0) + .expect_metric_creation(MetricType::Gauge, "active_http_calls") + .execute_and_expect(ReturnType::None) + .unwrap(); - module - .call_proxy_on_configure(filter_context, config.len() as i32) - .expect_get_buffer_bytes(Some(BufferType::PluginConfiguration)) - .returning(Some(config)) - .execute_and_expect(ReturnType::Bool(true)) - .unwrap(); + module + .call_proxy_on_configure(filter_context, config.len() as i32) + .expect_get_buffer_bytes(Some(BufferType::PluginConfiguration)) + .returning(Some(config)) + .execute_and_expect(ReturnType::Bool(true)) + .unwrap(); - filter_context + filter_context } fn default_config() -> &'static str { - r#" + r#" version: "0.1-beta" listener: - address: 0.0.0.0 - port: 10000 - message_format: huggingface - connect_timeout: 0.005s +address: 0.0.0.0 +port: 10000 +message_format: huggingface +connect_timeout: 0.005s endpoints: - api_server: - endpoint: api_server:80 - connect_timeout: 0.005s +api_server: + endpoint: api_server:80 + connect_timeout: 0.005s llm_providers: - - name: open-ai-gpt-4 - provider_interface: openai - access_key: secret_key - model: gpt-4 - default: true +- name: open-ai-gpt-4 + provider_interface: openai + access_key: secret_key + model: gpt-4 + default: true overrides: - # confidence threshold for prompt target intent matching - prompt_target_intent_matching_threshold: 0.0 +# confidence threshold for prompt target intent matching +prompt_target_intent_matching_threshold: 0.0 system_prompt: | - You are a helpful assistant. +You are a helpful assistant. prompt_guards: - input_guards: - jailbreak: - on_exception: - message: "Looks like you're curious about my abilities, but I can only provide assistance within my programmed parameters." +input_guards: + jailbreak: + on_exception: + message: "Looks like you're curious about my abilities, but I can only provide assistance within my programmed parameters." prompt_targets: - - name: weather_forecast - description: This function provides realtime weather forecast information for a given city. - parameters: - - name: city - required: true - description: The city for which the weather forecast is requested. - - name: days - description: The number of days for which the weather forecast is requested. - - name: units - description: The units in which the weather forecast is requested. - endpoint: - name: api_server - path: /weather - http_method: POST - system_prompt: | - You are a helpful weather forecaster. Use weater data that is provided to you. Please following following guidelines when responding to user queries: - - Use farenheight for temperature - - Use miles per hour for wind speed +- name: weather_forecast + description: This function provides realtime weather forecast information for a given city. + parameters: + - name: city + required: true + description: The city for which the weather forecast is requested. + - name: days + description: The number of days for which the weather forecast is requested. + - name: units + description: The units in which the weather forecast is requested. + endpoint: + name: api_server + path: /weather + http_method: POST + system_prompt: | + You are a helpful weather forecaster. Use weater data that is provided to you. Please following following guidelines when responding to user queries: + - Use farenheight for temperature + - Use miles per hour for wind speed ratelimits: - - model: gpt-4 - selector: - key: selector-key - value: selector-value - limit: - tokens: 1 - unit: minute +- model: gpt-4 + selector: + key: selector-key + value: selector-value + limit: + tokens: 1 + unit: minute "# } #[test] #[serial] fn prompt_gateway_successful_request_to_open_ai_chat_completions() { - let args = tester::MockSettings { - wasm_path: wasm_module(), - quiet: false, - allow_unexpected: false, - }; - let mut module = tester::mock(args).unwrap(); + let args = tester::MockSettings { + wasm_path: wasm_module(), + quiet: false, + allow_unexpected: false, + }; + let mut module = tester::mock(args).unwrap(); - module - .call_start() - .execute_and_expect(ReturnType::None) - .unwrap(); + module + .call_start() + .execute_and_expect(ReturnType::None) + .unwrap(); - // Setup Filter - let filter_context = setup_filter(&mut module, default_config()); + // Setup Filter + let filter_context = setup_filter(&mut module, default_config()); - // Setup HTTP Stream - let http_context = 2; + // Setup HTTP Stream + let http_context = 2; - module - .call_proxy_on_context_create(http_context, filter_context) - .expect_log(Some(LogLevel::Trace), None) - .execute_and_expect(ReturnType::None) - .unwrap(); + module + .call_proxy_on_context_create(http_context, filter_context) + .expect_log(Some(LogLevel::Trace), None) + .execute_and_expect(ReturnType::None) + .unwrap(); - request_headers_expectations(&mut module, http_context); + request_headers_expectations(&mut module, http_context); - // Request Body - let chat_completions_request_body = "\ - {\ - \"messages\": [\ - {\ - \"role\": \"system\",\ - \"content\": \"You are a poetic assistant, skilled in explaining complex programming concepts with creative flair.\"\ - },\ - {\ - \"role\": \"user\",\ - \"content\": \"Compose a poem that explains the concept of recursion in programming.\"\ - }\ - ],\ - \"model\": \"gpt-4\"\ - }"; + // Request Body + let chat_completions_request_body = "\ + {\ + \"messages\": [\ + {\ + \"role\": \"system\",\ + \"content\": \"You are a poetic assistant, skilled in explaining complex programming concepts with creative flair.\"\ + },\ + {\ + \"role\": \"user\",\ + \"content\": \"Compose a poem that explains the concept of recursion in programming.\"\ + }\ + ],\ + \"model\": \"gpt-4\"\ + }"; - module - .call_proxy_on_request_body( - http_context, - chat_completions_request_body.len() as i32, - true, - ) - .expect_get_buffer_bytes(Some(BufferType::HttpRequestBody)) - .returning(Some(chat_completions_request_body)) - .expect_log(Some(LogLevel::Trace), None) - .expect_log(Some(LogLevel::Trace), None) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Trace), None) - .expect_log(Some(LogLevel::Trace), None) - .expect_http_call(Some("arch_internal"), None, None, None, None) - .returning(Some(4)) - .expect_metric_increment("active_http_calls", 1) - .execute_and_expect(ReturnType::Action(Action::Pause)) - .unwrap(); + module + .call_proxy_on_request_body( + http_context, + chat_completions_request_body.len() as i32, + true, + ) + .expect_get_buffer_bytes(Some(BufferType::HttpRequestBody)) + .returning(Some(chat_completions_request_body)) + .expect_log(Some(LogLevel::Trace), None) + .expect_log(Some(LogLevel::Trace), None) + .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Trace), None) + .expect_log(Some(LogLevel::Trace), None) + .expect_http_call(Some("arch_internal"), None, None, None, None) + .returning(Some(4)) + .expect_metric_increment("active_http_calls", 1) + .execute_and_expect(ReturnType::Action(Action::Pause)) + .unwrap(); } #[test] #[serial] fn prompt_gateway_bad_request_to_open_ai_chat_completions() { - let args = tester::MockSettings { - wasm_path: wasm_module(), - quiet: false, - allow_unexpected: false, - }; - let mut module = tester::mock(args).unwrap(); + let args = tester::MockSettings { + wasm_path: wasm_module(), + quiet: false, + allow_unexpected: false, + }; + let mut module = tester::mock(args).unwrap(); - module - .call_start() - .execute_and_expect(ReturnType::None) - .unwrap(); + module + .call_start() + .execute_and_expect(ReturnType::None) + .unwrap(); - // Setup Filter - let filter_context = setup_filter(&mut module, default_config()); + // Setup Filter + let filter_context = setup_filter(&mut module, default_config()); - // Setup HTTP Stream - let http_context = 2; + // Setup HTTP Stream + let http_context = 2; - module - .call_proxy_on_context_create(http_context, filter_context) - .expect_log(Some(LogLevel::Trace), None) - .execute_and_expect(ReturnType::None) - .unwrap(); + module + .call_proxy_on_context_create(http_context, filter_context) + .expect_log(Some(LogLevel::Trace), None) + .execute_and_expect(ReturnType::None) + .unwrap(); - request_headers_expectations(&mut module, http_context); + request_headers_expectations(&mut module, http_context); - // Request Body - let incomplete_chat_completions_request_body = "\ - {\ - \"messages\": [\ - {\ - \"role\": \"system\",\ - },\ - {\ - \"role\": \"user\",\ - \"content\": \"Compose a poem that explains the concept of recursion in programming.\"\ - }\ - ]\ - }"; + // Request Body + let incomplete_chat_completions_request_body = "\ + {\ + \"messages\": [\ + {\ + \"role\": \"system\",\ + },\ + {\ + \"role\": \"user\",\ + \"content\": \"Compose a poem that explains the concept of recursion in programming.\"\ + }\ + ]\ + }"; - module - .call_proxy_on_request_body( - http_context, - incomplete_chat_completions_request_body.len() as i32, - true, - ) - .expect_get_buffer_bytes(Some(BufferType::HttpRequestBody)) - .returning(Some(incomplete_chat_completions_request_body)) - .expect_log(Some(LogLevel::Trace), None) - .expect_send_local_response( - Some(StatusCode::BAD_REQUEST.as_u16().into()), - None, - None, - None, - ) - .expect_log(Some(LogLevel::Trace), None) - .execute_and_expect(ReturnType::Action(Action::Pause)) - .unwrap(); + module + .call_proxy_on_request_body( + http_context, + incomplete_chat_completions_request_body.len() as i32, + true, + ) + .expect_get_buffer_bytes(Some(BufferType::HttpRequestBody)) + .returning(Some(incomplete_chat_completions_request_body)) + .expect_log(Some(LogLevel::Trace), None) + .expect_send_local_response( + Some(StatusCode::BAD_REQUEST.as_u16().into()), + None, + None, + None, + ) + .expect_log(Some(LogLevel::Trace), None) + .execute_and_expect(ReturnType::Action(Action::Pause)) + .unwrap(); } #[test] #[serial] fn prompt_gateway_request_to_llm_gateway() { - let args = tester::MockSettings { - wasm_path: wasm_module(), - quiet: false, - allow_unexpected: false, - }; - let mut module = tester::mock(args).unwrap(); + let args = tester::MockSettings { + wasm_path: wasm_module(), + quiet: false, + allow_unexpected: false, + }; + let mut module = tester::mock(args).unwrap(); - module - .call_start() - .execute_and_expect(ReturnType::None) - .unwrap(); + module + .call_start() + .execute_and_expect(ReturnType::None) + .unwrap(); - // Setup Filter - let mut config: Configuration = serde_yaml::from_str(default_config()).unwrap(); - config.ratelimits.as_mut().unwrap()[0].limit.tokens += 1000; - let config_str = serde_json::to_string(&config).unwrap(); + // Setup Filter + let mut config: Configuration = serde_yaml::from_str(default_config()).unwrap(); + config.ratelimits.as_mut().unwrap()[0].limit.tokens += 1000; + let config_str = serde_json::to_string(&config).unwrap(); - let filter_context = setup_filter(&mut module, &config_str); + let filter_context = setup_filter(&mut module, &config_str); - // Setup HTTP Stream - let http_context = 2; + // Setup HTTP Stream + let http_context = 2; - normal_flow(&mut module, filter_context, http_context); + normal_flow(&mut module, filter_context, http_context); - let arch_fc_resp = ChatCompletionsResponse { - usage: Some(Usage { - completion_tokens: 0, - }), - choices: vec![Choice { - finish_reason: Some("test".to_string()), - index: Some(0), - message: Message { - role: "system".to_string(), - content: None, - tool_calls: Some(vec![ToolCall { - id: String::from("test"), - tool_type: ToolType::Function, - function: FunctionCallDetail { - name: String::from("weather_forecast"), - arguments: HashMap::from([( - String::from("city"), - Value::String(String::from("seattle")), - )]), - }, - }]), - model: None, - tool_call_id: None, - }, - }], - model: String::from("test"), - metadata: None, - }; + let arch_fc_resp = ChatCompletionsResponse { + usage: Some(Usage { + completion_tokens: 0, + }), + choices: vec![Choice { + finish_reason: Some("test".to_string()), + index: Some(0), + message: Message { + role: "system".to_string(), + content: None, + tool_calls: Some(vec![ToolCall { + id: String::from("test"), + tool_type: ToolType::Function, + function: FunctionCallDetail { + name: String::from("weather_forecast"), + arguments: HashMap::from([( + String::from("city"), + Value::String(String::from("seattle")), + )]), + }, + }]), + model: None, + tool_call_id: None, + }, + }], + model: String::from("test"), + metadata: None, + }; - let arch_fc_resp_str = serde_json::to_string(&arch_fc_resp).unwrap(); - module - .call_proxy_on_http_call_response(http_context, 1, 0, arch_fc_resp_str.len() as i32, 0) - .expect_metric_increment("active_http_calls", -1) - .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody)) - .returning(Some(&arch_fc_resp_str)) - .expect_log(Some(LogLevel::Warn), None) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Trace), None) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Trace), None) - .expect_log(Some(LogLevel::Trace), None) - .expect_log(Some(LogLevel::Trace), None) - .expect_http_call( - Some("arch_internal"), - Some(vec![ - ("x-arch-upstream", "api_server"), - (":method", "POST"), - (":path", "/weather"), - (":authority", "api_server"), - ("content-type", "application/json"), - ("x-envoy-max-retries", "3"), - ]), - None, - None, - None, - ) - .returning(Some(2)) - .expect_metric_increment("active_http_calls", 1) - .execute_and_expect(ReturnType::None) - .unwrap(); + let expected_body = "{\"city\":\"seattle\"}"; + let arch_fc_resp_str = serde_json::to_string(&arch_fc_resp).unwrap(); + module + .call_proxy_on_http_call_response(http_context, 1, 0, arch_fc_resp_str.len() as i32, 0) + .expect_metric_increment("active_http_calls", -1) + .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody)) + .returning(Some(&arch_fc_resp_str)) + .expect_log(Some(LogLevel::Warn), None) + .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Trace), None) + .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Trace), None) + .expect_log(Some(LogLevel::Trace), None) + .expect_http_call( + Some("arch_internal"), + Some(vec![ + (":method", "POST"), + ("content-type", "application/json"), + ("x-arch-upstream", "api_server"), + (":authority", "api_server"), + ("x-envoy-max-retries", "3"), + (":path", "/weather"), + ]), + Some(expected_body), + None, + None, + ) + .returning(Some(2)) + .expect_metric_increment("active_http_calls", 1) + .execute_and_expect(ReturnType::None) + .unwrap(); - let body_text = String::from("test body"); - module - .call_proxy_on_http_call_response(http_context, 2, 0, body_text.len() as i32, 0) - .expect_metric_increment("active_http_calls", -1) - .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody)) - .returning(Some(&body_text)) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Trace), None) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Trace), None) - .expect_log(Some(LogLevel::Trace), None) - .expect_get_header_map_value(Some(MapType::HttpCallResponseHeaders), Some(":status")) - .returning(Some("200")) - .expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None) - .execute_and_expect(ReturnType::None) - .unwrap(); + let body_text = String::from("test body"); + module + .call_proxy_on_http_call_response(http_context, 2, 0, body_text.len() as i32, 0) + .expect_metric_increment("active_http_calls", -1) + .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody)) + .returning(Some(&body_text)) + .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Trace), None) + .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Trace), None) + .expect_log(Some(LogLevel::Trace), None) + .expect_get_header_map_value(Some(MapType::HttpCallResponseHeaders), Some(":status")) + .returning(Some("200")) + .expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None) + .execute_and_expect(ReturnType::None) + .unwrap(); - let chat_completion_response = ChatCompletionsResponse { - usage: Some(Usage { - completion_tokens: 0, - }), - choices: vec![Choice { - finish_reason: Some("test".to_string()), - index: Some(0), - message: Message { - role: "assistant".to_string(), - content: Some("hello from fake llm gateway".to_string()), - model: None, - tool_calls: None, - tool_call_id: None, - }, - }], - model: String::from("test"), - metadata: None, - }; + let chat_completion_response = ChatCompletionsResponse { + usage: Some(Usage { + completion_tokens: 0, + }), + choices: vec![Choice { + finish_reason: Some("test".to_string()), + index: Some(0), + message: Message { + role: "assistant".to_string(), + content: Some("hello from fake llm gateway".to_string()), + model: None, + tool_calls: None, + tool_call_id: None, + }, + }], + model: String::from("test"), + metadata: None, + }; - let chat_completion_response_str = serde_json::to_string(&chat_completion_response).unwrap(); - module - .call_proxy_on_response_body( - http_context, - chat_completion_response_str.len() as i32, - true, - ) - .expect_get_buffer_bytes(Some(BufferType::HttpResponseBody)) - .returning(Some(chat_completion_response_str.as_str())) - .expect_log(Some(LogLevel::Trace), None) - .expect_log(Some(LogLevel::Debug), None) - .expect_set_buffer_bytes(Some(BufferType::HttpResponseBody), None) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Trace), None) - .execute_and_expect(ReturnType::Action(Action::Continue)) - .unwrap(); + let chat_completion_response_str = serde_json::to_string(&chat_completion_response).unwrap(); + module + .call_proxy_on_response_body( + http_context, + chat_completion_response_str.len() as i32, + true, + ) + .expect_get_buffer_bytes(Some(BufferType::HttpResponseBody)) + .returning(Some(chat_completion_response_str.as_str())) + .expect_log(Some(LogLevel::Trace), None) + .expect_log(Some(LogLevel::Debug), None) + .expect_set_buffer_bytes(Some(BufferType::HttpResponseBody), None) + .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Trace), None) + .execute_and_expect(ReturnType::Action(Action::Continue)) + .unwrap(); }