Encode parameter values in http path and ...

- don't send param values in the request body for HTTP GET requests
- send param values in the request body for HTTP POST requests
This commit is contained in:
Adil Hafeez 2025-02-05 17:39:44 -08:00
parent fa089ef32d
commit 60471286bb
No known key found for this signature in database
GPG key ID: 9B18EF7691369645
5 changed files with 859 additions and 411 deletions

302
crates/Cargo.lock generated
View file

@ -234,6 +234,8 @@ dependencies = [
"serde_yaml",
"thiserror",
"tiktoken-rs",
"url",
"urlencoding",
]
[[package]]
@ -477,6 +479,17 @@ dependencies = [
"winapi",
]
[[package]]
name = "displaydoc"
version = "0.2.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.79",
]
[[package]]
name = "duration-string"
version = "0.3.0"
@ -557,6 +570,15 @@ version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f81ec6369c545a7d40e4589b5597581fa1c441fe1cce96dd1de43159910a36a2"
[[package]]
name = "form_urlencoded"
version = "1.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e13624c2627564efccf4934284bdd98cbaa14e79b0b5a141218e507b3a823456"
dependencies = [
"percent-encoding",
]
[[package]]
name = "futures"
version = "0.3.31"
@ -782,12 +804,151 @@ dependencies = [
"itoa",
]
[[package]]
name = "icu_collections"
version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "db2fa452206ebee18c4b5c2274dbf1de17008e874b4dc4f0aea9d01ca79e4526"
dependencies = [
"displaydoc",
"yoke",
"zerofrom",
"zerovec",
]
[[package]]
name = "icu_locid"
version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "13acbb8371917fc971be86fc8057c41a64b521c184808a698c02acc242dbf637"
dependencies = [
"displaydoc",
"litemap",
"tinystr",
"writeable",
"zerovec",
]
[[package]]
name = "icu_locid_transform"
version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "01d11ac35de8e40fdeda00d9e1e9d92525f3f9d887cdd7aa81d727596788b54e"
dependencies = [
"displaydoc",
"icu_locid",
"icu_locid_transform_data",
"icu_provider",
"tinystr",
"zerovec",
]
[[package]]
name = "icu_locid_transform_data"
version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fdc8ff3388f852bede6b579ad4e978ab004f139284d7b28715f773507b946f6e"
[[package]]
name = "icu_normalizer"
version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "19ce3e0da2ec68599d193c93d088142efd7f9c5d6fc9b803774855747dc6a84f"
dependencies = [
"displaydoc",
"icu_collections",
"icu_normalizer_data",
"icu_properties",
"icu_provider",
"smallvec",
"utf16_iter",
"utf8_iter",
"write16",
"zerovec",
]
[[package]]
name = "icu_normalizer_data"
version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f8cafbf7aa791e9b22bec55a167906f9e1215fd475cd22adfcf660e03e989516"
[[package]]
name = "icu_properties"
version = "1.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "93d6020766cfc6302c15dbbc9c8778c37e62c14427cb7f6e601d849e092aeef5"
dependencies = [
"displaydoc",
"icu_collections",
"icu_locid_transform",
"icu_properties_data",
"icu_provider",
"tinystr",
"zerovec",
]
[[package]]
name = "icu_properties_data"
version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "67a8effbc3dd3e4ba1afa8ad918d5684b8868b3b26500753effea8d2eed19569"
[[package]]
name = "icu_provider"
version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6ed421c8a8ef78d3e2dbc98a973be2f3770cb42b606e3ab18d6237c4dfde68d9"
dependencies = [
"displaydoc",
"icu_locid",
"icu_provider_macros",
"stable_deref_trait",
"tinystr",
"writeable",
"yoke",
"zerofrom",
"zerovec",
]
[[package]]
name = "icu_provider_macros"
version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1ec89e9337638ecdc08744df490b221a7399bf8d164eb52a665454e60e075ad6"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.79",
]
[[package]]
name = "id-arena"
version = "2.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "25a2bc672d1148e28034f176e01fffebb08b35768468cc954630da77a1449005"
[[package]]
name = "idna"
version = "1.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "686f825264d630750a544639377bae737628043f20d38bbc029e8f29ea968a7e"
dependencies = [
"idna_adapter",
"smallvec",
"utf8_iter",
]
[[package]]
name = "idna_adapter"
version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "daca1df1c957320b2cf139ac61e7bd64fed304c5040df000a745aa1de3b4ef71"
dependencies = [
"icu_normalizer",
"icu_properties",
]
[[package]]
name = "indexmap"
version = "2.6.0"
@ -883,6 +1044,12 @@ version = "0.4.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "78b3ae25bc7c8c38cec158d1f2757ee79e9b3740fbc7ccf0e59e4b08d793fa89"
[[package]]
name = "litemap"
version = "0.7.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4ee93343901ab17bd981295f2cf0026d4ad018c7c31ba84549a4ddbb47a45104"
[[package]]
name = "llm_gateway"
version = "0.1.0"
@ -1028,6 +1195,12 @@ version = "1.0.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a"
[[package]]
name = "percent-encoding"
version = "2.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e"
[[package]]
name = "pin-project-lite"
version = "0.2.14"
@ -1547,6 +1720,17 @@ dependencies = [
"unicode-ident",
]
[[package]]
name = "synstructure"
version = "0.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c8af7666ab7b6390ab78131fb5b0fce11d6b7a6951602017c35fa82800708971"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.79",
]
[[package]]
name = "target-lexicon"
version = "0.12.16"
@ -1606,6 +1790,16 @@ dependencies = [
"rustc-hash",
]
[[package]]
name = "tinystr"
version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9117f5d4db391c1cf6927e7bea3db74b9a1c1add8f7eda9ffd5364f40f57b82f"
dependencies = [
"displaydoc",
"zerovec",
]
[[package]]
name = "toml"
version = "0.8.19"
@ -1676,6 +1870,35 @@ version = "0.2.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "673aac59facbab8a9007c7f6108d11f63b603f7cabff99fabf650fea5c32b861"
[[package]]
name = "url"
version = "2.5.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "32f8b686cadd1473f4bd0117a5d28d36b1ade384ea9b5069a1c40aefed7fda60"
dependencies = [
"form_urlencoded",
"idna",
"percent-encoding",
]
[[package]]
name = "urlencoding"
version = "2.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da"
[[package]]
name = "utf16_iter"
version = "1.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c8232dd3cdaed5356e0f716d285e4b40b932ac434100fe9b7e0e8e935b9e6246"
[[package]]
name = "utf8_iter"
version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be"
[[package]]
name = "uuid"
version = "1.11.0"
@ -2189,12 +2412,48 @@ dependencies = [
"wasmparser 0.212.0",
]
[[package]]
name = "write16"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d1890f4022759daae28ed4fe62859b1236caebfc61ede2f63ed4e695f3f6d936"
[[package]]
name = "writeable"
version = "0.5.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e9df38ee2d2c3c5948ea468a8406ff0db0b29ae1ffde1bcf20ef305bcc95c51"
[[package]]
name = "yansi"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cfe53a6657fd280eaa890a3bc59152892ffa3e30101319d168b781ed6529b049"
[[package]]
name = "yoke"
version = "0.7.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "120e6aef9aa629e3d4f52dc8cc43a015c7724194c97dfaf45180d2daf2b77f40"
dependencies = [
"serde",
"stable_deref_trait",
"yoke-derive",
"zerofrom",
]
[[package]]
name = "yoke-derive"
version = "0.7.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2380878cad4ac9aac1e2435f3eb4020e8374b5f13c296cb75b4620ff8e229154"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.79",
"synstructure",
]
[[package]]
name = "zerocopy"
version = "0.7.35"
@ -2216,6 +2475,49 @@ dependencies = [
"syn 2.0.79",
]
[[package]]
name = "zerofrom"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cff3ee08c995dee1859d998dea82f7374f2826091dd9cd47def953cae446cd2e"
dependencies = [
"zerofrom-derive",
]
[[package]]
name = "zerofrom-derive"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "595eed982f7d355beb85837f651fa22e90b3c044842dc7f2c2842c086f295808"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.79",
"synstructure",
]
[[package]]
name = "zerovec"
version = "0.10.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "aa2b893d79df23bfb12d5461018d408ea19dfafe76c2c7ef6d4eba614f8ff079"
dependencies = [
"yoke",
"zerofrom",
"zerovec-derive",
]
[[package]]
name = "zerovec-derive"
version = "0.10.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6eafa6dfb17584ea3e2bd6e76e0cc15ad7af12b09abdd1ca55961bed9b1063c6"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.79",
]
[[package]]
name = "zstd"
version = "0.13.2"

View file

@ -16,6 +16,8 @@ tiktoken-rs = "0.5.9"
rand = "0.8.5"
serde_json = "1.0"
hex = "0.4.3"
urlencoding = "2.1.3"
url = "2.5.4"
[dev-dependencies]
pretty_assertions = "1.4.1"

View file

@ -1,21 +1,30 @@
use std::collections::HashMap;
use std::collections::{HashMap, HashSet};
use url::Url;
use urlencoding;
use crate::configuration::Parameter;
/// Fills the `{name}` placeholders of `path` with URL-encoded values and
/// collects every remaining value into a query string.
///
/// NOTE: this span is a rendered diff, so removed lines (two-argument version
/// returning a single `String`) and added lines (three-argument version) are
/// interleaved. The description below covers the added version.
///
/// * `path` - path template, optionally already carrying a `?query` part.
///   Every `{name}` placeholder must have an entry in `tool_params`.
/// * `tool_params` - values for the placeholders. Entries that no placeholder
///   consumed are appended to the URL as `name=value` pairs.
/// * `prompt_target_params` - declared parameters. One whose name has not been
///   used yet and which carries a `default` is appended as a pair as well.
///
/// Returns `(path, query_string, params)`: the path and query are obtained by
/// joining the assembled string onto a dummy base URL, and `params` contains
/// only the appended values (not the ones substituted into a placeholder).
/// Returns `Err` with "Missing value for parameter `<name>`" when a
/// placeholder has no value, or with the join error text when the assembled
/// string cannot be joined onto the base URL.
pub fn replace_params_in_path(
path: &str,
params: &HashMap<String, String>,
) -> Result<String, String> {
let mut result = String::new();
let mut in_param = false;
tool_params: &HashMap<String, String>,
prompt_target_params: &[Parameter],
) -> Result<(String, String, HashMap<String, String>), String> {
// Despite its name this buffer holds the whole rewritten path plus query.
let mut query_string_replaced = String::new();
let mut current_param = String::new();
// Names already emitted, either via a placeholder or as an appended pair.
let mut vars_replaced = HashSet::new();
let mut params: HashMap<String, String> = HashMap::new();
let mut in_param = false;
for c in path.chars() {
if c == '{' {
in_param = true;
} else if c == '}' {
in_param = false;
let param_name = current_param.clone();
if let Some(value) = params.get(&param_name) {
result.push_str(value);
if let Some(value) = tool_params.get(&param_name) {
// Percent-encode the value before placing it into the URL.
let value = urlencoding::encode(value);
query_string_replaced.push_str(value.into_owned().as_str());
vars_replaced.insert(param_name.clone());
} else {
return Err(format!("Missing value for parameter `{}`", param_name));
}
@ -23,31 +32,106 @@ pub fn replace_params_in_path(
// NOTE(review): the hunk boundary above hides at least one line of the
// function; presumably `current_param` is reset there - confirm in the full file.
} else if in_param {
current_param.push(c);
} else {
result.push(c);
query_string_replaced.push(c);
}
}
Ok(result)
// add the remaining params in path
// NOTE(review): `HashMap` iteration order is unspecified, so the order of the
// appended pairs is not stable when more than one tool param remains.
for (param_name, value) in tool_params.iter() {
let value = urlencoding::encode(value).into_owned();
if !vars_replaced.contains(param_name) {
vars_replaced.insert(param_name.clone());
// The map receives the already-encoded value, not the raw one.
params.insert(param_name.clone(), value.clone());
// Start the query with `?` unless the template already has one.
if query_string_replaced.contains("?") {
query_string_replaced.push_str(&format!("&{}={}", param_name, value));
} else {
query_string_replaced.push_str(&format!("?{}={}", param_name, value));
}
}
}
// add default values
// NOTE(review): unlike the tool values above, defaults are appended and stored
// without `urlencoding::encode`; parameter names are never encoded either.
for param in prompt_target_params.iter() {
if !vars_replaced.contains(&param.name) && param.default.is_some() {
params.insert(param.name.clone(), param.default.clone().unwrap());
if query_string_replaced.contains("?") {
query_string_replaced.push_str(&format!(
"&{}={}",
param.name,
param.default.as_ref().unwrap()
));
} else {
query_string_replaced.push_str(&format!(
"?{}={}",
param.name,
param.default.as_ref().unwrap()
));
}
}
}
// The base URL exists only so that `join` can split the assembled string
// into its path and query components; the host itself is never returned.
let parsed_uri = Url::parse("http://dummy.com").unwrap();
let parsed_uri = parsed_uri
.join(&query_string_replaced)
.map_err(|e| e.to_string())?;
let query_string = parsed_uri.query().unwrap_or("");
let path_uri = parsed_uri.path();
Ok((path_uri.to_string(), query_string.to_string(), params))
}
#[cfg(test)]
mod test {
use std::collections::HashMap;
use crate::configuration::Parameter;
// Exercises `replace_params_in_path` case by case. This span is a rendered
// diff: each `assert_eq!` shows the removed two-argument call and its expected
// value followed by the added three-argument call and its expected tuple.
#[test]
fn test_replace_path() {
// Case 1: one placeholder, one extra tool param and one declared default.
let path = "/cluster.open-cluster-management.io/v1/managedclusters/{cluster_name}";
let params = vec![("cluster_name".to_string(), "test1".to_string())]
.into_iter()
.collect();
let params = vec![
("cluster_name".to_string(), "test1".to_string()),
("hello".to_string(), "hello world".to_string()),
]
.into_iter()
.collect();
let prompt_target_params = vec![Parameter {
name: "country".to_string(),
parameter_type: None,
description: "test target".to_string(),
required: None,
enum_values: None,
default: Some("US".to_string()),
in_path: None,
format: None,
}];
// Expected map: the appended values only; `hello` is stored percent-encoded.
let out_params: HashMap<String, String> = vec![
("country".to_string(), "US".to_string()),
("hello".to_string(), "hello%20world".to_string()),
]
.into_iter()
.collect();
assert_eq!(
super::replace_params_in_path(path, &params),
Ok("/cluster.open-cluster-management.io/v1/managedclusters/test1".to_string())
super::replace_params_in_path(path, &params, &prompt_target_params),
Ok((
"/cluster.open-cluster-management.io/v1/managedclusters/test1".to_string(),
"hello=hello%20world&country=US".to_string(),
out_params.clone()
))
);
// Case 2: no placeholders and no params. `prompt_target_params` is shadowed
// with an empty vec here, so no later case has a default value.
let out_params = HashMap::new();
let prompt_target_params = vec![];
let path = "/cluster.open-cluster-management.io/v1/managedclusters";
let params = vec![].into_iter().collect();
assert_eq!(
super::replace_params_in_path(path, &params),
Ok("/cluster.open-cluster-management.io/v1/managedclusters".to_string())
super::replace_params_in_path(path, &params, &prompt_target_params),
Ok((
"/cluster.open-cluster-management.io/v1/managedclusters".to_string(),
"".to_string(),
out_params
))
);
// Case 3: a single placeholder (the `params` definition is hidden by the hunk boundary).
let path = "/foo/{bar}/baz";
@ -55,8 +139,8 @@ mod test {
.into_iter()
.collect();
assert_eq!(
super::replace_params_in_path(path, &params),
Ok("/foo/qux/baz".to_string())
super::replace_params_in_path(path, &params, &prompt_target_params),
Ok(("/foo/qux/baz".to_string(), "".to_string(), HashMap::new()))
);
// Case 4: two placeholders (the `params` definition is hidden by the hunk boundary).
let path = "/foo/{bar}/baz/{qux}";
@ -67,8 +151,45 @@ mod test {
.into_iter()
.collect();
assert_eq!(
super::replace_params_in_path(path, &params),
Ok("/foo/qux/baz/quux".to_string())
super::replace_params_in_path(path, &params, &prompt_target_params),
Ok((
"/foo/qux/baz/quux".to_string(),
"".to_string(),
HashMap::new()
))
);
// Case 5: a literal query string in the template is returned as the query.
let path = "/foo/{bar}/baz/{qux}?hello=world";
let params = vec![
("bar".to_string(), "qux".to_string()),
("qux".to_string(), "quux".to_string()),
]
.into_iter()
.collect();
assert_eq!(
super::replace_params_in_path(path, &params, &prompt_target_params),
Ok((
"/foo/qux/baz/quux".to_string(),
"hello=world".to_string(),
HashMap::new()
))
);
// Case 6: a placeholder inside the query string is percent-encoded and,
// being a placeholder substitution, is absent from the returned map.
let path = "/foo/{bar}/baz/{qux}?hello={hello}";
let params = vec![
("bar".to_string(), "qux".to_string()),
("qux".to_string(), "quux".to_string()),
("hello".to_string(), "hello world".to_string()),
]
.into_iter()
.collect();
assert_eq!(
super::replace_params_in_path(path, &params, &prompt_target_params),
Ok((
"/foo/qux/baz/quux".to_string(),
"hello=hello%20world".to_string(),
HashMap::new()
))
);
// Case 7: a placeholder without a value yields the "Missing value" error
// (the `params` definition is hidden by the hunk boundary).
let path = "/foo/{bar}/baz/{qux}";
@ -76,7 +197,7 @@ mod test {
.into_iter()
.collect();
assert_eq!(
super::replace_params_in_path(path, &params),
super::replace_params_in_path(path, &params, &prompt_target_params),
Err("Missing value for parameter `qux`".to_string())
);
}

View file

@ -3,7 +3,7 @@ use common::api::open_ai::{
to_server_events, ArchState, ChatCompletionStreamResponse, ChatCompletionsRequest,
ChatCompletionsResponse, Message, ModelServerResponse, ToolCall,
};
use common::configuration::{Overrides, PromptTarget, Tracing};
use common::configuration::{HttpMethod, Overrides, PromptTarget, Tracing};
use common::consts::{
ARCH_FC_MODEL_NAME, ARCH_FC_REQUEST_TIMEOUT_MS, ARCH_INTERNAL_CLUSTER_NAME,
ARCH_UPSTREAM_HOST_HEADER, ASSISTANT_ROLE, MESSAGES_KEY, REQUEST_ID_HEADER, SYSTEM_ROLE,
@ -276,22 +276,18 @@ impl StreamContext {
let prompt_target = self.prompt_targets.get(&tools_call_name).unwrap().clone();
let mut tool_params = self.tool_calls.as_ref().unwrap()[0]
let tool_params = self.tool_calls.as_ref().unwrap()[0]
.function
.arguments
.clone();
tool_params.insert(
String::from(MESSAGES_KEY),
serde_yaml::to_value(&callout_context.request_body.messages).unwrap(),
);
let tool_params_json_str = serde_json::to_string(&tool_params).unwrap();
let endpoint = prompt_target.endpoint.unwrap();
let path: String = endpoint.path.unwrap_or(String::from("/"));
let prompt_target_params = prompt_target.parameters.unwrap_or_default();
let http_method = endpoint.method.unwrap_or_default();
// only add params that are of string, number and bool type
let url_params = tool_params
let tool_url_params = tool_params
.iter()
.filter(|(_, value)| value.is_number() || value.is_string() || value.is_bool())
.map(|(key, value)| match value {
@ -305,50 +301,77 @@ impl StreamContext {
})
.collect::<HashMap<String, String>>();
let path = match common::path::replace_params_in_path(&path, &url_params) {
Ok(path) => path,
Err(e) => {
return self.send_server_error(
ServerError::BadRequest {
why: format!("error replacing params in path: {}", e),
},
Some(StatusCode::BAD_REQUEST),
);
let (path_with_params, query_string, additional_params) =
match common::path::replace_params_in_path(
&path,
&tool_url_params,
&prompt_target_params,
) {
Ok((path, query_string, additional_params)) => {
(path, query_string, additional_params)
}
Err(e) => {
return self.send_server_error(
ServerError::BadRequest {
why: format!("error replacing params in path: {}", e),
},
Some(StatusCode::BAD_REQUEST),
);
}
};
let (path, body) = match http_method {
HttpMethod::Get => {
(format!("{}?{}", path_with_params, query_string), None)
}
HttpMethod::Post => {
let mut additional_params = additional_params;
if !query_string.is_empty() {
query_string.split("&").for_each(|param| {
let mut parts = param.split("=");
let key = parts.next().unwrap();
let value = parts.next().unwrap();
additional_params.insert(key.to_string(), value.to_string());
});
}
let body = serde_json::to_string(&additional_params).unwrap();
(path_with_params, Some(body))
}
};
let http_method = endpoint.method.unwrap_or_default().to_string();
let mut headers = vec![
let http_method_str = http_method.to_string();
let mut headers: HashMap<_, _> = [
(ARCH_UPSTREAM_HOST_HEADER, endpoint.name.as_str()),
(":method", &http_method),
(":method", &http_method_str),
(":path", &path),
(":authority", endpoint.name.as_str()),
("content-type", "application/json"),
("x-envoy-max-retries", "3"),
];
]
.into_iter()
.collect();
if self.request_id.is_some() {
headers.push((REQUEST_ID_HEADER, self.request_id.as_ref().unwrap()));
headers.insert(REQUEST_ID_HEADER, self.request_id.as_ref().unwrap());
}
if self.traceparent.is_some() {
headers.push((TRACE_PARENT_HEADER, self.traceparent.as_ref().unwrap()));
headers.insert(TRACE_PARENT_HEADER, self.traceparent.as_ref().unwrap());
}
let call_args = CallArgs::new(
ARCH_INTERNAL_CLUSTER_NAME,
&path,
headers,
Some(tool_params_json_str.as_bytes()),
headers.into_iter().collect(),
body.as_deref().map(|s| s.as_bytes()),
vec![],
Duration::from_secs(5),
);
debug!(
"dispatching api call to developer endpoint: {}, path: {}",
endpoint.name, path
"dispatching api call to developer endpoint: {}, path: {}, method: {}",
endpoint.name, path, http_method_str
);
trace!("request body: {}", tool_params_json_str);
callout_context.upstream_cluster = Some(endpoint.name.to_owned());
callout_context.upstream_cluster_path = Some(path.to_owned());

View file

@ -1,11 +1,11 @@
use common::api::open_ai::{
ChatCompletionsResponse, Choice, FunctionCallDetail, Message, ToolCall, ToolType, Usage,
ChatCompletionsResponse, Choice, FunctionCallDetail, Message, ToolCall, ToolType, Usage,
};
use common::configuration::Configuration;
use http::StatusCode;
use proxy_wasm_test_framework::tester::{self, Tester};
use proxy_wasm_test_framework::types::{
Action, BufferType, LogLevel, MapType, MetricType, ReturnType,
Action, BufferType, LogLevel, MapType, MetricType, ReturnType,
};
use serde_yaml::Value;
use serial_test::serial;
@ -13,440 +13,440 @@ use std::collections::HashMap;
use std::path::Path;
/// Returns the relative path of the release build of `prompt_gateway.wasm`
/// as a `String`.
///
/// Panics with a hint to run the wasm32-wasip1 release build when the file
/// does not exist, and also panics if the path is not valid UTF-8.
///
/// NOTE: the body appears twice because this span is a rendered diff; the old
/// and new lines are textually identical here, presumably an indentation-only
/// change.
fn wasm_module() -> String {
let wasm_file = Path::new("../target/wasm32-wasip1/release/prompt_gateway.wasm");
assert!(
wasm_file.exists(),
"Run `cargo build --release --target=wasm32-wasip1` first"
);
wasm_file.to_str().unwrap().to_string()
let wasm_file = Path::new("../target/wasm32-wasip1/release/prompt_gateway.wasm");
assert!(
wasm_file.exists(),
"Run `cargo build --release --target=wasm32-wasip1` first"
);
wasm_file.to_str().unwrap().to_string()
}
/// Drives the request-headers callback for `http_context` and registers the
/// host calls the filter is expected to make, in order: removal of
/// `content-length`, a read of `:path` (answered with `/v1/chat/completions`),
/// a read of all request header pairs, one trace log, and reads of
/// `x-request-id` and `traceparent` (both answered with `None`). The callback
/// must return `Action::Continue`; any mismatch panics through `unwrap`.
///
/// NOTE: the call chain appears twice because this span is a rendered diff
/// whose old and new lines are textually identical.
fn request_headers_expectations(module: &mut Tester, http_context: i32) {
module
.call_proxy_on_request_headers(http_context, 0, false)
.expect_remove_header_map_value(Some(MapType::HttpRequestHeaders), Some("content-length"))
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":path"))
.returning(Some("/v1/chat/completions"))
.expect_get_header_map_pairs(Some(MapType::HttpRequestHeaders))
.returning(None)
.expect_log(Some(LogLevel::Trace), None)
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("x-request-id"))
.returning(None)
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("traceparent"))
.returning(None)
.execute_and_expect(ReturnType::Action(Action::Continue))
.unwrap();
module
.call_proxy_on_request_headers(http_context, 0, false)
.expect_remove_header_map_value(Some(MapType::HttpRequestHeaders), Some("content-length"))
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":path"))
.returning(Some("/v1/chat/completions"))
.expect_get_header_map_pairs(Some(MapType::HttpRequestHeaders))
.returning(None)
.expect_log(Some(LogLevel::Trace), None)
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("x-request-id"))
.returning(None)
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("traceparent"))
.returning(None)
.execute_and_expect(ReturnType::Action(Action::Continue))
.unwrap();
}
/// Runs the common opening of a request: creates `http_context` under
/// `filter_context`, replays the request-header expectations, then delivers a
/// chat-completions request body and expects the filter to dispatch an HTTP
/// call to `arch_internal` addressed to `model_server` at `/function_calling`
/// (the mock answers with token id 1), bump the `active_http_calls` metric and
/// return `Action::Pause`.
///
/// NOTE: this span is a rendered diff; old and new lines are textually
/// identical and interleaved, including inside the multi-line string literal,
/// so no comments are added inside the body.
fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
module
.call_proxy_on_context_create(http_context, filter_context)
.expect_log(Some(LogLevel::Trace), None)
.execute_and_expect(ReturnType::None)
.unwrap();
module
.call_proxy_on_context_create(http_context, filter_context)
.expect_log(Some(LogLevel::Trace), None)
.execute_and_expect(ReturnType::None)
.unwrap();
request_headers_expectations(module, http_context);
request_headers_expectations(module, http_context);
// Request Body
let chat_completions_request_body = "\
// Request Body
let chat_completions_request_body = "\
{\
\"messages\": [\
{\
\"role\": \"system\",\
\"content\": \"You are a poetic assistant, skilled in explaining complex programming concepts with creative flair.\"\
},\
{\
\"role\": \"user\",\
\"content\": \"Compose a poem that explains the concept of recursion in programming.\"\
}\
],\
\"model\": \"gpt-4\"\
\"messages\": [\
{\
\"role\": \"system\",\
\"content\": \"You are a poetic assistant, skilled in explaining complex programming concepts with creative flair.\"\
},\
{\
\"role\": \"user\",\
\"content\": \"Compose a poem that explains the concept of recursion in programming.\"\
}\
],\
\"model\": \"gpt-4\"\
}";
module
.call_proxy_on_request_body(
http_context,
chat_completions_request_body.len() as i32,
true,
)
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
.returning(Some(chat_completions_request_body))
// The actual call is not important in this test, we just need to grab the token_id
.expect_log(Some(LogLevel::Trace), None)
.expect_http_call(
Some("arch_internal"),
Some(vec![
("x-arch-upstream", "model_server"),
(":method", "POST"),
(":path", "/function_calling"),
("content-type", "application/json"),
(":authority", "model_server"),
]),
None,
None,
None,
)
.returning(Some(1))
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_metric_increment("active_http_calls", 1)
.execute_and_expect(ReturnType::Action(Action::Pause))
.unwrap();
module
.call_proxy_on_request_body(
http_context,
chat_completions_request_body.len() as i32,
true,
)
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
.returning(Some(chat_completions_request_body))
// The actual call is not important in this test, we just need to grab the token_id
.expect_log(Some(LogLevel::Trace), None)
.expect_http_call(
Some("arch_internal"),
Some(vec![
("x-arch-upstream", "model_server"),
(":method", "POST"),
(":path", "/function_calling"),
("content-type", "application/json"),
(":authority", "model_server"),
]),
None,
None,
None,
)
.returning(Some(1))
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_metric_increment("active_http_calls", 1)
.execute_and_expect(ReturnType::Action(Action::Pause))
.unwrap();
}
/// Creates the filter context (id 1, parent id 0), expecting the
/// `active_http_calls` gauge to be created, then feeds `config` to the
/// configure callback and expects it to return `true`.
///
/// Returns the filter context id so callers can create HTTP contexts under it.
///
/// NOTE: every statement appears twice because this span is a rendered diff
/// whose old and new lines are textually identical.
fn setup_filter(module: &mut Tester, config: &str) -> i32 {
let filter_context = 1;
let filter_context = 1;
module
.call_proxy_on_context_create(filter_context, 0)
.expect_metric_creation(MetricType::Gauge, "active_http_calls")
.execute_and_expect(ReturnType::None)
.unwrap();
module
.call_proxy_on_context_create(filter_context, 0)
.expect_metric_creation(MetricType::Gauge, "active_http_calls")
.execute_and_expect(ReturnType::None)
.unwrap();
module
.call_proxy_on_configure(filter_context, config.len() as i32)
.expect_get_buffer_bytes(Some(BufferType::PluginConfiguration))
.returning(Some(config))
.execute_and_expect(ReturnType::Bool(true))
.unwrap();
module
.call_proxy_on_configure(filter_context, config.len() as i32)
.expect_get_buffer_bytes(Some(BufferType::PluginConfiguration))
.returning(Some(config))
.execute_and_expect(ReturnType::Bool(true))
.unwrap();
filter_context
filter_context
}
/// Returns the static YAML configuration handed to the filter in these tests:
/// a listener, an `api_server` endpoint, one OpenAI `gpt-4` provider, an
/// intent-matching threshold override, a system prompt, a jailbreak input
/// guard, a `weather_forecast` prompt target that POSTs to `/weather`, and a
/// `gpt-4` rate limit.
///
/// NOTE: this span is a rendered diff, so every YAML line appears twice (old
/// and new are textually identical in this view, where indentation is not
/// shown). The string is runtime data and is reproduced untouched, including
/// the spelling inside the prompt text.
fn default_config() -> &'static str {
r#"
r#"
version: "0.1-beta"
listener:
address: 0.0.0.0
port: 10000
message_format: huggingface
connect_timeout: 0.005s
address: 0.0.0.0
port: 10000
message_format: huggingface
connect_timeout: 0.005s
endpoints:
api_server:
endpoint: api_server:80
connect_timeout: 0.005s
api_server:
endpoint: api_server:80
connect_timeout: 0.005s
llm_providers:
- name: open-ai-gpt-4
provider_interface: openai
access_key: secret_key
model: gpt-4
default: true
- name: open-ai-gpt-4
provider_interface: openai
access_key: secret_key
model: gpt-4
default: true
overrides:
# confidence threshold for prompt target intent matching
prompt_target_intent_matching_threshold: 0.0
# confidence threshold for prompt target intent matching
prompt_target_intent_matching_threshold: 0.0
system_prompt: |
You are a helpful assistant.
You are a helpful assistant.
prompt_guards:
input_guards:
jailbreak:
on_exception:
message: "Looks like you're curious about my abilities, but I can only provide assistance within my programmed parameters."
input_guards:
jailbreak:
on_exception:
message: "Looks like you're curious about my abilities, but I can only provide assistance within my programmed parameters."
prompt_targets:
- name: weather_forecast
description: This function provides realtime weather forecast information for a given city.
parameters:
- name: city
required: true
description: The city for which the weather forecast is requested.
- name: days
description: The number of days for which the weather forecast is requested.
- name: units
description: The units in which the weather forecast is requested.
endpoint:
name: api_server
path: /weather
http_method: POST
system_prompt: |
You are a helpful weather forecaster. Use weater data that is provided to you. Please following following guidelines when responding to user queries:
- Use farenheight for temperature
- Use miles per hour for wind speed
- name: weather_forecast
description: This function provides realtime weather forecast information for a given city.
parameters:
- name: city
required: true
description: The city for which the weather forecast is requested.
- name: days
description: The number of days for which the weather forecast is requested.
- name: units
description: The units in which the weather forecast is requested.
endpoint:
name: api_server
path: /weather
http_method: POST
system_prompt: |
You are a helpful weather forecaster. Use weater data that is provided to you. Please following following guidelines when responding to user queries:
- Use farenheight for temperature
- Use miles per hour for wind speed
ratelimits:
- model: gpt-4
selector:
key: selector-key
value: selector-value
limit:
tokens: 1
unit: minute
- model: gpt-4
selector:
key: selector-key
value: selector-value
limit:
tokens: 1
unit: minute
"#
}
/// Happy-path test: loads the wasm module into the mock host with unexpected
/// host calls disallowed, configures the filter with `default_config()`,
/// creates HTTP context 2, replays the request headers, and sends a
/// well-formed chat-completions body. The filter is expected to issue one HTTP
/// call to `arch_internal` (the mock answers with token id 4), increment
/// `active_http_calls`, and return `Action::Pause`.
///
/// NOTE: every statement appears twice because this span is a rendered diff
/// whose old and new lines are textually identical.
#[test]
#[serial]
fn prompt_gateway_successful_request_to_open_ai_chat_completions() {
let args = tester::MockSettings {
wasm_path: wasm_module(),
quiet: false,
allow_unexpected: false,
};
let mut module = tester::mock(args).unwrap();
let args = tester::MockSettings {
wasm_path: wasm_module(),
quiet: false,
allow_unexpected: false,
};
let mut module = tester::mock(args).unwrap();
module
.call_start()
.execute_and_expect(ReturnType::None)
.unwrap();
module
.call_start()
.execute_and_expect(ReturnType::None)
.unwrap();
// Setup Filter
let filter_context = setup_filter(&mut module, default_config());
// Setup Filter
let filter_context = setup_filter(&mut module, default_config());
// Setup HTTP Stream
let http_context = 2;
// Setup HTTP Stream
let http_context = 2;
module
.call_proxy_on_context_create(http_context, filter_context)
.expect_log(Some(LogLevel::Trace), None)
.execute_and_expect(ReturnType::None)
.unwrap();
module
.call_proxy_on_context_create(http_context, filter_context)
.expect_log(Some(LogLevel::Trace), None)
.execute_and_expect(ReturnType::None)
.unwrap();
request_headers_expectations(&mut module, http_context);
request_headers_expectations(&mut module, http_context);
// Request Body
let chat_completions_request_body = "\
{\
\"messages\": [\
{\
\"role\": \"system\",\
\"content\": \"You are a poetic assistant, skilled in explaining complex programming concepts with creative flair.\"\
},\
{\
\"role\": \"user\",\
\"content\": \"Compose a poem that explains the concept of recursion in programming.\"\
}\
],\
\"model\": \"gpt-4\"\
}";
// Request Body
let chat_completions_request_body = "\
{\
\"messages\": [\
{\
\"role\": \"system\",\
\"content\": \"You are a poetic assistant, skilled in explaining complex programming concepts with creative flair.\"\
},\
{\
\"role\": \"user\",\
\"content\": \"Compose a poem that explains the concept of recursion in programming.\"\
}\
],\
\"model\": \"gpt-4\"\
}";
module
.call_proxy_on_request_body(
http_context,
chat_completions_request_body.len() as i32,
true,
)
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
.returning(Some(chat_completions_request_body))
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_http_call(Some("arch_internal"), None, None, None, None)
.returning(Some(4))
.expect_metric_increment("active_http_calls", 1)
.execute_and_expect(ReturnType::Action(Action::Pause))
.unwrap();
module
.call_proxy_on_request_body(
http_context,
chat_completions_request_body.len() as i32,
true,
)
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
.returning(Some(chat_completions_request_body))
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_http_call(Some("arch_internal"), None, None, None, None)
.returning(Some(4))
.expect_metric_increment("active_http_calls", 1)
.execute_and_expect(ReturnType::Action(Action::Pause))
.unwrap();
}
/// Checks that an incomplete chat-completions request body is answered with a
/// local `400 Bad Request` and that request processing is paused.
///
/// The body sent here has a system message without a `content` field (followed
/// by a trailing comma) and carries no `model` key.
#[test]
#[serial]
fn prompt_gateway_bad_request_to_open_ai_chat_completions() {
    // `allow_unexpected: false` is set so the mock is strict about host calls
    // (NOTE(review): exact strictness semantics live in the tester crate).
    let args = tester::MockSettings {
        wasm_path: wasm_module(),
        quiet: false,
        allow_unexpected: false,
    };
    let mut module = tester::mock(args).unwrap();

    module
        .call_start()
        .execute_and_expect(ReturnType::None)
        .unwrap();

    // Setup Filter
    let filter_context = setup_filter(&mut module, default_config());

    // Setup HTTP Stream
    let http_context = 2;

    module
        .call_proxy_on_context_create(http_context, filter_context)
        .expect_log(Some(LogLevel::Trace), None)
        .execute_and_expect(ReturnType::None)
        .unwrap();

    request_headers_expectations(&mut module, http_context);

    // Request Body
    // The trailing `\` on each line removes the newline and the next line's
    // leading whitespace, so the indentation below is not part of the payload.
    let incomplete_chat_completions_request_body = "\
        {\
            \"messages\": [\
                {\
                    \"role\": \"system\",\
                },\
                {\
                    \"role\": \"user\",\
                    \"content\": \"Compose a poem that explains the concept of recursion in programming.\"\
                }\
            ]\
        }";

    // The filter is expected to reply locally with 400 and pause the stream.
    module
        .call_proxy_on_request_body(
            http_context,
            incomplete_chat_completions_request_body.len() as i32,
            true,
        )
        .expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
        .returning(Some(incomplete_chat_completions_request_body))
        .expect_log(Some(LogLevel::Trace), None)
        .expect_send_local_response(
            Some(StatusCode::BAD_REQUEST.as_u16().into()),
            None,
            None,
            None,
        )
        .expect_log(Some(LogLevel::Trace), None)
        .execute_and_expect(ReturnType::Action(Action::Pause))
        .unwrap();
}
/// Walks a request through the function-calling path and on to the LLM
/// response.
///
/// Steps exercised, in order:
/// 1. `normal_flow` drives the initial request (defined elsewhere in this file).
/// 2. A callout response (token `1`) carrying a `weather_forecast` tool call
///    with `city = "seattle"` must trigger a `POST /weather` callout to
///    `api_server` whose body is the JSON-encoded tool arguments.
/// 3. The response to that callout (token `2`, status `200`) must cause the
///    request body buffer to be rewritten.
/// 4. The final chat-completions response body must be rewritten and the
///    stream continued.
#[test]
#[serial]
fn prompt_gateway_request_to_llm_gateway() {
    // `allow_unexpected: false` is set so the mock is strict about host calls
    // (NOTE(review): exact strictness semantics live in the tester crate).
    let args = tester::MockSettings {
        wasm_path: wasm_module(),
        quiet: false,
        allow_unexpected: false,
    };
    let mut module = tester::mock(args).unwrap();

    module
        .call_start()
        .execute_and_expect(ReturnType::None)
        .unwrap();

    // Setup Filter
    // Raise the first rate limit's token budget by 1000 before handing the
    // config to the filter — presumably so this flow is not rate limited;
    // TODO confirm against the rate-limit tests.
    let mut config: Configuration = serde_yaml::from_str(default_config()).unwrap();
    config.ratelimits.as_mut().unwrap()[0].limit.tokens += 1000;
    let config_str = serde_json::to_string(&config).unwrap();

    let filter_context = setup_filter(&mut module, &config_str);

    // Setup HTTP Stream
    let http_context = 2;

    normal_flow(&mut module, filter_context, http_context);

    // Function-calling response: a single tool call to `weather_forecast`
    // with one argument, `city = "seattle"`.
    let arch_fc_resp = ChatCompletionsResponse {
        usage: Some(Usage {
            completion_tokens: 0,
        }),
        choices: vec![Choice {
            finish_reason: Some("test".to_string()),
            index: Some(0),
            message: Message {
                role: "system".to_string(),
                content: None,
                tool_calls: Some(vec![ToolCall {
                    id: String::from("test"),
                    tool_type: ToolType::Function,
                    function: FunctionCallDetail {
                        name: String::from("weather_forecast"),
                        arguments: HashMap::from([(
                            String::from("city"),
                            Value::String(String::from("seattle")),
                        )]),
                    },
                }]),
                model: None,
                tool_call_id: None,
            },
        }],
        model: String::from("test"),
        metadata: None,
    };

    // For a POST callout the tool arguments are expected in the request body.
    let expected_body = "{\"city\":\"seattle\"}";
    let arch_fc_resp_str = serde_json::to_string(&arch_fc_resp).unwrap();
    module
        .call_proxy_on_http_call_response(http_context, 1, 0, arch_fc_resp_str.len() as i32, 0)
        .expect_metric_increment("active_http_calls", -1)
        .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
        .returning(Some(&arch_fc_resp_str))
        .expect_log(Some(LogLevel::Warn), None)
        .expect_log(Some(LogLevel::Debug), None)
        .expect_log(Some(LogLevel::Trace), None)
        .expect_log(Some(LogLevel::Debug), None)
        .expect_log(Some(LogLevel::Trace), None)
        .expect_log(Some(LogLevel::Trace), None)
        .expect_http_call(
            Some("arch_internal"),
            Some(vec![
                (":method", "POST"),
                ("content-type", "application/json"),
                ("x-arch-upstream", "api_server"),
                (":authority", "api_server"),
                ("x-envoy-max-retries", "3"),
                (":path", "/weather"),
            ]),
            Some(expected_body),
            None,
            None,
        )
        // The mock hands back 2 for this callout; the next step answers it.
        .returning(Some(2))
        .expect_metric_increment("active_http_calls", 1)
        .execute_and_expect(ReturnType::None)
        .unwrap();

    // Response to callout 2: status 200 is expected to lead to the request
    // body buffer being overwritten.
    let body_text = String::from("test body");
    module
        .call_proxy_on_http_call_response(http_context, 2, 0, body_text.len() as i32, 0)
        .expect_metric_increment("active_http_calls", -1)
        .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
        .returning(Some(&body_text))
        .expect_log(Some(LogLevel::Debug), None)
        .expect_log(Some(LogLevel::Trace), None)
        .expect_log(Some(LogLevel::Debug), None)
        .expect_log(Some(LogLevel::Trace), None)
        .expect_log(Some(LogLevel::Trace), None)
        .expect_get_header_map_value(Some(MapType::HttpCallResponseHeaders), Some(":status"))
        .returning(Some("200"))
        .expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
        .execute_and_expect(ReturnType::None)
        .unwrap();

    // Final assistant response delivered on the response path.
    let chat_completion_response = ChatCompletionsResponse {
        usage: Some(Usage {
            completion_tokens: 0,
        }),
        choices: vec![Choice {
            finish_reason: Some("test".to_string()),
            index: Some(0),
            message: Message {
                role: "assistant".to_string(),
                content: Some("hello from fake llm gateway".to_string()),
                model: None,
                tool_calls: None,
                tool_call_id: None,
            },
        }],
        model: String::from("test"),
        metadata: None,
    };

    let chat_completion_response_str = serde_json::to_string(&chat_completion_response).unwrap();
    module
        .call_proxy_on_response_body(
            http_context,
            chat_completion_response_str.len() as i32,
            true,
        )
        .expect_get_buffer_bytes(Some(BufferType::HttpResponseBody))
        .returning(Some(chat_completion_response_str.as_str()))
        .expect_log(Some(LogLevel::Trace), None)
        .expect_log(Some(LogLevel::Debug), None)
        .expect_set_buffer_bytes(Some(BufferType::HttpResponseBody), None)
        .expect_log(Some(LogLevel::Debug), None)
        .expect_log(Some(LogLevel::Trace), None)
        .execute_and_expect(ReturnType::Action(Action::Continue))
        .unwrap();
}