diff --git a/crates/prompt_gateway/src/lib.rs b/crates/prompt_gateway/src/lib.rs index 1acd4d6d..7e7a24f9 100644 --- a/crates/prompt_gateway/src/lib.rs +++ b/crates/prompt_gateway/src/lib.rs @@ -7,6 +7,7 @@ mod filter_context; mod http_context; mod metrics; mod stream_context; +mod tools; proxy_wasm::main! {{ proxy_wasm::set_log_level(LogLevel::Trace); diff --git a/crates/prompt_gateway/src/stream_context.rs b/crates/prompt_gateway/src/stream_context.rs index 77608cca..24e37186 100644 --- a/crates/prompt_gateway/src/stream_context.rs +++ b/crates/prompt_gateway/src/stream_context.rs @@ -1,9 +1,10 @@ use crate::metrics::Metrics; +use crate::tools::compute_request_path_body; use common::api::open_ai::{ to_server_events, ArchState, ChatCompletionStreamResponse, ChatCompletionsRequest, ChatCompletionsResponse, Message, ModelServerResponse, ToolCall, }; -use common::configuration::{HttpMethod, Overrides, PromptTarget, Tracing}; +use common::configuration::{Overrides, PromptTarget, Tracing}; use common::consts::{ ARCH_FC_MODEL_NAME, ARCH_FC_REQUEST_TIMEOUT_MS, ARCH_INTERNAL_CLUSTER_NAME, ARCH_UPSTREAM_HOST_HEADER, ASSISTANT_ROLE, MESSAGES_KEY, REQUEST_ID_HEADER, SYSTEM_ROLE, @@ -16,7 +17,6 @@ use derivative::Derivative; use http::StatusCode; use log::{debug, trace, warn}; use proxy_wasm::traits::*; -use serde_yaml::Value; use std::cell::RefCell; use std::collections::HashMap; use std::rc::Rc; @@ -273,76 +273,41 @@ impl StreamContext { fn schedule_api_call_request(&mut self, mut callout_context: StreamCallContext) { let tools_call_name = self.tool_calls.as_ref().unwrap()[0].function.name.clone(); + let prompt_target = self.prompt_targets.get(&tools_call_name).unwrap(); + let tool_params = &self.tool_calls.as_ref().unwrap()[0].function.arguments; + let endpoint_details = prompt_target.endpoint.as_ref().unwrap(); + let endpoint_path: String = endpoint_details + .path + .as_ref() + .unwrap_or(&String::from("/")) + .to_string(); - let prompt_target = self.prompt_targets.get(&tools_call_name).unwrap().clone(); + let http_method = endpoint_details.method.clone().unwrap_or_default(); + let prompt_target_params = prompt_target.parameters.clone().unwrap_or_default(); - let tool_params = self.tool_calls.as_ref().unwrap()[0] - .function - .arguments - .clone(); - - let endpoint = prompt_target.endpoint.unwrap(); - let path: String = endpoint.path.unwrap_or(String::from("/")); - let prompt_target_params = prompt_target.parameters.unwrap_or_default(); - let http_method = endpoint.method.unwrap_or_default(); - - // only add params that are of string, number and bool type - let tool_url_params = tool_params - .iter() - .filter(|(_, value)| value.is_number() || value.is_string() || value.is_bool()) - .map(|(key, value)| match value { - Value::Number(n) => (key.clone(), n.to_string()), - Value::String(s) => (key.clone(), s.clone()), - Value::Bool(b) => (key.clone(), b.to_string()), - Value::Null => todo!(), - Value::Sequence(_) => todo!(), - Value::Mapping(_) => todo!(), - Value::Tagged(_) => todo!(), - }) - .collect::>(); - - let (path_with_params, query_string, additional_params) = - match common::path::replace_params_in_path( - &path, - &tool_url_params, - &prompt_target_params, - ) { - Ok((path, query_string, additional_params)) => { - (path, query_string, additional_params) - } - Err(e) => { - return self.send_server_error( - ServerError::BadRequest { - why: format!("error replacing params in path: {}", e), - }, - Some(StatusCode::BAD_REQUEST), - ); - } - }; - - let (path, body) = match http_method { - HttpMethod::Get => (format!("{}?{}", path_with_params, query_string), None), - HttpMethod::Post => { - let mut additional_params = additional_params; - if !query_string.is_empty() { - query_string.split("&").for_each(|param| { - let mut parts = param.split("="); - let key = parts.next().unwrap(); - let value = parts.next().unwrap(); - additional_params.insert(key.to_string(), value.to_string()); - }); - } - let body = serde_json::to_string(&additional_params).unwrap(); - (path_with_params, Some(body)) + let (path, body) = match compute_request_path_body( + &endpoint_path, + tool_params, + &prompt_target_params, + &http_method, + ) { + Ok((path, body)) => (path, body), + Err(e) => { + return self.send_server_error( + ServerError::BadRequest { + why: format!("error computing api request path or body: {}", e), + }, + Some(StatusCode::BAD_REQUEST), + ); } }; let http_method_str = http_method.to_string(); let mut headers: HashMap<_, _> = [ - (ARCH_UPSTREAM_HOST_HEADER, endpoint.name.as_str()), + (ARCH_UPSTREAM_HOST_HEADER, endpoint_details.name.as_str()), (":method", &http_method_str), (":path", &path), - (":authority", endpoint.name.as_str()), + (":authority", endpoint_details.name.as_str()), ("content-type", "application/json"), ("x-envoy-max-retries", "3"), ] @@ -368,10 +333,10 @@ impl StreamContext { debug!( "dispatching api call to developer endpoint: {}, path: {}, method: {}", - endpoint.name, path, http_method_str + endpoint_details.name, path, http_method_str ); - callout_context.upstream_cluster = Some(endpoint.name.to_owned()); + callout_context.upstream_cluster = Some(endpoint_details.name.to_owned()); callout_context.upstream_cluster_path = Some(path.to_owned()); callout_context.response_handler_type = ResponseHandlerType::FunctionCall; diff --git a/crates/prompt_gateway/src/tools.rs b/crates/prompt_gateway/src/tools.rs new file mode 100644 index 00000000..5ef8b1fb --- /dev/null +++ b/crates/prompt_gateway/src/tools.rs @@ -0,0 +1,54 @@ +use common::configuration::{HttpMethod, Parameter}; +use std::collections::HashMap; + +use serde_yaml::Value; + +// only add params that are of string, number and bool type +pub fn filter_tool_params(tool_params: &HashMap) -> HashMap { + tool_params + .iter() + .filter(|(_, value)| value.is_number() || value.is_string() || value.is_bool()) + .map(|(key, value)| match value { + Value::Number(n) => (key.clone(), n.to_string()), + Value::String(s) => (key.clone(), s.clone()), + Value::Bool(b) => (key.clone(), b.to_string()), + Value::Null => todo!(), + Value::Sequence(_) => todo!(), + Value::Mapping(_) => todo!(), + Value::Tagged(_) => todo!(), + }) + .collect::>() +} + +pub fn compute_request_path_body( + endpoint_path: &str, + tool_params: &HashMap, + prompt_target_params: &[Parameter], + http_method: &HttpMethod, +) -> Result<(String, Option), String> { + let tool_url_params = filter_tool_params(tool_params); + let (path_with_params, query_string, additional_params) = common::path::replace_params_in_path( + endpoint_path, + &tool_url_params, + prompt_target_params, + )?; + + let (path, body) = match http_method { + HttpMethod::Get => (format!("{}?{}", path_with_params, query_string), None), + HttpMethod::Post => { + let mut additional_params = additional_params; + if !query_string.is_empty() { + query_string.split("&").for_each(|param| { + let mut parts = param.split("="); + let key = parts.next().unwrap(); + let value = parts.next().unwrap(); + additional_params.insert(key.to_string(), value.to_string()); + }); + } + let body = serde_json::to_string(&additional_params).unwrap(); + (path_with_params, Some(body)) + } + }; + + Ok((path, body)) +}