diff --git a/arch/arch_config_schema.yaml b/arch/arch_config_schema.yaml index a11e9562..0f7e55ac 100644 --- a/arch/arch_config_schema.yaml +++ b/arch/arch_config_schema.yaml @@ -92,6 +92,8 @@ properties: type: array items: type: string + in_path: + type: boolean additionalProperties: false required: - name @@ -108,6 +110,11 @@ properties: required: - name - path + http_method: + type: string + enum: + - GET + - POST system_prompt: type: string additionalProperties: false diff --git a/arch/tools/cli/config_generator.py b/arch/tools/cli/config_generator.py index 293ffc8a..c32dae16 100644 --- a/arch/tools/cli/config_generator.py +++ b/arch/tools/cli/config_generator.py @@ -49,7 +49,9 @@ def validate_and_render_schema(): if "prompt_targets" in config_yaml: for prompt_target in config_yaml["prompt_targets"]: - name = prompt_target.get("endpoint", {}).get("name", "") + name = prompt_target.get("endpoint", {}).get("name", None) + if not name: + continue if name not in inferred_clusters: inferred_clusters[name] = { "name": name, diff --git a/crates/common/src/configuration.rs b/crates/common/src/configuration.rs index 543849c9..7486e3d9 100644 --- a/crates/common/src/configuration.rs +++ b/crates/common/src/configuration.rs @@ -179,8 +179,6 @@ impl Display for LlmProvider { #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Endpoint { pub endpoint: Option, - // pub connect_timeout: Option, - // pub timeout: Option, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -193,6 +191,7 @@ pub struct Parameter { #[serde(rename = "enum")] pub enum_values: Option>, pub default: Option, + pub in_path: Option, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -201,11 +200,31 @@ pub struct EndpointDetails { pub path: Option, } +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash, Default)] +pub enum HttpMethod { + #[serde(rename = "GET")] + Get, + #[default] + #[serde(rename = "POST")] + Post, +} + +impl Display for HttpMethod { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + HttpMethod::Get => write!(f, "GET"), + HttpMethod::Post => write!(f, "POST"), + } + } +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct PromptTarget { pub name: String, pub default: Option, pub description: String, + #[serde(rename = "http_method")] + pub method: Option, pub endpoint: Option, pub parameters: Option>, pub system_prompt: Option, diff --git a/crates/common/src/lib.rs b/crates/common/src/lib.rs index a41badf9..aa34f2fd 100644 --- a/crates/common/src/lib.rs +++ b/crates/common/src/lib.rs @@ -11,3 +11,4 @@ pub mod routing; pub mod stats; pub mod tokenizer; pub mod tracing; +pub mod path; diff --git a/crates/common/src/path.rs b/crates/common/src/path.rs new file mode 100644 index 00000000..2b289c9d --- /dev/null +++ b/crates/common/src/path.rs @@ -0,0 +1,82 @@ +use std::collections::HashMap; + +pub fn replace_params_in_path(path: &str, params: &HashMap) -> Result { + let mut result = String::new(); + let mut in_param = false; + let mut current_param = String::new(); + + for c in path.chars() { + if c == '{' { + in_param = true; + } else if c == '}' { + in_param = false; + let param_name = current_param.clone(); + if let Some(value) = params.get(¶m_name) { + result.push_str(value); + } else { + return Err(format!("Missing value for parameter `{}`", param_name)); + } + current_param.clear(); + } else { + if in_param { + current_param.push(c); + } else { + result.push(c); + } + } + } + + Ok(result) +} + +#[cfg(test)] +mod test { + #[test] + fn test_replace_path() { + let path = "/cluster.open-cluster-management.io/v1/managedclusters/{cluster_name}"; + let params = vec![("cluster_name".to_string(), "test1".to_string())] + .into_iter() + .collect(); + assert_eq!( + super::replace_params_in_path(path, ¶ms), + Ok("/cluster.open-cluster-management.io/v1/managedclusters/test1".to_string()) + ); + + let path = "/cluster.open-cluster-management.io/v1/managedclusters"; + let params = vec![].into_iter().collect(); + assert_eq!( + super::replace_params_in_path(path, ¶ms), + Ok("/cluster.open-cluster-management.io/v1/managedclusters".to_string()) + ); + + let path = "/foo/{bar}/baz"; + let params = vec![("bar".to_string(), "qux".to_string())] + .into_iter() + .collect(); + assert_eq!( + super::replace_params_in_path(path, ¶ms), + Ok("/foo/qux/baz".to_string()) + ); + + let path = "/foo/{bar}/baz/{qux}"; + let params = vec![ + ("bar".to_string(), "qux".to_string()), + ("qux".to_string(), "quux".to_string()), + ] + .into_iter() + .collect(); + assert_eq!( + super::replace_params_in_path(path, ¶ms), + Ok("/foo/qux/baz/quux".to_string()) + ); + + let path = "/foo/{bar}/baz/{qux}"; + let params = vec![("bar".to_string(), "qux".to_string())] + .into_iter() + .collect(); + assert_eq!( + super::replace_params_in_path(path, ¶ms), + Err("Missing value for parameter `qux`".to_string()) + ); + } +} diff --git a/crates/prompt_gateway/src/stream_context.rs b/crates/prompt_gateway/src/stream_context.rs index 6df5a4d1..e134e07c 100644 --- a/crates/prompt_gateway/src/stream_context.rs +++ b/crates/prompt_gateway/src/stream_context.rs @@ -29,6 +29,7 @@ use derivative::Derivative; use http::StatusCode; use log::{debug, info, trace, warn}; use proxy_wasm::traits::*; +use serde_yaml::Value; use std::cell::RefCell; use std::collections::HashMap; use std::rc::Rc; @@ -892,9 +893,37 @@ impl StreamContext { let endpoint = prompt_target.endpoint.unwrap(); let path: String = endpoint.path.unwrap_or(String::from("/")); + // only add params that are of string, number and bool type + let url_params = tool_params + .iter() + .filter(|(_, value)| value.is_number() || value.is_string() || value.is_bool()) + .map(|(key, value)| match value { + Value::Number(n) => (key.clone(), n.to_string()), + Value::String(s) => (key.clone(), s.clone()), + Value::Bool(b) => (key.clone(), b.to_string()), + Value::Null => todo!(), + Value::Sequence(_) => todo!(), + Value::Mapping(_) => todo!(), + Value::Tagged(_) => todo!(), + }) + .collect::>(); + + let path = match common::path::replace_params_in_path(&path, &url_params) { + Ok(path) => path, + Err(e) => { + return self.send_server_error( + ServerError::BadRequest { + why: format!("error replacing params in path: {}", e), + }, + Some(StatusCode::BAD_REQUEST), + ); + } + }; + + let http_method = prompt_target.method.unwrap_or_default().to_string(); let mut headers = vec![ (ARCH_UPSTREAM_HOST_HEADER, endpoint.name.as_str()), - (":method", "POST"), + (":method", &http_method), (":path", &path), (":authority", endpoint.name.as_str()), ("content-type", "application/json"), diff --git a/model_server/app/main.py b/model_server/app/main.py index 43be1f74..4f9337bd 100644 --- a/model_server/app/main.py +++ b/model_server/app/main.py @@ -220,6 +220,12 @@ async def hallucination(req: HallucinationRequest, res: Response): if "messages" in req.parameters: req.parameters.pop("messages") + if not req.parameters or len(req.parameters) == 0: + return { + "params_scores": {}, + "model": req.model, + } + candidate_labels = {f"{k} is {v}": k for k, v in req.parameters.items()} predictions = classifier(