mirror of
https://github.com/katanemo/plano.git
synced 2026-05-03 04:42:49 +02:00
Add Ratelimit on request tokens (#44)
Signed-off-by: José Ulises Niño Rivera <junr03@users.noreply.github.com>
This commit is contained in:
parent
d98517f240
commit
dd48689aee
10 changed files with 698 additions and 200 deletions
|
|
@ -3,5 +3,6 @@ pub const DEFAULT_COLLECTION_NAME: &str = "prompt_vector_store";
|
|||
pub const DEFAULT_NER_MODEL: &str = "urchade/gliner_large-v2.1";
|
||||
pub const DEFAULT_PROMPT_TARGET_THRESHOLD: f64 = 0.6;
|
||||
pub const DEFAULT_NER_THRESHOLD: f64 = 0.6;
|
||||
pub const RATELIMIT_SELECTOR_HEADER_KEY: &str = "x-katanemo-ratelimit-selector";
|
||||
pub const SYSTEM_ROLE: &str = "system";
|
||||
pub const USER_ROLE: &str = "user";
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
use crate::consts::DEFAULT_EMBEDDING_MODEL;
|
||||
use crate::ratelimit;
|
||||
use crate::stats::{Gauge, RecordingMetric};
|
||||
use crate::stats::{Counter, Gauge, RecordingMetric};
|
||||
use crate::stream_context::StreamContext;
|
||||
use log::info;
|
||||
use log::{debug, info};
|
||||
use md5::Digest;
|
||||
use open_message_format_embeddings::models::{
|
||||
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
|
||||
|
|
@ -15,23 +15,26 @@ use public_types::common_types::{
|
|||
use public_types::configuration::{Configuration, PromptTarget};
|
||||
use serde_json::to_string;
|
||||
use std::collections::HashMap;
|
||||
use std::rc::Rc;
|
||||
use std::time::Duration;
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
struct WasmMetrics {
|
||||
active_http_calls: Gauge,
|
||||
pub struct WasmMetrics {
|
||||
pub active_http_calls: Gauge,
|
||||
pub ratelimited_rq: Counter,
|
||||
}
|
||||
|
||||
impl WasmMetrics {
|
||||
fn new() -> WasmMetrics {
|
||||
WasmMetrics {
|
||||
active_http_calls: Gauge::new(String::from("active_http_calls")),
|
||||
ratelimited_rq: Counter::new(String::from("ratelimited_rq")),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct FilterContext {
|
||||
metrics: WasmMetrics,
|
||||
metrics: Rc<WasmMetrics>,
|
||||
// callouts stores token_id to request mapping that we use during #on_http_call_response to match the response to the request.
|
||||
callouts: HashMap<u32, CallContext>,
|
||||
config: Option<Configuration>,
|
||||
|
|
@ -42,7 +45,7 @@ impl FilterContext {
|
|||
FilterContext {
|
||||
callouts: HashMap::new(),
|
||||
config: None,
|
||||
metrics: WasmMetrics::new(),
|
||||
metrics: Rc::new(WasmMetrics::new()),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -259,6 +262,8 @@ impl RootContext for FilterContext {
|
|||
if let Some(config_bytes) = self.get_plugin_configuration() {
|
||||
self.config = serde_yaml::from_slice(&config_bytes).unwrap();
|
||||
|
||||
debug!("set configuration object: {:?}", self.config);
|
||||
|
||||
if let Some(ratelimits_config) = self
|
||||
.config
|
||||
.as_mut()
|
||||
|
|
@ -273,7 +278,9 @@ impl RootContext for FilterContext {
|
|||
fn create_http_context(&self, _context_id: u32) -> Option<Box<dyn HttpContext>> {
|
||||
Some(Box::new(StreamContext {
|
||||
host_header: None,
|
||||
ratelimit_selector: None,
|
||||
callouts: HashMap::new(),
|
||||
metrics: Rc::clone(&self.metrics),
|
||||
}))
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ mod filter_context;
|
|||
mod ratelimit;
|
||||
mod stats;
|
||||
mod stream_context;
|
||||
mod tokenizer;
|
||||
|
||||
proxy_wasm::main! {{
|
||||
proxy_wasm::set_log_level(LogLevel::Trace);
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
use governor::{DefaultKeyedRateLimiter, InsufficientCapacity, Quota};
|
||||
use log::debug;
|
||||
use public_types::configuration;
|
||||
use public_types::configuration::{Limit, Ratelimit, TimeUnit};
|
||||
use std::num::{NonZero, NonZeroU32};
|
||||
|
|
@ -28,9 +29,10 @@ pub struct RatelimitMap {
|
|||
|
||||
// This version of Header demands that the user passes a header value to match on.
#[allow(unused)]
#[derive(Debug)]
pub struct Header {
    /// Name of the header selected for ratelimit matching.
    pub key: String,
    /// Exact value the selected header carried on the request.
    pub value: String,
}
|
||||
|
||||
impl Header {
|
||||
|
|
@ -84,6 +86,11 @@ impl RatelimitMap {
|
|||
selector: Header,
|
||||
tokens_used: NonZeroU32,
|
||||
) -> Result<(), String> {
|
||||
debug!(
|
||||
"Checking limit for provider={}, with selector={:?}, consuming tokens={:?}",
|
||||
provider, selector, tokens_used
|
||||
);
|
||||
|
||||
let provider_limits = match self.datastore.get(&provider) {
|
||||
None => {
|
||||
// No limit configured for this provider, hence ok.
|
||||
|
|
|
|||
|
|
@ -74,7 +74,10 @@ impl Metric for Gauge {
|
|||
}
|
||||
}
|
||||
|
||||
/// For state of the world updates
|
||||
impl RecordingMetric for Gauge {}
|
||||
/// For offset deltas
|
||||
impl IncrementingMetric for Gauge {}
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
pub struct Histogram {
|
||||
|
|
|
|||
|
|
@ -1,11 +1,14 @@
|
|||
use crate::consts::{
|
||||
DEFAULT_COLLECTION_NAME, DEFAULT_EMBEDDING_MODEL, DEFAULT_NER_MODEL, DEFAULT_NER_THRESHOLD,
|
||||
DEFAULT_PROMPT_TARGET_THRESHOLD, SYSTEM_ROLE, USER_ROLE,
|
||||
DEFAULT_PROMPT_TARGET_THRESHOLD, RATELIMIT_SELECTOR_HEADER_KEY, SYSTEM_ROLE, USER_ROLE,
|
||||
};
|
||||
use crate::filter_context::WasmMetrics;
|
||||
use crate::ratelimit;
|
||||
use crate::ratelimit::Header;
|
||||
use crate::stats::IncrementingMetric;
|
||||
use crate::tokenizer;
|
||||
use http::StatusCode;
|
||||
use log::error;
|
||||
use log::info;
|
||||
use log::warn;
|
||||
use log::{debug, error, info, warn};
|
||||
use open_message_format_embeddings::models::{
|
||||
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
|
||||
};
|
||||
|
|
@ -17,6 +20,8 @@ use public_types::common_types::{
|
|||
};
|
||||
use public_types::configuration::{Entity, PromptTarget};
|
||||
use std::collections::HashMap;
|
||||
use std::num::NonZero;
|
||||
use std::rc::Rc;
|
||||
use std::time::Duration;
|
||||
|
||||
enum RequestType {
|
||||
|
|
@ -35,7 +40,9 @@ pub struct CallContext {
|
|||
|
||||
pub struct StreamContext {
|
||||
pub host_header: Option<String>,
|
||||
pub ratelimit_selector: Option<Header>,
|
||||
pub callouts: HashMap<u32, CallContext>,
|
||||
pub metrics: Rc<WasmMetrics>,
|
||||
}
|
||||
|
||||
impl StreamContext {
|
||||
|
|
@ -65,6 +72,15 @@ impl StreamContext {
|
|||
}
|
||||
}
|
||||
|
||||
fn save_ratelimit_header(&mut self) {
|
||||
self.ratelimit_selector = self
|
||||
.get_http_request_header(RATELIMIT_SELECTOR_HEADER_KEY)
|
||||
.and_then(|key| {
|
||||
self.get_http_request_header(&key)
|
||||
.map(|value| Header { key, value })
|
||||
});
|
||||
}
|
||||
|
||||
fn embeddings_handler(&mut self, body: Vec<u8>, mut callout_context: CallContext) {
|
||||
let embedding_response: CreateEmbeddingResponse = match serde_json::from_slice(&body) {
|
||||
Ok(embedding_response) => embedding_response,
|
||||
|
|
@ -115,6 +131,7 @@ impl StreamContext {
|
|||
if self.callouts.insert(token_id, callout_context).is_some() {
|
||||
panic!("duplicate token_id")
|
||||
}
|
||||
self.metrics.active_http_calls.increment(1);
|
||||
}
|
||||
|
||||
fn search_points_handler(&mut self, body: Vec<u8>, mut callout_context: CallContext) {
|
||||
|
|
@ -202,6 +219,7 @@ impl StreamContext {
|
|||
if self.callouts.insert(token_id, callout_context).is_some() {
|
||||
panic!("duplicate token_id")
|
||||
}
|
||||
self.metrics.active_http_calls.increment(1);
|
||||
}
|
||||
|
||||
fn ner_handler(&mut self, body: Vec<u8>, mut callout_context: CallContext) {
|
||||
|
|
@ -290,10 +308,11 @@ impl StreamContext {
|
|||
if self.callouts.insert(token_id, callout_context).is_some() {
|
||||
panic!("duplicate token_id")
|
||||
}
|
||||
self.metrics.active_http_calls.increment(1);
|
||||
}
|
||||
|
||||
fn context_resolver_handler(&mut self, body: Vec<u8>, callout_context: CallContext) {
|
||||
info!("response received for context_resolver");
|
||||
debug!("response received for context_resolver");
|
||||
let body_string = String::from_utf8(body);
|
||||
let prompt_target = callout_context.prompt_target.unwrap();
|
||||
let mut request_body = callout_context.request_body;
|
||||
|
|
@ -331,7 +350,30 @@ impl StreamContext {
|
|||
return;
|
||||
}
|
||||
};
|
||||
info!("sending request to openai: msg {}", json_string);
|
||||
|
||||
// Tokenize and Ratelimit.
|
||||
if let Some(selector) = self.ratelimit_selector.take() {
|
||||
if let Ok(token_count) = tokenizer::token_count(&request_body.model, &json_string) {
|
||||
match ratelimit::ratelimits(None).read().unwrap().check_limit(
|
||||
request_body.model,
|
||||
selector,
|
||||
NonZero::new(token_count as u32).unwrap(),
|
||||
) {
|
||||
Ok(_) => (),
|
||||
Err(err) => {
|
||||
self.send_http_response(
|
||||
StatusCode::TOO_MANY_REQUESTS.as_u16().into(),
|
||||
vec![],
|
||||
Some(format!("Exceeded Ratelimit: {}", err).as_bytes()),
|
||||
);
|
||||
self.metrics.ratelimited_rq.increment(1);
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
debug!("sending request to openai: msg {}", json_string);
|
||||
self.set_http_request_body(0, json_string.len(), &json_string.into_bytes());
|
||||
self.resume_http_request();
|
||||
}
|
||||
|
|
@ -345,6 +387,7 @@ impl HttpContext for StreamContext {
|
|||
self.save_host_header();
|
||||
self.delete_content_length_header();
|
||||
self.modify_path_header();
|
||||
self.save_ratelimit_header();
|
||||
|
||||
Action::Continue
|
||||
}
|
||||
|
|
@ -450,6 +493,7 @@ impl HttpContext for StreamContext {
|
|||
token_id
|
||||
)
|
||||
}
|
||||
self.metrics.active_http_calls.increment(1);
|
||||
|
||||
Action::Pause
|
||||
}
|
||||
|
|
@ -464,6 +508,7 @@ impl Context for StreamContext {
|
|||
_num_trailers: usize,
|
||||
) {
|
||||
let callout_context = self.callouts.remove(&token_id).expect("invalid token_id");
|
||||
self.metrics.active_http_calls.increment(-1);
|
||||
|
||||
let resp = self.get_http_call_response_body(0, body_size);
|
||||
|
||||
|
|
|
|||
39
envoyfilter/src/tokenizer.rs
Normal file
39
envoyfilter/src/tokenizer.rs
Normal file
|
|
@ -0,0 +1,39 @@
|
|||
use log::debug;
|
||||
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
#[allow(dead_code)]
|
||||
pub enum Error {
|
||||
UnknownModel,
|
||||
FailedToTokenize,
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn token_count(model_name: &str, text: &str) -> Result<usize, Error> {
|
||||
debug!("getting token count model={}", model_name);
|
||||
// Consideration: is it more expensive to instantiate the BPE object every time, or to contend the singleton?
|
||||
let bpe = tiktoken_rs::get_bpe_from_model(model_name).map_err(|_| Error::UnknownModel)?;
|
||||
Ok(bpe.encode_ordinary(text).len())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn encode_ordinary() {
|
||||
let model_name = "gpt-3.5-turbo";
|
||||
let text = "How many tokens does this sentence have?";
|
||||
assert_eq!(
|
||||
8,
|
||||
token_count(model_name, text).expect("correct tokenization")
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn unrecognized_model() {
|
||||
assert_eq!(
|
||||
Error::UnknownModel,
|
||||
token_count("unknown", "").expect_err("unknown model")
|
||||
)
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue