add demo for s3

This commit is contained in:
Adil Hafeez 2025-02-11 17:35:30 -08:00
parent 0ea237fbac
commit 6e953ad5ae
No known key found for this signature in database
GPG key ID: 9B18EF7691369645
11 changed files with 217 additions and 17 deletions

View file

@ -216,6 +216,7 @@ pub struct Parameter {
pub enum_values: Option<Vec<String>>,
pub default: Option<String>,
pub in_path: Option<bool>,
pub url_encode: Option<bool>,
pub format: Option<String>,
}

View file

@ -7,7 +7,7 @@ use crate::configuration::Parameter;
pub fn replace_params_in_path(
path: &str,
tool_params: &HashMap<String, String>,
prompt_target_params: &[Parameter],
prompt_target_params: &HashMap<String, Parameter>,
) -> Result<(String, String, HashMap<String, String>), String> {
let mut query_string_replaced = String::new();
let mut current_param = String::new();
@ -22,8 +22,16 @@ pub fn replace_params_in_path(
in_param = false;
let param_name = current_param.clone();
if let Some(value) = tool_params.get(&param_name) {
let value = urlencoding::encode(value);
query_string_replaced.push_str(value.into_owned().as_str());
let should_url_encode = prompt_target_params
.get(&param_name)
.map(|param| param.url_encode.unwrap_or_default())
.unwrap_or_default();
if should_url_encode {
let value = urlencoding::encode(value);
query_string_replaced.push_str(value.into_owned().as_str());
} else {
query_string_replaced.push_str(value);
}
vars_replaced.insert(param_name.clone());
} else {
return Err(format!("Missing value for parameter `{}`", param_name));
@ -51,7 +59,7 @@ pub fn replace_params_in_path(
}
// add default values
for param in prompt_target_params.iter() {
for param in prompt_target_params.values() {
if !vars_replaced.contains(&param.name) && param.default.is_some() {
params.insert(param.name.clone(), param.default.clone().unwrap());
if query_string_replaced.contains("?") {
@ -104,7 +112,11 @@ mod test {
default: Some("US".to_string()),
in_path: None,
format: None,
}];
url_encode: None,
}]
.into_iter()
.map(|param| (param.name.clone(), param))
.collect();
let out_params: HashMap<String, String> = vec![
("country".to_string(), "US".to_string()),
@ -122,7 +134,7 @@ mod test {
);
let out_params = HashMap::new();
let prompt_target_params = vec![];
let prompt_target_params: HashMap<String, Parameter> = HashMap::new();
let path = "/cluster.open-cluster-management.io/v1/managedclusters";
let params = vec![].into_iter().collect();
assert_eq!(

View file

@ -87,8 +87,9 @@ impl StreamContext {
));
debug!(
"request received: llm provider hint: {:?}, selected llm: {}",
self.get_http_request_header(ARCH_PROVIDER_HINT_HEADER),
"request received: llm provider hint: {}, selected llm: {}",
self.get_http_request_header(ARCH_PROVIDER_HINT_HEADER)
.unwrap_or_default(),
self.llm_provider.as_ref().unwrap().name
);
}

View file

@ -144,7 +144,10 @@ impl HttpContext for StreamContext {
if metadata.is_none() {
metadata = Some(HashMap::new());
}
metadata.as_mut().unwrap().insert("optimize_context_window".to_string(), "true".to_string());
metadata
.as_mut()
.unwrap()
.insert("optimize_context_window".to_string(), "true".to_string());
}
}

View file

@ -4,7 +4,7 @@ use common::api::open_ai::{
to_server_events, ArchState, ChatCompletionStreamResponse, ChatCompletionsRequest,
ChatCompletionsResponse, Message, ModelServerResponse, ToolCall,
};
use common::configuration::{Overrides, PromptTarget, Tracing};
use common::configuration::{Overrides, Parameter, PromptTarget, Tracing};
use common::consts::{
ARCH_FC_MODEL_NAME, ARCH_FC_REQUEST_TIMEOUT_MS, ARCH_INTERNAL_CLUSTER_NAME,
ARCH_UPSTREAM_HOST_HEADER, ASSISTANT_ROLE, MESSAGES_KEY, REQUEST_ID_HEADER, SYSTEM_ROLE,
@ -89,7 +89,7 @@ impl StreamContext {
streaming_response: false,
user_prompt: None,
is_chat_completions_request: false,
overrides: overrides,
overrides,
request_id: None,
traceparent: None,
_tracing: tracing,
@ -191,6 +191,10 @@ impl StreamContext {
callout_context.response_handler_type = ResponseHandlerType::DefaultTarget;
callout_context.prompt_target_name =
Some(default_prompt_target.name.clone());
debug!(
"prompt target name: {}",
callout_context.prompt_target_name.as_ref().unwrap()
);
if let Err(e) = self.http_call(call_args, callout_context) {
warn!("error dispatching default prompt target request: {}", e);
@ -267,6 +271,10 @@ impl StreamContext {
// update prompt target name from the tool call response
callout_context.prompt_target_name =
Some(self.tool_calls.as_ref().unwrap()[0].function.name.clone());
debug!(
"prompt target name: {}",
callout_context.prompt_target_name.as_ref().unwrap()
);
self.schedule_api_call_request(callout_context);
}
@ -283,7 +291,13 @@ impl StreamContext {
.to_string();
let http_method = endpoint_details.method.clone().unwrap_or_default();
let prompt_target_params = prompt_target.parameters.clone().unwrap_or_default();
let prompt_target_params: HashMap<String, Parameter> = prompt_target
.parameters
.clone()
.unwrap_or_default()
.into_iter()
.map(|param| (param.name.clone(), param))
.collect();
let (path, body) = match compute_request_path_body(
&endpoint_path,

View file

@ -23,7 +23,7 @@ pub fn filter_tool_params(tool_params: &HashMap<String, Value>) -> HashMap<Strin
pub fn compute_request_path_body(
endpoint_path: &str,
tool_params: &HashMap<String, Value>,
prompt_target_params: &[Parameter],
prompt_target_params: &HashMap<String, Parameter>,
http_method: &HttpMethod,
) -> Result<(String, Option<String>), String> {
let tool_url_params = filter_tool_params(tool_params);
@ -55,6 +55,8 @@ pub fn compute_request_path_body(
#[cfg(test)]
mod test {
use std::collections::HashMap;
use common::configuration::{HttpMethod, Parameter};
#[test]
@ -76,7 +78,11 @@ mod test {
default: Some("US".to_string()),
in_path: None,
format: None,
}];
url_encode: None,
}]
.into_iter()
.map(|param| (param.name.clone(), param))
.collect::<HashMap<String, Parameter>>();
let http_method = HttpMethod::Get;
let (path, body) = super::compute_request_path_body(
endpoint_path,
@ -105,7 +111,11 @@ mod test {
default: Some("US".to_string()),
in_path: None,
format: None,
}];
url_encode: None,
}]
.into_iter()
.map(|param| (param.name.clone(), param))
.collect::<HashMap<String, Parameter>>();
let http_method = HttpMethod::Get;
let (path, body) = super::compute_request_path_body(
endpoint_path,
@ -139,7 +149,11 @@ mod test {
default: Some("US".to_string()),
in_path: None,
format: None,
}];
url_encode: None,
}]
.into_iter()
.map(|param| (param.name.clone(), param))
.collect::<HashMap<String, Parameter>>();
let http_method = HttpMethod::Get;
let (path, body) = super::compute_request_path_body(
endpoint_path,