mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
Merge branch 'main' into adil/add_endpoint_http_headers
This commit is contained in:
commit
47dcbf9469
6 changed files with 190 additions and 4403 deletions
|
|
@ -25,9 +25,7 @@ repos:
|
|||
name: cargo-test
|
||||
language: system
|
||||
types: [file, rust]
|
||||
# --lib is to only test the library, since when integration tests are made,
|
||||
# they will be in a separate tests directory
|
||||
entry: bash -c "cd crates/llm_gateway && cargo test --lib"
|
||||
entry: bash -c "cd crates && cargo test --lib"
|
||||
|
||||
- repo: https://github.com/psf/black
|
||||
rev: 23.1.0
|
||||
|
|
|
|||
2166
crates/llm_gateway/Cargo.lock
generated
2166
crates/llm_gateway/Cargo.lock
generated
File diff suppressed because it is too large
Load diff
2166
crates/prompt_gateway/Cargo.lock
generated
2166
crates/prompt_gateway/Cargo.lock
generated
File diff suppressed because it is too large
Load diff
|
|
@ -7,6 +7,7 @@ mod filter_context;
|
|||
mod http_context;
|
||||
mod metrics;
|
||||
mod stream_context;
|
||||
mod tools;
|
||||
|
||||
proxy_wasm::main! {{
|
||||
proxy_wasm::set_log_level(LogLevel::Trace);
|
||||
|
|
|
|||
|
|
@ -1,9 +1,10 @@
|
|||
use crate::metrics::Metrics;
|
||||
use crate::tools::compute_request_path_body;
|
||||
use common::api::open_ai::{
|
||||
to_server_events, ArchState, ChatCompletionStreamResponse, ChatCompletionsRequest,
|
||||
ChatCompletionsResponse, Message, ModelServerResponse, ToolCall,
|
||||
};
|
||||
use common::configuration::{HttpMethod, Overrides, PromptTarget, Tracing};
|
||||
use common::configuration::{Overrides, PromptTarget, Tracing};
|
||||
use common::consts::{
|
||||
ARCH_FC_MODEL_NAME, ARCH_FC_REQUEST_TIMEOUT_MS, ARCH_INTERNAL_CLUSTER_NAME,
|
||||
ARCH_UPSTREAM_HOST_HEADER, ASSISTANT_ROLE, MESSAGES_KEY, REQUEST_ID_HEADER, SYSTEM_ROLE,
|
||||
|
|
@ -16,7 +17,6 @@ use derivative::Derivative;
|
|||
use http::StatusCode;
|
||||
use log::{debug, trace, warn};
|
||||
use proxy_wasm::traits::*;
|
||||
use serde_yaml::Value;
|
||||
use std::cell::RefCell;
|
||||
use std::collections::HashMap;
|
||||
use std::rc::Rc;
|
||||
|
|
@ -273,78 +273,41 @@ impl StreamContext {
|
|||
|
||||
fn schedule_api_call_request(&mut self, mut callout_context: StreamCallContext) {
|
||||
let tools_call_name = self.tool_calls.as_ref().unwrap()[0].function.name.clone();
|
||||
let prompt_target = self.prompt_targets.get(&tools_call_name).unwrap();
|
||||
let tool_params = &self.tool_calls.as_ref().unwrap()[0].function.arguments;
|
||||
let endpoint_details = prompt_target.endpoint.as_ref().unwrap();
|
||||
let endpoint_path: String = endpoint_details
|
||||
.path
|
||||
.as_ref()
|
||||
.unwrap_or(&String::from("/"))
|
||||
.to_string();
|
||||
|
||||
let prompt_target = self.prompt_targets.get(&tools_call_name).unwrap().clone();
|
||||
let http_method = endpoint_details.method.clone().unwrap_or_default();
|
||||
let prompt_target_params = prompt_target.parameters.clone().unwrap_or_default();
|
||||
|
||||
let tool_params = self.tool_calls.as_ref().unwrap()[0]
|
||||
.function
|
||||
.arguments
|
||||
.clone();
|
||||
|
||||
let endpoint = prompt_target.endpoint.unwrap();
|
||||
let path: String = endpoint.path.unwrap_or(String::from("/"));
|
||||
let prompt_target_params = prompt_target.parameters.unwrap_or_default();
|
||||
let http_method = endpoint.method.unwrap_or_default();
|
||||
|
||||
// only add params that are of string, number and bool type
|
||||
let tool_url_params = tool_params
|
||||
.iter()
|
||||
.filter(|(_, value)| value.is_number() || value.is_string() || value.is_bool())
|
||||
.map(|(key, value)| match value {
|
||||
Value::Number(n) => (key.clone(), n.to_string()),
|
||||
Value::String(s) => (key.clone(), s.clone()),
|
||||
Value::Bool(b) => (key.clone(), b.to_string()),
|
||||
Value::Null => todo!(),
|
||||
Value::Sequence(_) => todo!(),
|
||||
Value::Mapping(_) => todo!(),
|
||||
Value::Tagged(_) => todo!(),
|
||||
})
|
||||
.collect::<HashMap<String, String>>();
|
||||
|
||||
let (path_with_params, query_string, additional_params) =
|
||||
match common::path::replace_params_in_path(
|
||||
&path,
|
||||
&tool_url_params,
|
||||
&prompt_target_params,
|
||||
) {
|
||||
Ok((path, query_string, additional_params)) => {
|
||||
(path, query_string, additional_params)
|
||||
}
|
||||
Err(e) => {
|
||||
return self.send_server_error(
|
||||
ServerError::BadRequest {
|
||||
why: format!("error replacing params in path: {}", e),
|
||||
},
|
||||
Some(StatusCode::BAD_REQUEST),
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
let (path, body) = match http_method {
|
||||
HttpMethod::Get => {
|
||||
(format!("{}?{}", path_with_params, query_string), None)
|
||||
}
|
||||
HttpMethod::Post => {
|
||||
let mut additional_params = additional_params;
|
||||
if !query_string.is_empty() {
|
||||
query_string.split("&").for_each(|param| {
|
||||
let mut parts = param.split("=");
|
||||
let key = parts.next().unwrap();
|
||||
let value = parts.next().unwrap();
|
||||
additional_params.insert(key.to_string(), value.to_string());
|
||||
});
|
||||
}
|
||||
let body = serde_json::to_string(&additional_params).unwrap();
|
||||
(path_with_params, Some(body))
|
||||
let (path, body) = match compute_request_path_body(
|
||||
&endpoint_path,
|
||||
tool_params,
|
||||
&prompt_target_params,
|
||||
&http_method,
|
||||
) {
|
||||
Ok((path, body)) => (path, body),
|
||||
Err(e) => {
|
||||
return self.send_server_error(
|
||||
ServerError::BadRequest {
|
||||
why: format!("error computing api request path or body: {}", e),
|
||||
},
|
||||
Some(StatusCode::BAD_REQUEST),
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
let http_method_str = http_method.to_string();
|
||||
let mut headers: HashMap<_, _> = [
|
||||
(ARCH_UPSTREAM_HOST_HEADER, endpoint.name.as_str()),
|
||||
(ARCH_UPSTREAM_HOST_HEADER, endpoint_details.name.as_str()),
|
||||
(":method", &http_method_str),
|
||||
(":path", &path),
|
||||
(":authority", endpoint.name.as_str()),
|
||||
(":authority", endpoint_details.name.as_str()),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
]
|
||||
|
|
@ -360,7 +323,7 @@ impl StreamContext {
|
|||
}
|
||||
|
||||
// override http headers that are set in the prompt target
|
||||
let http_headers = endpoint.http_headers.unwrap_or_default();
|
||||
let http_headers = endpoint_details.http_headers.clone().unwrap_or_default();
|
||||
for (key, value) in http_headers.iter() {
|
||||
headers.insert(key.as_str(), value.as_str());
|
||||
}
|
||||
|
|
@ -376,10 +339,10 @@ impl StreamContext {
|
|||
|
||||
debug!(
|
||||
"dispatching api call to developer endpoint: {}, path: {}, method: {}",
|
||||
endpoint.name, path, http_method_str
|
||||
endpoint_details.name, path, http_method_str
|
||||
);
|
||||
|
||||
callout_context.upstream_cluster = Some(endpoint.name.to_owned());
|
||||
callout_context.upstream_cluster = Some(endpoint_details.name.to_owned());
|
||||
callout_context.upstream_cluster_path = Some(path.to_owned());
|
||||
callout_context.response_handler_type = ResponseHandlerType::FunctionCall;
|
||||
|
||||
|
|
|
|||
157
crates/prompt_gateway/src/tools.rs
Normal file
157
crates/prompt_gateway/src/tools.rs
Normal file
|
|
@ -0,0 +1,157 @@
|
|||
use common::configuration::{HttpMethod, Parameter};
|
||||
use std::collections::HashMap;
|
||||
|
||||
use serde_yaml::Value;
|
||||
|
||||
// only add params that are of string, number and bool type
|
||||
pub fn filter_tool_params(tool_params: &HashMap<String, Value>) -> HashMap<String, String> {
|
||||
tool_params
|
||||
.iter()
|
||||
.filter(|(_, value)| value.is_number() || value.is_string() || value.is_bool())
|
||||
.map(|(key, value)| match value {
|
||||
Value::Number(n) => (key.clone(), n.to_string()),
|
||||
Value::String(s) => (key.clone(), s.clone()),
|
||||
Value::Bool(b) => (key.clone(), b.to_string()),
|
||||
Value::Null => todo!(),
|
||||
Value::Sequence(_) => todo!(),
|
||||
Value::Mapping(_) => todo!(),
|
||||
Value::Tagged(_) => todo!(),
|
||||
})
|
||||
.collect::<HashMap<String, String>>()
|
||||
}
|
||||
|
||||
pub fn compute_request_path_body(
|
||||
endpoint_path: &str,
|
||||
tool_params: &HashMap<String, Value>,
|
||||
prompt_target_params: &[Parameter],
|
||||
http_method: &HttpMethod,
|
||||
) -> Result<(String, Option<String>), String> {
|
||||
let tool_url_params = filter_tool_params(tool_params);
|
||||
let (path_with_params, query_string, additional_params) = common::path::replace_params_in_path(
|
||||
endpoint_path,
|
||||
&tool_url_params,
|
||||
prompt_target_params,
|
||||
)?;
|
||||
|
||||
let (path, body) = match http_method {
|
||||
HttpMethod::Get => (format!("{}?{}", path_with_params, query_string), None),
|
||||
HttpMethod::Post => {
|
||||
let mut additional_params = additional_params;
|
||||
if !query_string.is_empty() {
|
||||
query_string.split("&").for_each(|param| {
|
||||
let mut parts = param.split("=");
|
||||
let key = parts.next().unwrap();
|
||||
let value = parts.next().unwrap();
|
||||
additional_params.insert(key.to_string(), value.to_string());
|
||||
});
|
||||
}
|
||||
let body = serde_json::to_string(&additional_params).unwrap();
|
||||
(path_with_params, Some(body))
|
||||
}
|
||||
};
|
||||
|
||||
Ok((path, body))
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod test {
    use common::configuration::{HttpMethod, Parameter};

    /// The single prompt-target parameter shared by every case: `country`,
    /// defaulting to "US".
    fn country_param() -> Parameter {
        Parameter {
            name: "country".to_string(),
            parameter_type: None,
            description: "test target".to_string(),
            required: None,
            enum_values: None,
            default: Some("US".to_string()),
            in_path: None,
            format: None,
        }
    }

    /// Parses `tool_params_yaml`, runs a GET computation against
    /// `endpoint_path`, and returns the resulting `(path, body)` pair.
    fn compute_get(endpoint_path: &str, tool_params_yaml: &str) -> (String, Option<String>) {
        let parsed_params = serde_yaml::from_str(tool_params_yaml).unwrap();
        let target_params = vec![country_param()];
        super::compute_request_path_body(
            endpoint_path,
            &parsed_params,
            &target_params,
            &HttpMethod::Get,
        )
        .unwrap()
    }

    #[test]
    fn test_compute_request_path_body() {
        // One tool param fills the path placeholder, the other becomes a query
        // pair, and the default `country` is appended.
        let (path, body) = compute_get(
            "/cluster.open-cluster-management.io/v1/managedclusters/{cluster_name}",
            r#"
            cluster_name: test1
            hello: hello world
            "#,
        );
        assert_eq!(
            path,
            "/cluster.open-cluster-management.io/v1/managedclusters/test1?hello=hello%20world&country=US"
        );
        assert_eq!(body, None);
    }

    #[test]
    fn test_compute_request_path_body_empty_params() {
        // With no tool params only the parameter default reaches the query.
        let (path, body) = compute_get(
            "/cluster.open-cluster-management.io/v1/managedclusters/",
            r#"{}"#,
        );
        assert_eq!(
            path,
            "/cluster.open-cluster-management.io/v1/managedclusters/?country=US"
        );
        assert_eq!(body, None);
    }

    #[test]
    fn test_compute_request_path_body_override_default_val() {
        // A tool-supplied value takes precedence over the parameter default.
        let (path, body) = compute_get(
            "/cluster.open-cluster-management.io/v1/managedclusters/",
            r#"
            country: UK
            "#,
        );
        assert_eq!(
            path,
            "/cluster.open-cluster-management.io/v1/managedclusters/?country=UK"
        );
        assert_eq!(body, None);
    }
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue