diff --git a/envoyfilter/Cargo.lock b/envoyfilter/Cargo.lock index 59910a0f..88230623 100644 --- a/envoyfilter/Cargo.lock +++ b/envoyfilter/Cargo.lock @@ -398,6 +398,19 @@ dependencies = [ "typenum", ] +[[package]] +name = "dashmap" +version = "5.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "978747c1d849a7d2ee5e8adc0159961c48fb7e5db2f06af6723b80123bb53856" +dependencies = [ + "cfg-if 1.0.0", + "hashbrown 0.14.5", + "lock_api", + "once_cell", + "parking_lot_core", +] + [[package]] name = "debugid" version = "0.8.0" @@ -559,6 +572,17 @@ version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1" +[[package]] +name = "futures-macro" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.72", +] + [[package]] name = "futures-sink" version = "0.3.30" @@ -571,6 +595,12 @@ version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004" +[[package]] +name = "futures-timer" +version = "3.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24" + [[package]] name = "futures-util" version = "0.3.30" @@ -580,6 +610,7 @@ dependencies = [ "futures-channel", "futures-core", "futures-io", + "futures-macro", "futures-sink", "futures-task", "memchr", @@ -648,6 +679,26 @@ version = "0.29.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "40ecd4077b5ae9fd2e9e169b102c6c330d0605168eb0e8bf79952b256dbefffd" +[[package]] +name = "governor" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68a7f542ee6b35af73b06abc0dad1c1bae89964e4e253bc4b587b91c9637867b" +dependencies = [ + "cfg-if 1.0.0", + "dashmap", + "futures", + "futures-timer", + "no-std-compat", + "nonzero_ext", + "parking_lot", + "portable-atomic", + "quanta", + "rand", + "smallvec", + "spinning_top", +] + [[package]] name = "h2" version = "0.4.5" @@ -860,6 +911,7 @@ dependencies = [ name = "intelligent-prompt-gateway" version = "0.1.0" dependencies = [ + "governor", "http", "log", "md5", @@ -1080,6 +1132,18 @@ dependencies = [ "tempfile", ] +[[package]] +name = "no-std-compat" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b93853da6d84c2e3c7d730d6473e8817692dd89be387eb01b94d7f108ecb5b8c" + +[[package]] +name = "nonzero_ext" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38bf9645c8b145698bb0b18a4637dcacbc421ea49bef2317e4fd8065a387cf21" + [[package]] name = "object" version = "0.33.0" @@ -1235,6 +1299,12 @@ version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d231b230927b5e4ad203db57bbcbee2802f6bce620b1e4a9024a07d94e2907ec" +[[package]] +name = "portable-atomic" +version = "1.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da544ee218f0d287a911e9c99a39a8c9bc8fcad3cb8db5959940044ecfc67265" + [[package]] name = "ppv-lite86" version = "0.2.17" @@ -1307,6 +1377,21 @@ dependencies = [ "cc", ] +[[package]] +name = "quanta" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e5167a477619228a0b284fac2674e3c388cba90631d7b7de620e6f1fcd08da5" +dependencies = [ + "crossbeam-utils", + "libc", + "once_cell", + "raw-cpuid", + "wasi", + "web-sys", + "winapi", +] + [[package]] name = "quote" version = "1.0.36" @@ -1346,6 +1431,15 @@ dependencies = [ "getrandom", ] +[[package]] +name = "raw-cpuid" +version = "11.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb9ee317cfe3fbd54b36a511efc1edd42e216903c9cd575e686dd68a2ba90d8d" +dependencies = [ + "bitflags 2.6.0", +] + [[package]] name = "rayon" version = "1.10.0" @@ -1726,6 +1820,15 @@ version = "0.9.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" +[[package]] +name = "spinning_top" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d96d2d1d716fb500937168cc09353ffdc7a012be8475ac7308e1bdf0e3923300" +dependencies = [ + "lock_api", +] + [[package]] name = "sptr" version = "0.3.2" diff --git a/envoyfilter/Cargo.toml b/envoyfilter/Cargo.toml index e3dff7c1..854e37a1 100644 --- a/envoyfilter/Cargo.toml +++ b/envoyfilter/Cargo.toml @@ -16,6 +16,7 @@ serde_json = "1.0" md5 = "0.7.0" open-message-format-embeddings = { path = "../open-message-format/clients/omf-embeddings-rust" } http = "1.1.0" +governor = "0.6.3" [dev-dependencies] proxy-wasm-test-framework = { git = "https://github.com/katanemo/test-framework.git", branch = "main" } diff --git a/envoyfilter/src/configuration.rs b/envoyfilter/src/configuration.rs index 03df69f9..a888d56b 100644 --- a/envoyfilter/src/configuration.rs +++ b/envoyfilter/src/configuration.rs @@ -14,27 +14,31 @@ pub struct Configuration { #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Ratelimit { - provider: String, - selectors: Vec
, - limit: Limit, + pub provider: String, + pub selector: Header, + pub limit: Limit, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Limit { - tokens: u32, - unit: TimeUnit, + pub tokens: u32, + pub unit: TimeUnit, } #[derive(Debug, Clone, Serialize, Deserialize)] pub enum TimeUnit { + #[serde(rename = "second")] + Second, #[serde(rename = "minute")] Minute, + #[serde(rename = "hour")] + Hour, } -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] pub struct Header { - key: String, - value: Option, + pub key: String, + pub value: Option, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -139,8 +143,8 @@ prompt_targets: ratelimits: - provider: open-ai-gpt-4 - selectors: - - key: x-katanemo-openai-limit-id + selector: + key: x-katanemo-openai-limit-id limit: tokens: 100 unit: minute diff --git a/envoyfilter/src/filter_context.rs b/envoyfilter/src/filter_context.rs index 70220359..d0941b50 100644 --- a/envoyfilter/src/filter_context.rs +++ b/envoyfilter/src/filter_context.rs @@ -3,6 +3,7 @@ use crate::common_types::{ }; use crate::configuration::{Configuration, PromptTarget}; use crate::consts::DEFAULT_EMBEDDING_MODEL; +use crate::ratelimit; use crate::stats::{Gauge, RecordingMetric}; use crate::stream_context::StreamContext; use log::info; @@ -207,9 +208,9 @@ impl FilterContext { { panic!("duplicate token_id") } - // self.metrics - // .active_http_calls - // .record(self.callouts.len().try_into().unwrap()); + self.metrics + .active_http_calls + .record(self.callouts.len().try_into().unwrap()); } } @@ -257,6 +258,14 @@ impl RootContext for FilterContext { fn on_configure(&mut self, _: usize) -> bool { if let Some(config_bytes) = self.get_plugin_configuration() { self.config = serde_yaml::from_slice(&config_bytes).unwrap(); + + if let Some(ratelimits_config) = self + .config + .as_mut() + .and_then(|config| config.ratelimits.as_mut()) + { + ratelimit::ratelimits(Some(std::mem::take(ratelimits_config))); + } } true } @@ -278,7 +287,6 @@ impl RootContext for FilterContext { } fn on_tick(&mut self) { - // initialize vector store self.init_vector_store(); self.process_prompt_targets(); self.set_tick_period(Duration::from_secs(0)); diff --git a/envoyfilter/src/lib.rs b/envoyfilter/src/lib.rs index f62c6367..5b7b9466 100644 --- a/envoyfilter/src/lib.rs +++ b/envoyfilter/src/lib.rs @@ -6,6 +6,7 @@ mod common_types; mod configuration; mod consts; mod filter_context; +mod ratelimit; mod stats; mod stream_context; diff --git a/envoyfilter/src/ratelimit.rs b/envoyfilter/src/ratelimit.rs new file mode 100644 index 00000000..3a147e9f --- /dev/null +++ b/envoyfilter/src/ratelimit.rs @@ -0,0 +1,419 @@ +use crate::configuration; +use crate::configuration::{Limit, Ratelimit, TimeUnit}; +use governor::{DefaultKeyedRateLimiter, InsufficientCapacity, Quota}; +use std::num::{NonZero, NonZeroU32}; +use std::sync::RwLock; +use std::{collections::HashMap, sync::OnceLock}; + +pub type RatelimitData = RwLock; + +pub fn ratelimits(ratelimits_config: Option>) -> &'static RatelimitData { + static RATELIMIT_DATA: OnceLock = OnceLock::new(); + RATELIMIT_DATA.get_or_init(|| { + RwLock::new(RatelimitMap::new( + ratelimits_config.expect("The initialization call has to have passed a config"), + )) + }) +} + +// The Data Structure is laid out in the following way: +// Provider -> Hash { Header -> Limit }. +// If the Header used to configure the given Limit: +// a) Has None value, then there will be N Limit keyed by the Header value. +// b) Has Some() value, then there will be 1 Limit keyed by the empty string. +// It would have been nicer to use a non-keyed limit for b). However, the type system made that option a nightmare. +pub struct RatelimitMap { + datastore: HashMap>>, +} + +// This version of Header demands that the user passes a header value to match on. +#[allow(unused)] +pub struct Header { + key: String, + value: String, +} + +impl Header { + fn into_config(self) -> configuration::Header { + configuration::Header { + key: self.key, + value: Some(self.value), + } + } +} + +impl RatelimitMap { + // n.b new is private so that the only access to the Ratelimits can be done via the static + // reference inside a RwLock via ratelimit::ratelimits(). + fn new(ratelimits_config: Vec) -> Self { + let mut new_ratelimit_map = RatelimitMap { + datastore: HashMap::new(), + }; + for ratelimit_config in ratelimits_config { + let limit = DefaultKeyedRateLimiter::keyed(get_quota(ratelimit_config.limit)); + + match new_ratelimit_map + .datastore + .get_mut(&ratelimit_config.provider) + { + Some(limits) => match limits.get_mut(&ratelimit_config.selector) { + Some(_) => { + panic!("repeated selector. Selectors per provider must be unique") + } + None => { + limits.insert(ratelimit_config.selector, limit); + } + }, + None => { + // The provider has not been seen before. + // Insert the provider and a new HashMap with the specified limit + let new_hash_map = HashMap::from([(ratelimit_config.selector, limit)]); + new_ratelimit_map + .datastore + .insert(ratelimit_config.provider, new_hash_map); + } + } + } + new_ratelimit_map + } + + #[allow(unused)] + pub fn check_limit( + &self, + provider: String, + selector: Header, + tokens_used: NonZeroU32, + ) -> Result<(), String> { + let provider_limits = match self.datastore.get(&provider) { + None => { + // No limit configured for this provider, hence ok. + return Ok(()); + } + Some(limit) => limit, + }; + + let mut config_selector = selector.into_config(); + + let (limit, limit_key) = match provider_limits.get(&config_selector) { + // This is a specific limit, i.e one that was configured with both key, and value. + // Therefore, the key for the internal limit does not matter, and hence the empty string is always returned. + Some(limit) => (limit, String::from("")), + None => { + // Unwrap is ok here because we _know_ the value exists. + let header_key = config_selector.value.take().unwrap(); + // Search for less specific limit, i.e, one that was configured without a value, therefore every Header + // value has its own key in the internal limit. + match provider_limits.get(&config_selector) { + Some(limit) => (limit, header_key), + // No limit for that header key, value pair exists within that provider limits. + None => { + return Ok(()); + } + } + } + }; + + match limit.check_key_n(&limit_key, tokens_used) { + Ok(Ok(())) => Ok(()), + Ok(Err(_)) => Err(String::from("Not allowed")), + Err(InsufficientCapacity(_)) => Err(String::from("Not allowed")), + } + } +} + +fn get_quota(limit: Limit) -> Quota { + let tokens = NonZero::new(limit.tokens).expect("Limit's tokens must be positive"); + match limit.unit { + TimeUnit::Second => Quota::per_second(tokens), + TimeUnit::Minute => Quota::per_minute(tokens), + TimeUnit::Hour => Quota::per_hour(tokens), + } +} + +// The following tests are inside the ratelimit module in order to access RatelimitMap::new() in order to provide +// different configuration values per test. +#[test] +fn non_existent_provider_is_ok() { + let ratelimits_config = vec![Ratelimit { + provider: String::from("provider"), + selector: configuration::Header { + key: String::from("only-key"), + value: None, + }, + limit: Limit { + tokens: 100, + unit: TimeUnit::Minute, + }, + }]; + + let ratelimits = RatelimitMap::new(ratelimits_config); + + assert!(ratelimits + .check_limit( + String::from("non-existent-provider"), + Header { + key: String::from("key"), + value: String::from("value"), + }, + NonZero::new(5000).unwrap(), + ) + .is_ok()) +} + +#[test] +fn non_existent_key_is_ok() { + let ratelimits_config = vec![Ratelimit { + provider: String::from("provider"), + selector: configuration::Header { + key: String::from("only-key"), + value: None, + }, + limit: Limit { + tokens: 100, + unit: TimeUnit::Minute, + }, + }]; + + let ratelimits = RatelimitMap::new(ratelimits_config); + + assert!(ratelimits + .check_limit( + String::from("provider"), + Header { + key: String::from("key"), + value: String::from("value"), + }, + NonZero::new(5000).unwrap(), + ) + .is_ok()) +} + +#[test] +fn specific_limit_does_not_catch_non_specific_value() { + let ratelimits_config = vec![Ratelimit { + provider: String::from("provider"), + selector: configuration::Header { + key: String::from("key"), + value: Some(String::from("value")), + }, + limit: Limit { + tokens: 200, + unit: TimeUnit::Second, + }, + }]; + + let ratelimits = RatelimitMap::new(ratelimits_config); + + assert!(ratelimits + .check_limit( + String::from("provider"), + Header { + key: String::from("key"), + value: String::from("not-the-correct-value"), + }, + NonZero::new(5000).unwrap(), + ) + .is_ok()) +} + +#[test] +fn specific_limit_is_hit() { + let ratelimits_config = vec![Ratelimit { + provider: String::from("provider"), + selector: configuration::Header { + key: String::from("key"), + value: Some(String::from("value")), + }, + limit: Limit { + tokens: 200, + unit: TimeUnit::Hour, + }, + }]; + + let ratelimits = RatelimitMap::new(ratelimits_config); + + assert!(ratelimits + .check_limit( + String::from("provider"), + Header { + key: String::from("key"), + value: String::from("value"), + }, + NonZero::new(5000).unwrap(), + ) + .is_err()) +} + +#[test] +fn non_specific_key_has_different_limits_for_different_values() { + let ratelimits_config = vec![Ratelimit { + provider: String::from("provider"), + selector: configuration::Header { + key: String::from("only-key"), + value: None, + }, + limit: Limit { + tokens: 100, + unit: TimeUnit::Hour, + }, + }]; + + let ratelimits = RatelimitMap::new(ratelimits_config); + + // Value1 takes 50. + assert!(ratelimits + .check_limit( + String::from("provider"), + Header { + key: String::from("only-key"), + value: String::from("value1"), + }, + NonZero::new(50).unwrap(), + ) + .is_ok()); + + // value2 takes 60 because it has its own 100 limit + assert!(ratelimits + .check_limit( + String::from("provider"), + Header { + key: String::from("only-key"), + value: String::from("value2"), + }, + NonZero::new(60).unwrap(), + ) + .is_ok()); + + // However value1 cannot take more than 100 per hour which 50+70 = 120 + assert!(ratelimits + .check_limit( + String::from("provider"), + Header { + key: String::from("only-key"), + value: String::from("value1"), + }, + NonZero::new(70).unwrap(), + ) + .is_err()) +} + +#[test] +fn different_provider_can_have_different_limits_with_the_same_keys() { + let ratelimits_config = vec![ + Ratelimit { + provider: String::from("first_provider"), + selector: configuration::Header { + key: String::from("key"), + value: Some(String::from("value")), + }, + limit: Limit { + tokens: 100, + unit: TimeUnit::Hour, + }, + }, + Ratelimit { + provider: String::from("second_provider"), + selector: configuration::Header { + key: String::from("key"), + value: Some(String::from("value")), + }, + limit: Limit { + tokens: 200, + unit: TimeUnit::Hour, + }, + }, + ]; + + let ratelimits = RatelimitMap::new(ratelimits_config); + + assert!(ratelimits + .check_limit( + String::from("first_provider"), + Header { + key: String::from("key"), + value: String::from("value"), + }, + NonZero::new(100).unwrap(), + ) + .is_ok()); + + assert!(ratelimits + .check_limit( + String::from("second_provider"), + Header { + key: String::from("key"), + value: String::from("value"), + }, + NonZero::new(200).unwrap(), + ) + .is_ok()); + + assert!(ratelimits + .check_limit( + String::from("first_provider"), + Header { + key: String::from("key"), + value: String::from("value"), + }, + NonZero::new(1).unwrap(), + ) + .is_err()); + + assert!(ratelimits + .check_limit( + String::from("second_provider"), + Header { + key: String::from("key"), + value: String::from("value"), + }, + NonZero::new(1).unwrap(), + ) + .is_err()); +} + +// These tests use the publicly exposed static singleton, thus the same configuration is used in every test. +// If more tests are written here, move the initial call out of the test. +#[cfg(test)] +mod test { + use super::ratelimits; + use crate::configuration; + use configuration::{Limit, Ratelimit, TimeUnit}; + use std::num::NonZero; + use std::thread; + + #[test] + fn different_threads_have_same_ratelimit_data_structure() { + let ratelimits_config = Some(vec![Ratelimit { + provider: String::from("provider"), + selector: configuration::Header { + key: String::from("key"), + value: Some(String::from("value")), + }, + limit: Limit { + tokens: 200, + unit: TimeUnit::Hour, + }, + }]); + + // Initialize in the main thread. + ratelimits(ratelimits_config); + + // Use the singleton in a different thread. + thread::spawn(|| { + let ratelimits = ratelimits(None); + + assert!(ratelimits + .read() + .unwrap() + .check_limit( + String::from("provider"), + super::Header { + key: String::from("key"), + value: String::from("value"), + }, + NonZero::new(5000).unwrap(), + ) + .is_err()) + }); + } +}