Parse katanemo config using serde/yaml package (#6)

* Parse katanemo config using serde/yaml package

- load yaml file into typed classes
- pass katanemo config to plugin using envoy wasm plugin config
- add tests in configuration.rs file
This commit is contained in:
Adil Hafeez 2024-07-16 14:50:32 -07:00 committed by GitHub
parent d741fdc2de
commit a386d68b41
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 292 additions and 19 deletions

View file

@ -0,0 +1,103 @@
use serde::{Deserialize, Serialize};
//TODO: possibly use protbuf to enforce schema
//FIX: it is unnecessary to place yaml config inside katanemo-prompt-config
//GH Issue: https://github.com/katanemo/intelligent-prompt-gateway/issues/7
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub struct Configuration {
#[serde(rename = "katanemo-prompt-config")]
pub prompt_config: PromptConfig,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum LoadBalancing {
#[serde(rename = "round-robin")]
RoundRobin,
#[serde(rename = "random")]
Random,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub struct PromptConfig {
pub default_prompt_endpoint: String,
pub load_balancing: LoadBalancing,
pub timeout_ms: u64,
pub embedding_provider: EmbeddingProviver,
pub llm_providers: Vec<LlmProvider>,
pub system_prompt: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
//TODO: use enum for model, but if there is a new model, we need to update the code
pub struct EmbeddingProviver {
pub name: String,
pub model: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
//TODO: use enum for model, but if there is a new model, we need to update the code
pub struct LlmProvider {
pub name: String,
pub api_key: String,
pub model: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub struct PromptTarget {
#[serde(rename = "type")]
pub prompt_type: String,
pub name: String,
pub few_shot_examples: Vec<String>,
pub endpoint: String,
}
#[cfg(test)]
mod test {
pub const CONFIGURATION: &str = r#"
katanemo-prompt-config:
default-prompt-endpoint: "127.0.0.1"
load-balancing: "round-robin"
timeout-ms: 5000
embedding-provider:
name: "SentenceTransformer"
model: "all-MiniLM-L6-v2"
llm-providers:
- name: "open-ai-gpt-4"
api-key: "$OPEN_AI_API_KEY"
model: gpt-4
system-prompt: |
You are a helpful weather forecaster. Please following following guidelines when responding to user queries:
- Use farenheight for temperature
- Use miles per hour for wind speed
prompt-targets:
- type: context-resolver
name: weather-forecast
few-shot-examples:
- what is the weather in New York?
endpoint: "POST:$WEATHER_FORECAST_API_ENDPOINT"
cache-response: true
cache-response-settings:
- cache-ttl-secs: 3600 # cache expiry in seconds
- cache-max-size: 1000 # in number of items
- cache-eviction-strategy: LRU
"#;
#[test]
fn test_deserialize_configuration() {
let _: super::Configuration = serde_yaml::from_str(CONFIGURATION).unwrap();
}
}

View file

@ -1,3 +1,5 @@
mod configuration;
use log::info;
use stats::IncrementingMetric;
use stats::Metric;
@ -13,19 +15,19 @@ proxy_wasm::main! {{
proxy_wasm::set_log_level(LogLevel::Trace);
proxy_wasm::set_root_context(|_| -> Box<dyn RootContext> {
Box::new(HttpHeaderRoot {
header_content: String::new(),
metrics: WasmMetrics {
config: None,
metrics: WasmMetrics {
counter: stats::Counter::new(String::from("wasm_counter")),
gauge: stats::Gauge::new(String::from("wasm_gauge")),
histogram: stats::Histogram::new(String::from("wasm_histogram")),
}
},
})
});
}}
struct HttpHeader {
context_id: u32,
header_content: String,
config: configuration::Configuration,
metrics: WasmMetrics,
}
@ -34,6 +36,8 @@ impl HttpContext for HttpHeader {
// Envoy's HTTP model is event driven. The WASM ABI has given implementors events to hook onto
// the lifecycle of the http request and response.
fn on_http_request_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action {
// Read config
info!("config: {:?}", self.config.prompt_config.system_prompt);
// Metrics
self.metrics.counter.increment(10);
info!("counter -> {}", self.metrics.counter.value());
@ -74,8 +78,7 @@ impl HttpContext for HttpHeader {
}
fn on_http_response_headers(&mut self, _: usize, _: bool) -> Action {
// Note that the filter can add custom headers. In this case the header is coming from a config value.
self.add_http_response_header("custom-header", self.header_content.as_str());
self.set_http_response_header("Powered-By", Some("Katanemo"));
Action::Continue
}
}
@ -113,17 +116,26 @@ struct WasmMetrics {
}
struct HttpHeaderRoot {
header_content: String,
metrics: WasmMetrics,
config: Option<configuration::Configuration>,
}
impl Context for HttpHeaderRoot {}
// RootContext allows the Rust code to reach into the Envoy Config
impl RootContext for HttpHeaderRoot {
fn on_configure(&mut self, _: usize) -> bool {
fn on_configure(&mut self, plugin_configuration_size: usize) -> bool {
info!(
"on_configure: plugin_configuration_size is {}",
plugin_configuration_size
);
if let Some(config_bytes) = self.get_plugin_configuration() {
self.header_content = String::from_utf8(config_bytes).unwrap()
let config_str = String::from_utf8(config_bytes).unwrap();
info!("on_configure: plugin configuration is {:?}", config_str);
self.config = serde_yaml::from_str(&config_str).unwrap();
info!("on_configure: plugin configuration loaded");
info!("on_configure: {:?}", self.config);
}
true
}
@ -131,7 +143,7 @@ impl RootContext for HttpHeaderRoot {
fn create_http_context(&self, context_id: u32) -> Option<Box<dyn HttpContext>> {
Some(Box::new(HttpHeader {
context_id,
header_content: self.header_content.clone(),
config: self.config.clone()?,
metrics: self.metrics,
}))
}

View file

@ -1,4 +1,3 @@
use proxy_wasm::hostcalls;
use proxy_wasm::types::*;
@ -17,7 +16,7 @@ pub trait Metric {
pub trait IncrementingMetric: Metric {
fn increment(&self, offset: i64) {
match hostcalls::increment_metric(self.id(), offset) {
Ok(_) => return,
Ok(data) => data,
Err(Status::NotFound) => panic!("metric not found: {}", self.id()),
Err(err) => panic!("unexpected status: {:?}", err),
}
@ -27,7 +26,7 @@ pub trait IncrementingMetric: Metric {
pub trait RecordingMetric: Metric {
fn record(&self, value: u64) {
match hostcalls::record_metric(self.id(), value) {
Ok(_) => return,
Ok(data) => data,
Err(Status::NotFound) => panic!("metric not found: {}", self.id()),
Err(err) => panic!("unexpected status: {:?}", err),
}