mirror of
https://github.com/katanemo/plano.git
synced 2026-05-04 13:23:00 +02:00
rename envoyfilter => arch (#91)
* rename envoyfilter => arch * fix more files * more fixes * more renames
This commit is contained in:
parent
7168b14ed3
commit
ea86f73605
33 changed files with 91 additions and 99 deletions
11
arch/src/consts.rs
Normal file
11
arch/src/consts.rs
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
/// Embedding model requested from the model server when none is configured.
pub const DEFAULT_EMBEDDING_MODEL: &str = "BAAI/bge-large-en-v1.5";

/// NLI model name — presumably used for zero-shot intent detection; confirm against model server.
pub const DEFAULT_INTENT_MODEL: &str = "tasksource/deberta-base-long-nli";

/// Similarity threshold — presumably the minimum score for selecting a prompt target.
pub const DEFAULT_PROMPT_TARGET_THRESHOLD: f64 = 0.8;

/// Request header carrying the ratelimit selector value.
pub const RATELIMIT_SELECTOR_HEADER_KEY: &str = "x-arch-ratelimit-selector";

/// Chat message role names (OpenAI-style).
pub const SYSTEM_ROLE: &str = "system";
pub const USER_ROLE: &str = "user";

/// Model identifier (also hardcoded in llm_providers::OPENAI_PROVIDER).
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";

// NOTE(review): "ARC_FC" looks like a typo for "ARCH_FC" (cf. ARCH_FC_REQUEST_TIMEOUT_MS
// below); renaming would break callers, so it is only flagged here.
/// Cluster name of the Arch function-calling (FC) upstream.
pub const ARC_FC_CLUSTER: &str = "arch_fc";

/// Timeout for requests to the Arch FC cluster.
pub const ARCH_FC_REQUEST_TIMEOUT_MS: u64 = 120000; // 2 minutes

/// Cluster name of the embedding/intent model server.
pub const MODEL_SERVER_NAME: &str = "model_server";

/// Request header used to select the upstream LLM provider.
pub const ARCH_ROUTING_HEADER: &str = "x-arch-llm-provider";
|
||||
285
arch/src/filter_context.rs
Normal file
285
arch/src/filter_context.rs
Normal file
|
|
@ -0,0 +1,285 @@
|
|||
use crate::consts::{DEFAULT_EMBEDDING_MODEL, MODEL_SERVER_NAME};
|
||||
use crate::ratelimit;
|
||||
use crate::stats::{Counter, Gauge, RecordingMetric};
|
||||
use crate::stream_context::StreamContext;
|
||||
use log::debug;
|
||||
use proxy_wasm::traits::*;
|
||||
use proxy_wasm::types::*;
|
||||
use public_types::common_types::EmbeddingType;
|
||||
use public_types::configuration::{Configuration, Overrides, PromptGuards, PromptTarget};
|
||||
use public_types::embeddings::{
|
||||
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
|
||||
};
|
||||
use serde_json::to_string;
|
||||
use std::collections::HashMap;
|
||||
use std::rc::Rc;
|
||||
use std::sync::{OnceLock, RwLock};
|
||||
use std::time::Duration;
|
||||
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
pub struct WasmMetrics {
|
||||
pub active_http_calls: Gauge,
|
||||
pub ratelimited_rq: Counter,
|
||||
}
|
||||
|
||||
impl WasmMetrics {
|
||||
fn new() -> WasmMetrics {
|
||||
WasmMetrics {
|
||||
active_http_calls: Gauge::new(String::from("active_http_calls")),
|
||||
ratelimited_rq: Counter::new(String::from("ratelimited_rq")),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Bookkeeping for one in-flight embeddings callout: which prompt target it
/// belongs to and whether the embedded text was the target's name or its
/// description.
#[derive(Debug)]
struct CallContext {
    prompt_target: String,
    embedding_type: EmbeddingType,
}

/// Embeddings for one prompt target, keyed by what was embedded (name vs description).
pub type EmbeddingTypeMap = HashMap<EmbeddingType, Vec<f64>>;
|
||||
|
||||
/// Root context of the filter: parses plugin configuration, schedules the
/// embeddings callouts, and hands shared state to each per-request
/// `StreamContext`.
#[derive(Debug)]
pub struct FilterContext {
    metrics: Rc<WasmMetrics>,
    // callouts stores token_id to request mapping that we use during #on_http_call_response to match the response to the request.
    callouts: HashMap<u32, CallContext>,
    // Parsed plugin configuration; None until on_configure has run.
    config: Option<Configuration>,
    // Config-derived state shared (via Rc) with every StreamContext.
    overrides: Rc<Option<Overrides>>,
    prompt_targets: Rc<RwLock<HashMap<String, PromptTarget>>>,
    prompt_guards: Rc<Option<PromptGuards>>,
}
|
||||
|
||||
pub fn embeddings_store() -> &'static RwLock<HashMap<String, EmbeddingTypeMap>> {
|
||||
static EMBEDDINGS: OnceLock<RwLock<HashMap<String, EmbeddingTypeMap>>> = OnceLock::new();
|
||||
EMBEDDINGS.get_or_init(|| {
|
||||
let embeddings: HashMap<String, EmbeddingTypeMap> = HashMap::new();
|
||||
RwLock::new(embeddings)
|
||||
})
|
||||
}
|
||||
|
||||
impl FilterContext {
|
||||
pub fn new() -> FilterContext {
|
||||
FilterContext {
|
||||
callouts: HashMap::new(),
|
||||
config: None,
|
||||
metrics: Rc::new(WasmMetrics::new()),
|
||||
prompt_targets: Rc::new(RwLock::new(HashMap::new())),
|
||||
overrides: Rc::new(None),
|
||||
prompt_guards: Rc::new(Some(PromptGuards::default())),
|
||||
}
|
||||
}
|
||||
|
||||
fn process_prompt_targets(&mut self) {
|
||||
let prompt_targets = match self.prompt_targets.read() {
|
||||
Ok(prompt_targets) => prompt_targets,
|
||||
Err(e) => {
|
||||
panic!("Error reading prompt targets: {:?}", e);
|
||||
}
|
||||
};
|
||||
for values in prompt_targets.iter() {
|
||||
let prompt_target = &values.1;
|
||||
|
||||
// schedule embeddings call for prompt target name
|
||||
let token_id = self.schedule_embeddings_call(prompt_target.name.clone());
|
||||
if self
|
||||
.callouts
|
||||
.insert(token_id, {
|
||||
CallContext {
|
||||
prompt_target: prompt_target.name.clone(),
|
||||
embedding_type: EmbeddingType::Name,
|
||||
}
|
||||
})
|
||||
.is_some()
|
||||
{
|
||||
panic!("duplicate token_id")
|
||||
}
|
||||
|
||||
// schedule embeddings call for prompt target description
|
||||
let token_id = self.schedule_embeddings_call(prompt_target.description.clone());
|
||||
if self
|
||||
.callouts
|
||||
.insert(token_id, {
|
||||
CallContext {
|
||||
prompt_target: prompt_target.name.clone(),
|
||||
embedding_type: EmbeddingType::Description,
|
||||
}
|
||||
})
|
||||
.is_some()
|
||||
{
|
||||
panic!("duplicate token_id")
|
||||
}
|
||||
|
||||
self.metrics
|
||||
.active_http_calls
|
||||
.record(self.callouts.len().try_into().unwrap());
|
||||
}
|
||||
}
|
||||
|
||||
fn schedule_embeddings_call(&self, input: String) -> u32 {
|
||||
let embeddings_input = CreateEmbeddingRequest {
|
||||
input: Box::new(CreateEmbeddingRequestInput::String(input)),
|
||||
model: String::from(DEFAULT_EMBEDDING_MODEL),
|
||||
encoding_format: None,
|
||||
dimensions: None,
|
||||
user: None,
|
||||
};
|
||||
|
||||
let json_data = to_string(&embeddings_input).unwrap();
|
||||
let token_id = match self.dispatch_http_call(
|
||||
MODEL_SERVER_NAME,
|
||||
vec![
|
||||
(":method", "POST"),
|
||||
(":path", "/embeddings"),
|
||||
(":authority", MODEL_SERVER_NAME),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-upstream-rq-timeout-ms", "60000"),
|
||||
],
|
||||
Some(json_data.as_bytes()),
|
||||
vec![],
|
||||
Duration::from_secs(60),
|
||||
) {
|
||||
Ok(token_id) => token_id,
|
||||
Err(e) => {
|
||||
panic!("Error dispatching HTTP call: {:?}", e);
|
||||
}
|
||||
};
|
||||
token_id
|
||||
}
|
||||
|
||||
fn embedding_response_handler(
|
||||
&mut self,
|
||||
body_size: usize,
|
||||
embedding_type: EmbeddingType,
|
||||
prompt_target_name: String,
|
||||
) {
|
||||
let prompt_targets = self.prompt_targets.read().unwrap();
|
||||
let prompt_target = prompt_targets.get(&prompt_target_name).unwrap();
|
||||
if let Some(body) = self.get_http_call_response_body(0, body_size) {
|
||||
if !body.is_empty() {
|
||||
let mut embedding_response: CreateEmbeddingResponse =
|
||||
match serde_json::from_slice(&body) {
|
||||
Ok(response) => response,
|
||||
Err(e) => {
|
||||
panic!(
|
||||
"Error deserializing embedding response. body: {:?}: {:?}",
|
||||
String::from_utf8(body).unwrap(),
|
||||
e
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
let embeddings = embedding_response.data.remove(0).embedding;
|
||||
log::info!(
|
||||
"Adding embeddings for prompt target name: {:?}, description: {:?}, embedding type: {:?}",
|
||||
prompt_target.name,
|
||||
prompt_target.description,
|
||||
embedding_type
|
||||
);
|
||||
|
||||
embeddings_store().write().unwrap().insert(
|
||||
prompt_target.name.clone(),
|
||||
HashMap::from([(embedding_type, embeddings)]),
|
||||
);
|
||||
}
|
||||
} else {
|
||||
panic!("No body in response");
|
||||
}
|
||||
}
|
||||
}
|
||||
impl Context for FilterContext {
    /// Matches an embeddings response to the callout recorded in
    /// `process_prompt_targets` (via `token_id`) and stores the result.
    fn on_http_call_response(
        &mut self,
        token_id: u32,
        _num_headers: usize,
        body_size: usize,
        _num_trailers: usize,
    ) {
        debug!(
            "filter_context: on_http_call_response called with token_id: {:?}",
            token_id
        );
        // Every response must correspond to a callout this context scheduled.
        let callout_data = self.callouts.remove(&token_id).expect("invalid token_id");

        // Gauge mirrors the number of still-outstanding callouts.
        self.metrics
            .active_http_calls
            .record(self.callouts.len().try_into().unwrap());

        self.embedding_response_handler(
            body_size,
            callout_data.embedding_type,
            callout_data.prompt_target,
        )
    }
}
|
||||
|
||||
// RootContext allows the Rust code to reach into the Envoy Config
|
||||
impl RootContext for FilterContext {
|
||||
fn on_configure(&mut self, _: usize) -> bool {
|
||||
if let Some(config_bytes) = self.get_plugin_configuration() {
|
||||
self.config = serde_yaml::from_slice(&config_bytes).unwrap();
|
||||
|
||||
if let Some(overrides_config) = self
|
||||
.config
|
||||
.as_mut()
|
||||
.and_then(|config| config.overrides.as_mut())
|
||||
{
|
||||
self.overrides = Rc::new(Some(std::mem::take(overrides_config)));
|
||||
}
|
||||
|
||||
for pt in self.config.clone().unwrap().prompt_targets {
|
||||
self.prompt_targets
|
||||
.write()
|
||||
.unwrap()
|
||||
.insert(pt.name.clone(), pt.clone());
|
||||
}
|
||||
|
||||
debug!("set configuration object");
|
||||
|
||||
if let Some(ratelimits_config) = self
|
||||
.config
|
||||
.as_mut()
|
||||
.and_then(|config| config.ratelimits.as_mut())
|
||||
{
|
||||
ratelimit::ratelimits(Some(std::mem::take(ratelimits_config)));
|
||||
}
|
||||
|
||||
if let Some(prompt_guards) = self
|
||||
.config
|
||||
.as_mut()
|
||||
.and_then(|config| config.prompt_guards.as_mut())
|
||||
{
|
||||
self.prompt_guards = Rc::new(Some(std::mem::take(prompt_guards)));
|
||||
}
|
||||
}
|
||||
true
|
||||
}
|
||||
|
||||
fn create_http_context(&self, context_id: u32) -> Option<Box<dyn HttpContext>> {
|
||||
debug!(
|
||||
"||| create_http_context called with context_id: {:?} |||",
|
||||
context_id
|
||||
);
|
||||
Some(Box::new(StreamContext::new(
|
||||
context_id,
|
||||
Rc::clone(&self.metrics),
|
||||
Rc::clone(&self.prompt_targets),
|
||||
Rc::clone(&self.prompt_guards),
|
||||
Rc::clone(&self.overrides),
|
||||
)))
|
||||
}
|
||||
|
||||
fn get_type(&self) -> Option<ContextType> {
|
||||
Some(ContextType::HttpContext)
|
||||
}
|
||||
|
||||
fn on_vm_start(&mut self, _: usize) -> bool {
|
||||
self.set_tick_period(Duration::from_secs(1));
|
||||
true
|
||||
}
|
||||
|
||||
fn on_tick(&mut self) {
|
||||
self.process_prompt_targets();
|
||||
self.set_tick_period(Duration::from_secs(0));
|
||||
}
|
||||
}
|
||||
19
arch/src/lib.rs
Normal file
19
arch/src/lib.rs
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
use filter_context::FilterContext;
use proxy_wasm::traits::*;
use proxy_wasm::types::*;

mod consts;
mod filter_context;
mod llm_providers;
mod ratelimit;
mod routing;
mod stats;
mod stream_context;
mod tokenizer;

// WASM entry point: register FilterContext as the root context so the host
// drives configuration (on_configure), VM start, and ticking.
proxy_wasm::main! {{
    proxy_wasm::set_log_level(LogLevel::Trace);
    proxy_wasm::set_root_context(|_| -> Box<dyn RootContext> {
        Box::new(FilterContext::new())
    });
}}
|
||||
47
arch/src/llm_providers.rs
Normal file
47
arch/src/llm_providers.rs
Normal file
|
|
@ -0,0 +1,47 @@
|
|||
// Namespace-only struct holding the statically known provider definitions.
#[non_exhaustive]
pub struct LlmProviders;

impl LlmProviders {
    // NOTE(review): "gpt-3.5-turbo" duplicates consts::GPT_35_TURBO — keep in sync.
    pub const OPENAI_PROVIDER: LlmProvider<'static> = LlmProvider {
        name: "openai",
        api_key_header: "x-arch-openai-api-key",
        model: "gpt-3.5-turbo",
    };
    pub const MISTRAL_PROVIDER: LlmProvider<'static> = LlmProvider {
        name: "mistral",
        api_key_header: "x-arch-mistral-api-key",
        model: "mistral-large-latest",
    };

    // Every provider routing may choose from.
    pub const VARIANTS: &'static [LlmProvider<'static>] =
        &[Self::OPENAI_PROVIDER, Self::MISTRAL_PROVIDER];
}
|
||||
|
||||
/// A statically known upstream LLM provider. Fields are private; access goes
/// through the accessor methods and trait impls.
pub struct LlmProvider<'prov> {
    name: &'prov str,
    api_key_header: &'prov str,
    model: &'prov str,
}
|
||||
|
||||
/// Borrows the provider's name.
impl AsRef<str> for LlmProvider<'_> {
    fn as_ref(&self) -> &str {
        self.name
    }
}
|
||||
|
||||
impl std::fmt::Display for LlmProvider<'_> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", self.name)
|
||||
}
|
||||
}
|
||||
|
||||
impl LlmProvider<'_> {
    /// Request header that carries this provider's API key.
    pub fn api_key_header(&self) -> &str {
        self.api_key_header
    }

    /// Model to request from this provider.
    pub fn choose_model(&self) -> &str {
        // In the future this can be a more complex function balancing reliability, cost, performance, etc.
        self.model
    }
}
|
||||
426
arch/src/ratelimit.rs
Normal file
426
arch/src/ratelimit.rs
Normal file
|
|
@ -0,0 +1,426 @@
|
|||
use governor::{DefaultKeyedRateLimiter, InsufficientCapacity, Quota};
|
||||
use log::debug;
|
||||
use public_types::configuration;
|
||||
use public_types::configuration::{Limit, Ratelimit, TimeUnit};
|
||||
use std::num::{NonZero, NonZeroU32};
|
||||
use std::sync::RwLock;
|
||||
use std::{collections::HashMap, sync::OnceLock};
|
||||
|
||||
/// Global ratelimit state, guarded for cross-thread access.
pub type RatelimitData = RwLock<RatelimitMap>;

