Parse katanemo config using serde/yaml package (#6)

* Parse katanemo config using serde/yaml package

- load yaml file into typed classes
- pass katanemo config to plugin using envoy wasm plugin config
- add tests in configuration.rs file
This commit is contained in:
Adil Hafeez 2024-07-16 14:50:32 -07:00 committed by GitHub
parent d741fdc2de
commit a386d68b41
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 292 additions and 19 deletions

95
envoyfilter/Cargo.lock generated
View file

@ -20,6 +20,12 @@ version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]]
name = "equivalent"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5"
[[package]]
name = "hashbrown"
version = "0.13.2"
@ -29,14 +35,39 @@ dependencies = [
"ahash",
]
[[package]]
name = "hashbrown"
version = "0.14.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1"
[[package]]
name = "indexmap"
version = "2.2.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "168fb715dda47215e360912c096649d23d58bf392ac62f73919e831745e40f26"
dependencies = [
"equivalent",
"hashbrown 0.14.5",
]
[[package]]
name = "intelligent-prompt-gateway"
version = "0.1.0"
dependencies = [
"log",
"proxy-wasm",
"serde",
"serde_json",
"serde_yaml",
]
[[package]]
name = "itoa"
version = "1.0.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b"
[[package]]
name = "log"
version = "0.4.22"
@ -64,7 +95,7 @@ version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "823b744520cd4a54ba7ebacbffe4562e839d6dcd8f89209f96a1ace4f5229cd4"
dependencies = [
"hashbrown",
"hashbrown 0.13.2",
"log",
]
@ -78,10 +109,60 @@ dependencies = [
]
[[package]]
name = "syn"
version = "2.0.70"
name = "ryu"
version = "1.0.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2f0209b68b3613b093e0ec905354eccaedcfe83b8cb37cbdeae64026c3064c16"
checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f"
[[package]]
name = "serde"
version = "1.0.204"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bc76f558e0cbb2a839d37354c575f1dc3fdc6546b5be373ba43d95f231bf7c12"
dependencies = [
"serde_derive",
]
[[package]]
name = "serde_derive"
version = "1.0.204"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e0cd7e117be63d3c3678776753929474f3b04a43a080c744d6b0ae2a8c28e222"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "serde_json"
version = "1.0.120"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4e0d21c9a8cae1235ad58a00c11cb40d4b1e5c784f1ef2c537876ed6ffd8b7c5"
dependencies = [
"itoa",
"ryu",
"serde",
]
[[package]]
name = "serde_yaml"
version = "0.9.34+deprecated"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6a8b1a1a2ebf674015cc02edccce75287f1a0130d394307b36743c2f5d504b47"
dependencies = [
"indexmap",
"itoa",
"ryu",
"serde",
"unsafe-libyaml",
]
[[package]]
name = "syn"
version = "2.0.71"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b146dcf730474b4bcd16c311627b31ede9ab149045db4d6088b3becaea046462"
dependencies = [
"proc-macro2",
"quote",
@ -94,6 +175,12 @@ version = "1.0.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b"
[[package]]
name = "unsafe-libyaml"
version = "0.2.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "673aac59facbab8a9007c7f6108d11f63b603f7cabff99fabf650fea5c32b861"
[[package]]
name = "version_check"
version = "0.9.4"

View file

@ -10,3 +10,6 @@ crate-type = ["cdylib"]
[dependencies]
proxy-wasm = "0.2.1"
log = "0.4"
serde = { version = "1.0", features = ["derive"] }
serde_yaml = "0.9.34"
serde_json = "1.0"

View file

@ -4,10 +4,12 @@ services:
hostname: envoy
ports:
- "10000:10000"
- "19901:9901"
volumes:
- ./envoy.yaml:/etc/envoy/envoy.yaml
- ./target/wasm32-wasi/release:/etc/envoy/proxy-wasm-plugins
networks:
- envoymesh
networks:
envoymesh: {}

View file

@ -1,3 +1,6 @@
admin:
address:
socket_address: { address: 0.0.0.0, port_value: 9901 }
static_resources:
listeners:
address:
@ -38,7 +41,40 @@ static_resources:
name: "http_config"
configuration:
"@type": "type.googleapis.com/google.protobuf.StringValue"
value: katanemo filter
value: |
katanemo-prompt-config:
default-prompt-endpoint: "127.0.0.1"
load-balancing: "round-robin"
timeout-ms: 5000
embedding-provider:
name: "SentenceTransformer"
model: "all-MiniLM-L6-v2"
llm-providers:
- name: "open-ai-gpt-4"
api-key: "$OPEN_AI_API_KEY"
model: gpt-4
system-prompt: |
You are a helpful weather forecaster. Please following following guidelines when responding to user queries:
- Use farenheight for temperature
- Use miles per hour for wind speed
prompt-targets:
- type: context-resolver
name: weather-forecast
few-shot-examples:
- what is the weather in New York?
endpoint: "POST:$WEATHER_FORECAST_API_ENDPOINT"
cache-response: true
cache-response-settings:
- cache-ttl-secs: 3600 # cache expiry in seconds
- cache-max-size: 1000 # in number of items
- cache-eviction-strategy: LRU
vm_config:
runtime: "envoy.wasm.runtime.v8"
code:
@ -47,7 +83,6 @@ static_resources:
- name: envoy.filters.http.router
typed_config:
"@type": type.googleapis.com/envoy.extensions.filters.http.router.v3.Router
clusters:
- name: httpbin
connect_timeout: 5s

View file

@ -0,0 +1,32 @@
katanemo-prompt-config:
default-prompt-endpoint: "127.0.0.1"
load-balancing: "round-robin"
timeout-ms: 5000
embedding-provider:
name: "SentenceTransformer"
model: "all-MiniLM-L6-v2"
llm-providers:
- name: "open-ai-gpt-4"
api-key: "$OPEN_AI_API_KEY"
model: gpt-4
system-prompt: |
You are a helpful weather forecaster. Please following following guidelines when responding to user queries:
- Use farenheight for temperature
- Use miles per hour for wind speed
prompt-targets:
- type: context-resolver
name: weather-forecast
few-shot-examples:
- what is the weather in New York?
endpoint: "POST:$WEATHER_FORECAST_API_ENDPOINT"
cache-response: true
cache-response-settings:
- cache-ttl-secs: 3600 # cache expiry in seconds
- cache-max-size: 1000 # in number of items
- cache-eviction-strategy: LRU

View file

@ -0,0 +1,103 @@
use serde::{Deserialize, Serialize};
//TODO: possibly use protbuf to enforce schema
//FIX: it is unnecessary to place yaml config inside katanemo-prompt-config
//GH Issue: https://github.com/katanemo/intelligent-prompt-gateway/issues/7
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub struct Configuration {
#[serde(rename = "katanemo-prompt-config")]
pub prompt_config: PromptConfig,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum LoadBalancing {
#[serde(rename = "round-robin")]
RoundRobin,
#[serde(rename = "random")]
Random,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub struct PromptConfig {
pub default_prompt_endpoint: String,
pub load_balancing: LoadBalancing,
pub timeout_ms: u64,
pub embedding_provider: EmbeddingProviver,
pub llm_providers: Vec<LlmProvider>,
pub system_prompt: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
//TODO: use enum for model, but if there is a new model, we need to update the code
pub struct EmbeddingProviver {
pub name: String,
pub model: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
//TODO: use enum for model, but if there is a new model, we need to update the code
pub struct LlmProvider {
pub name: String,
pub api_key: String,
pub model: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub struct PromptTarget {
#[serde(rename = "type")]
pub prompt_type: String,
pub name: String,
pub few_shot_examples: Vec<String>,
pub endpoint: String,
}
#[cfg(test)]
mod test {
pub const CONFIGURATION: &str = r#"
katanemo-prompt-config:
default-prompt-endpoint: "127.0.0.1"
load-balancing: "round-robin"
timeout-ms: 5000
embedding-provider:
name: "SentenceTransformer"
model: "all-MiniLM-L6-v2"
llm-providers:
- name: "open-ai-gpt-4"
api-key: "$OPEN_AI_API_KEY"
model: gpt-4
system-prompt: |
You are a helpful weather forecaster. Please following following guidelines when responding to user queries:
- Use farenheight for temperature
- Use miles per hour for wind speed
prompt-targets:
- type: context-resolver
name: weather-forecast
few-shot-examples:
- what is the weather in New York?
endpoint: "POST:$WEATHER_FORECAST_API_ENDPOINT"
cache-response: true
cache-response-settings:
- cache-ttl-secs: 3600 # cache expiry in seconds
- cache-max-size: 1000 # in number of items
- cache-eviction-strategy: LRU
"#;
#[test]
fn test_deserialize_configuration() {
let _: super::Configuration = serde_yaml::from_str(CONFIGURATION).unwrap();
}
}

View file

@ -1,3 +1,5 @@
mod configuration;
use log::info;
use stats::IncrementingMetric;
use stats::Metric;
@ -13,19 +15,19 @@ proxy_wasm::main! {{
proxy_wasm::set_log_level(LogLevel::Trace);
proxy_wasm::set_root_context(|_| -> Box<dyn RootContext> {
Box::new(HttpHeaderRoot {
header_content: String::new(),
metrics: WasmMetrics {
config: None,
metrics: WasmMetrics {
counter: stats::Counter::new(String::from("wasm_counter")),
gauge: stats::Gauge::new(String::from("wasm_gauge")),
histogram: stats::Histogram::new(String::from("wasm_histogram")),
}
},
})
});
}}
struct HttpHeader {
context_id: u32,
header_content: String,
config: configuration::Configuration,
metrics: WasmMetrics,
}
@ -34,6 +36,8 @@ impl HttpContext for HttpHeader {
// Envoy's HTTP model is event driven. The WASM ABI has given implementors events to hook onto
// the lifecycle of the http request and response.
fn on_http_request_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action {
// Read config
info!("config: {:?}", self.config.prompt_config.system_prompt);
// Metrics
self.metrics.counter.increment(10);
info!("counter -> {}", self.metrics.counter.value());
@ -74,8 +78,7 @@ impl HttpContext for HttpHeader {
}
fn on_http_response_headers(&mut self, _: usize, _: bool) -> Action {
// Note that the filter can add custom headers. In this case the header is coming from a config value.
self.add_http_response_header("custom-header", self.header_content.as_str());
self.set_http_response_header("Powered-By", Some("Katanemo"));
Action::Continue
}
}
@ -113,17 +116,26 @@ struct WasmMetrics {
}
struct HttpHeaderRoot {
header_content: String,
metrics: WasmMetrics,
config: Option<configuration::Configuration>,
}
impl Context for HttpHeaderRoot {}
// RootContext allows the Rust code to reach into the Envoy Config
impl RootContext for HttpHeaderRoot {
fn on_configure(&mut self, _: usize) -> bool {
fn on_configure(&mut self, plugin_configuration_size: usize) -> bool {
info!(
"on_configure: plugin_configuration_size is {}",
plugin_configuration_size
);
if let Some(config_bytes) = self.get_plugin_configuration() {
self.header_content = String::from_utf8(config_bytes).unwrap()
let config_str = String::from_utf8(config_bytes).unwrap();
info!("on_configure: plugin configuration is {:?}", config_str);
self.config = serde_yaml::from_str(&config_str).unwrap();
info!("on_configure: plugin configuration loaded");
info!("on_configure: {:?}", self.config);
}
true
}
@ -131,7 +143,7 @@ impl RootContext for HttpHeaderRoot {
fn create_http_context(&self, context_id: u32) -> Option<Box<dyn HttpContext>> {
Some(Box::new(HttpHeader {
context_id,
header_content: self.header_content.clone(),
config: self.config.clone()?,
metrics: self.metrics,
}))
}

View file

@ -1,4 +1,3 @@
use proxy_wasm::hostcalls;
use proxy_wasm::types::*;
@ -17,7 +16,7 @@ pub trait Metric {
pub trait IncrementingMetric: Metric {
fn increment(&self, offset: i64) {
match hostcalls::increment_metric(self.id(), offset) {
Ok(_) => return,
Ok(data) => data,
Err(Status::NotFound) => panic!("metric not found: {}", self.id()),
Err(err) => panic!("unexpected status: {:?}", err),
}
@ -27,7 +26,7 @@ pub trait IncrementingMetric: Metric {
pub trait RecordingMetric: Metric {
fn record(&self, value: u64) {
match hostcalls::record_metric(self.id(), value) {
Ok(_) => return,
Ok(data) => data,
Err(Status::NotFound) => panic!("metric not found: {}", self.id()),
Err(err) => panic!("unexpected status: {:?}", err),
}