fix tests

This commit is contained in:
Adil Hafeez 2025-02-05 12:53:48 -08:00
parent 16f5025bc9
commit 7e9059f4f8
No known key found for this signature in database
GPG key ID: 9B18EF7691369645
3 changed files with 64 additions and 23 deletions

View file

@ -1,9 +1,12 @@
use std::collections::{HashMap, HashSet};
use urlencoding;
use crate::configuration::Parameter;
pub fn replace_params_in_path(
path: &str,
params: &HashMap<String, String>,
tool_params: &HashMap<String, String>,
prompt_target_params: &Vec<Parameter>,
) -> Result<String, String> {
let mut result = String::new();
let mut in_param = false;
@ -16,7 +19,7 @@ pub fn replace_params_in_path(
} else if c == '}' {
in_param = false;
let param_name = current_param.clone();
if let Some(value) = params.get(&param_name) {
if let Some(value) = tool_params.get(&param_name) {
let value = urlencoding::encode(value);
result.push_str(value.into_owned().as_str());
vars_replaced.insert(param_name.clone());
@ -32,9 +35,10 @@ pub fn replace_params_in_path(
}
// add the remaining params in path
for (param_name, value) in params.iter() {
for (param_name, value) in tool_params.iter() {
let value = urlencoding::encode(value);
if !vars_replaced.contains(param_name) {
vars_replaced.insert(param_name.clone());
if result.contains("?") {
result.push_str(&format!("&{}={}", param_name, value));
} else {
@ -43,11 +47,32 @@ pub fn replace_params_in_path(
}
}
// add default values
for param in prompt_target_params.iter() {
if !vars_replaced.contains(&param.name) && param.default.is_some() {
if result.contains("?") {
result.push_str(&format!(
"&{}={}",
param.name,
param.default.as_ref().unwrap()
));
} else {
result.push_str(&format!(
"?{}={}",
param.name,
param.default.as_ref().unwrap()
));
}
}
}
Ok(result)
}
#[cfg(test)]
mod test {
use crate::configuration::Parameter;
#[test]
fn test_replace_path() {
let path = "/cluster.open-cluster-management.io/v1/managedclusters/{cluster_name}";
@ -57,18 +82,31 @@ mod test {
]
.into_iter()
.collect();
let prompt_target_params = vec![Parameter {
name: "country".to_string(),
parameter_type: None,
description: "test target".to_string(),
required: None,
enum_values: None,
default: Some("US".to_string()),
in_path: None,
format: None,
}];
assert_eq!(
super::replace_params_in_path(path, &params),
super::replace_params_in_path(path, &params, &prompt_target_params),
Ok(
"/cluster.open-cluster-management.io/v1/managedclusters/test1?hello=hello%20world"
"/cluster.open-cluster-management.io/v1/managedclusters/test1?hello=hello%20world&country=US"
.to_string()
)
);
let prompt_target_params = vec![];
let path = "/cluster.open-cluster-management.io/v1/managedclusters";
let params = vec![].into_iter().collect();
assert_eq!(
super::replace_params_in_path(path, &params),
super::replace_params_in_path(path, &params, &prompt_target_params),
Ok("/cluster.open-cluster-management.io/v1/managedclusters".to_string())
);
@ -77,7 +115,7 @@ mod test {
.into_iter()
.collect();
assert_eq!(
super::replace_params_in_path(path, &params),
super::replace_params_in_path(path, &params, &prompt_target_params),
Ok("/foo/qux/baz".to_string())
);
@ -89,7 +127,7 @@ mod test {
.into_iter()
.collect();
assert_eq!(
super::replace_params_in_path(path, &params),
super::replace_params_in_path(path, &params, &prompt_target_params),
Ok("/foo/qux/baz/quux".to_string())
);
@ -98,7 +136,7 @@ mod test {
.into_iter()
.collect();
assert_eq!(
super::replace_params_in_path(path, &params),
super::replace_params_in_path(path, &params, &prompt_target_params),
Err("Missing value for parameter `qux`".to_string())
);
}

View file

@ -276,22 +276,17 @@ impl StreamContext {
let prompt_target = self.prompt_targets.get(&tools_call_name).unwrap().clone();
let mut tool_params = self.tool_calls.as_ref().unwrap()[0]
let tool_params = self.tool_calls.as_ref().unwrap()[0]
.function
.arguments
.clone();
tool_params.insert(
String::from(MESSAGES_KEY),
serde_yaml::to_value(&callout_context.request_body.messages).unwrap(),
);
let tool_params_json_str = serde_json::to_string(&tool_params).unwrap();
let endpoint = prompt_target.endpoint.unwrap();
let path: String = endpoint.path.unwrap_or(String::from("/"));
let prompt_target_params = prompt_target.parameters.unwrap_or_default();
// only add params that are of string, number and bool type
let url_params = tool_params
let tool_url_params = tool_params
.iter()
.filter(|(_, value)| value.is_number() || value.is_string() || value.is_bool())
.map(|(key, value)| match value {
@ -305,7 +300,7 @@ impl StreamContext {
})
.collect::<HashMap<String, String>>();
let path = match common::path::replace_params_in_path(&path, &url_params) {
let path = match common::path::replace_params_in_path(&path, &tool_url_params, &prompt_target_params) {
Ok(path) => path,
Err(e) => {
return self.send_server_error(

View file

@ -99,19 +99,27 @@ prompt_targets:
Authorization: "Bearer $SPOTIFY_CLIENT_KEY"
description: Get a list of new album releases featured in Spotify (shown, for example, on a Spotify players “Browse” tab).
- name: search_for_shows_and_podcasts
- name: get_catalog_information
parameters:
- name: q
description: The search filter to narrow down results
description: keywords to search about catalog
required: true
type: str
- name: type
type: str
description: The type of catalog item
default: show
enum:
- album
- artist
- playlist
- track
- show
- episode
- audiobook
required: true
- name: market
type: str
description: A country code
description: A country code for catalog
default: US
- name: limit
type: integer
@ -122,4 +130,4 @@ prompt_targets:
path: /v1/search
http_headers:
Authorization: "Bearer $SPOTIFY_CLIENT_KEY"
description: get Spotify catalog information about shows and podcasts that match the keyword filter.
description: Get Spotify catalog information about albums, artists, playlists, tracks, shows, episodes or audiobooks that match a keyword string.