/// Returns the process-wide ratelimit singleton.
///
/// The FIRST call must pass `Some(config)` to initialize the store; later
/// calls may pass `None`. If the very first call passes `None`, the `expect`
/// below panics.
pub fn ratelimits(ratelimits_config: Option<Vec<Ratelimit>>) -> &'static RatelimitData {
    static RATELIMIT_DATA: OnceLock<RatelimitData> = OnceLock::new();
    RATELIMIT_DATA.get_or_init(|| {
        RwLock::new(RatelimitMap::new(
            ratelimits_config.expect("The initialization call has to have passed a config"),
        ))
    })
}
|
||||
|
||||
// The Data Structure is laid out in the following way:
// Provider -> Hash { Header -> Limit }.
// If the Header used to configure the given Limit:
// a) Has None value, then there will be N Limit keyed by the Header value.
// b) Has Some() value, then there will be 1 Limit keyed by the empty string.
// It would have been nicer to use a non-keyed limit for b). However, the type system made that option a nightmare.
/// Provider-scoped keyed rate limiters; see the layout notes above.
pub struct RatelimitMap {
    datastore: HashMap<String, HashMap<configuration::Header, DefaultKeyedRateLimiter<String>>>,
}
|
||||
|
||||
// This version of Header demands that the user passes a header value to match on.
|
||||
// This version of Header demands that the user passes a header value to match on.
#[allow(unused)]
#[derive(Debug)]
pub struct Header {
    pub key: String,
    pub value: String,
}
|
||||
|
||||
impl From<Header> for configuration::Header {
|
||||
fn from(header: Header) -> Self {
|
||||
Self {
|
||||
key: header.key,
|
||||
value: Some(header.value),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl RatelimitMap {
|
||||
// n.b new is private so that the only access to the Ratelimits can be done via the static
|
||||
// reference inside a RwLock via ratelimit::ratelimits().
|
||||
fn new(ratelimits_config: Vec<Ratelimit>) -> Self {
|
||||
let mut new_ratelimit_map = RatelimitMap {
|
||||
datastore: HashMap::new(),
|
||||
};
|
||||
for ratelimit_config in ratelimits_config {
|
||||
let limit = DefaultKeyedRateLimiter::keyed(get_quota(ratelimit_config.limit));
|
||||
|
||||
match new_ratelimit_map
|
||||
.datastore
|
||||
.get_mut(&ratelimit_config.provider)
|
||||
{
|
||||
Some(limits) => match limits.get_mut(&ratelimit_config.selector) {
|
||||
Some(_) => {
|
||||
panic!("repeated selector. Selectors per provider must be unique")
|
||||
}
|
||||
None => {
|
||||
limits.insert(ratelimit_config.selector, limit);
|
||||
}
|
||||
},
|
||||
None => {
|
||||
// The provider has not been seen before.
|
||||
// Insert the provider and a new HashMap with the specified limit
|
||||
let new_hash_map = HashMap::from([(ratelimit_config.selector, limit)]);
|
||||
new_ratelimit_map
|
||||
.datastore
|
||||
.insert(ratelimit_config.provider, new_hash_map);
|
||||
}
|
||||
}
|
||||
}
|
||||
new_ratelimit_map
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
pub fn check_limit(
|
||||
&self,
|
||||
provider: String,
|
||||
selector: Header,
|
||||
tokens_used: NonZeroU32,
|
||||
) -> Result<(), String> {
|
||||
debug!(
|
||||
"Checking limit for provider={}, with selector={:?}, consuming tokens={:?}",
|
||||
provider, selector, tokens_used
|
||||
);
|
||||
|
||||
let provider_limits = match self.datastore.get(&provider) {
|
||||
None => {
|
||||
// No limit configured for this provider, hence ok.
|
||||
return Ok(());
|
||||
}
|
||||
Some(limit) => limit,
|
||||
};
|
||||
|
||||
let mut config_selector = configuration::Header::from(selector);
|
||||
|
||||
let (limit, limit_key) = match provider_limits.get(&config_selector) {
|
||||
// This is a specific limit, i.e one that was configured with both key, and value.
|
||||
// Therefore, the key for the internal limit does not matter, and hence the empty string is always returned.
|
||||
Some(limit) => (limit, String::from("")),
|
||||
None => {
|
||||
// Unwrap is ok here because we _know_ the value exists.
|
||||
let header_key = config_selector.value.take().unwrap();
|
||||
// Search for less specific limit, i.e, one that was configured without a value, therefore every Header
|
||||
// value has its own key in the internal limit.
|
||||
match provider_limits.get(&config_selector) {
|
||||
Some(limit) => (limit, header_key),
|
||||
// No limit for that header key, value pair exists within that provider limits.
|
||||
None => {
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
match limit.check_key_n(&limit_key, tokens_used) {
|
||||
Ok(Ok(())) => Ok(()),
|
||||
Ok(Err(_)) => Err(String::from("Not allowed")),
|
||||
Err(InsufficientCapacity(_)) => Err(String::from("Not allowed")),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Converts a configured `Limit` into a governor `Quota`.
/// Panics if the configured token count is zero.
fn get_quota(limit: Limit) -> Quota {
    let tokens = NonZero::new(limit.tokens).expect("Limit's tokens must be positive");
    match limit.unit {
        TimeUnit::Second => Quota::per_second(tokens),
        TimeUnit::Minute => Quota::per_minute(tokens),
        TimeUnit::Hour => Quota::per_hour(tokens),
    }
}
|
||||
|
||||
// The following tests are inside the ratelimit module in order to access RatelimitMap::new() in order to provide
// different configuration values per test.
// A limit exists only for "provider": querying an unconfigured provider passes.
#[test]
fn non_existent_provider_is_ok() {
    let ratelimits_config = vec![Ratelimit {
        provider: String::from("provider"),
        selector: configuration::Header {
            key: String::from("only-key"),
            value: None,
        },
        limit: Limit {
            tokens: 100,
            unit: TimeUnit::Minute,
        },
    }];

    let ratelimits = RatelimitMap::new(ratelimits_config);

    assert!(ratelimits
        .check_limit(
            String::from("non-existent-provider"),
            Header {
                key: String::from("key"),
                value: String::from("value"),
            },
            NonZero::new(5000).unwrap(),
        )
        .is_ok())
}
|
||||
|
||||
// The provider exists but only "only-key" is limited: a different header key passes.
#[test]
fn non_existent_key_is_ok() {
    let ratelimits_config = vec![Ratelimit {
        provider: String::from("provider"),
        selector: configuration::Header {
            key: String::from("only-key"),
            value: None,
        },
        limit: Limit {
            tokens: 100,
            unit: TimeUnit::Minute,
        },
    }];

    let ratelimits = RatelimitMap::new(ratelimits_config);

    assert!(ratelimits
        .check_limit(
            String::from("provider"),
            Header {
                key: String::from("key"),
                value: String::from("value"),
            },
            NonZero::new(5000).unwrap(),
        )
        .is_ok())
}
|
||||
|
||||
// A key+value limit only applies to that exact value: other values pass untouched.
#[test]
fn specific_limit_does_not_catch_non_specific_value() {
    let ratelimits_config = vec![Ratelimit {
        provider: String::from("provider"),
        selector: configuration::Header {
            key: String::from("key"),
            value: Some(String::from("value")),
        },
        limit: Limit {
            tokens: 200,
            unit: TimeUnit::Second,
        },
    }];

    let ratelimits = RatelimitMap::new(ratelimits_config);

    assert!(ratelimits
        .check_limit(
            String::from("provider"),
            Header {
                key: String::from("key"),
                value: String::from("not-the-correct-value"),
            },
            NonZero::new(5000).unwrap(),
        )
        .is_ok())
}
|
||||
|
||||
// Consuming 5000 against a 200/hour key+value limit is rejected immediately.
#[test]
fn specific_limit_is_hit() {
    let ratelimits_config = vec![Ratelimit {
        provider: String::from("provider"),
        selector: configuration::Header {
            key: String::from("key"),
            value: Some(String::from("value")),
        },
        limit: Limit {
            tokens: 200,
            unit: TimeUnit::Hour,
        },
    }];

    let ratelimits = RatelimitMap::new(ratelimits_config);

    assert!(ratelimits
        .check_limit(
            String::from("provider"),
            Header {
                key: String::from("key"),
                value: String::from("value"),
            },
            NonZero::new(5000).unwrap(),
        )
        .is_err())
}
|
||||
|
||||
// A key-only (value: None) limit buckets per observed value: each value gets
// its own 100/hour budget.
#[test]
fn non_specific_key_has_different_limits_for_different_values() {
    let ratelimits_config = vec![Ratelimit {
        provider: String::from("provider"),
        selector: configuration::Header {
            key: String::from("only-key"),
            value: None,
        },
        limit: Limit {
            tokens: 100,
            unit: TimeUnit::Hour,
        },
    }];

    let ratelimits = RatelimitMap::new(ratelimits_config);

    // Value1 takes 50.
    assert!(ratelimits
        .check_limit(
            String::from("provider"),
            Header {
                key: String::from("only-key"),
                value: String::from("value1"),
            },
            NonZero::new(50).unwrap(),
        )
        .is_ok());

    // value2 takes 60 because it has its own 100 limit
    assert!(ratelimits
        .check_limit(
            String::from("provider"),
            Header {
                key: String::from("only-key"),
                value: String::from("value2"),
            },
            NonZero::new(60).unwrap(),
        )
        .is_ok());

    // However value1 cannot take more than 100 per hour which 50+70 = 120
    assert!(ratelimits
        .check_limit(
            String::from("provider"),
            Header {
                key: String::from("only-key"),
                value: String::from("value1"),
            },
            NonZero::new(70).unwrap(),
        )
        .is_err())
}
|
||||
|
||||
// The same selector under two providers is two independent limiters
// (100/hour vs 200/hour); exhausting one does not affect the other.
#[test]
fn different_provider_can_have_different_limits_with_the_same_keys() {
    let ratelimits_config = vec![
        Ratelimit {
            provider: String::from("first_provider"),
            selector: configuration::Header {
                key: String::from("key"),
                value: Some(String::from("value")),
            },
            limit: Limit {
                tokens: 100,
                unit: TimeUnit::Hour,
            },
        },
        Ratelimit {
            provider: String::from("second_provider"),
            selector: configuration::Header {
                key: String::from("key"),
                value: Some(String::from("value")),
            },
            limit: Limit {
                tokens: 200,
                unit: TimeUnit::Hour,
            },
        },
    ];

    let ratelimits = RatelimitMap::new(ratelimits_config);

    // Each provider can consume exactly its own full budget...
    assert!(ratelimits
        .check_limit(
            String::from("first_provider"),
            Header {
                key: String::from("key"),
                value: String::from("value"),
            },
            NonZero::new(100).unwrap(),
        )
        .is_ok());

    assert!(ratelimits
        .check_limit(
            String::from("second_provider"),
            Header {
                key: String::from("key"),
                value: String::from("value"),
            },
            NonZero::new(200).unwrap(),
        )
        .is_ok());

    // ...and afterwards even a single extra token is rejected on each.
    assert!(ratelimits
        .check_limit(
            String::from("first_provider"),
            Header {
                key: String::from("key"),
                value: String::from("value"),
            },
            NonZero::new(1).unwrap(),
        )
        .is_err());

    assert!(ratelimits
        .check_limit(
            String::from("second_provider"),
            Header {
                key: String::from("key"),
                value: String::from("value"),
            },
            NonZero::new(1).unwrap(),
        )
        .is_err());
}
|
||||
|
||||
// These tests use the publicly exposed static singleton, thus the same configuration is used in every test.
// If more tests are written here, move the initial call out of the test.
#[cfg(test)]
mod test {
    use super::ratelimits;
    use configuration::{Limit, Ratelimit, TimeUnit};
    use public_types::configuration;
    use std::num::NonZero;
    use std::thread;

    #[test]
    fn different_threads_have_same_ratelimit_data_structure() {
        let ratelimits_config = Some(vec![Ratelimit {
            provider: String::from("provider"),
            selector: configuration::Header {
                key: String::from("key"),
                value: Some(String::from("value")),
            },
            limit: Limit {
                tokens: 200,
                unit: TimeUnit::Hour,
            },
        }]);

        // Initialize in the main thread.
        ratelimits(ratelimits_config);

        // Use the singleton in a different thread.
        // BUG FIX: the handle must be joined — a failed assert (panic) in a
        // detached thread would never fail this test.
        thread::spawn(|| {
            let ratelimits = ratelimits(None);

            // 5000 tokens can never fit the 200/hour quota, so this is denied.
            assert!(ratelimits
                .read()
                .unwrap()
                .check_limit(
                    String::from("provider"),
                    super::Header {
                        key: String::from("key"),
                        value: String::from("value"),
                    },
                    NonZero::new(5000).unwrap(),
                )
                .is_err())
        })
        .join()
        .expect("ratelimit check thread panicked");
    }
}
|
||||
13
arch/src/routing.rs
Normal file
13
arch/src/routing.rs
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
use crate::llm_providers::{LlmProvider, LlmProviders};
|
||||
use rand::{seq::SliceRandom, thread_rng};
|
||||
|
||||
pub fn get_llm_provider<'hostname>(deterministic: bool) -> &'static LlmProvider<'hostname> {
|
||||
if deterministic {
|
||||
&LlmProviders::OPENAI_PROVIDER
|
||||
} else {
|
||||
let mut rng = thread_rng();
|
||||
LlmProviders::VARIANTS
|
||||
.choose(&mut rng)
|
||||
.expect("There should always be at least one llm provider")
|
||||
}
|
||||
}
|
||||
102
arch/src/stats.rs
Normal file
102
arch/src/stats.rs
Normal file
|
|
@ -0,0 +1,102 @@
|
|||
use log::error;
|
||||
use proxy_wasm::hostcalls;
|
||||
use proxy_wasm::types::*;
|
||||
|
||||
#[allow(unused)]
/// A metric registered with the proxy-wasm host, identified by a host-assigned id.
pub trait Metric {
    /// Host-assigned id of this metric.
    fn id(&self) -> u32;
    /// Reads the current value from the host; errors are mapped to strings.
    fn value(&self) -> Result<u64, String> {
        match hostcalls::get_metric(self.id()) {
            Ok(value) => Ok(value),
            Err(Status::NotFound) => Err(format!("metric not found: {}", self.id())),
            Err(err) => Err(format!("unexpected status: {:?}", err)),
        }
    }
}
|
||||
|
||||
#[allow(unused)]
/// Metrics updated by applying a (possibly negative) delta.
pub trait IncrementingMetric: Metric {
    /// Adds `offset` to the metric on the host; failures are logged, not returned.
    fn increment(&self, offset: i64) {
        match hostcalls::increment_metric(self.id(), offset) {
            Ok(_) => (),
            Err(err) => error!("error incrementing metric: {:?}", err),
        }
    }
}
|
||||
|
||||
/// Metrics updated by recording an absolute value.
pub trait RecordingMetric: Metric {
    /// Sets the metric to `value` on the host; failures are logged, not returned.
    fn record(&self, value: u64) {
        match hostcalls::record_metric(self.id(), value) {
            Ok(_) => (),
            Err(err) => error!("error recording metric: {:?}", err),
        }
    }
}
|
||||
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
pub struct Counter {
|
||||
id: u32,
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
impl Counter {
|
||||
pub fn new(name: String) -> Counter {
|
||||
let returned_id = hostcalls::define_metric(MetricType::Counter, &name)
|
||||
.expect("failed to define counter '{}', name");
|
||||
Counter { id: returned_id }
|
||||
}
|
||||
}
|
||||
|
||||
impl Metric for Counter {
    fn id(&self) -> u32 {
        self.id
    }
}

// Counters are updated via offset deltas only.
impl IncrementingMetric for Counter {}
|
||||
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
pub struct Gauge {
|
||||
id: u32,
|
||||
}
|
||||
|
||||
impl Gauge {
|
||||
pub fn new(name: String) -> Gauge {
|
||||
let returned_id = hostcalls::define_metric(MetricType::Gauge, &name)
|
||||
.expect("failed to define gauge '{}', name");
|
||||
Gauge { id: returned_id }
|
||||
}
|
||||
}
|
||||
|
||||
impl Metric for Gauge {
    fn id(&self) -> u32 {
        self.id
    }
}

/// For state of the world updates
impl RecordingMetric for Gauge {}
/// For offset deltas
impl IncrementingMetric for Gauge {}
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
pub struct Histogram {
|
||||
id: u32,
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
impl Histogram {
|
||||
pub fn new(name: String) -> Histogram {
|
||||
let returned_id = hostcalls::define_metric(MetricType::Histogram, &name)
|
||||
.expect("failed to define histogram '{}', name");
|
||||
Histogram { id: returned_id }
|
||||
}
|
||||
}
|
||||
|
||||
impl Metric for Histogram {
    fn id(&self) -> u32 {
        self.id
    }
}

// Histograms take recorded sample values.
impl RecordingMetric for Histogram {}
|
||||
1098
arch/src/stream_context.rs
Normal file
1098
arch/src/stream_context.rs
Normal file
File diff suppressed because it is too large
Load diff
39
arch/src/tokenizer.rs
Normal file
39
arch/src/tokenizer.rs
Normal file
|
|
@ -0,0 +1,39 @@
|
|||
use log::debug;
|
||||
|
||||
/// Failure modes of token counting.
#[derive(Debug, PartialEq, Eq)]
#[allow(dead_code)]
pub enum Error {
    // Model name not recognized by the tokenizer library.
    UnknownModel,
    // NOTE(review): not constructed anywhere in this module — verify callers.
    FailedToTokenize,
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn token_count(model_name: &str, text: &str) -> Result<usize, Error> {
|
||||
debug!("getting token count model={}", model_name);
|
||||
// Consideration: is it more expensive to instantiate the BPE object every time, or to contend the singleton?
|
||||
let bpe = tiktoken_rs::get_bpe_from_model(model_name).map_err(|_| Error::UnknownModel)?;
|
||||
Ok(bpe.encode_ordinary(text).len())
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod test {
    use super::*;

    // A fixed sentence has a stable token count under the gpt-3.5-turbo BPE.
    #[test]
    fn encode_ordinary() {
        let model_name = "gpt-3.5-turbo";
        let text = "How many tokens does this sentence have?";
        assert_eq!(
            8,
            token_count(model_name, text).expect("correct tokenization")
        );
    }

    // Unknown model names map to Error::UnknownModel.
    #[test]
    fn unrecognized_model() {
        assert_eq!(
            Error::UnknownModel,
            token_count("unknown", "").expect_err("unknown model")
        )
    }
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue