From 7e9059f4f849ded8d148feb3f568b0516f9c0f28 Mon Sep 17 00:00:00 2001 From: Adil Hafeez Date: Wed, 5 Feb 2025 12:53:48 -0800 Subject: [PATCH] fix tests --- crates/common/src/path.rs | 56 +++++++++++++++++---- crates/prompt_gateway/src/stream_context.rs | 13 ++--- demos/spotify_demo/arch_config.yaml | 18 +++++-- 3 files changed, 64 insertions(+), 23 deletions(-) diff --git a/crates/common/src/path.rs b/crates/common/src/path.rs index 95138e10..45a5ef33 100644 --- a/crates/common/src/path.rs +++ b/crates/common/src/path.rs @@ -1,9 +1,12 @@ use std::collections::{HashMap, HashSet}; use urlencoding; +use crate::configuration::Parameter; + pub fn replace_params_in_path( path: &str, - params: &HashMap, + tool_params: &HashMap, + prompt_target_params: &Vec, ) -> Result { let mut result = String::new(); let mut in_param = false; @@ -16,7 +19,7 @@ pub fn replace_params_in_path( } else if c == '}' { in_param = false; let param_name = current_param.clone(); - if let Some(value) = params.get(¶m_name) { + if let Some(value) = tool_params.get(¶m_name) { let value = urlencoding::encode(value); result.push_str(value.into_owned().as_str()); vars_replaced.insert(param_name.clone()); @@ -32,9 +35,10 @@ pub fn replace_params_in_path( } // add the remaining params in path - for (param_name, value) in params.iter() { + for (param_name, value) in tool_params.iter() { let value = urlencoding::encode(value); if !vars_replaced.contains(param_name) { + vars_replaced.insert(param_name.clone()); if result.contains("?") { result.push_str(&format!("&{}={}", param_name, value)); } else { @@ -43,11 +47,32 @@ pub fn replace_params_in_path( } } + // add default values + for param in prompt_target_params.iter() { + if !vars_replaced.contains(¶m.name) && param.default.is_some() { + if result.contains("?") { + result.push_str(&format!( + "&{}={}", + param.name, + param.default.as_ref().unwrap() + )); + } else { + result.push_str(&format!( + "?{}={}", + param.name, + param.default.as_ref().unwrap() + )); + } + } + } + Ok(result) } #[cfg(test)] mod test { + use crate::configuration::Parameter; + #[test] fn test_replace_path() { let path = "/cluster.open-cluster-management.io/v1/managedclusters/{cluster_name}"; @@ -57,18 +82,31 @@ mod test { ] .into_iter() .collect(); + let prompt_target_params = vec![Parameter { + name: "country".to_string(), + parameter_type: None, + description: "test target".to_string(), + required: None, + enum_values: None, + default: Some("US".to_string()), + in_path: None, + format: None, + }]; + assert_eq!( - super::replace_params_in_path(path, ¶ms), + super::replace_params_in_path(path, ¶ms, &prompt_target_params), Ok( - "/cluster.open-cluster-management.io/v1/managedclusters/test1?hello=hello%20world" + "/cluster.open-cluster-management.io/v1/managedclusters/test1?hello=hello%20world&country=US" .to_string() ) ); + let prompt_target_params = vec![]; + let path = "/cluster.open-cluster-management.io/v1/managedclusters"; let params = vec![].into_iter().collect(); assert_eq!( - super::replace_params_in_path(path, ¶ms), + super::replace_params_in_path(path, ¶ms, &prompt_target_params), Ok("/cluster.open-cluster-management.io/v1/managedclusters".to_string()) ); @@ -77,7 +115,7 @@ mod test { .into_iter() .collect(); assert_eq!( - super::replace_params_in_path(path, ¶ms), + super::replace_params_in_path(path, ¶ms, &prompt_target_params), Ok("/foo/qux/baz".to_string()) ); @@ -89,7 +127,7 @@ mod test { .into_iter() .collect(); assert_eq!( - super::replace_params_in_path(path, ¶ms), + super::replace_params_in_path(path, ¶ms, &prompt_target_params), Ok("/foo/qux/baz/quux".to_string()) ); @@ -98,7 +136,7 @@ mod test { .into_iter() .collect(); assert_eq!( - super::replace_params_in_path(path, ¶ms), + super::replace_params_in_path(path, ¶ms, &prompt_target_params), Err("Missing value for parameter `qux`".to_string()) ); } diff --git a/crates/prompt_gateway/src/stream_context.rs b/crates/prompt_gateway/src/stream_context.rs index c0707bd6..26edf6dc 100644 --- a/crates/prompt_gateway/src/stream_context.rs +++ b/crates/prompt_gateway/src/stream_context.rs @@ -276,22 +276,17 @@ impl StreamContext { let prompt_target = self.prompt_targets.get(&tools_call_name).unwrap().clone(); - let mut tool_params = self.tool_calls.as_ref().unwrap()[0] + let tool_params = self.tool_calls.as_ref().unwrap()[0] .function .arguments .clone(); - tool_params.insert( - String::from(MESSAGES_KEY), - serde_yaml::to_value(&callout_context.request_body.messages).unwrap(), - ); - - let tool_params_json_str = serde_json::to_string(&tool_params).unwrap(); let endpoint = prompt_target.endpoint.unwrap(); let path: String = endpoint.path.unwrap_or(String::from("/")); + let prompt_target_params = prompt_target.parameters.unwrap_or_default(); // only add params that are of string, number and bool type - let url_params = tool_params + let tool_url_params = tool_params .iter() .filter(|(_, value)| value.is_number() || value.is_string() || value.is_bool()) .map(|(key, value)| match value { @@ -305,7 +300,7 @@ impl StreamContext { }) .collect::>(); - let path = match common::path::replace_params_in_path(&path, &url_params) { + let path = match common::path::replace_params_in_path(&path, &tool_url_params, &prompt_target_params) { Ok(path) => path, Err(e) => { return self.send_server_error( diff --git a/demos/spotify_demo/arch_config.yaml b/demos/spotify_demo/arch_config.yaml index 7dd8dcb2..96a898ef 100644 --- a/demos/spotify_demo/arch_config.yaml +++ b/demos/spotify_demo/arch_config.yaml @@ -99,19 +99,27 @@ prompt_targets: Authorization: "Bearer $SPOTIFY_CLIENT_KEY" description: Get a list of new album releases featured in Spotify (shown, for example, on a Spotify player’s “Browse” tab). - - name: search_for_shows_and_podcasts + - name: get_catalog_information parameters: - name: q - description: The search filter to narrow down results + description: keywords to search about catalog required: true type: str - name: type type: str description: The type of catalog item - default: show + enum: + - album + - artist + - playlist + - track + - show + - episode + - audiobook + required: true - name: market type: str - description: A country code + description: A country code for catalog default: US - name: limit type: integer @@ -122,4 +130,4 @@ prompt_targets: path: /v1/search http_headers: Authorization: "Bearer $SPOTIFY_CLIENT_KEY" - description: get Spotify catalog information about shows and podcasts that match the keyword filter. + description: Get Spotify catalog information about albums, artists, playlists, tracks, shows, episodes or audiobooks that match a keyword string.