Merge branch 'adil/add_in_path_support' into adil/update_getting_started_guide

This commit is contained in:
Adil Hafeez 2024-12-05 14:57:18 -08:00
commit 1a352166d8
7 changed files with 150 additions and 4 deletions

View file

@ -92,6 +92,8 @@ properties:
type: array
items:
type: string
in_path:
type: boolean
additionalProperties: false
required:
- name
@ -108,6 +110,11 @@ properties:
required:
- name
- path
http_method:
type: string
enum:
- GET
- POST
system_prompt:
type: string
additionalProperties: false

View file

@ -49,7 +49,9 @@ def validate_and_render_schema():
if "prompt_targets" in config_yaml:
for prompt_target in config_yaml["prompt_targets"]:
name = prompt_target.get("endpoint", {}).get("name", "")
name = prompt_target.get("endpoint", {}).get("name", None)
if not name:
continue
if name not in inferred_clusters:
inferred_clusters[name] = {
"name": name,

View file

@ -179,8 +179,6 @@ impl Display for LlmProvider {
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Endpoint {
pub endpoint: Option<String>,
// pub connect_timeout: Option<DurationString>,
// pub timeout: Option<DurationString>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
@ -193,6 +191,7 @@ pub struct Parameter {
#[serde(rename = "enum")]
pub enum_values: Option<Vec<String>>,
pub default: Option<String>,
pub in_path: Option<bool>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
@ -201,11 +200,31 @@ pub struct EndpointDetails {
pub path: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash, Default)]
pub enum HttpMethod {
#[serde(rename = "GET")]
Get,
#[default]
#[serde(rename = "POST")]
Post,
}
impl Display for HttpMethod {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
HttpMethod::Get => write!(f, "GET"),
HttpMethod::Post => write!(f, "POST"),
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PromptTarget {
pub name: String,
pub default: Option<bool>,
pub description: String,
#[serde(rename = "http_method")]
pub method: Option<HttpMethod>,
pub endpoint: Option<EndpointDetails>,
pub parameters: Option<Vec<Parameter>>,
pub system_prompt: Option<String>,

View file

@ -11,3 +11,4 @@ pub mod routing;
pub mod stats;
pub mod tokenizer;
pub mod tracing;
pub mod path;

82
crates/common/src/path.rs Normal file
View file

@ -0,0 +1,82 @@
use std::collections::HashMap;
pub fn replace_params_in_path(path: &str, params: &HashMap<String, String>) -> Result<String, String> {
let mut result = String::new();
let mut in_param = false;
let mut current_param = String::new();
for c in path.chars() {
if c == '{' {
in_param = true;
} else if c == '}' {
in_param = false;
let param_name = current_param.clone();
if let Some(value) = params.get(&param_name) {
result.push_str(value);
} else {
return Err(format!("Missing value for parameter `{}`", param_name));
}
current_param.clear();
} else {
if in_param {
current_param.push(c);
} else {
result.push(c);
}
}
}
Ok(result)
}
#[cfg(test)]
mod test {
#[test]
fn test_replace_path() {
let path = "/cluster.open-cluster-management.io/v1/managedclusters/{cluster_name}";
let params = vec![("cluster_name".to_string(), "test1".to_string())]
.into_iter()
.collect();
assert_eq!(
super::replace_params_in_path(path, &params),
Ok("/cluster.open-cluster-management.io/v1/managedclusters/test1".to_string())
);
let path = "/cluster.open-cluster-management.io/v1/managedclusters";
let params = vec![].into_iter().collect();
assert_eq!(
super::replace_params_in_path(path, &params),
Ok("/cluster.open-cluster-management.io/v1/managedclusters".to_string())
);
let path = "/foo/{bar}/baz";
let params = vec![("bar".to_string(), "qux".to_string())]
.into_iter()
.collect();
assert_eq!(
super::replace_params_in_path(path, &params),
Ok("/foo/qux/baz".to_string())
);
let path = "/foo/{bar}/baz/{qux}";
let params = vec![
("bar".to_string(), "qux".to_string()),
("qux".to_string(), "quux".to_string()),
]
.into_iter()
.collect();
assert_eq!(
super::replace_params_in_path(path, &params),
Ok("/foo/qux/baz/quux".to_string())
);
let path = "/foo/{bar}/baz/{qux}";
let params = vec![("bar".to_string(), "qux".to_string())]
.into_iter()
.collect();
assert_eq!(
super::replace_params_in_path(path, &params),
Err("Missing value for parameter `qux`".to_string())
);
}
}

View file

@ -29,6 +29,7 @@ use derivative::Derivative;
use http::StatusCode;
use log::{debug, info, trace, warn};
use proxy_wasm::traits::*;
use serde_yaml::Value;
use std::cell::RefCell;
use std::collections::HashMap;
use std::rc::Rc;
@ -892,9 +893,37 @@ impl StreamContext {
let endpoint = prompt_target.endpoint.unwrap();
let path: String = endpoint.path.unwrap_or(String::from("/"));
// only add params that are of string, number and bool type
let url_params = tool_params
.iter()
.filter(|(_, value)| value.is_number() || value.is_string() || value.is_bool())
.map(|(key, value)| match value {
Value::Number(n) => (key.clone(), n.to_string()),
Value::String(s) => (key.clone(), s.clone()),
Value::Bool(b) => (key.clone(), b.to_string()),
Value::Null => todo!(),
Value::Sequence(_) => todo!(),
Value::Mapping(_) => todo!(),
Value::Tagged(_) => todo!(),
})
.collect::<HashMap<String, String>>();
let path = match common::path::replace_params_in_path(&path, &url_params) {
Ok(path) => path,
Err(e) => {
return self.send_server_error(
ServerError::BadRequest {
why: format!("error replacing params in path: {}", e),
},
Some(StatusCode::BAD_REQUEST),
);
}
};
let http_method = prompt_target.method.unwrap_or_default().to_string();
let mut headers = vec![
(ARCH_UPSTREAM_HOST_HEADER, endpoint.name.as_str()),
(":method", "POST"),
(":method", &http_method),
(":path", &path),
(":authority", endpoint.name.as_str()),
("content-type", "application/json"),

View file

@ -220,6 +220,12 @@ async def hallucination(req: HallucinationRequest, res: Response):
if "messages" in req.parameters:
req.parameters.pop("messages")
if not req.parameters or len(req.parameters) == 0:
return {
"params_scores": {},
"model": req.model,
}
candidate_labels = {f"{k} is {v}": k for k, v in req.parameters.items()}
predictions = classifier(