Add demo for acm

This commit is contained in:
Adil Hafeez 2024-12-03 19:22:31 -08:00
parent a0c159c9ba
commit 4343387adc
44 changed files with 3194 additions and 3 deletions

View file

@ -179,8 +179,6 @@ impl Display for LlmProvider {
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Endpoint {
pub endpoint: Option<String>,
// pub connect_timeout: Option<DurationString>,
// pub timeout: Option<DurationString>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
@ -193,6 +191,7 @@ pub struct Parameter {
#[serde(rename = "enum")]
pub enum_values: Option<Vec<String>>,
pub default: Option<String>,
pub in_path: Option<bool>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
@ -201,11 +200,31 @@ pub struct EndpointDetails {
pub path: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash, Default)]
pub enum HttpMethod {
#[serde(rename = "GET")]
Get,
#[default]
#[serde(rename = "POST")]
Post,
}
impl Display for HttpMethod {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
HttpMethod::Get => write!(f, "GET"),
HttpMethod::Post => write!(f, "POST"),
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PromptTarget {
pub name: String,
pub default: Option<bool>,
pub description: String,
#[serde(rename = "http_method")]
pub method: Option<HttpMethod>,
pub endpoint: Option<EndpointDetails>,
pub parameters: Option<Vec<Parameter>>,
pub system_prompt: Option<String>,

View file

@ -11,3 +11,4 @@ pub mod routing;
pub mod stats;
pub mod tokenizer;
pub mod tracing;
pub mod path;

82
crates/common/src/path.rs Normal file
View file

@ -0,0 +1,82 @@
use std::collections::HashMap;
pub fn replace_params_in_path(path: &str, params: &HashMap<String, String>) -> Result<String, String> {
let mut result = String::new();
let mut in_param = false;
let mut current_param = String::new();
for c in path.chars() {
if c == '{' {
in_param = true;
} else if c == '}' {
in_param = false;
let param_name = current_param.clone();
if let Some(value) = params.get(&param_name) {
result.push_str(value);
} else {
return Err(format!("Missing value for parameter `{}`", param_name));
}
current_param.clear();
} else {
if in_param {
current_param.push(c);
} else {
result.push(c);
}
}
}
Ok(result)
}
#[cfg(test)]
mod test {
#[test]
fn test_replace_path() {
let path = "/cluster.open-cluster-management.io/v1/managedclusters/{cluster_name}";
let params = vec![("cluster_name".to_string(), "test1".to_string())]
.into_iter()
.collect();
assert_eq!(
super::replace_params_in_path(path, &params),
Ok("/cluster.open-cluster-management.io/v1/managedclusters/test1".to_string())
);
let path = "/cluster.open-cluster-management.io/v1/managedclusters";
let params = vec![].into_iter().collect();
assert_eq!(
super::replace_params_in_path(path, &params),
Ok("/cluster.open-cluster-management.io/v1/managedclusters".to_string())
);
let path = "/foo/{bar}/baz";
let params = vec![("bar".to_string(), "qux".to_string())]
.into_iter()
.collect();
assert_eq!(
super::replace_params_in_path(path, &params),
Ok("/foo/qux/baz".to_string())
);
let path = "/foo/{bar}/baz/{qux}";
let params = vec![
("bar".to_string(), "qux".to_string()),
("qux".to_string(), "quux".to_string()),
]
.into_iter()
.collect();
assert_eq!(
super::replace_params_in_path(path, &params),
Ok("/foo/qux/baz/quux".to_string())
);
let path = "/foo/{bar}/baz/{qux}";
let params = vec![("bar".to_string(), "qux".to_string())]
.into_iter()
.collect();
assert_eq!(
super::replace_params_in_path(path, &params),
Err("Missing value for parameter `qux`".to_string())
);
}
}

View file

@ -29,6 +29,7 @@ use derivative::Derivative;
use http::StatusCode;
use log::{debug, info, trace, warn};
use proxy_wasm::traits::*;
use serde_yaml::Value;
use std::cell::RefCell;
use std::collections::HashMap;
use std::rc::Rc;
@ -892,9 +893,38 @@ impl StreamContext {
let endpoint = prompt_target.endpoint.unwrap();
let path: String = endpoint.path.unwrap_or(String::from("/"));
// only add params that are of string and number type
let url_params = tool_params
.iter()
.filter(|(_, value)| value.is_number() || value.is_string() || value.is_bool())
.map(|(key, value)| match value {
Value::Number(n) => (key.clone(), n.to_string()),
Value::String(s) => (key.clone(), s.clone()),
Value::Bool(b) => (key.clone(), b.to_string()),
Value::Null => todo!(),
Value::Sequence(_) => todo!(),
Value::Mapping(_) => todo!(),
Value::Tagged(_) => todo!(),
})
.collect::<HashMap<String, String>>();
let path = match common::path::replace_params_in_path(&path, &url_params) {
Ok(path) => path,
Err(e) => {
return self.send_server_error(
ServerError::BadRequest {
why: format!("error replacing params in path: {}", e),
},
Some(StatusCode::BAD_REQUEST),
);
}
};
let http_method = prompt_target.method.unwrap_or_default().to_string();
info!("http_method: {}", http_method);
let mut headers = vec![
(ARCH_UPSTREAM_HOST_HEADER, endpoint.name.as_str()),
(":method", "POST"),
(":method", &http_method),
(":path", &path),
(":authority", endpoint.name.as_str()),
("content-type", "application/json"),