From 16f5025bc9e3afc311be806d678f5a88c4e4d831 Mon Sep 17 00:00:00 2001 From: Adil Hafeez Date: Tue, 4 Feb 2025 17:42:27 -0800 Subject: [PATCH] path replacement --- crates/Cargo.lock | 7 +++++ crates/common/Cargo.toml | 1 + crates/common/src/path.rs | 34 +++++++++++++++++---- crates/prompt_gateway/src/stream_context.rs | 2 -- demos/spotify_demo/arch_config.yaml | 2 +- 5 files changed, 37 insertions(+), 9 deletions(-) diff --git a/crates/Cargo.lock b/crates/Cargo.lock index 98157733..3d1d9f7e 100644 --- a/crates/Cargo.lock +++ b/crates/Cargo.lock @@ -234,6 +234,7 @@ dependencies = [ "serde_yaml", "thiserror", "tiktoken-rs", + "urlencoding", ] [[package]] @@ -1676,6 +1677,12 @@ version = "0.2.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "673aac59facbab8a9007c7f6108d11f63b603f7cabff99fabf650fea5c32b861" +[[package]] +name = "urlencoding" +version = "2.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da" + [[package]] name = "uuid" version = "1.11.0" diff --git a/crates/common/Cargo.toml b/crates/common/Cargo.toml index 84aa636c..043d8657 100644 --- a/crates/common/Cargo.toml +++ b/crates/common/Cargo.toml @@ -16,6 +16,7 @@ tiktoken-rs = "0.5.9" rand = "0.8.5" serde_json = "1.0" hex = "0.4.3" +urlencoding = "2.1.3" [dev-dependencies] pretty_assertions = "1.4.1" diff --git a/crates/common/src/path.rs b/crates/common/src/path.rs index 3bf2aed5..95138e10 100644 --- a/crates/common/src/path.rs +++ b/crates/common/src/path.rs @@ -1,4 +1,5 @@ -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; +use urlencoding; pub fn replace_params_in_path( path: &str, @@ -7,6 +8,7 @@ pub fn replace_params_in_path( let mut result = String::new(); let mut in_param = false; let mut current_param = String::new(); + let mut vars_replaced = HashSet::new(); for c in path.chars() { if c == '{' { @@ -15,7 +17,9 @@ pub fn replace_params_in_path( in_param = false; let param_name = current_param.clone(); if let Some(value) = params.get(¶m_name) { - result.push_str(value); + let value = urlencoding::encode(value); + result.push_str(value.into_owned().as_str()); + vars_replaced.insert(param_name.clone()); } else { return Err(format!("Missing value for parameter `{}`", param_name)); } @@ -27,6 +31,18 @@ pub fn replace_params_in_path( } } + // add the remaining params in path + for (param_name, value) in params.iter() { + let value = urlencoding::encode(value); + if !vars_replaced.contains(param_name) { + if result.contains("?") { + result.push_str(&format!("&{}={}", param_name, value)); + } else { + result.push_str(&format!("?{}={}", param_name, value)); + } + } + } + Ok(result) } @@ -35,12 +51,18 @@ mod test { #[test] fn test_replace_path() { let path = "/cluster.open-cluster-management.io/v1/managedclusters/{cluster_name}"; - let params = vec![("cluster_name".to_string(), "test1".to_string())] - .into_iter() - .collect(); + let params = vec![ + ("cluster_name".to_string(), "test1".to_string()), + ("hello".to_string(), "hello world".to_string()), + ] + .into_iter() + .collect(); assert_eq!( super::replace_params_in_path(path, ¶ms), - Ok("/cluster.open-cluster-management.io/v1/managedclusters/test1".to_string()) + Ok( + "/cluster.open-cluster-management.io/v1/managedclusters/test1?hello=hello%20world" + .to_string() + ) ); let path = "/cluster.open-cluster-management.io/v1/managedclusters"; diff --git a/crates/prompt_gateway/src/stream_context.rs b/crates/prompt_gateway/src/stream_context.rs index f5ba8af8..c0707bd6 100644 --- a/crates/prompt_gateway/src/stream_context.rs +++ b/crates/prompt_gateway/src/stream_context.rs @@ -348,7 +348,6 @@ impl StreamContext { &path, headers.into_iter().collect(), None, - // Some(tool_params_json_str.as_bytes()), vec![], Duration::from_secs(5), ); @@ -357,7 +356,6 @@ impl StreamContext { "dispatching api call to developer endpoint: {}, path: {}", endpoint.name, path ); - trace!("request body: {}", tool_params_json_str); callout_context.upstream_cluster = Some(endpoint.name.to_owned()); callout_context.upstream_cluster_path = Some(path.to_owned()); diff --git a/demos/spotify_demo/arch_config.yaml b/demos/spotify_demo/arch_config.yaml index 496b59a4..7dd8dcb2 100644 --- a/demos/spotify_demo/arch_config.yaml +++ b/demos/spotify_demo/arch_config.yaml @@ -94,7 +94,7 @@ prompt_targets: in_path: true endpoint: name: spotify - path: /v1/browse/new-releases?country={country}&limit=5 + path: /v1/browse/new-releases?limit=5 http_headers: Authorization: "Bearer $SPOTIFY_CLIENT_KEY" description: Get a list of new album releases featured in Spotify (shown, for example, on a Spotify player’s “Browse” tab).