mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
Merge branch 'adil/add_in_path_support' into adil/update_getting_started_guide
This commit is contained in:
commit
1a352166d8
7 changed files with 150 additions and 4 deletions
|
|
@ -92,6 +92,8 @@ properties:
|
|||
type: array
|
||||
items:
|
||||
type: string
|
||||
in_path:
|
||||
type: boolean
|
||||
additionalProperties: false
|
||||
required:
|
||||
- name
|
||||
|
|
@ -108,6 +110,11 @@ properties:
|
|||
required:
|
||||
- name
|
||||
- path
|
||||
http_method:
|
||||
type: string
|
||||
enum:
|
||||
- GET
|
||||
- POST
|
||||
system_prompt:
|
||||
type: string
|
||||
additionalProperties: false
|
||||
|
|
|
|||
|
|
@ -49,7 +49,9 @@ def validate_and_render_schema():
|
|||
|
||||
if "prompt_targets" in config_yaml:
|
||||
for prompt_target in config_yaml["prompt_targets"]:
|
||||
name = prompt_target.get("endpoint", {}).get("name", "")
|
||||
name = prompt_target.get("endpoint", {}).get("name", None)
|
||||
if not name:
|
||||
continue
|
||||
if name not in inferred_clusters:
|
||||
inferred_clusters[name] = {
|
||||
"name": name,
|
||||
|
|
|
|||
|
|
@ -179,8 +179,6 @@ impl Display for LlmProvider {
|
|||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Endpoint {
|
||||
pub endpoint: Option<String>,
|
||||
// pub connect_timeout: Option<DurationString>,
|
||||
// pub timeout: Option<DurationString>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
|
|
@ -193,6 +191,7 @@ pub struct Parameter {
|
|||
#[serde(rename = "enum")]
|
||||
pub enum_values: Option<Vec<String>>,
|
||||
pub default: Option<String>,
|
||||
pub in_path: Option<bool>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
|
|
@ -201,11 +200,31 @@ pub struct EndpointDetails {
|
|||
pub path: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash, Default)]
|
||||
pub enum HttpMethod {
|
||||
#[serde(rename = "GET")]
|
||||
Get,
|
||||
#[default]
|
||||
#[serde(rename = "POST")]
|
||||
Post,
|
||||
}
|
||||
|
||||
impl Display for HttpMethod {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
HttpMethod::Get => write!(f, "GET"),
|
||||
HttpMethod::Post => write!(f, "POST"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PromptTarget {
|
||||
pub name: String,
|
||||
pub default: Option<bool>,
|
||||
pub description: String,
|
||||
#[serde(rename = "http_method")]
|
||||
pub method: Option<HttpMethod>,
|
||||
pub endpoint: Option<EndpointDetails>,
|
||||
pub parameters: Option<Vec<Parameter>>,
|
||||
pub system_prompt: Option<String>,
|
||||
|
|
|
|||
|
|
@ -11,3 +11,4 @@ pub mod routing;
|
|||
pub mod stats;
|
||||
pub mod tokenizer;
|
||||
pub mod tracing;
|
||||
pub mod path;
|
||||
|
|
|
|||
82
crates/common/src/path.rs
Normal file
82
crates/common/src/path.rs
Normal file
|
|
@ -0,0 +1,82 @@
|
|||
use std::collections::HashMap;
|
||||
|
||||
pub fn replace_params_in_path(path: &str, params: &HashMap<String, String>) -> Result<String, String> {
|
||||
let mut result = String::new();
|
||||
let mut in_param = false;
|
||||
let mut current_param = String::new();
|
||||
|
||||
for c in path.chars() {
|
||||
if c == '{' {
|
||||
in_param = true;
|
||||
} else if c == '}' {
|
||||
in_param = false;
|
||||
let param_name = current_param.clone();
|
||||
if let Some(value) = params.get(¶m_name) {
|
||||
result.push_str(value);
|
||||
} else {
|
||||
return Err(format!("Missing value for parameter `{}`", param_name));
|
||||
}
|
||||
current_param.clear();
|
||||
} else {
|
||||
if in_param {
|
||||
current_param.push(c);
|
||||
} else {
|
||||
result.push(c);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
#[test]
|
||||
fn test_replace_path() {
|
||||
let path = "/cluster.open-cluster-management.io/v1/managedclusters/{cluster_name}";
|
||||
let params = vec![("cluster_name".to_string(), "test1".to_string())]
|
||||
.into_iter()
|
||||
.collect();
|
||||
assert_eq!(
|
||||
super::replace_params_in_path(path, ¶ms),
|
||||
Ok("/cluster.open-cluster-management.io/v1/managedclusters/test1".to_string())
|
||||
);
|
||||
|
||||
let path = "/cluster.open-cluster-management.io/v1/managedclusters";
|
||||
let params = vec![].into_iter().collect();
|
||||
assert_eq!(
|
||||
super::replace_params_in_path(path, ¶ms),
|
||||
Ok("/cluster.open-cluster-management.io/v1/managedclusters".to_string())
|
||||
);
|
||||
|
||||
let path = "/foo/{bar}/baz";
|
||||
let params = vec![("bar".to_string(), "qux".to_string())]
|
||||
.into_iter()
|
||||
.collect();
|
||||
assert_eq!(
|
||||
super::replace_params_in_path(path, ¶ms),
|
||||
Ok("/foo/qux/baz".to_string())
|
||||
);
|
||||
|
||||
let path = "/foo/{bar}/baz/{qux}";
|
||||
let params = vec![
|
||||
("bar".to_string(), "qux".to_string()),
|
||||
("qux".to_string(), "quux".to_string()),
|
||||
]
|
||||
.into_iter()
|
||||
.collect();
|
||||
assert_eq!(
|
||||
super::replace_params_in_path(path, ¶ms),
|
||||
Ok("/foo/qux/baz/quux".to_string())
|
||||
);
|
||||
|
||||
let path = "/foo/{bar}/baz/{qux}";
|
||||
let params = vec![("bar".to_string(), "qux".to_string())]
|
||||
.into_iter()
|
||||
.collect();
|
||||
assert_eq!(
|
||||
super::replace_params_in_path(path, ¶ms),
|
||||
Err("Missing value for parameter `qux`".to_string())
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
@ -29,6 +29,7 @@ use derivative::Derivative;
|
|||
use http::StatusCode;
|
||||
use log::{debug, info, trace, warn};
|
||||
use proxy_wasm::traits::*;
|
||||
use serde_yaml::Value;
|
||||
use std::cell::RefCell;
|
||||
use std::collections::HashMap;
|
||||
use std::rc::Rc;
|
||||
|
|
@ -892,9 +893,37 @@ impl StreamContext {
|
|||
let endpoint = prompt_target.endpoint.unwrap();
|
||||
let path: String = endpoint.path.unwrap_or(String::from("/"));
|
||||
|
||||
// only add params that are of string, number and bool type
|
||||
let url_params = tool_params
|
||||
.iter()
|
||||
.filter(|(_, value)| value.is_number() || value.is_string() || value.is_bool())
|
||||
.map(|(key, value)| match value {
|
||||
Value::Number(n) => (key.clone(), n.to_string()),
|
||||
Value::String(s) => (key.clone(), s.clone()),
|
||||
Value::Bool(b) => (key.clone(), b.to_string()),
|
||||
Value::Null => todo!(),
|
||||
Value::Sequence(_) => todo!(),
|
||||
Value::Mapping(_) => todo!(),
|
||||
Value::Tagged(_) => todo!(),
|
||||
})
|
||||
.collect::<HashMap<String, String>>();
|
||||
|
||||
let path = match common::path::replace_params_in_path(&path, &url_params) {
|
||||
Ok(path) => path,
|
||||
Err(e) => {
|
||||
return self.send_server_error(
|
||||
ServerError::BadRequest {
|
||||
why: format!("error replacing params in path: {}", e),
|
||||
},
|
||||
Some(StatusCode::BAD_REQUEST),
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
let http_method = prompt_target.method.unwrap_or_default().to_string();
|
||||
let mut headers = vec![
|
||||
(ARCH_UPSTREAM_HOST_HEADER, endpoint.name.as_str()),
|
||||
(":method", "POST"),
|
||||
(":method", &http_method),
|
||||
(":path", &path),
|
||||
(":authority", endpoint.name.as_str()),
|
||||
("content-type", "application/json"),
|
||||
|
|
|
|||
|
|
@ -220,6 +220,12 @@ async def hallucination(req: HallucinationRequest, res: Response):
|
|||
if "messages" in req.parameters:
|
||||
req.parameters.pop("messages")
|
||||
|
||||
if not req.parameters or len(req.parameters) == 0:
|
||||
return {
|
||||
"params_scores": {},
|
||||
"model": req.model,
|
||||
}
|
||||
|
||||
candidate_labels = {f"{k} is {v}": k for k, v in req.parameters.items()}
|
||||
|
||||
predictions = classifier(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue