Add Ratelimit on request tokens (#44)

Signed-off-by: José Ulises Niño Rivera <junr03@users.noreply.github.com>
This commit is contained in:
José Ulises Niño Rivera 2024-09-04 17:28:12 -07:00 committed by GitHub
parent d98517f240
commit dd48689aee
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 698 additions and 200 deletions

View file

@ -3,5 +3,6 @@ pub const DEFAULT_COLLECTION_NAME: &str = "prompt_vector_store";
pub const DEFAULT_NER_MODEL: &str = "urchade/gliner_large-v2.1";
pub const DEFAULT_PROMPT_TARGET_THRESHOLD: f64 = 0.6;
pub const DEFAULT_NER_THRESHOLD: f64 = 0.6;
pub const RATELIMIT_SELECTOR_HEADER_KEY: &str = "x-katanemo-ratelimit-selector";
pub const SYSTEM_ROLE: &str = "system";
pub const USER_ROLE: &str = "user";

View file

@ -1,8 +1,8 @@
use crate::consts::DEFAULT_EMBEDDING_MODEL;
use crate::ratelimit;
use crate::stats::{Gauge, RecordingMetric};
use crate::stats::{Counter, Gauge, RecordingMetric};
use crate::stream_context::StreamContext;
use log::info;
use log::{debug, info};
use md5::Digest;
use open_message_format_embeddings::models::{
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
@ -15,23 +15,26 @@ use public_types::common_types::{
use public_types::configuration::{Configuration, PromptTarget};
use serde_json::to_string;
use std::collections::HashMap;
use std::rc::Rc;
use std::time::Duration;
#[derive(Copy, Clone)]
struct WasmMetrics {
active_http_calls: Gauge,
pub struct WasmMetrics {
pub active_http_calls: Gauge,
pub ratelimited_rq: Counter,
}
impl WasmMetrics {
    /// Build the plugin's metric set, registering each metric under its
    /// fixed, well-known name.
    fn new() -> WasmMetrics {
        let active_http_calls = Gauge::new(String::from("active_http_calls"));
        let ratelimited_rq = Counter::new(String::from("ratelimited_rq"));
        WasmMetrics {
            active_http_calls,
            ratelimited_rq,
        }
    }
}
pub struct FilterContext {
metrics: WasmMetrics,
metrics: Rc<WasmMetrics>,
// callouts stores token_id to request mapping that we use during #on_http_call_response to match the response to the request.
callouts: HashMap<u32, CallContext>,
config: Option<Configuration>,
@ -42,7 +45,7 @@ impl FilterContext {
FilterContext {
callouts: HashMap::new(),
config: None,
metrics: WasmMetrics::new(),
metrics: Rc::new(WasmMetrics::new()),
}
}
@ -259,6 +262,8 @@ impl RootContext for FilterContext {
if let Some(config_bytes) = self.get_plugin_configuration() {
self.config = serde_yaml::from_slice(&config_bytes).unwrap();
debug!("set configuration object: {:?}", self.config);
if let Some(ratelimits_config) = self
.config
.as_mut()
@ -273,7 +278,9 @@ impl RootContext for FilterContext {
fn create_http_context(&self, _context_id: u32) -> Option<Box<dyn HttpContext>> {
    // Fresh per-stream state for each HTTP context; the metrics struct is
    // shared with the root FilterContext through a reference-counted handle.
    let stream = StreamContext {
        host_header: None,
        ratelimit_selector: None,
        callouts: HashMap::new(),
        metrics: Rc::clone(&self.metrics),
    };
    Some(Box::new(stream))
}

View file

@ -7,6 +7,7 @@ mod filter_context;
mod ratelimit;
mod stats;
mod stream_context;
mod tokenizer;
proxy_wasm::main! {{
proxy_wasm::set_log_level(LogLevel::Trace);

View file

@ -1,4 +1,5 @@
use governor::{DefaultKeyedRateLimiter, InsufficientCapacity, Quota};
use log::debug;
use public_types::configuration;
use public_types::configuration::{Limit, Ratelimit, TimeUnit};
use std::num::{NonZero, NonZeroU32};
@ -28,9 +29,10 @@ pub struct RatelimitMap {
// This version of Header demands that the user passes a header value to match on.
#[allow(unused)]
#[derive(Debug)]
pub struct Header {
key: String,
value: String,
pub key: String,
pub value: String,
}
impl Header {
@ -84,6 +86,11 @@ impl RatelimitMap {
selector: Header,
tokens_used: NonZeroU32,
) -> Result<(), String> {
debug!(
"Checking limit for provider={}, with selector={:?}, consuming tokens={:?}",
provider, selector, tokens_used
);
let provider_limits = match self.datastore.get(&provider) {
None => {
// No limit configured for this provider, hence ok.

View file

@ -74,7 +74,10 @@ impl Metric for Gauge {
}
}
/// For state of the world updates
impl RecordingMetric for Gauge {}
/// For offset deltas
impl IncrementingMetric for Gauge {}
#[derive(Copy, Clone)]
pub struct Histogram {

View file

@ -1,11 +1,14 @@
use crate::consts::{
DEFAULT_COLLECTION_NAME, DEFAULT_EMBEDDING_MODEL, DEFAULT_NER_MODEL, DEFAULT_NER_THRESHOLD,
DEFAULT_PROMPT_TARGET_THRESHOLD, SYSTEM_ROLE, USER_ROLE,
DEFAULT_PROMPT_TARGET_THRESHOLD, RATELIMIT_SELECTOR_HEADER_KEY, SYSTEM_ROLE, USER_ROLE,
};
use crate::filter_context::WasmMetrics;
use crate::ratelimit;
use crate::ratelimit::Header;
use crate::stats::IncrementingMetric;
use crate::tokenizer;
use http::StatusCode;
use log::error;
use log::info;
use log::warn;
use log::{debug, error, info, warn};
use open_message_format_embeddings::models::{
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
};
@ -17,6 +20,8 @@ use public_types::common_types::{
};
use public_types::configuration::{Entity, PromptTarget};
use std::collections::HashMap;
use std::num::NonZero;
use std::rc::Rc;
use std::time::Duration;
enum RequestType {
@ -35,7 +40,9 @@ pub struct CallContext {
// Per-HTTP-stream state; one instance is created per request by the root
// context's create_http_context.
pub struct StreamContext {
    // Saved Host header of the inbound request, if present.
    pub host_header: Option<String>,
    // (header name, header value) pair resolved from the
    // x-katanemo-ratelimit-selector request header; selects the ratelimit
    // bucket to charge for this request. None when the header is absent.
    pub ratelimit_selector: Option<Header>,
    // Maps outbound HTTP callout token_id -> originating request context, so
    // responses in on_http_call_response can be matched back to the request.
    pub callouts: HashMap<u32, CallContext>,
    // Plugin-wide metrics shared (via Rc) with the root filter context.
    pub metrics: Rc<WasmMetrics>,
}
impl StreamContext {
@ -65,6 +72,15 @@ impl StreamContext {
}
}
fn save_ratelimit_header(&mut self) {
    // The selector header's value names a *second* request header; that
    // header's (name, value) pair becomes the ratelimit selector. If either
    // header is missing, no selector is stored and the request is not
    // ratelimited by selector.
    self.ratelimit_selector =
        match self.get_http_request_header(RATELIMIT_SELECTOR_HEADER_KEY) {
            Some(key) => self
                .get_http_request_header(&key)
                .map(|value| Header { key, value }),
            None => None,
        };
}
fn embeddings_handler(&mut self, body: Vec<u8>, mut callout_context: CallContext) {
let embedding_response: CreateEmbeddingResponse = match serde_json::from_slice(&body) {
Ok(embedding_response) => embedding_response,
@ -115,6 +131,7 @@ impl StreamContext {
if self.callouts.insert(token_id, callout_context).is_some() {
panic!("duplicate token_id")
}
self.metrics.active_http_calls.increment(1);
}
fn search_points_handler(&mut self, body: Vec<u8>, mut callout_context: CallContext) {
@ -202,6 +219,7 @@ impl StreamContext {
if self.callouts.insert(token_id, callout_context).is_some() {
panic!("duplicate token_id")
}
self.metrics.active_http_calls.increment(1);
}
fn ner_handler(&mut self, body: Vec<u8>, mut callout_context: CallContext) {
@ -290,10 +308,11 @@ impl StreamContext {
if self.callouts.insert(token_id, callout_context).is_some() {
panic!("duplicate token_id")
}
self.metrics.active_http_calls.increment(1);
}
fn context_resolver_handler(&mut self, body: Vec<u8>, callout_context: CallContext) {
info!("response received for context_resolver");
debug!("response received for context_resolver");
let body_string = String::from_utf8(body);
let prompt_target = callout_context.prompt_target.unwrap();
let mut request_body = callout_context.request_body;
@ -331,7 +350,30 @@ impl StreamContext {
return;
}
};
info!("sending request to openai: msg {}", json_string);
// Tokenize and Ratelimit.
if let Some(selector) = self.ratelimit_selector.take() {
if let Ok(token_count) = tokenizer::token_count(&request_body.model, &json_string) {
match ratelimit::ratelimits(None).read().unwrap().check_limit(
request_body.model,
selector,
NonZero::new(token_count as u32).unwrap(),
) {
Ok(_) => (),
Err(err) => {
self.send_http_response(
StatusCode::TOO_MANY_REQUESTS.as_u16().into(),
vec![],
Some(format!("Exceeded Ratelimit: {}", err).as_bytes()),
);
self.metrics.ratelimited_rq.increment(1);
return;
}
}
}
}
debug!("sending request to openai: msg {}", json_string);
self.set_http_request_body(0, json_string.len(), &json_string.into_bytes());
self.resume_http_request();
}
@ -345,6 +387,7 @@ impl HttpContext for StreamContext {
self.save_host_header();
self.delete_content_length_header();
self.modify_path_header();
self.save_ratelimit_header();
Action::Continue
}
@ -450,6 +493,7 @@ impl HttpContext for StreamContext {
token_id
)
}
self.metrics.active_http_calls.increment(1);
Action::Pause
}
@ -464,6 +508,7 @@ impl Context for StreamContext {
_num_trailers: usize,
) {
let callout_context = self.callouts.remove(&token_id).expect("invalid token_id");
self.metrics.active_http_calls.increment(-1);
let resp = self.get_http_call_response_body(0, body_size);

View file

@ -0,0 +1,39 @@
use log::debug;
/// Errors returned by [`token_count`].
#[derive(Debug, PartialEq, Eq)]
#[allow(dead_code)]
pub enum Error {
    /// `tiktoken_rs` has no BPE encoding registered for the model name.
    UnknownModel,
    /// Tokenization failed. Not produced by `token_count` in this file;
    /// presumably reserved for callers/future use — TODO confirm.
    FailedToTokenize,
}
#[allow(dead_code)]
pub fn token_count(model_name: &str, text: &str) -> Result<usize, Error> {
debug!("getting token count model={}", model_name);
// Consideration: is it more expensive to instantiate the BPE object every time, or to contend the singleton?
let bpe = tiktoken_rs::get_bpe_from_model(model_name).map_err(|_| Error::UnknownModel)?;
Ok(bpe.encode_ordinary(text).len())
}
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn encode_ordinary() {
        // This sentence encodes to exactly 8 BPE tokens under the
        // gpt-3.5-turbo encoding.
        let text = "How many tokens does this sentence have?";
        let count = token_count("gpt-3.5-turbo", text).expect("correct tokenization");
        assert_eq!(8, count);
    }

    #[test]
    fn unrecognized_model() {
        let err = token_count("unknown", "").expect_err("unknown model");
        assert_eq!(Error::UnknownModel, err);
    }
}