mirror of
https://github.com/katanemo/plano.git
synced 2026-06-29 15:49:40 +02:00
update config (#93)
This commit is contained in:
parent
4182879717
commit
cc35eb0cd7
13 changed files with 575 additions and 329 deletions
10
public_types/Cargo.lock
generated
10
public_types/Cargo.lock
generated
|
|
@ -8,6 +8,15 @@ version = "0.1.13"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "56254986775e3233ffa9c4d7d3faaf6d36a2c09d30b20687e9f88bc8bafc16c8"
|
||||
|
||||
[[package]]
|
||||
name = "duration-string"
|
||||
version = "0.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6fcc1d9ae294a15ed05aeae8e11ee5f2b3fe971c077d45a42fb20825fba6ee13"
|
||||
dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "equivalent"
|
||||
version = "1.0.1"
|
||||
|
|
@ -65,6 +74,7 @@ dependencies = [
|
|||
name = "public_types"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"duration-string",
|
||||
"pretty_assertions",
|
||||
"serde",
|
||||
"serde_json",
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ edition = "2021"
|
|||
[dependencies]
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_yaml = "0.9.34"
|
||||
duration-string = { version = "0.3.0", features = ["serde"] }
|
||||
|
||||
[dev-dependencies]
|
||||
pretty_assertions = "1.4.1"
|
||||
|
|
|
|||
|
|
@ -151,11 +151,16 @@ pub mod open_ai {
|
|||
fn from(s: String) -> Self {
|
||||
match s.as_str() {
|
||||
"int" => ParameterType::Int,
|
||||
"integer" => ParameterType::Int,
|
||||
"float" => ParameterType::Float,
|
||||
"bool" => ParameterType::Bool,
|
||||
"boolean" => ParameterType::Bool,
|
||||
"str" => ParameterType::String,
|
||||
"string" => ParameterType::String,
|
||||
"list" => ParameterType::List,
|
||||
"array" => ParameterType::List,
|
||||
"dict" => ParameterType::Dict,
|
||||
"dictionary" => ParameterType::Dict,
|
||||
_ => ParameterType::String,
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,4 +1,7 @@
|
|||
use serde::{Deserialize, Serialize};
|
||||
use std::{collections::HashMap, time::Duration};
|
||||
|
||||
use duration_string::DurationString;
|
||||
use serde::{Deserialize, Serialize, Deserializer};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct Overrides {
|
||||
|
|
@ -7,31 +10,88 @@ pub struct Overrides {
|
|||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Configuration {
|
||||
pub default_prompt_endpoint: String,
|
||||
pub load_balancing: LoadBalancing,
|
||||
pub timeout_ms: u64,
|
||||
pub overrides: Option<Overrides>,
|
||||
pub version: String,
|
||||
pub listener: Listener,
|
||||
pub endpoints: HashMap<String, Endpoint>,
|
||||
pub llm_providers: Vec<LlmProvider>,
|
||||
pub prompt_guards: Option<PromptGuards>,
|
||||
pub overrides: Option<Overrides>,
|
||||
pub system_prompt: Option<String>,
|
||||
pub prompt_guards: Option<PromptGuards>,
|
||||
pub prompt_targets: Vec<PromptTarget>,
|
||||
pub error_target: Option<ErrorTargetDetail>,
|
||||
pub tracing: Option<i16>,
|
||||
pub ratelimits: Option<Vec<Ratelimit>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ErrorTargetDetail {
|
||||
pub endpoint: Option<EndpointDetails>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Listener {
|
||||
pub address: String,
|
||||
pub port: u16,
|
||||
pub message_format: MessageFormat,
|
||||
// pub connect_timeout: Option<DurationString>,
|
||||
}
|
||||
|
||||
impl Default for Listener {
|
||||
fn default() -> Self {
|
||||
Listener {
|
||||
address: "".to_string(),
|
||||
port: 0,
|
||||
message_format: MessageFormat::default(),
|
||||
// connect_timeout: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub enum MessageFormat {
|
||||
#[serde(rename = "huggingface")]
|
||||
#[default]
|
||||
Huggingface,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct PromptGuards {
|
||||
pub input_guards: InputGuards,
|
||||
pub input_guards: HashMap<GuardType, GuardOptions>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct InputGuards {
|
||||
pub jailbreak: Option<GuardOptions>,
|
||||
pub toxicity: Option<GuardOptions>,
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
|
||||
pub enum GuardType {
|
||||
#[serde(rename = "jailbreak")]
|
||||
Jailbreak,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct GuardOptions {
|
||||
pub on_exception_message: Option<String>,
|
||||
pub on_exception: Option<OnExceptionDetails>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct OnExceptionDetails {
|
||||
pub forward_to_error_target: Option<bool>,
|
||||
pub error_handler: Option<String>,
|
||||
pub message: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct LlmRatelimit {
|
||||
pub selector: LlmRatelimitSelector,
|
||||
pub limit: Limit,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct LlmRatelimitSelector {
|
||||
pub http_header: Option<RatelimitHeader>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
|
||||
pub struct Header {
|
||||
pub key: String,
|
||||
pub value: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
|
|
@ -58,19 +118,11 @@ pub enum TimeUnit {
|
|||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
|
||||
pub struct Header {
|
||||
pub key: String,
|
||||
pub struct RatelimitHeader {
|
||||
pub name: String,
|
||||
pub value: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub enum LoadBalancing {
|
||||
#[serde(rename = "round_robin")]
|
||||
RoundRobin,
|
||||
#[serde(rename = "random")]
|
||||
Random,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
//TODO: use enum for model, but if there is a new model, we need to update the code
|
||||
pub struct EmbeddingProviver {
|
||||
|
|
@ -82,23 +134,19 @@ pub struct EmbeddingProviver {
|
|||
//TODO: use enum for model, but if there is a new model, we need to update the code
|
||||
pub struct LlmProvider {
|
||||
pub name: String,
|
||||
pub api_key: Option<String>,
|
||||
//TODO: handle env var replacement
|
||||
pub access_key: Option<String>,
|
||||
pub model: String,
|
||||
pub default: Option<bool>,
|
||||
pub endpoint: Option<EnpointType>,
|
||||
}
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum EnpointType {
|
||||
String(String),
|
||||
Struct(Endpoint),
|
||||
pub stream: Option<bool>,
|
||||
pub rate_limits: Option<LlmRatelimit>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Endpoint {
|
||||
pub cluster: String,
|
||||
pub path: Option<String>,
|
||||
pub method: Option<String>,
|
||||
pub endpoint: Option<String>,
|
||||
// pub connect_timeout: Option<DurationString>,
|
||||
// pub timeout: Option<DurationString>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
|
|
@ -114,82 +162,144 @@ pub struct Parameter {
|
|||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub enum PromptType {
|
||||
#[serde(rename = "function_resolver")]
|
||||
FunctionResolver,
|
||||
pub struct EndpointDetails {
|
||||
pub name: String,
|
||||
pub path: Option<String>,
|
||||
pub method: Option<Method>,
|
||||
}
|
||||
|
||||
|
||||
#[derive(Debug, Clone, Serialize, PartialEq, Eq, Hash)]
|
||||
#[serde(rename_all = "UPPERCASE")]
|
||||
pub enum Method {
|
||||
Get,
|
||||
Post,
|
||||
Put,
|
||||
Delete,
|
||||
}
|
||||
|
||||
impl ToString for Method {
|
||||
fn to_string(&self) -> String {
|
||||
match self {
|
||||
Method::Get => "GET".to_string(),
|
||||
Method::Post => "POST".to_string(),
|
||||
Method::Put => "PUT".to_string(),
|
||||
Method::Delete => "DELETE".to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Deserialize<'de> for Method {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
let s = String::deserialize(deserializer)?;
|
||||
match s.to_uppercase().as_str() {
|
||||
"GET" => Ok(Method::Get),
|
||||
"POST" => Ok(Method::Post),
|
||||
"PUT" => Ok(Method::Put),
|
||||
"DELETE" => Ok(Method::Delete),
|
||||
_ => Err(serde::de::Error::custom(format!("Invalid enum variant: {}", s))),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PromptTarget {
|
||||
#[serde(rename = "type")]
|
||||
pub prompt_type: PromptType,
|
||||
pub name: String,
|
||||
pub default: Option<bool>,
|
||||
pub description: String,
|
||||
pub endpoint: Option<EndpointDetails>,
|
||||
pub parameters: Option<Vec<Parameter>>,
|
||||
pub endpoint: Option<Endpoint>,
|
||||
pub system_prompt: Option<String>,
|
||||
pub auto_llm_dispatch_on_response: Option<bool>,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
pub const CONFIGURATION: &str = r#"
|
||||
default_prompt_endpoint: "127.0.0.1"
|
||||
load_balancing: "round_robin"
|
||||
timeout_ms: 5000
|
||||
use std::fs;
|
||||
|
||||
llm_providers:
|
||||
- name: "open-ai-gpt-4"
|
||||
api_key: "$OPEN_AI_API_KEY"
|
||||
model: gpt-4
|
||||
|
||||
system_prompt: |
|
||||
You are a helpful weather forecaster. Please following following guidelines when responding to user queries:
|
||||
- Use farenheight for temperature
|
||||
- Use miles per hour for wind speed
|
||||
|
||||
prompt_guards:
|
||||
input_guards:
|
||||
jailbreak:
|
||||
on_exception_message: Looks like you are curious about my abilities…
|
||||
toxicity:
|
||||
on_exception_message: Looks like you are curious about my abilities…
|
||||
|
||||
prompt_targets:
|
||||
|
||||
- type: function_resolver
|
||||
name: weather_forecast
|
||||
description: Get the weather forecast for a location
|
||||
endpoint:
|
||||
cluster: weatherhost
|
||||
path: /weather
|
||||
parameters:
|
||||
- name: location
|
||||
required: true
|
||||
description: "The location for which the weather is requested"
|
||||
|
||||
- type: function_resolver
|
||||
name: weather_forecast_2
|
||||
description: Get the weather forecast for a location
|
||||
few_shot_examples:
|
||||
- what is the weather in New York?
|
||||
endpoint:
|
||||
cluster: weatherhost
|
||||
path: /weather
|
||||
parameters:
|
||||
- name: city
|
||||
description: "The location for which the weather is requested"
|
||||
|
||||
ratelimits:
|
||||
- provider: open-ai-gpt-4
|
||||
selector:
|
||||
key: x-katanemo-openai-limit-id
|
||||
limit:
|
||||
tokens: 100
|
||||
unit: minute
|
||||
"#;
|
||||
use crate::configuration::GuardType;
|
||||
|
||||
#[test]
|
||||
fn test_deserialize_configuration() {
|
||||
let _: super::Configuration = serde_yaml::from_str(CONFIGURATION).unwrap();
|
||||
let ref_config =
|
||||
fs::read_to_string("../docs/source/_config/prompt-config-full-reference.yml")
|
||||
.expect("reference config file not found");
|
||||
|
||||
let config: super::Configuration = serde_yaml::from_str(&ref_config).unwrap();
|
||||
assert_eq!(config.version, "0.1-beta");
|
||||
|
||||
let open_ai_provider = config
|
||||
.llm_providers
|
||||
.iter()
|
||||
.find(|p| p.name.to_lowercase() == "openai")
|
||||
.unwrap();
|
||||
assert_eq!(open_ai_provider.name.to_lowercase(), "openai");
|
||||
assert_eq!(
|
||||
open_ai_provider.access_key,
|
||||
Some("$OPENAI_API_KEY".to_string())
|
||||
);
|
||||
assert_eq!(open_ai_provider.model, "gpt-4o");
|
||||
assert_eq!(open_ai_provider.default, Some(true));
|
||||
assert_eq!(open_ai_provider.stream, Some(true));
|
||||
|
||||
let prompt_guards = config.prompt_guards.as_ref().unwrap();
|
||||
let input_guards = &prompt_guards.input_guards;
|
||||
let jailbreak_guard = input_guards.get(&GuardType::Jailbreak).unwrap();
|
||||
assert_eq!(
|
||||
jailbreak_guard
|
||||
.on_exception
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.forward_to_error_target,
|
||||
None
|
||||
);
|
||||
assert_eq!(
|
||||
jailbreak_guard.on_exception.as_ref().unwrap().error_handler,
|
||||
None
|
||||
);
|
||||
|
||||
let prompt_targets = &config.prompt_targets;
|
||||
assert_eq!(prompt_targets.len(), 2);
|
||||
let prompt_target = prompt_targets
|
||||
.iter()
|
||||
.find(|p| p.name == "reboot_network_device")
|
||||
.unwrap();
|
||||
assert_eq!(prompt_target.name, "reboot_network_device");
|
||||
assert_eq!(prompt_target.default, None);
|
||||
|
||||
let prompt_target = prompt_targets
|
||||
.iter()
|
||||
.find(|p| p.name == "information_extraction")
|
||||
.unwrap();
|
||||
assert_eq!(prompt_target.name, "information_extraction");
|
||||
assert_eq!(prompt_target.default, Some(true));
|
||||
assert_eq!(
|
||||
prompt_target.endpoint.as_ref().unwrap().name,
|
||||
"app_server".to_string()
|
||||
);
|
||||
assert_eq!(
|
||||
prompt_target.endpoint.as_ref().unwrap().path,
|
||||
Some("/agent/summary".to_string())
|
||||
);
|
||||
assert_eq!(
|
||||
prompt_target.endpoint.as_ref().unwrap().method.as_ref().unwrap().to_string(),
|
||||
"POST".to_string()
|
||||
);
|
||||
|
||||
let error_target = config.error_target.as_ref().unwrap();
|
||||
assert_eq!(
|
||||
error_target.endpoint.as_ref().unwrap().name,
|
||||
"error_target_1".to_string()
|
||||
);
|
||||
assert_eq!(
|
||||
error_target.endpoint.as_ref().unwrap().path,
|
||||
Some("/error".to_string())
|
||||
);
|
||||
|
||||
let tracing = config.tracing.as_ref().unwrap();
|
||||
assert_eq!(*tracing, 100);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue