mirror of
https://github.com/katanemo/plano.git
synced 2026-05-27 14:17:15 +02:00
split wasm filter
This commit is contained in:
parent
b1746b38b4
commit
0e04b09f56
44 changed files with 6009 additions and 272 deletions
2164
crates/llm_gateway/Cargo.lock
generated
Normal file
2164
crates/llm_gateway/Cargo.lock
generated
Normal file
File diff suppressed because it is too large
Load diff
29
crates/llm_gateway/Cargo.toml
Normal file
29
crates/llm_gateway/Cargo.toml
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
[package]
|
||||
name = "llm_gateway"
|
||||
version = "0.1.0"
|
||||
authors = ["Katanemo Inc <info@katanemo.com>"]
|
||||
edition = "2021"
|
||||
|
||||
[lib]
|
||||
crate-type = ["cdylib"]
|
||||
|
||||
[dependencies]
|
||||
proxy-wasm = "0.2.1"
|
||||
log = "0.4"
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_yaml = "0.9.34"
|
||||
serde_json = "1.0"
|
||||
md5 = "0.7.0"
|
||||
public_types = { path = "../public_types" }
|
||||
http = "1.1.0"
|
||||
governor = { version = "0.6.3", default-features = false, features = ["no_std"]}
|
||||
tiktoken-rs = "0.5.9"
|
||||
acap = "0.3.0"
|
||||
rand = "0.8.5"
|
||||
thiserror = "1.0.64"
|
||||
derivative = "2.2.0"
|
||||
sha2 = "0.10.8"
|
||||
|
||||
[dev-dependencies]
|
||||
proxy-wasm-test-framework = { git = "https://github.com/katanemo/test-framework.git", branch = "new" }
|
||||
serial_test = "3.1.1"
|
||||
324
crates/llm_gateway/src/filter_context.rs
Normal file
324
crates/llm_gateway/src/filter_context.rs
Normal file
|
|
@ -0,0 +1,324 @@
|
|||
use crate::llm_providers::LlmProviders;
|
||||
use crate::ratelimit;
|
||||
use crate::stream_context::StreamContext;
|
||||
use log::debug;
|
||||
use proxy_wasm::traits::*;
|
||||
use proxy_wasm::types::*;
|
||||
use public_types::common_types::EmbeddingType;
|
||||
use public_types::configuration::{
|
||||
Configuration, GatewayMode, Overrides, PromptGuards, PromptTarget,
|
||||
};
|
||||
use public_types::consts::ARCH_INTERNAL_CLUSTER_NAME;
|
||||
use public_types::consts::ARCH_UPSTREAM_HOST_HEADER;
|
||||
use public_types::consts::DEFAULT_EMBEDDING_MODEL;
|
||||
use public_types::consts::MODEL_SERVER_NAME;
|
||||
use public_types::embeddings::{
|
||||
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
|
||||
};
|
||||
use public_types::http::CallArgs;
|
||||
use public_types::http::Client;
|
||||
use public_types::stats::Counter;
|
||||
use public_types::stats::Gauge;
|
||||
use public_types::stats::IncrementingMetric;
|
||||
use std::cell::RefCell;
|
||||
use std::collections::hash_map::Entry;
|
||||
use std::collections::HashMap;
|
||||
use std::rc::Rc;
|
||||
use std::time::Duration;
|
||||
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
pub struct WasmMetrics {
|
||||
pub active_http_calls: Gauge,
|
||||
pub ratelimited_rq: Counter,
|
||||
}
|
||||
|
||||
impl WasmMetrics {
|
||||
fn new() -> WasmMetrics {
|
||||
WasmMetrics {
|
||||
active_http_calls: Gauge::new(String::from("active_http_calls")),
|
||||
ratelimited_rq: Counter::new(String::from("ratelimited_rq")),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub type EmbeddingTypeMap = HashMap<EmbeddingType, Vec<f64>>;
|
||||
pub type EmbeddingsStore = HashMap<String, EmbeddingTypeMap>;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct FilterCallContext {
|
||||
pub prompt_target_name: String,
|
||||
pub embedding_type: EmbeddingType,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct FilterContext {
|
||||
metrics: Rc<WasmMetrics>,
|
||||
// callouts stores token_id to request mapping that we use during #on_http_call_response to match the response to the request.
|
||||
callouts: RefCell<HashMap<u32, FilterCallContext>>,
|
||||
overrides: Rc<Option<Overrides>>,
|
||||
system_prompt: Rc<Option<String>>,
|
||||
prompt_targets: Rc<HashMap<String, PromptTarget>>,
|
||||
mode: GatewayMode,
|
||||
prompt_guards: Rc<PromptGuards>,
|
||||
llm_providers: Option<Rc<LlmProviders>>,
|
||||
embeddings_store: Option<Rc<EmbeddingsStore>>,
|
||||
temp_embeddings_store: EmbeddingsStore,
|
||||
}
|
||||
|
||||
impl FilterContext {
|
||||
pub fn new() -> FilterContext {
|
||||
FilterContext {
|
||||
callouts: RefCell::new(HashMap::new()),
|
||||
metrics: Rc::new(WasmMetrics::new()),
|
||||
system_prompt: Rc::new(None),
|
||||
prompt_targets: Rc::new(HashMap::new()),
|
||||
overrides: Rc::new(None),
|
||||
prompt_guards: Rc::new(PromptGuards::default()),
|
||||
mode: GatewayMode::Prompt,
|
||||
llm_providers: None,
|
||||
embeddings_store: Some(Rc::new(HashMap::new())),
|
||||
temp_embeddings_store: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn process_prompt_targets(&self) {
|
||||
for values in self.prompt_targets.iter() {
|
||||
let prompt_target = values.1;
|
||||
self.schedule_embeddings_call(
|
||||
&prompt_target.name,
|
||||
&prompt_target.description,
|
||||
EmbeddingType::Description,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn schedule_embeddings_call(
|
||||
&self,
|
||||
prompt_target_name: &str,
|
||||
input: &str,
|
||||
embedding_type: EmbeddingType,
|
||||
) {
|
||||
let embeddings_input = CreateEmbeddingRequest {
|
||||
input: Box::new(CreateEmbeddingRequestInput::String(String::from(input))),
|
||||
model: String::from(DEFAULT_EMBEDDING_MODEL),
|
||||
encoding_format: None,
|
||||
dimensions: None,
|
||||
user: None,
|
||||
};
|
||||
let json_data = serde_json::to_string(&embeddings_input).unwrap();
|
||||
|
||||
let call_args = CallArgs::new(
|
||||
ARCH_INTERNAL_CLUSTER_NAME,
|
||||
"/embeddings",
|
||||
vec![
|
||||
(ARCH_UPSTREAM_HOST_HEADER, MODEL_SERVER_NAME),
|
||||
(":method", "POST"),
|
||||
(":path", "/embeddings"),
|
||||
(":authority", MODEL_SERVER_NAME),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-upstream-rq-timeout-ms", "60000"),
|
||||
],
|
||||
Some(json_data.as_bytes()),
|
||||
vec![],
|
||||
Duration::from_secs(60),
|
||||
);
|
||||
|
||||
let call_context = crate::filter_context::FilterCallContext {
|
||||
prompt_target_name: String::from(prompt_target_name),
|
||||
embedding_type,
|
||||
};
|
||||
|
||||
if let Err(error) = self.http_call(call_args, call_context) {
|
||||
panic!("{error}")
|
||||
}
|
||||
}
|
||||
|
||||
fn embedding_response_handler(
|
||||
&mut self,
|
||||
body_size: usize,
|
||||
embedding_type: EmbeddingType,
|
||||
prompt_target_name: String,
|
||||
) {
|
||||
let prompt_target = self
|
||||
.prompt_targets
|
||||
.get(&prompt_target_name)
|
||||
.unwrap_or_else(|| {
|
||||
panic!(
|
||||
"Received embeddings response for unknown prompt target name={}",
|
||||
prompt_target_name
|
||||
)
|
||||
});
|
||||
|
||||
let body = self
|
||||
.get_http_call_response_body(0, body_size)
|
||||
.expect("No body in response");
|
||||
if !body.is_empty() {
|
||||
let mut embedding_response: CreateEmbeddingResponse =
|
||||
match serde_json::from_slice(&body) {
|
||||
Ok(response) => response,
|
||||
Err(e) => {
|
||||
panic!(
|
||||
"Error deserializing embedding response. body: {:?}: {:?}",
|
||||
String::from_utf8(body).unwrap(),
|
||||
e
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
let embeddings = embedding_response.data.remove(0).embedding;
|
||||
debug!(
|
||||
"Adding embeddings for prompt target name: {:?}, description: {:?}, embedding type: {:?}",
|
||||
prompt_target.name,
|
||||
prompt_target.description,
|
||||
embedding_type
|
||||
);
|
||||
|
||||
let entry = self.temp_embeddings_store.entry(prompt_target_name);
|
||||
match entry {
|
||||
Entry::Occupied(_) => {
|
||||
entry.and_modify(|e| {
|
||||
if let Entry::Vacant(e) = e.entry(embedding_type) {
|
||||
e.insert(embeddings);
|
||||
} else {
|
||||
panic!(
|
||||
"Duplicate {:?} for prompt target with name=\"{}\"",
|
||||
&embedding_type, prompt_target.name
|
||||
)
|
||||
}
|
||||
});
|
||||
}
|
||||
Entry::Vacant(_) => {
|
||||
entry.or_insert(HashMap::from([(embedding_type, embeddings)]));
|
||||
}
|
||||
}
|
||||
|
||||
if self.prompt_targets.len() == self.temp_embeddings_store.len() {
|
||||
self.embeddings_store =
|
||||
Some(Rc::new(std::mem::take(&mut self.temp_embeddings_store)))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Client for FilterContext {
|
||||
type CallContext = FilterCallContext;
|
||||
|
||||
fn callouts(&self) -> &RefCell<HashMap<u32, Self::CallContext>> {
|
||||
&self.callouts
|
||||
}
|
||||
|
||||
fn active_http_calls(&self) -> &Gauge {
|
||||
&self.metrics.active_http_calls
|
||||
}
|
||||
}
|
||||
|
||||
impl Context for FilterContext {
|
||||
fn on_http_call_response(
|
||||
&mut self,
|
||||
token_id: u32,
|
||||
_num_headers: usize,
|
||||
body_size: usize,
|
||||
_num_trailers: usize,
|
||||
) {
|
||||
debug!(
|
||||
"filter_context: on_http_call_response called with token_id: {:?}",
|
||||
token_id
|
||||
);
|
||||
let callout_data = self
|
||||
.callouts
|
||||
.borrow_mut()
|
||||
.remove(&token_id)
|
||||
.expect("invalid token_id");
|
||||
|
||||
self.metrics.active_http_calls.increment(-1);
|
||||
|
||||
self.embedding_response_handler(
|
||||
body_size,
|
||||
callout_data.embedding_type,
|
||||
callout_data.prompt_target_name,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// RootContext allows the Rust code to reach into the Envoy Config
|
||||
impl RootContext for FilterContext {
|
||||
fn on_configure(&mut self, _: usize) -> bool {
|
||||
let config_bytes = self
|
||||
.get_plugin_configuration()
|
||||
.expect("Arch config cannot be empty");
|
||||
|
||||
let config: Configuration = match serde_yaml::from_slice(&config_bytes) {
|
||||
Ok(config) => config,
|
||||
Err(err) => panic!("Invalid arch config \"{:?}\"", err),
|
||||
};
|
||||
|
||||
self.overrides = Rc::new(config.overrides);
|
||||
|
||||
let mut prompt_targets = HashMap::new();
|
||||
for pt in config.prompt_targets {
|
||||
prompt_targets.insert(pt.name.clone(), pt.clone());
|
||||
}
|
||||
self.system_prompt = Rc::new(config.system_prompt);
|
||||
self.prompt_targets = Rc::new(prompt_targets);
|
||||
self.mode = config.mode.unwrap_or_default();
|
||||
|
||||
ratelimit::ratelimits(Some(config.ratelimits.unwrap_or_default()));
|
||||
|
||||
if let Some(prompt_guards) = config.prompt_guards {
|
||||
self.prompt_guards = Rc::new(prompt_guards)
|
||||
}
|
||||
|
||||
match config.llm_providers.try_into() {
|
||||
Ok(llm_providers) => self.llm_providers = Some(Rc::new(llm_providers)),
|
||||
Err(err) => panic!("{err}"),
|
||||
}
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
fn create_http_context(&self, context_id: u32) -> Option<Box<dyn HttpContext>> {
|
||||
debug!(
|
||||
"||| create_http_context called with context_id: {:?} |||",
|
||||
context_id
|
||||
);
|
||||
|
||||
// No StreamContext can be created until the Embedding Store is fully initialized.
|
||||
let embedding_store = match self.mode {
|
||||
GatewayMode::Llm => None,
|
||||
GatewayMode::Prompt => Some(Rc::clone(self.embeddings_store.as_ref().unwrap())),
|
||||
};
|
||||
Some(Box::new(StreamContext::new(
|
||||
context_id,
|
||||
Rc::clone(&self.metrics),
|
||||
Rc::clone(&self.system_prompt),
|
||||
Rc::clone(&self.prompt_targets),
|
||||
Rc::clone(&self.prompt_guards),
|
||||
Rc::clone(&self.overrides),
|
||||
Rc::clone(
|
||||
self.llm_providers
|
||||
.as_ref()
|
||||
.expect("LLM Providers must exist when Streams are being created"),
|
||||
),
|
||||
embedding_store,
|
||||
self.mode.clone(),
|
||||
)))
|
||||
}
|
||||
|
||||
fn get_type(&self) -> Option<ContextType> {
|
||||
Some(ContextType::HttpContext)
|
||||
}
|
||||
|
||||
fn on_vm_start(&mut self, _: usize) -> bool {
|
||||
self.set_tick_period(Duration::from_secs(1));
|
||||
true
|
||||
}
|
||||
|
||||
fn on_tick(&mut self) {
|
||||
debug!("starting up arch filter in mode: {:?}", self.mode);
|
||||
if self.mode == GatewayMode::Prompt {
|
||||
self.process_prompt_targets();
|
||||
}
|
||||
|
||||
self.set_tick_period(Duration::from_secs(0));
|
||||
}
|
||||
}
|
||||
17
crates/llm_gateway/src/lib.rs
Normal file
17
crates/llm_gateway/src/lib.rs
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
use filter_context::FilterContext;
|
||||
use proxy_wasm::traits::*;
|
||||
use proxy_wasm::types::*;
|
||||
|
||||
mod filter_context;
|
||||
mod llm_providers;
|
||||
mod ratelimit;
|
||||
mod routing;
|
||||
mod stream_context;
|
||||
mod tokenizer;
|
||||
|
||||
proxy_wasm::main! {{
|
||||
proxy_wasm::set_log_level(LogLevel::Trace);
|
||||
proxy_wasm::set_root_context(|_| -> Box<dyn RootContext> {
|
||||
Box::new(FilterContext::new())
|
||||
});
|
||||
}}
|
||||
69
crates/llm_gateway/src/llm_providers.rs
Normal file
69
crates/llm_gateway/src/llm_providers.rs
Normal file
|
|
@ -0,0 +1,69 @@
|
|||
use public_types::configuration::LlmProvider;
|
||||
use std::collections::HashMap;
|
||||
use std::rc::Rc;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct LlmProviders {
|
||||
providers: HashMap<String, Rc<LlmProvider>>,
|
||||
default: Option<Rc<LlmProvider>>,
|
||||
}
|
||||
|
||||
impl LlmProviders {
|
||||
pub fn iter(&self) -> std::collections::hash_map::Iter<'_, String, Rc<LlmProvider>> {
|
||||
self.providers.iter()
|
||||
}
|
||||
|
||||
pub fn default(&self) -> Option<Rc<LlmProvider>> {
|
||||
self.default.as_ref().map(|rc| rc.clone())
|
||||
}
|
||||
|
||||
pub fn get(&self, name: &str) -> Option<Rc<LlmProvider>> {
|
||||
self.providers.get(name).cloned()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum LlmProvidersNewError {
|
||||
#[error("There must be at least one LLM Provider")]
|
||||
EmptySource,
|
||||
#[error("There must be at most one default LLM Provider")]
|
||||
MoreThanOneDefault,
|
||||
#[error("\'{0}\' is not a unique name")]
|
||||
DuplicateName(String),
|
||||
}
|
||||
|
||||
impl TryFrom<Vec<LlmProvider>> for LlmProviders {
|
||||
type Error = LlmProvidersNewError;
|
||||
|
||||
fn try_from(llm_providers_config: Vec<LlmProvider>) -> Result<Self, Self::Error> {
|
||||
if llm_providers_config.is_empty() {
|
||||
return Err(LlmProvidersNewError::EmptySource);
|
||||
}
|
||||
|
||||
let mut llm_providers = LlmProviders {
|
||||
providers: HashMap::new(),
|
||||
default: None,
|
||||
};
|
||||
|
||||
for llm_provider in llm_providers_config {
|
||||
let llm_provider: Rc<LlmProvider> = Rc::new(llm_provider);
|
||||
if llm_provider.default.unwrap_or_default() {
|
||||
match llm_providers.default {
|
||||
Some(_) => return Err(LlmProvidersNewError::MoreThanOneDefault),
|
||||
None => llm_providers.default = Some(Rc::clone(&llm_provider)),
|
||||
}
|
||||
}
|
||||
|
||||
// Insert and check that there is no other provider with the same name.
|
||||
let name = llm_provider.name.clone();
|
||||
if llm_providers
|
||||
.providers
|
||||
.insert(name.clone(), llm_provider)
|
||||
.is_some()
|
||||
{
|
||||
return Err(LlmProvidersNewError::DuplicateName(name));
|
||||
}
|
||||
}
|
||||
Ok(llm_providers)
|
||||
}
|
||||
}
|
||||
450
crates/llm_gateway/src/ratelimit.rs
Normal file
450
crates/llm_gateway/src/ratelimit.rs
Normal file
|
|
@ -0,0 +1,450 @@
|
|||
use governor::{DefaultKeyedRateLimiter, InsufficientCapacity, Quota};
|
||||
use log::debug;
|
||||
use public_types::configuration;
|
||||
use public_types::configuration::{Limit, Ratelimit, TimeUnit};
|
||||
use std::fmt::Display;
|
||||
use std::num::{NonZero, NonZeroU32};
|
||||
use std::sync::RwLock;
|
||||
use std::{collections::HashMap, sync::OnceLock};
|
||||
|
||||
pub type RatelimitData = RwLock<RatelimitMap>;
|
||||
|
||||
pub fn ratelimits(ratelimits_config: Option<Vec<Ratelimit>>) -> &'static RatelimitData {
|
||||
static RATELIMIT_DATA: OnceLock<RatelimitData> = OnceLock::new();
|
||||
RATELIMIT_DATA.get_or_init(|| {
|
||||
RwLock::new(RatelimitMap::new(
|
||||
ratelimits_config.expect("The initialization call has to have passed a config"),
|
||||
))
|
||||
})
|
||||
}
|
||||
|
||||
// The Data Structure is laid out in the following way:
|
||||
// Provider -> Hash { Header -> Limit }.
|
||||
// If the Header used to configure the given Limit:
|
||||
// a) Has None value, then there will be N Limit keyed by the Header value.
|
||||
// b) Has Some() value, then there will be 1 Limit keyed by the empty string.
|
||||
// It would have been nicer to use a non-keyed limit for b). However, the type system made that option a nightmare.
|
||||
pub struct RatelimitMap {
|
||||
datastore: HashMap<String, HashMap<configuration::Header, DefaultKeyedRateLimiter<String>>>,
|
||||
}
|
||||
|
||||
// This version of Header demands that the user passes a header value to match on.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Header {
|
||||
pub key: String,
|
||||
pub value: String,
|
||||
}
|
||||
|
||||
impl Display for Header {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{self:?}")
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Header> for configuration::Header {
|
||||
fn from(header: Header) -> Self {
|
||||
Self {
|
||||
key: header.key,
|
||||
value: Some(header.value),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum Error {
|
||||
#[error("exceeded limit provider={provider}, selector={selector}, tokens_used={tokens_used}")]
|
||||
ExceededLimit {
|
||||
provider: String,
|
||||
selector: Header,
|
||||
tokens_used: NonZeroU32,
|
||||
},
|
||||
}
|
||||
|
||||
impl RatelimitMap {
|
||||
// n.b new is private so that the only access to the Ratelimits can be done via the static
|
||||
// reference inside a RwLock via ratelimit::ratelimits().
|
||||
fn new(ratelimits_config: Vec<Ratelimit>) -> Self {
|
||||
let mut new_ratelimit_map = RatelimitMap {
|
||||
datastore: HashMap::new(),
|
||||
};
|
||||
for ratelimit_config in ratelimits_config {
|
||||
let limit = DefaultKeyedRateLimiter::keyed(get_quota(ratelimit_config.limit));
|
||||
|
||||
match new_ratelimit_map.datastore.get_mut(&ratelimit_config.model) {
|
||||
Some(limits) => match limits.get_mut(&ratelimit_config.selector) {
|
||||
Some(_) => {
|
||||
panic!("repeated selector. Selectors per provider must be unique")
|
||||
}
|
||||
None => {
|
||||
limits.insert(ratelimit_config.selector, limit);
|
||||
}
|
||||
},
|
||||
None => {
|
||||
// The provider has not been seen before.
|
||||
// Insert the provider and a new HashMap with the specified limit
|
||||
let new_hash_map = HashMap::from([(ratelimit_config.selector, limit)]);
|
||||
new_ratelimit_map
|
||||
.datastore
|
||||
.insert(ratelimit_config.model, new_hash_map);
|
||||
}
|
||||
}
|
||||
}
|
||||
new_ratelimit_map
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
pub fn check_limit(
|
||||
&self,
|
||||
provider: String,
|
||||
selector: Header,
|
||||
tokens_used: NonZeroU32,
|
||||
) -> Result<(), Error> {
|
||||
debug!(
|
||||
"Checking limit for provider={}, with selector={:?}, consuming tokens={:?}",
|
||||
provider, selector, tokens_used
|
||||
);
|
||||
|
||||
let provider_limits = match self.datastore.get(&provider) {
|
||||
None => {
|
||||
// No limit configured for this provider, hence ok.
|
||||
return Ok(());
|
||||
}
|
||||
Some(limit) => limit,
|
||||
};
|
||||
|
||||
let mut config_selector = configuration::Header::from(selector.clone());
|
||||
|
||||
let (limit, limit_key) = match provider_limits.get(&config_selector) {
|
||||
// This is a specific limit, i.e one that was configured with both key, and value.
|
||||
// Therefore, the key for the internal limit does not matter, and hence the empty string is always returned.
|
||||
Some(limit) => (limit, String::from("")),
|
||||
None => {
|
||||
// Unwrap is ok here because we _know_ the value exists.
|
||||
let header_key = config_selector.value.take().unwrap();
|
||||
// Search for less specific limit, i.e, one that was configured without a value, therefore every Header
|
||||
// value has its own key in the internal limit.
|
||||
match provider_limits.get(&config_selector) {
|
||||
Some(limit) => (limit, header_key),
|
||||
// No limit for that header key, value pair exists within that provider limits.
|
||||
None => {
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
match limit.check_key_n(&limit_key, tokens_used) {
|
||||
Ok(Ok(())) => Ok(()),
|
||||
Ok(Err(_)) | Err(InsufficientCapacity(_)) => Err(Error::ExceededLimit {
|
||||
provider,
|
||||
selector,
|
||||
tokens_used,
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn get_quota(limit: Limit) -> Quota {
|
||||
let tokens = NonZero::new(limit.tokens).expect("Limit's tokens must be positive");
|
||||
match limit.unit {
|
||||
TimeUnit::Second => Quota::per_second(tokens),
|
||||
TimeUnit::Minute => Quota::per_minute(tokens),
|
||||
TimeUnit::Hour => Quota::per_hour(tokens),
|
||||
}
|
||||
}
|
||||
|
||||
// The following tests are inside the ratelimit module in order to access RatelimitMap::new() in order to provide
|
||||
// different configuration values per test.
|
||||
#[test]
|
||||
fn non_existent_provider_is_ok() {
|
||||
let ratelimits_config = vec![Ratelimit {
|
||||
model: String::from("provider"),
|
||||
selector: configuration::Header {
|
||||
key: String::from("only-key"),
|
||||
value: None,
|
||||
},
|
||||
limit: Limit {
|
||||
tokens: 100,
|
||||
unit: TimeUnit::Minute,
|
||||
},
|
||||
}];
|
||||
|
||||
let ratelimits = RatelimitMap::new(ratelimits_config);
|
||||
|
||||
assert!(ratelimits
|
||||
.check_limit(
|
||||
String::from("non-existent-provider"),
|
||||
Header {
|
||||
key: String::from("key"),
|
||||
value: String::from("value"),
|
||||
},
|
||||
NonZero::new(5000).unwrap(),
|
||||
)
|
||||
.is_ok())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn non_existent_key_is_ok() {
|
||||
let ratelimits_config = vec![Ratelimit {
|
||||
model: String::from("provider"),
|
||||
selector: configuration::Header {
|
||||
key: String::from("only-key"),
|
||||
value: None,
|
||||
},
|
||||
limit: Limit {
|
||||
tokens: 100,
|
||||
unit: TimeUnit::Minute,
|
||||
},
|
||||
}];
|
||||
|
||||
let ratelimits = RatelimitMap::new(ratelimits_config);
|
||||
|
||||
assert!(ratelimits
|
||||
.check_limit(
|
||||
String::from("provider"),
|
||||
Header {
|
||||
key: String::from("key"),
|
||||
value: String::from("value"),
|
||||
},
|
||||
NonZero::new(5000).unwrap(),
|
||||
)
|
||||
.is_ok())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn specific_limit_does_not_catch_non_specific_value() {
|
||||
let ratelimits_config = vec![Ratelimit {
|
||||
model: String::from("provider"),
|
||||
selector: configuration::Header {
|
||||
key: String::from("key"),
|
||||
value: Some(String::from("value")),
|
||||
},
|
||||
limit: Limit {
|
||||
tokens: 200,
|
||||
unit: TimeUnit::Second,
|
||||
},
|
||||
}];
|
||||
|
||||
let ratelimits = RatelimitMap::new(ratelimits_config);
|
||||
|
||||
assert!(ratelimits
|
||||
.check_limit(
|
||||
String::from("provider"),
|
||||
Header {
|
||||
key: String::from("key"),
|
||||
value: String::from("not-the-correct-value"),
|
||||
},
|
||||
NonZero::new(5000).unwrap(),
|
||||
)
|
||||
.is_ok())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn specific_limit_is_hit() {
|
||||
let ratelimits_config = vec![Ratelimit {
|
||||
model: String::from("provider"),
|
||||
selector: configuration::Header {
|
||||
key: String::from("key"),
|
||||
value: Some(String::from("value")),
|
||||
},
|
||||
limit: Limit {
|
||||
tokens: 200,
|
||||
unit: TimeUnit::Hour,
|
||||
},
|
||||
}];
|
||||
|
||||
let ratelimits = RatelimitMap::new(ratelimits_config);
|
||||
|
||||
assert!(ratelimits
|
||||
.check_limit(
|
||||
String::from("provider"),
|
||||
Header {
|
||||
key: String::from("key"),
|
||||
value: String::from("value"),
|
||||
},
|
||||
NonZero::new(5000).unwrap(),
|
||||
)
|
||||
.is_err())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn non_specific_key_has_different_limits_for_different_values() {
|
||||
let ratelimits_config = vec![Ratelimit {
|
||||
model: String::from("provider"),
|
||||
selector: configuration::Header {
|
||||
key: String::from("only-key"),
|
||||
value: None,
|
||||
},
|
||||
limit: Limit {
|
||||
tokens: 100,
|
||||
unit: TimeUnit::Hour,
|
||||
},
|
||||
}];
|
||||
|
||||
let ratelimits = RatelimitMap::new(ratelimits_config);
|
||||
|
||||
// Value1 takes 50.
|
||||
assert!(ratelimits
|
||||
.check_limit(
|
||||
String::from("provider"),
|
||||
Header {
|
||||
key: String::from("only-key"),
|
||||
value: String::from("value1"),
|
||||
},
|
||||
NonZero::new(50).unwrap(),
|
||||
)
|
||||
.is_ok());
|
||||
|
||||
// value2 takes 60 because it has its own 100 limit
|
||||
assert!(ratelimits
|
||||
.check_limit(
|
||||
String::from("provider"),
|
||||
Header {
|
||||
key: String::from("only-key"),
|
||||
value: String::from("value2"),
|
||||
},
|
||||
NonZero::new(60).unwrap(),
|
||||
)
|
||||
.is_ok());
|
||||
|
||||
// However value1 cannot take more than 100 per hour which 50+70 = 120
|
||||
assert!(ratelimits
|
||||
.check_limit(
|
||||
String::from("provider"),
|
||||
Header {
|
||||
key: String::from("only-key"),
|
||||
value: String::from("value1"),
|
||||
},
|
||||
NonZero::new(70).unwrap(),
|
||||
)
|
||||
.is_err())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn different_provider_can_have_different_limits_with_the_same_keys() {
|
||||
let ratelimits_config = vec![
|
||||
Ratelimit {
|
||||
model: String::from("first_provider"),
|
||||
selector: configuration::Header {
|
||||
key: String::from("key"),
|
||||
value: Some(String::from("value")),
|
||||
},
|
||||
limit: Limit {
|
||||
tokens: 100,
|
||||
unit: TimeUnit::Hour,
|
||||
},
|
||||
},
|
||||
Ratelimit {
|
||||
model: String::from("second_provider"),
|
||||
selector: configuration::Header {
|
||||
key: String::from("key"),
|
||||
value: Some(String::from("value")),
|
||||
},
|
||||
limit: Limit {
|
||||
tokens: 200,
|
||||
unit: TimeUnit::Hour,
|
||||
},
|
||||
},
|
||||
];
|
||||
|
||||
let ratelimits = RatelimitMap::new(ratelimits_config);
|
||||
|
||||
assert!(ratelimits
|
||||
.check_limit(
|
||||
String::from("first_provider"),
|
||||
Header {
|
||||
key: String::from("key"),
|
||||
value: String::from("value"),
|
||||
},
|
||||
NonZero::new(100).unwrap(),
|
||||
)
|
||||
.is_ok());
|
||||
|
||||
assert!(ratelimits
|
||||
.check_limit(
|
||||
String::from("second_provider"),
|
||||
Header {
|
||||
key: String::from("key"),
|
||||
value: String::from("value"),
|
||||
},
|
||||
NonZero::new(200).unwrap(),
|
||||
)
|
||||
.is_ok());
|
||||
|
||||
assert!(ratelimits
|
||||
.check_limit(
|
||||
String::from("first_provider"),
|
||||
Header {
|
||||
key: String::from("key"),
|
||||
value: String::from("value"),
|
||||
},
|
||||
NonZero::new(1).unwrap(),
|
||||
)
|
||||
.is_err());
|
||||
|
||||
assert!(ratelimits
|
||||
.check_limit(
|
||||
String::from("second_provider"),
|
||||
Header {
|
||||
key: String::from("key"),
|
||||
value: String::from("value"),
|
||||
},
|
||||
NonZero::new(1).unwrap(),
|
||||
)
|
||||
.is_err());
|
||||
}
|
||||
|
||||
// These tests use the publicly exposed static singleton, thus the same configuration is used in every test.
|
||||
// If more tests are written here, move the initial call out of the test.
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::ratelimits;
|
||||
use configuration::{Limit, Ratelimit, TimeUnit};
|
||||
use public_types::configuration;
|
||||
use std::num::NonZero;
|
||||
use std::thread;
|
||||
|
||||
#[test]
|
||||
fn make_ratelimits_optional() {
|
||||
let ratelimits_config = Vec::new();
|
||||
|
||||
// Initialize in the main thread.
|
||||
ratelimits(Some(ratelimits_config));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn different_threads_have_same_ratelimit_data_structure() {
|
||||
let ratelimits_config = Some(vec![Ratelimit {
|
||||
model: String::from("provider"),
|
||||
selector: configuration::Header {
|
||||
key: String::from("key"),
|
||||
value: Some(String::from("value")),
|
||||
},
|
||||
limit: Limit {
|
||||
tokens: 200,
|
||||
unit: TimeUnit::Hour,
|
||||
},
|
||||
}]);
|
||||
|
||||
// Initialize in the main thread.
|
||||
ratelimits(ratelimits_config);
|
||||
|
||||
// Use the singleton in a different thread.
|
||||
thread::spawn(|| {
|
||||
let ratelimits = ratelimits(None);
|
||||
|
||||
assert!(ratelimits
|
||||
.read()
|
||||
.unwrap()
|
||||
.check_limit(
|
||||
String::from("provider"),
|
||||
super::Header {
|
||||
key: String::from("key"),
|
||||
value: String::from("value"),
|
||||
},
|
||||
NonZero::new(5000).unwrap(),
|
||||
)
|
||||
.is_err())
|
||||
});
|
||||
}
|
||||
}
|
||||
50
crates/llm_gateway/src/routing.rs
Normal file
50
crates/llm_gateway/src/routing.rs
Normal file
|
|
@ -0,0 +1,50 @@
|
|||
use std::rc::Rc;
|
||||
|
||||
use crate::llm_providers::LlmProviders;
|
||||
use log::debug;
|
||||
use public_types::configuration::LlmProvider;
|
||||
use rand::{seq::IteratorRandom, thread_rng};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum ProviderHint {
|
||||
Default,
|
||||
Name(String),
|
||||
}
|
||||
|
||||
impl From<String> for ProviderHint {
|
||||
fn from(value: String) -> Self {
|
||||
match value.as_str() {
|
||||
"default" => ProviderHint::Default,
|
||||
_ => ProviderHint::Name(value),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_llm_provider(
|
||||
llm_providers: &LlmProviders,
|
||||
provider_hint: Option<ProviderHint>,
|
||||
) -> Rc<LlmProvider> {
|
||||
let maybe_provider = provider_hint.and_then(|hint| match hint {
|
||||
ProviderHint::Default => llm_providers.default(),
|
||||
// FIXME: should a non-existent name in the hint be more explicit? i.e, return a BAD_REQUEST?
|
||||
ProviderHint::Name(name) => llm_providers.get(&name),
|
||||
});
|
||||
|
||||
if let Some(provider) = maybe_provider {
|
||||
return provider;
|
||||
}
|
||||
|
||||
if llm_providers.default().is_some() {
|
||||
debug!("no llm provider found for hint, using default llm provider");
|
||||
return llm_providers.default().unwrap();
|
||||
}
|
||||
|
||||
debug!("no default llm found, using random llm provider");
|
||||
let mut rng = thread_rng();
|
||||
llm_providers
|
||||
.iter()
|
||||
.choose(&mut rng)
|
||||
.expect("There should always be at least one llm provider")
|
||||
.1
|
||||
.clone()
|
||||
}
|
||||
1576
crates/llm_gateway/src/stream_context.rs
Normal file
1576
crates/llm_gateway/src/stream_context.rs
Normal file
File diff suppressed because it is too large
Load diff
39
crates/llm_gateway/src/tokenizer.rs
Normal file
39
crates/llm_gateway/src/tokenizer.rs
Normal file
|
|
@ -0,0 +1,39 @@
|
|||
use log::debug;
|
||||
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
#[allow(dead_code)]
|
||||
pub enum Error {
|
||||
UnknownModel,
|
||||
FailedToTokenize,
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn token_count(model_name: &str, text: &str) -> Result<usize, Error> {
|
||||
debug!("getting token count model={}", model_name);
|
||||
// Consideration: is it more expensive to instantiate the BPE object every time, or to contend the singleton?
|
||||
let bpe = tiktoken_rs::get_bpe_from_model(model_name).map_err(|_| Error::UnknownModel)?;
|
||||
Ok(bpe.encode_ordinary(text).len())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn encode_ordinary() {
|
||||
let model_name = "gpt-3.5-turbo";
|
||||
let text = "How many tokens does this sentence have?";
|
||||
assert_eq!(
|
||||
8,
|
||||
token_count(model_name, text).expect("correct tokenization")
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn unrecognized_model() {
|
||||
assert_eq!(
|
||||
Error::UnknownModel,
|
||||
token_count("unknown", "").expect_err("unknown model")
|
||||
)
|
||||
}
|
||||
}
|
||||
805
crates/llm_gateway/tests/integration.rs
Normal file
805
crates/llm_gateway/tests/integration.rs
Normal file
|
|
@ -0,0 +1,805 @@
|
|||
use http::StatusCode;
|
||||
use proxy_wasm_test_framework::tester::{self, Tester};
|
||||
use proxy_wasm_test_framework::types::{
|
||||
Action, BufferType, LogLevel, MapType, MetricType, ReturnType,
|
||||
};
|
||||
use public_types::common_types::open_ai::{ChatCompletionsResponse, Choice, Message, Usage};
|
||||
use public_types::common_types::open_ai::{FunctionCallDetail, ToolCall, ToolType};
|
||||
use public_types::common_types::{HallucinationClassificationResponse, PromptGuardResponse};
|
||||
use public_types::embeddings::{
|
||||
create_embedding_response, embedding, CreateEmbeddingResponse, CreateEmbeddingResponseUsage,
|
||||
Embedding,
|
||||
};
|
||||
use public_types::{common_types::ZeroShotClassificationResponse, configuration::Configuration};
|
||||
use serde_yaml::Value;
|
||||
use serial_test::serial;
|
||||
use std::collections::HashMap;
|
||||
use std::path::Path;
|
||||
|
||||
fn wasm_module() -> String {
|
||||
let wasm_file = Path::new("target/wasm32-wasi/release/intelligent_prompt_gateway.wasm");
|
||||
assert!(
|
||||
wasm_file.exists(),
|
||||
"Run `cargo build --release --target=wasm32-wasi` first"
|
||||
);
|
||||
wasm_file.to_str().unwrap().to_string()
|
||||
}
|
||||
|
||||
fn request_headers_expectations(module: &mut Tester, http_context: i32) {
|
||||
module
|
||||
.call_proxy_on_request_headers(http_context, 0, false)
|
||||
.expect_get_header_map_value(
|
||||
Some(MapType::HttpRequestHeaders),
|
||||
Some("x-arch-llm-provider-hint"),
|
||||
)
|
||||
.returning(Some("default"))
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_add_header_map_value(
|
||||
Some(MapType::HttpRequestHeaders),
|
||||
Some("x-arch-upstream"),
|
||||
Some("arch_llm_listener"),
|
||||
)
|
||||
.expect_add_header_map_value(
|
||||
Some(MapType::HttpRequestHeaders),
|
||||
Some("x-arch-llm-provider"),
|
||||
Some("open-ai-gpt-4"),
|
||||
)
|
||||
.expect_replace_header_map_value(
|
||||
Some(MapType::HttpRequestHeaders),
|
||||
Some("Authorization"),
|
||||
Some("Bearer secret_key"),
|
||||
)
|
||||
.expect_remove_header_map_value(Some(MapType::HttpRequestHeaders), Some("content-length"))
|
||||
.expect_get_header_map_value(
|
||||
Some(MapType::HttpRequestHeaders),
|
||||
Some("x-arch-ratelimit-selector"),
|
||||
)
|
||||
.returning(Some("selector-key"))
|
||||
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("selector-key"))
|
||||
.returning(Some("selector-value"))
|
||||
.expect_get_header_map_pairs(Some(MapType::HttpRequestHeaders))
|
||||
.returning(None)
|
||||
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":path"))
|
||||
.returning(Some("/v1/chat/completions"))
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("x-request-id"))
|
||||
.returning(None)
|
||||
.execute_and_expect(ReturnType::Action(Action::Continue))
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
|
||||
module
|
||||
.call_proxy_on_context_create(http_context, filter_context)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
request_headers_expectations(module, http_context);
|
||||
|
||||
// Request Body
|
||||
let chat_completions_request_body = "\
|
||||
{\
|
||||
\"messages\": [\
|
||||
{\
|
||||
\"role\": \"system\",\
|
||||
\"content\": \"You are a poetic assistant, skilled in explaining complex programming concepts with creative flair.\"\
|
||||
},\
|
||||
{\
|
||||
\"role\": \"user\",\
|
||||
\"content\": \"Compose a poem that explains the concept of recursion in programming.\"\
|
||||
}\
|
||||
],\
|
||||
\"model\": \"gpt-4\"\
|
||||
}";
|
||||
|
||||
module
|
||||
.call_proxy_on_request_body(
|
||||
http_context,
|
||||
chat_completions_request_body.len() as i32,
|
||||
true,
|
||||
)
|
||||
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
|
||||
.returning(Some(chat_completions_request_body))
|
||||
// The actual call is not important in this test, we just need to grab the token_id
|
||||
.expect_http_call(
|
||||
Some("arch_internal"),
|
||||
Some(vec![
|
||||
("x-arch-upstream", "model_server"),
|
||||
(":method", "POST"),
|
||||
(":path", "/guard"),
|
||||
(":authority", "model_server"),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
("x-envoy-upstream-rq-timeout-ms", "60000"),
|
||||
]),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.returning(Some(1))
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_metric_increment("active_http_calls", 1)
|
||||
.execute_and_expect(ReturnType::Action(Action::Pause))
|
||||
.unwrap();
|
||||
|
||||
let prompt_guard_response = PromptGuardResponse {
|
||||
toxic_prob: None,
|
||||
toxic_verdict: None,
|
||||
jailbreak_prob: None,
|
||||
jailbreak_verdict: None,
|
||||
};
|
||||
let prompt_guard_response_buffer = serde_json::to_string(&prompt_guard_response).unwrap();
|
||||
module
|
||||
.call_proxy_on_http_call_response(
|
||||
http_context,
|
||||
1,
|
||||
0,
|
||||
prompt_guard_response_buffer.len() as i32,
|
||||
0,
|
||||
)
|
||||
.expect_metric_increment("active_http_calls", -1)
|
||||
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
|
||||
.returning(Some(&prompt_guard_response_buffer))
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_http_call(
|
||||
Some("arch_internal"),
|
||||
Some(vec![
|
||||
("x-arch-upstream", "model_server"),
|
||||
(":method", "POST"),
|
||||
(":path", "/embeddings"),
|
||||
(":authority", "model_server"),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
("x-envoy-upstream-rq-timeout-ms", "60000"),
|
||||
]),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.returning(Some(2))
|
||||
.expect_metric_increment("active_http_calls", 1)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
let embedding_response = CreateEmbeddingResponse {
|
||||
data: vec![Embedding {
|
||||
index: 0,
|
||||
embedding: vec![],
|
||||
object: embedding::Object::default(),
|
||||
}],
|
||||
model: String::from("test"),
|
||||
object: create_embedding_response::Object::default(),
|
||||
usage: Box::new(CreateEmbeddingResponseUsage::new(0, 0)),
|
||||
};
|
||||
let embeddings_response_buffer = serde_json::to_string(&embedding_response).unwrap();
|
||||
module
|
||||
.call_proxy_on_http_call_response(
|
||||
http_context,
|
||||
2,
|
||||
0,
|
||||
embeddings_response_buffer.len() as i32,
|
||||
0,
|
||||
)
|
||||
.expect_metric_increment("active_http_calls", -1)
|
||||
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
|
||||
.returning(Some(&embeddings_response_buffer))
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_http_call(
|
||||
Some("arch_internal"),
|
||||
Some(vec![
|
||||
("x-arch-upstream", "model_server"),
|
||||
(":method", "POST"),
|
||||
(":path", "/zeroshot"),
|
||||
(":authority", "model_server"),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
("x-envoy-upstream-rq-timeout-ms", "60000"),
|
||||
]),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.returning(Some(3))
|
||||
.expect_metric_increment("active_http_calls", 1)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
let zero_shot_response = ZeroShotClassificationResponse {
|
||||
predicted_class: "weather_forecast".to_string(),
|
||||
predicted_class_score: 0.1,
|
||||
scores: HashMap::new(),
|
||||
model: "test-model".to_string(),
|
||||
};
|
||||
let zeroshot_intent_detection_buffer = serde_json::to_string(&zero_shot_response).unwrap();
|
||||
module
|
||||
.call_proxy_on_http_call_response(
|
||||
http_context,
|
||||
3,
|
||||
0,
|
||||
zeroshot_intent_detection_buffer.len() as i32,
|
||||
0,
|
||||
)
|
||||
.expect_metric_increment("active_http_calls", -1)
|
||||
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
|
||||
.returning(Some(&zeroshot_intent_detection_buffer))
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Info), None)
|
||||
.expect_http_call(
|
||||
Some("arch_internal"),
|
||||
Some(vec![
|
||||
(":method", "POST"),
|
||||
("x-arch-upstream", "arch_fc"),
|
||||
(":path", "/v1/chat/completions"),
|
||||
(":authority", "arch_fc"),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
("x-envoy-upstream-rq-timeout-ms", "120000"),
|
||||
]),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.returning(Some(4))
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_metric_increment("active_http_calls", 1)
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
fn setup_filter(module: &mut Tester, config: &str) -> i32 {
|
||||
let filter_context = 1;
|
||||
|
||||
module
|
||||
.call_proxy_on_context_create(filter_context, 0)
|
||||
.expect_metric_creation(MetricType::Gauge, "active_http_calls")
|
||||
.expect_metric_creation(MetricType::Counter, "ratelimited_rq")
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
module
|
||||
.call_proxy_on_configure(filter_context, config.len() as i32)
|
||||
.expect_get_buffer_bytes(Some(BufferType::PluginConfiguration))
|
||||
.returning(Some(config))
|
||||
.execute_and_expect(ReturnType::Bool(true))
|
||||
.unwrap();
|
||||
|
||||
module
|
||||
.call_proxy_on_tick(filter_context)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_http_call(
|
||||
Some("arch_internal"),
|
||||
Some(vec![
|
||||
("x-arch-upstream", "model_server"),
|
||||
(":method", "POST"),
|
||||
(":path", "/embeddings"),
|
||||
(":authority", "model_server"),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-upstream-rq-timeout-ms", "60000"),
|
||||
]),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.returning(Some(101))
|
||||
.expect_metric_increment("active_http_calls", 1)
|
||||
.expect_set_tick_period_millis(Some(0))
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
let embedding_response = CreateEmbeddingResponse {
|
||||
data: vec![Embedding {
|
||||
embedding: vec![],
|
||||
index: 0,
|
||||
object: embedding::Object::default(),
|
||||
}],
|
||||
model: String::from("test"),
|
||||
object: create_embedding_response::Object::default(),
|
||||
usage: Box::new(CreateEmbeddingResponseUsage {
|
||||
prompt_tokens: 0,
|
||||
total_tokens: 0,
|
||||
}),
|
||||
};
|
||||
let embedding_response_str = serde_json::to_string(&embedding_response).unwrap();
|
||||
module
|
||||
.call_proxy_on_http_call_response(
|
||||
filter_context,
|
||||
101,
|
||||
0,
|
||||
embedding_response_str.len() as i32,
|
||||
0,
|
||||
)
|
||||
.expect_log(
|
||||
Some(LogLevel::Debug),
|
||||
Some(
|
||||
format!(
|
||||
"filter_context: on_http_call_response called with token_id: {:?}",
|
||||
101
|
||||
)
|
||||
.as_str(),
|
||||
),
|
||||
)
|
||||
.expect_metric_increment("active_http_calls", -1)
|
||||
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
|
||||
.returning(Some(&embedding_response_str))
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
filter_context
|
||||
}
|
||||
|
||||
fn default_config() -> &'static str {
|
||||
r#"
|
||||
version: "0.1-beta"
|
||||
|
||||
listener:
|
||||
address: 0.0.0.0
|
||||
port: 10000
|
||||
message_format: huggingface
|
||||
connect_timeout: 0.005s
|
||||
|
||||
endpoints:
|
||||
api_server:
|
||||
endpoint: api_server:80
|
||||
connect_timeout: 0.005s
|
||||
|
||||
llm_providers:
|
||||
- name: open-ai-gpt-4
|
||||
provider: openai
|
||||
access_key: secret_key
|
||||
model: gpt-4
|
||||
default: true
|
||||
|
||||
overrides:
|
||||
# confidence threshold for prompt target intent matching
|
||||
prompt_target_intent_matching_threshold: 0.6
|
||||
|
||||
system_prompt: |
|
||||
You are a helpful assistant.
|
||||
|
||||
prompt_guards:
|
||||
input_guards:
|
||||
jailbreak:
|
||||
on_exception:
|
||||
message: "Looks like you're curious about my abilities, but I can only provide assistance within my programmed parameters."
|
||||
|
||||
prompt_targets:
|
||||
- name: weather_forecast
|
||||
description: This function provides realtime weather forecast information for a given city.
|
||||
parameters:
|
||||
- name: city
|
||||
required: true
|
||||
description: The city for which the weather forecast is requested.
|
||||
- name: days
|
||||
description: The number of days for which the weather forecast is requested.
|
||||
- name: units
|
||||
description: The units in which the weather forecast is requested.
|
||||
endpoint:
|
||||
name: api_server
|
||||
path: /weather
|
||||
system_prompt: |
|
||||
You are a helpful weather forecaster. Use weater data that is provided to you. Please following following guidelines when responding to user queries:
|
||||
- Use farenheight for temperature
|
||||
- Use miles per hour for wind speed
|
||||
|
||||
ratelimits:
|
||||
- model: gpt-4
|
||||
selector:
|
||||
key: selector-key
|
||||
value: selector-value
|
||||
limit:
|
||||
tokens: 1
|
||||
unit: minute
|
||||
"#
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[serial]
|
||||
fn successful_request_to_open_ai_chat_completions() {
|
||||
let args = tester::MockSettings {
|
||||
wasm_path: wasm_module(),
|
||||
quiet: false,
|
||||
allow_unexpected: false,
|
||||
};
|
||||
let mut module = tester::mock(args).unwrap();
|
||||
|
||||
module
|
||||
.call_start()
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
// Setup Filter
|
||||
let filter_context = setup_filter(&mut module, default_config());
|
||||
|
||||
// Setup HTTP Stream
|
||||
let http_context = 2;
|
||||
|
||||
module
|
||||
.call_proxy_on_context_create(http_context, filter_context)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
request_headers_expectations(&mut module, http_context);
|
||||
|
||||
// Request Body
|
||||
let chat_completions_request_body = "\
|
||||
{\
|
||||
\"messages\": [\
|
||||
{\
|
||||
\"role\": \"system\",\
|
||||
\"content\": \"You are a poetic assistant, skilled in explaining complex programming concepts with creative flair.\"\
|
||||
},\
|
||||
{\
|
||||
\"role\": \"user\",\
|
||||
\"content\": \"Compose a poem that explains the concept of recursion in programming.\"\
|
||||
}\
|
||||
],\
|
||||
\"model\": \"gpt-4\"\
|
||||
}";
|
||||
|
||||
module
|
||||
.call_proxy_on_request_body(
|
||||
http_context,
|
||||
chat_completions_request_body.len() as i32,
|
||||
true,
|
||||
)
|
||||
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
|
||||
.returning(Some(chat_completions_request_body))
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_http_call(Some("arch_internal"), None, None, None, None)
|
||||
.returning(Some(4))
|
||||
.expect_metric_increment("active_http_calls", 1)
|
||||
.execute_and_expect(ReturnType::Action(Action::Pause))
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[serial]
|
||||
fn bad_request_to_open_ai_chat_completions() {
|
||||
let args = tester::MockSettings {
|
||||
wasm_path: wasm_module(),
|
||||
quiet: false,
|
||||
allow_unexpected: false,
|
||||
};
|
||||
let mut module = tester::mock(args).unwrap();
|
||||
|
||||
module
|
||||
.call_start()
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
// Setup Filter
|
||||
let filter_context = setup_filter(&mut module, default_config());
|
||||
|
||||
// Setup HTTP Stream
|
||||
let http_context = 2;
|
||||
|
||||
module
|
||||
.call_proxy_on_context_create(http_context, filter_context)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
request_headers_expectations(&mut module, http_context);
|
||||
|
||||
// Request Body
|
||||
let incomplete_chat_completions_request_body = "\
|
||||
{\
|
||||
\"messages\": [\
|
||||
{\
|
||||
\"role\": \"system\",\
|
||||
},\
|
||||
{\
|
||||
\"role\": \"user\",\
|
||||
\"content\": \"Compose a poem that explains the concept of recursion in programming.\"\
|
||||
}\
|
||||
]\
|
||||
}";
|
||||
|
||||
module
|
||||
.call_proxy_on_request_body(
|
||||
http_context,
|
||||
incomplete_chat_completions_request_body.len() as i32,
|
||||
true,
|
||||
)
|
||||
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
|
||||
.returning(Some(incomplete_chat_completions_request_body))
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_send_local_response(
|
||||
Some(StatusCode::BAD_REQUEST.as_u16().into()),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.execute_and_expect(ReturnType::Action(Action::Pause))
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[serial]
|
||||
fn request_ratelimited() {
|
||||
let args = tester::MockSettings {
|
||||
wasm_path: wasm_module(),
|
||||
quiet: false,
|
||||
allow_unexpected: false,
|
||||
};
|
||||
let mut module = tester::mock(args).unwrap();
|
||||
|
||||
module
|
||||
.call_start()
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
// Setup Filter
|
||||
let filter_context = setup_filter(&mut module, default_config());
|
||||
|
||||
// Setup HTTP Stream
|
||||
let http_context = 2;
|
||||
|
||||
normal_flow(&mut module, filter_context, http_context);
|
||||
|
||||
let arch_fc_resp = ChatCompletionsResponse {
|
||||
usage: Some(Usage {
|
||||
completion_tokens: 0,
|
||||
}),
|
||||
choices: vec![Choice {
|
||||
finish_reason: "test".to_string(),
|
||||
index: 0,
|
||||
message: Message {
|
||||
role: "system".to_string(),
|
||||
content: None,
|
||||
tool_calls: Some(vec![ToolCall {
|
||||
id: String::from("test"),
|
||||
tool_type: ToolType::Function,
|
||||
function: FunctionCallDetail {
|
||||
name: String::from("weather_forecast"),
|
||||
arguments: HashMap::from([(
|
||||
String::from("city"),
|
||||
Value::String(String::from("seattle")),
|
||||
)]),
|
||||
},
|
||||
}]),
|
||||
model: None,
|
||||
},
|
||||
}],
|
||||
model: String::from("test"),
|
||||
metadata: None,
|
||||
};
|
||||
|
||||
let arch_fc_resp_str = serde_json::to_string(&arch_fc_resp).unwrap();
|
||||
module
|
||||
.call_proxy_on_http_call_response(http_context, 4, 0, arch_fc_resp_str.len() as i32, 0)
|
||||
.expect_metric_increment("active_http_calls", -1)
|
||||
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
|
||||
.returning(Some(&arch_fc_resp_str))
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_http_call(
|
||||
Some("arch_internal"),
|
||||
Some(vec![
|
||||
("x-arch-upstream", "model_server"),
|
||||
(":method", "POST"),
|
||||
(":path", "/hallucination"),
|
||||
(":authority", "model_server"),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
("x-envoy-upstream-rq-timeout-ms", "60000"),
|
||||
]),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.returning(Some(5))
|
||||
.expect_metric_increment("active_http_calls", 1)
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
let hallucatination_body = HallucinationClassificationResponse {
|
||||
params_scores: HashMap::from([("city".to_string(), 0.99)]),
|
||||
model: "nli-model".to_string(),
|
||||
};
|
||||
|
||||
let body_text = serde_json::to_string(&hallucatination_body).unwrap();
|
||||
|
||||
module
|
||||
.call_proxy_on_http_call_response(http_context, 5, 0, body_text.len() as i32, 0)
|
||||
.expect_metric_increment("active_http_calls", -1)
|
||||
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
|
||||
.returning(Some(&body_text))
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_http_call(
|
||||
Some("arch_internal"),
|
||||
Some(vec![
|
||||
("x-arch-upstream", "api_server"),
|
||||
(":method", "POST"),
|
||||
(":path", "/weather"),
|
||||
(":authority", "api_server"),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
]),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.returning(Some(6))
|
||||
.expect_metric_increment("active_http_calls", 1)
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
let body_text = String::from("test body");
|
||||
module
|
||||
.call_proxy_on_http_call_response(http_context, 6, 0, body_text.len() as i32, 0)
|
||||
.expect_metric_increment("active_http_calls", -1)
|
||||
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
|
||||
.returning(Some(&body_text))
|
||||
.expect_get_header_map_value(Some(MapType::HttpCallResponseHeaders), Some(":status"))
|
||||
.returning(Some("200"))
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_send_local_response(
|
||||
Some(StatusCode::TOO_MANY_REQUESTS.as_u16().into()),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.expect_metric_increment("ratelimited_rq", 1)
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[serial]
|
||||
fn request_not_ratelimited() {
|
||||
let args = tester::MockSettings {
|
||||
wasm_path: wasm_module(),
|
||||
quiet: false,
|
||||
allow_unexpected: false,
|
||||
};
|
||||
let mut module = tester::mock(args).unwrap();
|
||||
|
||||
module
|
||||
.call_start()
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
// Setup Filter
|
||||
let mut config: Configuration = serde_yaml::from_str(default_config()).unwrap();
|
||||
config.ratelimits.as_mut().unwrap()[0].limit.tokens += 1000;
|
||||
let config_str = serde_json::to_string(&config).unwrap();
|
||||
|
||||
let filter_context = setup_filter(&mut module, &config_str);
|
||||
|
||||
// Setup HTTP Stream
|
||||
let http_context = 2;
|
||||
|
||||
normal_flow(&mut module, filter_context, http_context);
|
||||
|
||||
let arch_fc_resp = ChatCompletionsResponse {
|
||||
usage: Some(Usage {
|
||||
completion_tokens: 0,
|
||||
}),
|
||||
choices: vec![Choice {
|
||||
finish_reason: "test".to_string(),
|
||||
index: 0,
|
||||
message: Message {
|
||||
role: "system".to_string(),
|
||||
content: None,
|
||||
tool_calls: Some(vec![ToolCall {
|
||||
id: String::from("test"),
|
||||
tool_type: ToolType::Function,
|
||||
function: FunctionCallDetail {
|
||||
name: String::from("weather_forecast"),
|
||||
arguments: HashMap::from([(
|
||||
String::from("city"),
|
||||
Value::String(String::from("seattle")),
|
||||
)]),
|
||||
},
|
||||
}]),
|
||||
model: None,
|
||||
},
|
||||
}],
|
||||
model: String::from("test"),
|
||||
metadata: None,
|
||||
};
|
||||
|
||||
let arch_fc_resp_str = serde_json::to_string(&arch_fc_resp).unwrap();
|
||||
module
|
||||
.call_proxy_on_http_call_response(http_context, 4, 0, arch_fc_resp_str.len() as i32, 0)
|
||||
.expect_metric_increment("active_http_calls", -1)
|
||||
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
|
||||
.returning(Some(&arch_fc_resp_str))
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_http_call(
|
||||
Some("arch_internal"),
|
||||
Some(vec![
|
||||
("x-arch-upstream", "model_server"),
|
||||
(":method", "POST"),
|
||||
(":path", "/hallucination"),
|
||||
(":authority", "model_server"),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
("x-envoy-upstream-rq-timeout-ms", "60000"),
|
||||
]),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.returning(Some(5))
|
||||
.expect_metric_increment("active_http_calls", 1)
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
// hallucination should return that parameters were not halliucinated
|
||||
// prompt: str
|
||||
// parameters: dict
|
||||
// model: str
|
||||
|
||||
let hallucatination_body = HallucinationClassificationResponse {
|
||||
params_scores: HashMap::from([("city".to_string(), 0.99)]),
|
||||
model: "nli-model".to_string(),
|
||||
};
|
||||
|
||||
let body_text = serde_json::to_string(&hallucatination_body).unwrap();
|
||||
|
||||
module
|
||||
.call_proxy_on_http_call_response(http_context, 5, 0, body_text.len() as i32, 0)
|
||||
.expect_metric_increment("active_http_calls", -1)
|
||||
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
|
||||
.returning(Some(&body_text))
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_http_call(
|
||||
Some("arch_internal"),
|
||||
Some(vec![
|
||||
("x-arch-upstream", "api_server"),
|
||||
(":method", "POST"),
|
||||
(":path", "/weather"),
|
||||
(":authority", "api_server"),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
]),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.returning(Some(6))
|
||||
.expect_metric_increment("active_http_calls", 1)
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
let body_text = String::from("test body");
|
||||
module
|
||||
.call_proxy_on_http_call_response(http_context, 6, 0, body_text.len() as i32, 0)
|
||||
.expect_metric_increment("active_http_calls", -1)
|
||||
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
|
||||
.returning(Some(&body_text))
|
||||
.expect_get_header_map_value(Some(MapType::HttpCallResponseHeaders), Some(":status"))
|
||||
.returning(Some("200"))
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
}
|
||||
2164
crates/prompt_gateway/Cargo.lock
generated
Normal file
2164
crates/prompt_gateway/Cargo.lock
generated
Normal file
File diff suppressed because it is too large
Load diff
29
crates/prompt_gateway/Cargo.toml
Normal file
29
crates/prompt_gateway/Cargo.toml
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
[package]
|
||||
name = "prompt_gateway"
|
||||
version = "0.1.0"
|
||||
authors = ["Katanemo Inc <info@katanemo.com>"]
|
||||
edition = "2021"
|
||||
|
||||
[lib]
|
||||
crate-type = ["cdylib"]
|
||||
|
||||
[dependencies]
|
||||
proxy-wasm = "0.2.1"
|
||||
log = "0.4"
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_yaml = "0.9.34"
|
||||
serde_json = "1.0"
|
||||
md5 = "0.7.0"
|
||||
public_types = { path = "../public_types" }
|
||||
http = "1.1.0"
|
||||
governor = { version = "0.6.3", default-features = false, features = ["no_std"]}
|
||||
tiktoken-rs = "0.5.9"
|
||||
acap = "0.3.0"
|
||||
rand = "0.8.5"
|
||||
thiserror = "1.0.64"
|
||||
derivative = "2.2.0"
|
||||
sha2 = "0.10.8"
|
||||
|
||||
[dev-dependencies]
|
||||
proxy-wasm-test-framework = { git = "https://github.com/katanemo/test-framework.git", branch = "new" }
|
||||
serial_test = "3.1.1"
|
||||
324
crates/prompt_gateway/src/filter_context.rs
Normal file
324
crates/prompt_gateway/src/filter_context.rs
Normal file
|
|
@ -0,0 +1,324 @@
|
|||
use crate::llm_providers::LlmProviders;
|
||||
use crate::ratelimit;
|
||||
use crate::stream_context::StreamContext;
|
||||
use log::debug;
|
||||
use proxy_wasm::traits::*;
|
||||
use proxy_wasm::types::*;
|
||||
use public_types::common_types::EmbeddingType;
|
||||
use public_types::configuration::{
|
||||
Configuration, GatewayMode, Overrides, PromptGuards, PromptTarget,
|
||||
};
|
||||
use public_types::consts::ARCH_INTERNAL_CLUSTER_NAME;
|
||||
use public_types::consts::ARCH_UPSTREAM_HOST_HEADER;
|
||||
use public_types::consts::DEFAULT_EMBEDDING_MODEL;
|
||||
use public_types::consts::MODEL_SERVER_NAME;
|
||||
use public_types::embeddings::{
|
||||
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
|
||||
};
|
||||
use public_types::http::CallArgs;
|
||||
use public_types::http::Client;
|
||||
use public_types::stats::Counter;
|
||||
use public_types::stats::Gauge;
|
||||
use public_types::stats::IncrementingMetric;
|
||||
use std::cell::RefCell;
|
||||
use std::collections::hash_map::Entry;
|
||||
use std::collections::HashMap;
|
||||
use std::rc::Rc;
|
||||
use std::time::Duration;
|
||||
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
pub struct WasmMetrics {
|
||||
pub active_http_calls: Gauge,
|
||||
pub ratelimited_rq: Counter,
|
||||
}
|
||||
|
||||
impl WasmMetrics {
|
||||
fn new() -> WasmMetrics {
|
||||
WasmMetrics {
|
||||
active_http_calls: Gauge::new(String::from("active_http_calls")),
|
||||
ratelimited_rq: Counter::new(String::from("ratelimited_rq")),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub type EmbeddingTypeMap = HashMap<EmbeddingType, Vec<f64>>;
|
||||
pub type EmbeddingsStore = HashMap<String, EmbeddingTypeMap>;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct FilterCallContext {
|
||||
pub prompt_target_name: String,
|
||||
pub embedding_type: EmbeddingType,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct FilterContext {
|
||||
metrics: Rc<WasmMetrics>,
|
||||
// callouts stores token_id to request mapping that we use during #on_http_call_response to match the response to the request.
|
||||
callouts: RefCell<HashMap<u32, FilterCallContext>>,
|
||||
overrides: Rc<Option<Overrides>>,
|
||||
system_prompt: Rc<Option<String>>,
|
||||
prompt_targets: Rc<HashMap<String, PromptTarget>>,
|
||||
mode: GatewayMode,
|
||||
prompt_guards: Rc<PromptGuards>,
|
||||
llm_providers: Option<Rc<LlmProviders>>,
|
||||
embeddings_store: Option<Rc<EmbeddingsStore>>,
|
||||
temp_embeddings_store: EmbeddingsStore,
|
||||
}
|
||||
|
||||
impl FilterContext {
|
||||
pub fn new() -> FilterContext {
|
||||
FilterContext {
|
||||
callouts: RefCell::new(HashMap::new()),
|
||||
metrics: Rc::new(WasmMetrics::new()),
|
||||
system_prompt: Rc::new(None),
|
||||
prompt_targets: Rc::new(HashMap::new()),
|
||||
overrides: Rc::new(None),
|
||||
prompt_guards: Rc::new(PromptGuards::default()),
|
||||
mode: GatewayMode::Prompt,
|
||||
llm_providers: None,
|
||||
embeddings_store: Some(Rc::new(HashMap::new())),
|
||||
temp_embeddings_store: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn process_prompt_targets(&self) {
|
||||
for values in self.prompt_targets.iter() {
|
||||
let prompt_target = values.1;
|
||||
self.schedule_embeddings_call(
|
||||
&prompt_target.name,
|
||||
&prompt_target.description,
|
||||
EmbeddingType::Description,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn schedule_embeddings_call(
|
||||
&self,
|
||||
prompt_target_name: &str,
|
||||
input: &str,
|
||||
embedding_type: EmbeddingType,
|
||||
) {
|
||||
let embeddings_input = CreateEmbeddingRequest {
|
||||
input: Box::new(CreateEmbeddingRequestInput::String(String::from(input))),
|
||||
model: String::from(DEFAULT_EMBEDDING_MODEL),
|
||||
encoding_format: None,
|
||||
dimensions: None,
|
||||
user: None,
|
||||
};
|
||||
let json_data = serde_json::to_string(&embeddings_input).unwrap();
|
||||
|
||||
let call_args = CallArgs::new(
|
||||
ARCH_INTERNAL_CLUSTER_NAME,
|
||||
"/embeddings",
|
||||
vec![
|
||||
(ARCH_UPSTREAM_HOST_HEADER, MODEL_SERVER_NAME),
|
||||
(":method", "POST"),
|
||||
(":path", "/embeddings"),
|
||||
(":authority", MODEL_SERVER_NAME),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-upstream-rq-timeout-ms", "60000"),
|
||||
],
|
||||
Some(json_data.as_bytes()),
|
||||
vec![],
|
||||
Duration::from_secs(60),
|
||||
);
|
||||
|
||||
let call_context = crate::filter_context::FilterCallContext {
|
||||
prompt_target_name: String::from(prompt_target_name),
|
||||
embedding_type,
|
||||
};
|
||||
|
||||
if let Err(error) = self.http_call(call_args, call_context) {
|
||||
panic!("{error}")
|
||||
}
|
||||
}
|
||||
|
||||
fn embedding_response_handler(
|
||||
&mut self,
|
||||
body_size: usize,
|
||||
embedding_type: EmbeddingType,
|
||||
prompt_target_name: String,
|
||||
) {
|
||||
let prompt_target = self
|
||||
.prompt_targets
|
||||
.get(&prompt_target_name)
|
||||
.unwrap_or_else(|| {
|
||||
panic!(
|
||||
"Received embeddings response for unknown prompt target name={}",
|
||||
prompt_target_name
|
||||
)
|
||||
});
|
||||
|
||||
let body = self
|
||||
.get_http_call_response_body(0, body_size)
|
||||
.expect("No body in response");
|
||||
if !body.is_empty() {
|
||||
let mut embedding_response: CreateEmbeddingResponse =
|
||||
match serde_json::from_slice(&body) {
|
||||
Ok(response) => response,
|
||||
Err(e) => {
|
||||
panic!(
|
||||
"Error deserializing embedding response. body: {:?}: {:?}",
|
||||
String::from_utf8(body).unwrap(),
|
||||
e
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
let embeddings = embedding_response.data.remove(0).embedding;
|
||||
debug!(
|
||||
"Adding embeddings for prompt target name: {:?}, description: {:?}, embedding type: {:?}",
|
||||
prompt_target.name,
|
||||
prompt_target.description,
|
||||
embedding_type
|
||||
);
|
||||
|
||||
let entry = self.temp_embeddings_store.entry(prompt_target_name);
|
||||
match entry {
|
||||
Entry::Occupied(_) => {
|
||||
entry.and_modify(|e| {
|
||||
if let Entry::Vacant(e) = e.entry(embedding_type) {
|
||||
e.insert(embeddings);
|
||||
} else {
|
||||
panic!(
|
||||
"Duplicate {:?} for prompt target with name=\"{}\"",
|
||||
&embedding_type, prompt_target.name
|
||||
)
|
||||
}
|
||||
});
|
||||
}
|
||||
Entry::Vacant(_) => {
|
||||
entry.or_insert(HashMap::from([(embedding_type, embeddings)]));
|
||||
}
|
||||
}
|
||||
|
||||
if self.prompt_targets.len() == self.temp_embeddings_store.len() {
|
||||
self.embeddings_store =
|
||||
Some(Rc::new(std::mem::take(&mut self.temp_embeddings_store)))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Client for FilterContext {
|
||||
type CallContext = FilterCallContext;
|
||||
|
||||
fn callouts(&self) -> &RefCell<HashMap<u32, Self::CallContext>> {
|
||||
&self.callouts
|
||||
}
|
||||
|
||||
fn active_http_calls(&self) -> &Gauge {
|
||||
&self.metrics.active_http_calls
|
||||
}
|
||||
}
|
||||
|
||||
impl Context for FilterContext {
|
||||
fn on_http_call_response(
|
||||
&mut self,
|
||||
token_id: u32,
|
||||
_num_headers: usize,
|
||||
body_size: usize,
|
||||
_num_trailers: usize,
|
||||
) {
|
||||
debug!(
|
||||
"filter_context: on_http_call_response called with token_id: {:?}",
|
||||
token_id
|
||||
);
|
||||
let callout_data = self
|
||||
.callouts
|
||||
.borrow_mut()
|
||||
.remove(&token_id)
|
||||
.expect("invalid token_id");
|
||||
|
||||
self.metrics.active_http_calls.increment(-1);
|
||||
|
||||
self.embedding_response_handler(
|
||||
body_size,
|
||||
callout_data.embedding_type,
|
||||
callout_data.prompt_target_name,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// RootContext allows the Rust code to reach into the Envoy Config
|
||||
impl RootContext for FilterContext {
|
||||
fn on_configure(&mut self, _: usize) -> bool {
|
||||
let config_bytes = self
|
||||
.get_plugin_configuration()
|
||||
.expect("Arch config cannot be empty");
|
||||
|
||||
let config: Configuration = match serde_yaml::from_slice(&config_bytes) {
|
||||
Ok(config) => config,
|
||||
Err(err) => panic!("Invalid arch config \"{:?}\"", err),
|
||||
};
|
||||
|
||||
self.overrides = Rc::new(config.overrides);
|
||||
|
||||
let mut prompt_targets = HashMap::new();
|
||||
for pt in config.prompt_targets {
|
||||
prompt_targets.insert(pt.name.clone(), pt.clone());
|
||||
}
|
||||
self.system_prompt = Rc::new(config.system_prompt);
|
||||
self.prompt_targets = Rc::new(prompt_targets);
|
||||
self.mode = config.mode.unwrap_or_default();
|
||||
|
||||
ratelimit::ratelimits(Some(config.ratelimits.unwrap_or_default()));
|
||||
|
||||
if let Some(prompt_guards) = config.prompt_guards {
|
||||
self.prompt_guards = Rc::new(prompt_guards)
|
||||
}
|
||||
|
||||
match config.llm_providers.try_into() {
|
||||
Ok(llm_providers) => self.llm_providers = Some(Rc::new(llm_providers)),
|
||||
Err(err) => panic!("{err}"),
|
||||
}
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
fn create_http_context(&self, context_id: u32) -> Option<Box<dyn HttpContext>> {
|
||||
debug!(
|
||||
"||| create_http_context called with context_id: {:?} |||",
|
||||
context_id
|
||||
);
|
||||
|
||||
// No StreamContext can be created until the Embedding Store is fully initialized.
|
||||
let embedding_store = match self.mode {
|
||||
GatewayMode::Llm => None,
|
||||
GatewayMode::Prompt => Some(Rc::clone(self.embeddings_store.as_ref().unwrap())),
|
||||
};
|
||||
Some(Box::new(StreamContext::new(
|
||||
context_id,
|
||||
Rc::clone(&self.metrics),
|
||||
Rc::clone(&self.system_prompt),
|
||||
Rc::clone(&self.prompt_targets),
|
||||
Rc::clone(&self.prompt_guards),
|
||||
Rc::clone(&self.overrides),
|
||||
Rc::clone(
|
||||
self.llm_providers
|
||||
.as_ref()
|
||||
.expect("LLM Providers must exist when Streams are being created"),
|
||||
),
|
||||
embedding_store,
|
||||
self.mode.clone(),
|
||||
)))
|
||||
}
|
||||
|
||||
fn get_type(&self) -> Option<ContextType> {
|
||||
Some(ContextType::HttpContext)
|
||||
}
|
||||
|
||||
fn on_vm_start(&mut self, _: usize) -> bool {
|
||||
self.set_tick_period(Duration::from_secs(1));
|
||||
true
|
||||
}
|
||||
|
||||
fn on_tick(&mut self) {
|
||||
debug!("starting up arch filter in mode: {:?}", self.mode);
|
||||
if self.mode == GatewayMode::Prompt {
|
||||
self.process_prompt_targets();
|
||||
}
|
||||
|
||||
self.set_tick_period(Duration::from_secs(0));
|
||||
}
|
||||
}
|
||||
17
crates/prompt_gateway/src/lib.rs
Normal file
17
crates/prompt_gateway/src/lib.rs
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
use filter_context::FilterContext;
|
||||
use proxy_wasm::traits::*;
|
||||
use proxy_wasm::types::*;
|
||||
|
||||
mod filter_context;
|
||||
mod llm_providers;
|
||||
mod ratelimit;
|
||||
mod routing;
|
||||
mod stream_context;
|
||||
mod tokenizer;
|
||||
|
||||
proxy_wasm::main! {{
|
||||
proxy_wasm::set_log_level(LogLevel::Trace);
|
||||
proxy_wasm::set_root_context(|_| -> Box<dyn RootContext> {
|
||||
Box::new(FilterContext::new())
|
||||
});
|
||||
}}
|
||||
69
crates/prompt_gateway/src/llm_providers.rs
Normal file
69
crates/prompt_gateway/src/llm_providers.rs
Normal file
|
|
@ -0,0 +1,69 @@
|
|||
use public_types::configuration::LlmProvider;
|
||||
use std::collections::HashMap;
|
||||
use std::rc::Rc;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct LlmProviders {
|
||||
providers: HashMap<String, Rc<LlmProvider>>,
|
||||
default: Option<Rc<LlmProvider>>,
|
||||
}
|
||||
|
||||
impl LlmProviders {
|
||||
pub fn iter(&self) -> std::collections::hash_map::Iter<'_, String, Rc<LlmProvider>> {
|
||||
self.providers.iter()
|
||||
}
|
||||
|
||||
pub fn default(&self) -> Option<Rc<LlmProvider>> {
|
||||
self.default.as_ref().map(|rc| rc.clone())
|
||||
}
|
||||
|
||||
pub fn get(&self, name: &str) -> Option<Rc<LlmProvider>> {
|
||||
self.providers.get(name).cloned()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum LlmProvidersNewError {
|
||||
#[error("There must be at least one LLM Provider")]
|
||||
EmptySource,
|
||||
#[error("There must be at most one default LLM Provider")]
|
||||
MoreThanOneDefault,
|
||||
#[error("\'{0}\' is not a unique name")]
|
||||
DuplicateName(String),
|
||||
}
|
||||
|
||||
impl TryFrom<Vec<LlmProvider>> for LlmProviders {
|
||||
type Error = LlmProvidersNewError;
|
||||
|
||||
fn try_from(llm_providers_config: Vec<LlmProvider>) -> Result<Self, Self::Error> {
|
||||
if llm_providers_config.is_empty() {
|
||||
return Err(LlmProvidersNewError::EmptySource);
|
||||
}
|
||||
|
||||
let mut llm_providers = LlmProviders {
|
||||
providers: HashMap::new(),
|
||||
default: None,
|
||||
};
|
||||
|
||||
for llm_provider in llm_providers_config {
|
||||
let llm_provider: Rc<LlmProvider> = Rc::new(llm_provider);
|
||||
if llm_provider.default.unwrap_or_default() {
|
||||
match llm_providers.default {
|
||||
Some(_) => return Err(LlmProvidersNewError::MoreThanOneDefault),
|
||||
None => llm_providers.default = Some(Rc::clone(&llm_provider)),
|
||||
}
|
||||
}
|
||||
|
||||
// Insert and check that there is no other provider with the same name.
|
||||
let name = llm_provider.name.clone();
|
||||
if llm_providers
|
||||
.providers
|
||||
.insert(name.clone(), llm_provider)
|
||||
.is_some()
|
||||
{
|
||||
return Err(LlmProvidersNewError::DuplicateName(name));
|
||||
}
|
||||
}
|
||||
Ok(llm_providers)
|
||||
}
|
||||
}
|
||||
450
crates/prompt_gateway/src/ratelimit.rs
Normal file
450
crates/prompt_gateway/src/ratelimit.rs
Normal file
|
|
@ -0,0 +1,450 @@
|
|||
use governor::{DefaultKeyedRateLimiter, InsufficientCapacity, Quota};
|
||||
use log::debug;
|
||||
use public_types::configuration;
|
||||
use public_types::configuration::{Limit, Ratelimit, TimeUnit};
|
||||
use std::fmt::Display;
|
||||
use std::num::{NonZero, NonZeroU32};
|
||||
use std::sync::RwLock;
|
||||
use std::{collections::HashMap, sync::OnceLock};
|
||||
|
||||
pub type RatelimitData = RwLock<RatelimitMap>;
|
||||
|
||||
pub fn ratelimits(ratelimits_config: Option<Vec<Ratelimit>>) -> &'static RatelimitData {
|
||||
static RATELIMIT_DATA: OnceLock<RatelimitData> = OnceLock::new();
|
||||
RATELIMIT_DATA.get_or_init(|| {
|
||||
RwLock::new(RatelimitMap::new(
|
||||
ratelimits_config.expect("The initialization call has to have passed a config"),
|
||||
))
|
||||
})
|
||||
}
|
||||
|
||||
// The Data Structure is laid out in the following way:
|
||||
// Provider -> Hash { Header -> Limit }.
|
||||
// If the Header used to configure the given Limit:
|
||||
// a) Has None value, then there will be N Limit keyed by the Header value.
|
||||
// b) Has Some() value, then there will be 1 Limit keyed by the empty string.
|
||||
// It would have been nicer to use a non-keyed limit for b). However, the type system made that option a nightmare.
|
||||
pub struct RatelimitMap {
|
||||
datastore: HashMap<String, HashMap<configuration::Header, DefaultKeyedRateLimiter<String>>>,
|
||||
}
|
||||
|
||||
// This version of Header demands that the user passes a header value to match on.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Header {
|
||||
pub key: String,
|
||||
pub value: String,
|
||||
}
|
||||
|
||||
impl Display for Header {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{self:?}")
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Header> for configuration::Header {
|
||||
fn from(header: Header) -> Self {
|
||||
Self {
|
||||
key: header.key,
|
||||
value: Some(header.value),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum Error {
|
||||
#[error("exceeded limit provider={provider}, selector={selector}, tokens_used={tokens_used}")]
|
||||
ExceededLimit {
|
||||
provider: String,
|
||||
selector: Header,
|
||||
tokens_used: NonZeroU32,
|
||||
},
|
||||
}
|
||||
|
||||
impl RatelimitMap {
|
||||
// n.b new is private so that the only access to the Ratelimits can be done via the static
|
||||
// reference inside a RwLock via ratelimit::ratelimits().
|
||||
fn new(ratelimits_config: Vec<Ratelimit>) -> Self {
|
||||
let mut new_ratelimit_map = RatelimitMap {
|
||||
datastore: HashMap::new(),
|
||||
};
|
||||
for ratelimit_config in ratelimits_config {
|
||||
let limit = DefaultKeyedRateLimiter::keyed(get_quota(ratelimit_config.limit));
|
||||
|
||||
match new_ratelimit_map.datastore.get_mut(&ratelimit_config.model) {
|
||||
Some(limits) => match limits.get_mut(&ratelimit_config.selector) {
|
||||
Some(_) => {
|
||||
panic!("repeated selector. Selectors per provider must be unique")
|
||||
}
|
||||
None => {
|
||||
limits.insert(ratelimit_config.selector, limit);
|
||||
}
|
||||
},
|
||||
None => {
|
||||
// The provider has not been seen before.
|
||||
// Insert the provider and a new HashMap with the specified limit
|
||||
let new_hash_map = HashMap::from([(ratelimit_config.selector, limit)]);
|
||||
new_ratelimit_map
|
||||
.datastore
|
||||
.insert(ratelimit_config.model, new_hash_map);
|
||||
}
|
||||
}
|
||||
}
|
||||
new_ratelimit_map
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
pub fn check_limit(
|
||||
&self,
|
||||
provider: String,
|
||||
selector: Header,
|
||||
tokens_used: NonZeroU32,
|
||||
) -> Result<(), Error> {
|
||||
debug!(
|
||||
"Checking limit for provider={}, with selector={:?}, consuming tokens={:?}",
|
||||
provider, selector, tokens_used
|
||||
);
|
||||
|
||||
let provider_limits = match self.datastore.get(&provider) {
|
||||
None => {
|
||||
// No limit configured for this provider, hence ok.
|
||||
return Ok(());
|
||||
}
|
||||
Some(limit) => limit,
|
||||
};
|
||||
|
||||
let mut config_selector = configuration::Header::from(selector.clone());
|
||||
|
||||
let (limit, limit_key) = match provider_limits.get(&config_selector) {
|
||||
// This is a specific limit, i.e one that was configured with both key, and value.
|
||||
// Therefore, the key for the internal limit does not matter, and hence the empty string is always returned.
|
||||
Some(limit) => (limit, String::from("")),
|
||||
None => {
|
||||
// Unwrap is ok here because we _know_ the value exists.
|
||||
let header_key = config_selector.value.take().unwrap();
|
||||
// Search for less specific limit, i.e, one that was configured without a value, therefore every Header
|
||||
// value has its own key in the internal limit.
|
||||
match provider_limits.get(&config_selector) {
|
||||
Some(limit) => (limit, header_key),
|
||||
// No limit for that header key, value pair exists within that provider limits.
|
||||
None => {
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
match limit.check_key_n(&limit_key, tokens_used) {
|
||||
Ok(Ok(())) => Ok(()),
|
||||
Ok(Err(_)) | Err(InsufficientCapacity(_)) => Err(Error::ExceededLimit {
|
||||
provider,
|
||||
selector,
|
||||
tokens_used,
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn get_quota(limit: Limit) -> Quota {
|
||||
let tokens = NonZero::new(limit.tokens).expect("Limit's tokens must be positive");
|
||||
match limit.unit {
|
||||
TimeUnit::Second => Quota::per_second(tokens),
|
||||
TimeUnit::Minute => Quota::per_minute(tokens),
|
||||
TimeUnit::Hour => Quota::per_hour(tokens),
|
||||
}
|
||||
}
|
||||
|
||||
// The following tests are inside the ratelimit module in order to access RatelimitMap::new() in order to provide
|
||||
// different configuration values per test.
|
||||
#[test]
|
||||
fn non_existent_provider_is_ok() {
|
||||
let ratelimits_config = vec![Ratelimit {
|
||||
model: String::from("provider"),
|
||||
selector: configuration::Header {
|
||||
key: String::from("only-key"),
|
||||
value: None,
|
||||
},
|
||||
limit: Limit {
|
||||
tokens: 100,
|
||||
unit: TimeUnit::Minute,
|
||||
},
|
||||
}];
|
||||
|
||||
let ratelimits = RatelimitMap::new(ratelimits_config);
|
||||
|
||||
assert!(ratelimits
|
||||
.check_limit(
|
||||
String::from("non-existent-provider"),
|
||||
Header {
|
||||
key: String::from("key"),
|
||||
value: String::from("value"),
|
||||
},
|
||||
NonZero::new(5000).unwrap(),
|
||||
)
|
||||
.is_ok())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn non_existent_key_is_ok() {
|
||||
let ratelimits_config = vec![Ratelimit {
|
||||
model: String::from("provider"),
|
||||
selector: configuration::Header {
|
||||
key: String::from("only-key"),
|
||||
value: None,
|
||||
},
|
||||
limit: Limit {
|
||||
tokens: 100,
|
||||
unit: TimeUnit::Minute,
|
||||
},
|
||||
}];
|
||||
|
||||
let ratelimits = RatelimitMap::new(ratelimits_config);
|
||||
|
||||
assert!(ratelimits
|
||||
.check_limit(
|
||||
String::from("provider"),
|
||||
Header {
|
||||
key: String::from("key"),
|
||||
value: String::from("value"),
|
||||
},
|
||||
NonZero::new(5000).unwrap(),
|
||||
)
|
||||
.is_ok())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn specific_limit_does_not_catch_non_specific_value() {
|
||||
let ratelimits_config = vec![Ratelimit {
|
||||
model: String::from("provider"),
|
||||
selector: configuration::Header {
|
||||
key: String::from("key"),
|
||||
value: Some(String::from("value")),
|
||||
},
|
||||
limit: Limit {
|
||||
tokens: 200,
|
||||
unit: TimeUnit::Second,
|
||||
},
|
||||
}];
|
||||
|
||||
let ratelimits = RatelimitMap::new(ratelimits_config);
|
||||
|
||||
assert!(ratelimits
|
||||
.check_limit(
|
||||
String::from("provider"),
|
||||
Header {
|
||||
key: String::from("key"),
|
||||
value: String::from("not-the-correct-value"),
|
||||
},
|
||||
NonZero::new(5000).unwrap(),
|
||||
)
|
||||
.is_ok())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn specific_limit_is_hit() {
|
||||
let ratelimits_config = vec![Ratelimit {
|
||||
model: String::from("provider"),
|
||||
selector: configuration::Header {
|
||||
key: String::from("key"),
|
||||
value: Some(String::from("value")),
|
||||
},
|
||||
limit: Limit {
|
||||
tokens: 200,
|
||||
unit: TimeUnit::Hour,
|
||||
},
|
||||
}];
|
||||
|
||||
let ratelimits = RatelimitMap::new(ratelimits_config);
|
||||
|
||||
assert!(ratelimits
|
||||
.check_limit(
|
||||
String::from("provider"),
|
||||
Header {
|
||||
key: String::from("key"),
|
||||
value: String::from("value"),
|
||||
},
|
||||
NonZero::new(5000).unwrap(),
|
||||
)
|
||||
.is_err())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn non_specific_key_has_different_limits_for_different_values() {
|
||||
let ratelimits_config = vec![Ratelimit {
|
||||
model: String::from("provider"),
|
||||
selector: configuration::Header {
|
||||
key: String::from("only-key"),
|
||||
value: None,
|
||||
},
|
||||
limit: Limit {
|
||||
tokens: 100,
|
||||
unit: TimeUnit::Hour,
|
||||
},
|
||||
}];
|
||||
|
||||
let ratelimits = RatelimitMap::new(ratelimits_config);
|
||||
|
||||
// Value1 takes 50.
|
||||
assert!(ratelimits
|
||||
.check_limit(
|
||||
String::from("provider"),
|
||||
Header {
|
||||
key: String::from("only-key"),
|
||||
value: String::from("value1"),
|
||||
},
|
||||
NonZero::new(50).unwrap(),
|
||||
)
|
||||
.is_ok());
|
||||
|
||||
// value2 takes 60 because it has its own 100 limit
|
||||
assert!(ratelimits
|
||||
.check_limit(
|
||||
String::from("provider"),
|
||||
Header {
|
||||
key: String::from("only-key"),
|
||||
value: String::from("value2"),
|
||||
},
|
||||
NonZero::new(60).unwrap(),
|
||||
)
|
||||
.is_ok());
|
||||
|
||||
// However value1 cannot take more than 100 per hour which 50+70 = 120
|
||||
assert!(ratelimits
|
||||
.check_limit(
|
||||
String::from("provider"),
|
||||
Header {
|
||||
key: String::from("only-key"),
|
||||
value: String::from("value1"),
|
||||
},
|
||||
NonZero::new(70).unwrap(),
|
||||
)
|
||||
.is_err())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn different_provider_can_have_different_limits_with_the_same_keys() {
|
||||
let ratelimits_config = vec![
|
||||
Ratelimit {
|
||||
model: String::from("first_provider"),
|
||||
selector: configuration::Header {
|
||||
key: String::from("key"),
|
||||
value: Some(String::from("value")),
|
||||
},
|
||||
limit: Limit {
|
||||
tokens: 100,
|
||||
unit: TimeUnit::Hour,
|
||||
},
|
||||
},
|
||||
Ratelimit {
|
||||
model: String::from("second_provider"),
|
||||
selector: configuration::Header {
|
||||
key: String::from("key"),
|
||||
value: Some(String::from("value")),
|
||||
},
|
||||
limit: Limit {
|
||||
tokens: 200,
|
||||
unit: TimeUnit::Hour,
|
||||
},
|
||||
},
|
||||
];
|
||||
|
||||
let ratelimits = RatelimitMap::new(ratelimits_config);
|
||||
|
||||
assert!(ratelimits
|
||||
.check_limit(
|
||||
String::from("first_provider"),
|
||||
Header {
|
||||
key: String::from("key"),
|
||||
value: String::from("value"),
|
||||
},
|
||||
NonZero::new(100).unwrap(),
|
||||
)
|
||||
.is_ok());
|
||||
|
||||
assert!(ratelimits
|
||||
.check_limit(
|
||||
String::from("second_provider"),
|
||||
Header {
|
||||
key: String::from("key"),
|
||||
value: String::from("value"),
|
||||
},
|
||||
NonZero::new(200).unwrap(),
|
||||
)
|
||||
.is_ok());
|
||||
|
||||
assert!(ratelimits
|
||||
.check_limit(
|
||||
String::from("first_provider"),
|
||||
Header {
|
||||
key: String::from("key"),
|
||||
value: String::from("value"),
|
||||
},
|
||||
NonZero::new(1).unwrap(),
|
||||
)
|
||||
.is_err());
|
||||
|
||||
assert!(ratelimits
|
||||
.check_limit(
|
||||
String::from("second_provider"),
|
||||
Header {
|
||||
key: String::from("key"),
|
||||
value: String::from("value"),
|
||||
},
|
||||
NonZero::new(1).unwrap(),
|
||||
)
|
||||
.is_err());
|
||||
}
|
||||
|
||||
// These tests use the publicly exposed static singleton, thus the same configuration is used in every test.
|
||||
// If more tests are written here, move the initial call out of the test.
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::ratelimits;
|
||||
use configuration::{Limit, Ratelimit, TimeUnit};
|
||||
use public_types::configuration;
|
||||
use std::num::NonZero;
|
||||
use std::thread;
|
||||
|
||||
#[test]
|
||||
fn make_ratelimits_optional() {
|
||||
let ratelimits_config = Vec::new();
|
||||
|
||||
// Initialize in the main thread.
|
||||
ratelimits(Some(ratelimits_config));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn different_threads_have_same_ratelimit_data_structure() {
|
||||
let ratelimits_config = Some(vec![Ratelimit {
|
||||
model: String::from("provider"),
|
||||
selector: configuration::Header {
|
||||
key: String::from("key"),
|
||||
value: Some(String::from("value")),
|
||||
},
|
||||
limit: Limit {
|
||||
tokens: 200,
|
||||
unit: TimeUnit::Hour,
|
||||
},
|
||||
}]);
|
||||
|
||||
// Initialize in the main thread.
|
||||
ratelimits(ratelimits_config);
|
||||
|
||||
// Use the singleton in a different thread.
|
||||
thread::spawn(|| {
|
||||
let ratelimits = ratelimits(None);
|
||||
|
||||
assert!(ratelimits
|
||||
.read()
|
||||
.unwrap()
|
||||
.check_limit(
|
||||
String::from("provider"),
|
||||
super::Header {
|
||||
key: String::from("key"),
|
||||
value: String::from("value"),
|
||||
},
|
||||
NonZero::new(5000).unwrap(),
|
||||
)
|
||||
.is_err())
|
||||
});
|
||||
}
|
||||
}
|
||||
50
crates/prompt_gateway/src/routing.rs
Normal file
50
crates/prompt_gateway/src/routing.rs
Normal file
|
|
@ -0,0 +1,50 @@
|
|||
use std::rc::Rc;
|
||||
|
||||
use crate::llm_providers::LlmProviders;
|
||||
use log::debug;
|
||||
use public_types::configuration::LlmProvider;
|
||||
use rand::{seq::IteratorRandom, thread_rng};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum ProviderHint {
|
||||
Default,
|
||||
Name(String),
|
||||
}
|
||||
|
||||
impl From<String> for ProviderHint {
|
||||
fn from(value: String) -> Self {
|
||||
match value.as_str() {
|
||||
"default" => ProviderHint::Default,
|
||||
_ => ProviderHint::Name(value),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_llm_provider(
|
||||
llm_providers: &LlmProviders,
|
||||
provider_hint: Option<ProviderHint>,
|
||||
) -> Rc<LlmProvider> {
|
||||
let maybe_provider = provider_hint.and_then(|hint| match hint {
|
||||
ProviderHint::Default => llm_providers.default(),
|
||||
// FIXME: should a non-existent name in the hint be more explicit? i.e, return a BAD_REQUEST?
|
||||
ProviderHint::Name(name) => llm_providers.get(&name),
|
||||
});
|
||||
|
||||
if let Some(provider) = maybe_provider {
|
||||
return provider;
|
||||
}
|
||||
|
||||
if llm_providers.default().is_some() {
|
||||
debug!("no llm provider found for hint, using default llm provider");
|
||||
return llm_providers.default().unwrap();
|
||||
}
|
||||
|
||||
debug!("no default llm found, using random llm provider");
|
||||
let mut rng = thread_rng();
|
||||
llm_providers
|
||||
.iter()
|
||||
.choose(&mut rng)
|
||||
.expect("There should always be at least one llm provider")
|
||||
.1
|
||||
.clone()
|
||||
}
|
||||
1576
crates/prompt_gateway/src/stream_context.rs
Normal file
1576
crates/prompt_gateway/src/stream_context.rs
Normal file
File diff suppressed because it is too large
Load diff
39
crates/prompt_gateway/src/tokenizer.rs
Normal file
39
crates/prompt_gateway/src/tokenizer.rs
Normal file
|
|
@ -0,0 +1,39 @@
|
|||
use log::debug;
|
||||
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
#[allow(dead_code)]
|
||||
pub enum Error {
|
||||
UnknownModel,
|
||||
FailedToTokenize,
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn token_count(model_name: &str, text: &str) -> Result<usize, Error> {
|
||||
debug!("getting token count model={}", model_name);
|
||||
// Consideration: is it more expensive to instantiate the BPE object every time, or to contend the singleton?
|
||||
let bpe = tiktoken_rs::get_bpe_from_model(model_name).map_err(|_| Error::UnknownModel)?;
|
||||
Ok(bpe.encode_ordinary(text).len())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn encode_ordinary() {
|
||||
let model_name = "gpt-3.5-turbo";
|
||||
let text = "How many tokens does this sentence have?";
|
||||
assert_eq!(
|
||||
8,
|
||||
token_count(model_name, text).expect("correct tokenization")
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn unrecognized_model() {
|
||||
assert_eq!(
|
||||
Error::UnknownModel,
|
||||
token_count("unknown", "").expect_err("unknown model")
|
||||
)
|
||||
}
|
||||
}
|
||||
805
crates/prompt_gateway/tests/integration.rs
Normal file
805
crates/prompt_gateway/tests/integration.rs
Normal file
|
|
@ -0,0 +1,805 @@
|
|||
use http::StatusCode;
|
||||
use proxy_wasm_test_framework::tester::{self, Tester};
|
||||
use proxy_wasm_test_framework::types::{
|
||||
Action, BufferType, LogLevel, MapType, MetricType, ReturnType,
|
||||
};
|
||||
use public_types::common_types::open_ai::{ChatCompletionsResponse, Choice, Message, Usage};
|
||||
use public_types::common_types::open_ai::{FunctionCallDetail, ToolCall, ToolType};
|
||||
use public_types::common_types::{HallucinationClassificationResponse, PromptGuardResponse};
|
||||
use public_types::embeddings::{
|
||||
create_embedding_response, embedding, CreateEmbeddingResponse, CreateEmbeddingResponseUsage,
|
||||
Embedding,
|
||||
};
|
||||
use public_types::{common_types::ZeroShotClassificationResponse, configuration::Configuration};
|
||||
use serde_yaml::Value;
|
||||
use serial_test::serial;
|
||||
use std::collections::HashMap;
|
||||
use std::path::Path;
|
||||
|
||||
fn wasm_module() -> String {
|
||||
let wasm_file = Path::new("target/wasm32-wasi/release/intelligent_prompt_gateway.wasm");
|
||||
assert!(
|
||||
wasm_file.exists(),
|
||||
"Run `cargo build --release --target=wasm32-wasi` first"
|
||||
);
|
||||
wasm_file.to_str().unwrap().to_string()
|
||||
}
|
||||
|
||||
fn request_headers_expectations(module: &mut Tester, http_context: i32) {
|
||||
module
|
||||
.call_proxy_on_request_headers(http_context, 0, false)
|
||||
.expect_get_header_map_value(
|
||||
Some(MapType::HttpRequestHeaders),
|
||||
Some("x-arch-llm-provider-hint"),
|
||||
)
|
||||
.returning(Some("default"))
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_add_header_map_value(
|
||||
Some(MapType::HttpRequestHeaders),
|
||||
Some("x-arch-upstream"),
|
||||
Some("arch_llm_listener"),
|
||||
)
|
||||
.expect_add_header_map_value(
|
||||
Some(MapType::HttpRequestHeaders),
|
||||
Some("x-arch-llm-provider"),
|
||||
Some("open-ai-gpt-4"),
|
||||
)
|
||||
.expect_replace_header_map_value(
|
||||
Some(MapType::HttpRequestHeaders),
|
||||
Some("Authorization"),
|
||||
Some("Bearer secret_key"),
|
||||
)
|
||||
.expect_remove_header_map_value(Some(MapType::HttpRequestHeaders), Some("content-length"))
|
||||
.expect_get_header_map_value(
|
||||
Some(MapType::HttpRequestHeaders),
|
||||
Some("x-arch-ratelimit-selector"),
|
||||
)
|
||||
.returning(Some("selector-key"))
|
||||
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("selector-key"))
|
||||
.returning(Some("selector-value"))
|
||||
.expect_get_header_map_pairs(Some(MapType::HttpRequestHeaders))
|
||||
.returning(None)
|
||||
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":path"))
|
||||
.returning(Some("/v1/chat/completions"))
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("x-request-id"))
|
||||
.returning(None)
|
||||
.execute_and_expect(ReturnType::Action(Action::Continue))
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
|
||||
module
|
||||
.call_proxy_on_context_create(http_context, filter_context)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
request_headers_expectations(module, http_context);
|
||||
|
||||
// Request Body
|
||||
let chat_completions_request_body = "\
|
||||
{\
|
||||
\"messages\": [\
|
||||
{\
|
||||
\"role\": \"system\",\
|
||||
\"content\": \"You are a poetic assistant, skilled in explaining complex programming concepts with creative flair.\"\
|
||||
},\
|
||||
{\
|
||||
\"role\": \"user\",\
|
||||
\"content\": \"Compose a poem that explains the concept of recursion in programming.\"\
|
||||
}\
|
||||
],\
|
||||
\"model\": \"gpt-4\"\
|
||||
}";
|
||||
|
||||
module
|
||||
.call_proxy_on_request_body(
|
||||
http_context,
|
||||
chat_completions_request_body.len() as i32,
|
||||
true,
|
||||
)
|
||||
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
|
||||
.returning(Some(chat_completions_request_body))
|
||||
// The actual call is not important in this test, we just need to grab the token_id
|
||||
.expect_http_call(
|
||||
Some("arch_internal"),
|
||||
Some(vec![
|
||||
("x-arch-upstream", "model_server"),
|
||||
(":method", "POST"),
|
||||
(":path", "/guard"),
|
||||
(":authority", "model_server"),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
("x-envoy-upstream-rq-timeout-ms", "60000"),
|
||||
]),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.returning(Some(1))
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_metric_increment("active_http_calls", 1)
|
||||
.execute_and_expect(ReturnType::Action(Action::Pause))
|
||||
.unwrap();
|
||||
|
||||
let prompt_guard_response = PromptGuardResponse {
|
||||
toxic_prob: None,
|
||||
toxic_verdict: None,
|
||||
jailbreak_prob: None,
|
||||
jailbreak_verdict: None,
|
||||
};
|
||||
let prompt_guard_response_buffer = serde_json::to_string(&prompt_guard_response).unwrap();
|
||||
module
|
||||
.call_proxy_on_http_call_response(
|
||||
http_context,
|
||||
1,
|
||||
0,
|
||||
prompt_guard_response_buffer.len() as i32,
|
||||
0,
|
||||
)
|
||||
.expect_metric_increment("active_http_calls", -1)
|
||||
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
|
||||
.returning(Some(&prompt_guard_response_buffer))
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_http_call(
|
||||
Some("arch_internal"),
|
||||
Some(vec![
|
||||
("x-arch-upstream", "model_server"),
|
||||
(":method", "POST"),
|
||||
(":path", "/embeddings"),
|
||||
(":authority", "model_server"),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
("x-envoy-upstream-rq-timeout-ms", "60000"),
|
||||
]),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.returning(Some(2))
|
||||
.expect_metric_increment("active_http_calls", 1)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
let embedding_response = CreateEmbeddingResponse {
|
||||
data: vec![Embedding {
|
||||
index: 0,
|
||||
embedding: vec![],
|
||||
object: embedding::Object::default(),
|
||||
}],
|
||||
model: String::from("test"),
|
||||
object: create_embedding_response::Object::default(),
|
||||
usage: Box::new(CreateEmbeddingResponseUsage::new(0, 0)),
|
||||
};
|
||||
let embeddings_response_buffer = serde_json::to_string(&embedding_response).unwrap();
|
||||
module
|
||||
.call_proxy_on_http_call_response(
|
||||
http_context,
|
||||
2,
|
||||
0,
|
||||
embeddings_response_buffer.len() as i32,
|
||||
0,
|
||||
)
|
||||
.expect_metric_increment("active_http_calls", -1)
|
||||
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
|
||||
.returning(Some(&embeddings_response_buffer))
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_http_call(
|
||||
Some("arch_internal"),
|
||||
Some(vec![
|
||||
("x-arch-upstream", "model_server"),
|
||||
(":method", "POST"),
|
||||
(":path", "/zeroshot"),
|
||||
(":authority", "model_server"),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
("x-envoy-upstream-rq-timeout-ms", "60000"),
|
||||
]),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.returning(Some(3))
|
||||
.expect_metric_increment("active_http_calls", 1)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
let zero_shot_response = ZeroShotClassificationResponse {
|
||||
predicted_class: "weather_forecast".to_string(),
|
||||
predicted_class_score: 0.1,
|
||||
scores: HashMap::new(),
|
||||
model: "test-model".to_string(),
|
||||
};
|
||||
let zeroshot_intent_detection_buffer = serde_json::to_string(&zero_shot_response).unwrap();
|
||||
module
|
||||
.call_proxy_on_http_call_response(
|
||||
http_context,
|
||||
3,
|
||||
0,
|
||||
zeroshot_intent_detection_buffer.len() as i32,
|
||||
0,
|
||||
)
|
||||
.expect_metric_increment("active_http_calls", -1)
|
||||
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
|
||||
.returning(Some(&zeroshot_intent_detection_buffer))
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Info), None)
|
||||
.expect_http_call(
|
||||
Some("arch_internal"),
|
||||
Some(vec![
|
||||
(":method", "POST"),
|
||||
("x-arch-upstream", "arch_fc"),
|
||||
(":path", "/v1/chat/completions"),
|
||||
(":authority", "arch_fc"),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
("x-envoy-upstream-rq-timeout-ms", "120000"),
|
||||
]),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.returning(Some(4))
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_metric_increment("active_http_calls", 1)
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
fn setup_filter(module: &mut Tester, config: &str) -> i32 {
|
||||
let filter_context = 1;
|
||||
|
||||
module
|
||||
.call_proxy_on_context_create(filter_context, 0)
|
||||
.expect_metric_creation(MetricType::Gauge, "active_http_calls")
|
||||
.expect_metric_creation(MetricType::Counter, "ratelimited_rq")
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
module
|
||||
.call_proxy_on_configure(filter_context, config.len() as i32)
|
||||
.expect_get_buffer_bytes(Some(BufferType::PluginConfiguration))
|
||||
.returning(Some(config))
|
||||
.execute_and_expect(ReturnType::Bool(true))
|
||||
.unwrap();
|
||||
|
||||
module
|
||||
.call_proxy_on_tick(filter_context)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_http_call(
|
||||
Some("arch_internal"),
|
||||
Some(vec![
|
||||
("x-arch-upstream", "model_server"),
|
||||
(":method", "POST"),
|
||||
(":path", "/embeddings"),
|
||||
(":authority", "model_server"),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-upstream-rq-timeout-ms", "60000"),
|
||||
]),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.returning(Some(101))
|
||||
.expect_metric_increment("active_http_calls", 1)
|
||||
.expect_set_tick_period_millis(Some(0))
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
let embedding_response = CreateEmbeddingResponse {
|
||||
data: vec![Embedding {
|
||||
embedding: vec![],
|
||||
index: 0,
|
||||
object: embedding::Object::default(),
|
||||
}],
|
||||
model: String::from("test"),
|
||||
object: create_embedding_response::Object::default(),
|
||||
usage: Box::new(CreateEmbeddingResponseUsage {
|
||||
prompt_tokens: 0,
|
||||
total_tokens: 0,
|
||||
}),
|
||||
};
|
||||
let embedding_response_str = serde_json::to_string(&embedding_response).unwrap();
|
||||
module
|
||||
.call_proxy_on_http_call_response(
|
||||
filter_context,
|
||||
101,
|
||||
0,
|
||||
embedding_response_str.len() as i32,
|
||||
0,
|
||||
)
|
||||
.expect_log(
|
||||
Some(LogLevel::Debug),
|
||||
Some(
|
||||
format!(
|
||||
"filter_context: on_http_call_response called with token_id: {:?}",
|
||||
101
|
||||
)
|
||||
.as_str(),
|
||||
),
|
||||
)
|
||||
.expect_metric_increment("active_http_calls", -1)
|
||||
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
|
||||
.returning(Some(&embedding_response_str))
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
filter_context
|
||||
}
|
||||
|
||||
fn default_config() -> &'static str {
|
||||
r#"
|
||||
version: "0.1-beta"
|
||||
|
||||
listener:
|
||||
address: 0.0.0.0
|
||||
port: 10000
|
||||
message_format: huggingface
|
||||
connect_timeout: 0.005s
|
||||
|
||||
endpoints:
|
||||
api_server:
|
||||
endpoint: api_server:80
|
||||
connect_timeout: 0.005s
|
||||
|
||||
llm_providers:
|
||||
- name: open-ai-gpt-4
|
||||
provider: openai
|
||||
access_key: secret_key
|
||||
model: gpt-4
|
||||
default: true
|
||||
|
||||
overrides:
|
||||
# confidence threshold for prompt target intent matching
|
||||
prompt_target_intent_matching_threshold: 0.6
|
||||
|
||||
system_prompt: |
|
||||
You are a helpful assistant.
|
||||
|
||||
prompt_guards:
|
||||
input_guards:
|
||||
jailbreak:
|
||||
on_exception:
|
||||
message: "Looks like you're curious about my abilities, but I can only provide assistance within my programmed parameters."
|
||||
|
||||
prompt_targets:
|
||||
- name: weather_forecast
|
||||
description: This function provides realtime weather forecast information for a given city.
|
||||
parameters:
|
||||
- name: city
|
||||
required: true
|
||||
description: The city for which the weather forecast is requested.
|
||||
- name: days
|
||||
description: The number of days for which the weather forecast is requested.
|
||||
- name: units
|
||||
description: The units in which the weather forecast is requested.
|
||||
endpoint:
|
||||
name: api_server
|
||||
path: /weather
|
||||
system_prompt: |
|
||||
You are a helpful weather forecaster. Use weater data that is provided to you. Please following following guidelines when responding to user queries:
|
||||
- Use farenheight for temperature
|
||||
- Use miles per hour for wind speed
|
||||
|
||||
ratelimits:
|
||||
- model: gpt-4
|
||||
selector:
|
||||
key: selector-key
|
||||
value: selector-value
|
||||
limit:
|
||||
tokens: 1
|
||||
unit: minute
|
||||
"#
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[serial]
|
||||
fn successful_request_to_open_ai_chat_completions() {
|
||||
let args = tester::MockSettings {
|
||||
wasm_path: wasm_module(),
|
||||
quiet: false,
|
||||
allow_unexpected: false,
|
||||
};
|
||||
let mut module = tester::mock(args).unwrap();
|
||||
|
||||
module
|
||||
.call_start()
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
// Setup Filter
|
||||
let filter_context = setup_filter(&mut module, default_config());
|
||||
|
||||
// Setup HTTP Stream
|
||||
let http_context = 2;
|
||||
|
||||
module
|
||||
.call_proxy_on_context_create(http_context, filter_context)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
request_headers_expectations(&mut module, http_context);
|
||||
|
||||
// Request Body
|
||||
let chat_completions_request_body = "\
|
||||
{\
|
||||
\"messages\": [\
|
||||
{\
|
||||
\"role\": \"system\",\
|
||||
\"content\": \"You are a poetic assistant, skilled in explaining complex programming concepts with creative flair.\"\
|
||||
},\
|
||||
{\
|
||||
\"role\": \"user\",\
|
||||
\"content\": \"Compose a poem that explains the concept of recursion in programming.\"\
|
||||
}\
|
||||
],\
|
||||
\"model\": \"gpt-4\"\
|
||||
}";
|
||||
|
||||
module
|
||||
.call_proxy_on_request_body(
|
||||
http_context,
|
||||
chat_completions_request_body.len() as i32,
|
||||
true,
|
||||
)
|
||||
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
|
||||
.returning(Some(chat_completions_request_body))
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_http_call(Some("arch_internal"), None, None, None, None)
|
||||
.returning(Some(4))
|
||||
.expect_metric_increment("active_http_calls", 1)
|
||||
.execute_and_expect(ReturnType::Action(Action::Pause))
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[serial]
|
||||
fn bad_request_to_open_ai_chat_completions() {
|
||||
let args = tester::MockSettings {
|
||||
wasm_path: wasm_module(),
|
||||
quiet: false,
|
||||
allow_unexpected: false,
|
||||
};
|
||||
let mut module = tester::mock(args).unwrap();
|
||||
|
||||
module
|
||||
.call_start()
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
// Setup Filter
|
||||
let filter_context = setup_filter(&mut module, default_config());
|
||||
|
||||
// Setup HTTP Stream
|
||||
let http_context = 2;
|
||||
|
||||
module
|
||||
.call_proxy_on_context_create(http_context, filter_context)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
request_headers_expectations(&mut module, http_context);
|
||||
|
||||
// Request Body
|
||||
let incomplete_chat_completions_request_body = "\
|
||||
{\
|
||||
\"messages\": [\
|
||||
{\
|
||||
\"role\": \"system\",\
|
||||
},\
|
||||
{\
|
||||
\"role\": \"user\",\
|
||||
\"content\": \"Compose a poem that explains the concept of recursion in programming.\"\
|
||||
}\
|
||||
]\
|
||||
}";
|
||||
|
||||
module
|
||||
.call_proxy_on_request_body(
|
||||
http_context,
|
||||
incomplete_chat_completions_request_body.len() as i32,
|
||||
true,
|
||||
)
|
||||
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
|
||||
.returning(Some(incomplete_chat_completions_request_body))
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_send_local_response(
|
||||
Some(StatusCode::BAD_REQUEST.as_u16().into()),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.execute_and_expect(ReturnType::Action(Action::Pause))
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[serial]
|
||||
fn request_ratelimited() {
|
||||
let args = tester::MockSettings {
|
||||
wasm_path: wasm_module(),
|
||||
quiet: false,
|
||||
allow_unexpected: false,
|
||||
};
|
||||
let mut module = tester::mock(args).unwrap();
|
||||
|
||||
module
|
||||
.call_start()
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
// Setup Filter
|
||||
let filter_context = setup_filter(&mut module, default_config());
|
||||
|
||||
// Setup HTTP Stream
|
||||
let http_context = 2;
|
||||
|
||||
normal_flow(&mut module, filter_context, http_context);
|
||||
|
||||
let arch_fc_resp = ChatCompletionsResponse {
|
||||
usage: Some(Usage {
|
||||
completion_tokens: 0,
|
||||
}),
|
||||
choices: vec![Choice {
|
||||
finish_reason: "test".to_string(),
|
||||
index: 0,
|
||||
message: Message {
|
||||
role: "system".to_string(),
|
||||
content: None,
|
||||
tool_calls: Some(vec![ToolCall {
|
||||
id: String::from("test"),
|
||||
tool_type: ToolType::Function,
|
||||
function: FunctionCallDetail {
|
||||
name: String::from("weather_forecast"),
|
||||
arguments: HashMap::from([(
|
||||
String::from("city"),
|
||||
Value::String(String::from("seattle")),
|
||||
)]),
|
||||
},
|
||||
}]),
|
||||
model: None,
|
||||
},
|
||||
}],
|
||||
model: String::from("test"),
|
||||
metadata: None,
|
||||
};
|
||||
|
||||
let arch_fc_resp_str = serde_json::to_string(&arch_fc_resp).unwrap();
|
||||
module
|
||||
.call_proxy_on_http_call_response(http_context, 4, 0, arch_fc_resp_str.len() as i32, 0)
|
||||
.expect_metric_increment("active_http_calls", -1)
|
||||
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
|
||||
.returning(Some(&arch_fc_resp_str))
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_http_call(
|
||||
Some("arch_internal"),
|
||||
Some(vec![
|
||||
("x-arch-upstream", "model_server"),
|
||||
(":method", "POST"),
|
||||
(":path", "/hallucination"),
|
||||
(":authority", "model_server"),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
("x-envoy-upstream-rq-timeout-ms", "60000"),
|
||||
]),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.returning(Some(5))
|
||||
.expect_metric_increment("active_http_calls", 1)
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
let hallucatination_body = HallucinationClassificationResponse {
|
||||
params_scores: HashMap::from([("city".to_string(), 0.99)]),
|
||||
model: "nli-model".to_string(),
|
||||
};
|
||||
|
||||
let body_text = serde_json::to_string(&hallucatination_body).unwrap();
|
||||
|
||||
module
|
||||
.call_proxy_on_http_call_response(http_context, 5, 0, body_text.len() as i32, 0)
|
||||
.expect_metric_increment("active_http_calls", -1)
|
||||
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
|
||||
.returning(Some(&body_text))
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_http_call(
|
||||
Some("arch_internal"),
|
||||
Some(vec![
|
||||
("x-arch-upstream", "api_server"),
|
||||
(":method", "POST"),
|
||||
(":path", "/weather"),
|
||||
(":authority", "api_server"),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
]),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.returning(Some(6))
|
||||
.expect_metric_increment("active_http_calls", 1)
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
let body_text = String::from("test body");
|
||||
module
|
||||
.call_proxy_on_http_call_response(http_context, 6, 0, body_text.len() as i32, 0)
|
||||
.expect_metric_increment("active_http_calls", -1)
|
||||
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
|
||||
.returning(Some(&body_text))
|
||||
.expect_get_header_map_value(Some(MapType::HttpCallResponseHeaders), Some(":status"))
|
||||
.returning(Some("200"))
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_send_local_response(
|
||||
Some(StatusCode::TOO_MANY_REQUESTS.as_u16().into()),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.expect_metric_increment("ratelimited_rq", 1)
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[serial]
|
||||
fn request_not_ratelimited() {
|
||||
let args = tester::MockSettings {
|
||||
wasm_path: wasm_module(),
|
||||
quiet: false,
|
||||
allow_unexpected: false,
|
||||
};
|
||||
let mut module = tester::mock(args).unwrap();
|
||||
|
||||
module
|
||||
.call_start()
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
// Setup Filter
|
||||
let mut config: Configuration = serde_yaml::from_str(default_config()).unwrap();
|
||||
config.ratelimits.as_mut().unwrap()[0].limit.tokens += 1000;
|
||||
let config_str = serde_json::to_string(&config).unwrap();
|
||||
|
||||
let filter_context = setup_filter(&mut module, &config_str);
|
||||
|
||||
// Setup HTTP Stream
|
||||
let http_context = 2;
|
||||
|
||||
normal_flow(&mut module, filter_context, http_context);
|
||||
|
||||
let arch_fc_resp = ChatCompletionsResponse {
|
||||
usage: Some(Usage {
|
||||
completion_tokens: 0,
|
||||
}),
|
||||
choices: vec![Choice {
|
||||
finish_reason: "test".to_string(),
|
||||
index: 0,
|
||||
message: Message {
|
||||
role: "system".to_string(),
|
||||
content: None,
|
||||
tool_calls: Some(vec![ToolCall {
|
||||
id: String::from("test"),
|
||||
tool_type: ToolType::Function,
|
||||
function: FunctionCallDetail {
|
||||
name: String::from("weather_forecast"),
|
||||
arguments: HashMap::from([(
|
||||
String::from("city"),
|
||||
Value::String(String::from("seattle")),
|
||||
)]),
|
||||
},
|
||||
}]),
|
||||
model: None,
|
||||
},
|
||||
}],
|
||||
model: String::from("test"),
|
||||
metadata: None,
|
||||
};
|
||||
|
||||
let arch_fc_resp_str = serde_json::to_string(&arch_fc_resp).unwrap();
|
||||
module
|
||||
.call_proxy_on_http_call_response(http_context, 4, 0, arch_fc_resp_str.len() as i32, 0)
|
||||
.expect_metric_increment("active_http_calls", -1)
|
||||
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
|
||||
.returning(Some(&arch_fc_resp_str))
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_http_call(
|
||||
Some("arch_internal"),
|
||||
Some(vec![
|
||||
("x-arch-upstream", "model_server"),
|
||||
(":method", "POST"),
|
||||
(":path", "/hallucination"),
|
||||
(":authority", "model_server"),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
("x-envoy-upstream-rq-timeout-ms", "60000"),
|
||||
]),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.returning(Some(5))
|
||||
.expect_metric_increment("active_http_calls", 1)
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
// hallucination should return that parameters were not halliucinated
|
||||
// prompt: str
|
||||
// parameters: dict
|
||||
// model: str
|
||||
|
||||
let hallucatination_body = HallucinationClassificationResponse {
|
||||
params_scores: HashMap::from([("city".to_string(), 0.99)]),
|
||||
model: "nli-model".to_string(),
|
||||
};
|
||||
|
||||
let body_text = serde_json::to_string(&hallucatination_body).unwrap();
|
||||
|
||||
module
|
||||
.call_proxy_on_http_call_response(http_context, 5, 0, body_text.len() as i32, 0)
|
||||
.expect_metric_increment("active_http_calls", -1)
|
||||
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
|
||||
.returning(Some(&body_text))
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_http_call(
|
||||
Some("arch_internal"),
|
||||
Some(vec![
|
||||
("x-arch-upstream", "api_server"),
|
||||
(":method", "POST"),
|
||||
(":path", "/weather"),
|
||||
(":authority", "api_server"),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
]),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.returning(Some(6))
|
||||
.expect_metric_increment("active_http_calls", 1)
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
let body_text = String::from("test body");
|
||||
module
|
||||
.call_proxy_on_http_call_response(http_context, 6, 0, body_text.len() as i32, 0)
|
||||
.expect_metric_increment("active_http_calls", -1)
|
||||
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
|
||||
.returning(Some(&body_text))
|
||||
.expect_get_header_map_value(Some(MapType::HttpCallResponseHeaders), Some(":status"))
|
||||
.returning(Some("200"))
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
}
|
||||
382
crates/public_types/Cargo.lock
generated
Normal file
382
crates/public_types/Cargo.lock
generated
Normal file
|
|
@ -0,0 +1,382 @@
|
|||
# This file is automatically @generated by Cargo.
|
||||
# It is not intended for manual editing.
|
||||
version = 3
|
||||
|
||||
[[package]]
|
||||
name = "ahash"
|
||||
version = "0.3.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e8fd72866655d1904d6b0997d0b07ba561047d070fbe29de039031c641b61217"
|
||||
|
||||
[[package]]
|
||||
name = "ahash"
|
||||
version = "0.8.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"once_cell",
|
||||
"version_check",
|
||||
"zerocopy",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "allocator-api2"
|
||||
version = "0.2.18"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5c6cb57a04249c6480766f7f7cef5467412af1490f8d1e243141daddada3264f"
|
||||
|
||||
[[package]]
|
||||
name = "autocfg"
|
||||
version = "1.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26"
|
||||
|
||||
[[package]]
|
||||
name = "cfg-if"
|
||||
version = "1.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
|
||||
|
||||
[[package]]
|
||||
name = "derivative"
|
||||
version = "2.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fcc3dd5e9e9c0b295d6e1e4d811fb6f157d5ffd784b8d202fc62eac8035a770b"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 1.0.109",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "diff"
|
||||
version = "0.1.13"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "56254986775e3233ffa9c4d7d3faaf6d36a2c09d30b20687e9f88bc8bafc16c8"
|
||||
|
||||
[[package]]
|
||||
name = "duration-string"
|
||||
version = "0.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6fcc1d9ae294a15ed05aeae8e11ee5f2b3fe971c077d45a42fb20825fba6ee13"
|
||||
dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "equivalent"
|
||||
version = "1.0.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5"
|
||||
|
||||
[[package]]
|
||||
name = "governor"
|
||||
version = "0.6.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "68a7f542ee6b35af73b06abc0dad1c1bae89964e4e253bc4b587b91c9637867b"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"no-std-compat",
|
||||
"nonzero_ext",
|
||||
"portable-atomic",
|
||||
"smallvec",
|
||||
"spinning_top",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hashbrown"
|
||||
version = "0.8.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e91b62f79061a0bc2e046024cb7ba44b08419ed238ecbd9adbd787434b9e8c25"
|
||||
dependencies = [
|
||||
"ahash 0.3.8",
|
||||
"autocfg",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hashbrown"
|
||||
version = "0.14.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1"
|
||||
dependencies = [
|
||||
"ahash 0.8.11",
|
||||
"allocator-api2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "indexmap"
|
||||
version = "2.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "68b900aa2f7301e21c36462b170ee99994de34dff39a4a6a528e80e7376d07e5"
|
||||
dependencies = [
|
||||
"equivalent",
|
||||
"hashbrown 0.14.5",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "itoa"
|
||||
version = "1.0.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b"
|
||||
|
||||
[[package]]
|
||||
name = "lock_api"
|
||||
version = "0.4.12"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "07af8b9cdd281b7915f413fa73f29ebd5d55d0d3f0155584dade1ff18cea1b17"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
"scopeguard",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "log"
|
||||
version = "0.4.22"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24"
|
||||
|
||||
[[package]]
|
||||
name = "memchr"
|
||||
version = "2.7.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3"
|
||||
|
||||
[[package]]
|
||||
name = "no-std-compat"
|
||||
version = "0.4.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b93853da6d84c2e3c7d730d6473e8817692dd89be387eb01b94d7f108ecb5b8c"
|
||||
dependencies = [
|
||||
"hashbrown 0.8.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "nonzero_ext"
|
||||
version = "0.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "38bf9645c8b145698bb0b18a4637dcacbc421ea49bef2317e4fd8065a387cf21"
|
||||
|
||||
[[package]]
|
||||
name = "once_cell"
|
||||
version = "1.20.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775"
|
||||
|
||||
[[package]]
|
||||
name = "portable-atomic"
|
||||
version = "1.9.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cc9c68a3f6da06753e9335d63e27f6b9754dd1920d941135b7ea8224f141adb2"
|
||||
|
||||
[[package]]
|
||||
name = "pretty_assertions"
|
||||
version = "1.4.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3ae130e2f271fbc2ac3a40fb1d07180839cdbbe443c7a27e1e3c13c5cac0116d"
|
||||
dependencies = [
|
||||
"diff",
|
||||
"yansi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "proc-macro2"
|
||||
version = "1.0.86"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5e719e8df665df0d1c8fbfd238015744736151d4445ec0836b8e628aae103b77"
|
||||
dependencies = [
|
||||
"unicode-ident",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "proxy-wasm"
|
||||
version = "0.2.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "14a5a4df5a1ab77235e36a0a0f638687ee1586d21ee9774037693001e94d4e11"
|
||||
dependencies = [
|
||||
"hashbrown 0.14.5",
|
||||
"log",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "public_types"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"derivative",
|
||||
"duration-string",
|
||||
"governor",
|
||||
"log",
|
||||
"pretty_assertions",
|
||||
"proxy-wasm",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"serde_yaml",
|
||||
"thiserror",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "quote"
|
||||
version = "1.0.37"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b5b9d34b8991d19d98081b46eacdd8eb58c6f2b201139f7c5f643cc155a633af"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ryu"
|
||||
version = "1.0.18"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f"
|
||||
|
||||
[[package]]
|
||||
name = "scopeguard"
|
||||
version = "1.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
|
||||
|
||||
[[package]]
|
||||
name = "serde"
|
||||
version = "1.0.210"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c8e3592472072e6e22e0a54d5904d9febf8508f65fb8552499a1abc7d1078c3a"
|
||||
dependencies = [
|
||||
"serde_derive",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_derive"
|
||||
version = "1.0.210"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "243902eda00fad750862fc144cea25caca5e20d615af0a81bee94ca738f1df1f"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.77",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_json"
|
||||
version = "1.0.128"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6ff5456707a1de34e7e37f2a6fd3d3f808c318259cbd01ab6377795054b483d8"
|
||||
dependencies = [
|
||||
"itoa",
|
||||
"memchr",
|
||||
"ryu",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_yaml"
|
||||
version = "0.9.34+deprecated"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6a8b1a1a2ebf674015cc02edccce75287f1a0130d394307b36743c2f5d504b47"
|
||||
dependencies = [
|
||||
"indexmap",
|
||||
"itoa",
|
||||
"ryu",
|
||||
"serde",
|
||||
"unsafe-libyaml",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "smallvec"
|
||||
version = "1.13.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67"
|
||||
|
||||
[[package]]
|
||||
name = "spinning_top"
|
||||
version = "0.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d96d2d1d716fb500937168cc09353ffdc7a012be8475ac7308e1bdf0e3923300"
|
||||
dependencies = [
|
||||
"lock_api",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "syn"
|
||||
version = "1.0.109"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"unicode-ident",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "syn"
|
||||
version = "2.0.77"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9f35bcdf61fd8e7be6caf75f429fdca8beb3ed76584befb503b1569faee373ed"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"unicode-ident",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "thiserror"
|
||||
version = "1.0.64"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d50af8abc119fb8bb6dbabcfa89656f46f84aa0ac7688088608076ad2b459a84"
|
||||
dependencies = [
|
||||
"thiserror-impl",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "thiserror-impl"
|
||||
version = "1.0.64"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "08904e7672f5eb876eaaf87e0ce17857500934f4981c4a0ab2b4aa98baac7fc3"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.77",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "unicode-ident"
|
||||
version = "1.0.13"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e91b56cd4cadaeb79bbf1a5645f6b4f8dc5bde8834ad5894a8db35fda9efa1fe"
|
||||
|
||||
[[package]]
|
||||
name = "unsafe-libyaml"
|
||||
version = "0.2.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "673aac59facbab8a9007c7f6108d11f63b603f7cabff99fabf650fea5c32b861"
|
||||
|
||||
[[package]]
|
||||
name = "version_check"
|
||||
version = "0.9.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a"
|
||||
|
||||
[[package]]
|
||||
name = "yansi"
|
||||
version = "1.0.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cfe53a6657fd280eaa890a3bc59152892ffa3e30101319d168b781ed6529b049"
|
||||
|
||||
[[package]]
|
||||
name = "zerocopy"
|
||||
version = "0.7.35"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0"
|
||||
dependencies = [
|
||||
"zerocopy-derive",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "zerocopy-derive"
|
||||
version = "0.7.35"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.77",
|
||||
]
|
||||
18
crates/public_types/Cargo.toml
Normal file
18
crates/public_types/Cargo.toml
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
[package]
|
||||
name = "public_types"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_yaml = "0.9.34"
|
||||
duration-string = { version = "0.3.0", features = ["serde"] }
|
||||
proxy-wasm = "0.2.1"
|
||||
governor = { version = "0.6.3", default-features = false, features = ["no_std"]}
|
||||
log = "0.4"
|
||||
derivative = "2.2.0"
|
||||
thiserror = "1.0.64"
|
||||
|
||||
[dev-dependencies]
|
||||
pretty_assertions = "1.4.1"
|
||||
serde_json = "1.0.64"
|
||||
448
crates/public_types/src/common_types.rs
Normal file
448
crates/public_types/src/common_types.rs
Normal file
|
|
@ -0,0 +1,448 @@
|
|||
use crate::configuration::PromptTarget;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct EmbeddingRequest {
|
||||
pub prompt_target: PromptTarget,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
|
||||
pub enum EmbeddingType {
|
||||
Name,
|
||||
Description,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct VectorPoint {
|
||||
pub id: String,
|
||||
pub payload: HashMap<String, String>,
|
||||
pub vector: Vec<f64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct StoreVectorEmbeddingsRequest {
|
||||
pub points: Vec<VectorPoint>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SearchPointResult {
|
||||
pub id: String,
|
||||
pub version: i32,
|
||||
pub score: f64,
|
||||
pub payload: HashMap<String, String>,
|
||||
}
|
||||
|
||||
pub mod open_ai {
|
||||
use std::collections::HashMap;
|
||||
|
||||
use serde::{ser::SerializeMap, Deserialize, Serialize};
|
||||
use serde_yaml::Value;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChatCompletionsRequest {
|
||||
#[serde(default)]
|
||||
pub model: String,
|
||||
pub messages: Vec<Message>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tools: Option<Vec<ChatCompletionTool>>,
|
||||
#[serde(default)]
|
||||
pub stream: bool,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub stream_options: Option<StreamOptions>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub metadata: Option<HashMap<String, String>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub enum ToolType {
|
||||
#[serde(rename = "function")]
|
||||
Function,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChatCompletionTool {
|
||||
#[serde(rename = "type")]
|
||||
pub tool_type: ToolType,
|
||||
pub function: FunctionDefinition,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct FunctionDefinition {
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
pub parameters: FunctionParameters,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct FunctionParameters {
|
||||
pub properties: HashMap<String, FunctionParameter>,
|
||||
}
|
||||
|
||||
impl Serialize for FunctionParameters {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
// select all requried parameters
|
||||
let required: Vec<&String> = self
|
||||
.properties
|
||||
.iter()
|
||||
.filter(|(_, v)| v.required.unwrap_or(false))
|
||||
.map(|(k, _)| k)
|
||||
.collect();
|
||||
let mut map = serializer.serialize_map(Some(2))?;
|
||||
map.serialize_entry("properties", &self.properties)?;
|
||||
if !required.is_empty() {
|
||||
map.serialize_entry("required", &required)?;
|
||||
}
|
||||
map.end()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct FunctionParameter {
|
||||
#[serde(rename = "type")]
|
||||
#[serde(default = "ParameterType::string")]
|
||||
pub parameter_type: ParameterType,
|
||||
pub description: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub required: Option<bool>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
#[serde(rename = "enum")]
|
||||
pub enum_values: Option<Vec<String>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub default: Option<String>,
|
||||
}
|
||||
|
||||
impl Serialize for FunctionParameter {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
let mut map = serializer.serialize_map(Some(5))?;
|
||||
map.serialize_entry("type", &self.parameter_type)?;
|
||||
map.serialize_entry("description", &self.description)?;
|
||||
if let Some(enum_values) = &self.enum_values {
|
||||
map.serialize_entry("enum", enum_values)?;
|
||||
}
|
||||
if let Some(default) = &self.default {
|
||||
map.serialize_entry("default", default)?;
|
||||
}
|
||||
map.end()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
pub enum ParameterType {
|
||||
#[serde(rename = "int")]
|
||||
Int,
|
||||
#[serde(rename = "float")]
|
||||
Float,
|
||||
#[serde(rename = "bool")]
|
||||
Bool,
|
||||
#[serde(rename = "str")]
|
||||
String,
|
||||
#[serde(rename = "list")]
|
||||
List,
|
||||
#[serde(rename = "dict")]
|
||||
Dict,
|
||||
}
|
||||
|
||||
impl From<String> for ParameterType {
|
||||
fn from(s: String) -> Self {
|
||||
match s.as_str() {
|
||||
"int" => ParameterType::Int,
|
||||
"integer" => ParameterType::Int,
|
||||
"float" => ParameterType::Float,
|
||||
"bool" => ParameterType::Bool,
|
||||
"boolean" => ParameterType::Bool,
|
||||
"str" => ParameterType::String,
|
||||
"string" => ParameterType::String,
|
||||
"list" => ParameterType::List,
|
||||
"array" => ParameterType::List,
|
||||
"dict" => ParameterType::Dict,
|
||||
"dictionary" => ParameterType::Dict,
|
||||
_ => ParameterType::String,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ParameterType {
|
||||
pub fn string() -> ParameterType {
|
||||
ParameterType::String
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct StreamOptions {
|
||||
pub include_usage: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Message {
|
||||
pub role: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub content: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub model: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_calls: Option<Vec<ToolCall>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Choice {
|
||||
pub finish_reason: String,
|
||||
pub index: usize,
|
||||
pub message: Message,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ToolCall {
|
||||
pub id: String,
|
||||
#[serde(rename = "type")]
|
||||
pub tool_type: ToolType,
|
||||
pub function: FunctionCallDetail,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct FunctionCallDetail {
|
||||
pub name: String,
|
||||
pub arguments: HashMap<String, Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
pub struct ToolCallState {
|
||||
pub key: String,
|
||||
pub message: Option<Message>,
|
||||
pub tool_call: FunctionCallDetail,
|
||||
pub tool_response: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum ArchState {
|
||||
ToolCall(Vec<ToolCallState>),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChatCompletionsResponse {
|
||||
pub usage: Option<Usage>,
|
||||
pub choices: Vec<Choice>,
|
||||
pub model: String,
|
||||
pub metadata: Option<HashMap<String, String>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Usage {
|
||||
pub completion_tokens: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChatCompletionChunkResponse {
|
||||
pub model: String,
|
||||
pub choices: Vec<ChunkChoice>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChunkChoice {
|
||||
pub delta: Delta,
|
||||
// TODO: could this be an enum?
|
||||
pub finish_reason: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Delta {
|
||||
pub content: Option<String>,
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ZeroShotClassificationRequest {
|
||||
pub input: String,
|
||||
pub labels: Vec<String>,
|
||||
pub model: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ZeroShotClassificationResponse {
|
||||
pub predicted_class: String,
|
||||
pub predicted_class_score: f64,
|
||||
pub scores: HashMap<String, f64>,
|
||||
pub model: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct HallucinationClassificationRequest {
|
||||
pub prompt: String,
|
||||
pub parameters: HashMap<String, String>,
|
||||
pub model: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct HallucinationClassificationResponse {
|
||||
pub params_scores: HashMap<String, f64>,
|
||||
pub model: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub enum PromptGuardTask {
|
||||
#[serde(rename = "jailbreak")]
|
||||
Jailbreak,
|
||||
#[serde(rename = "toxicity")]
|
||||
Toxicity,
|
||||
#[serde(rename = "both")]
|
||||
Both,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PromptGuardRequest {
|
||||
pub input: String,
|
||||
pub task: PromptGuardTask,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PromptGuardResponse {
|
||||
pub toxic_prob: Option<f64>,
|
||||
pub jailbreak_prob: Option<f64>,
|
||||
pub toxic_verdict: Option<bool>,
|
||||
pub jailbreak_verdict: Option<bool>,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use crate::common_types::open_ai::Message;
|
||||
use pretty_assertions::{assert_eq, assert_ne};
|
||||
use std::collections::HashMap;
|
||||
|
||||
const TOOL_SERIALIZED: &str = r#"{
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What city do you want to know the weather for?"
|
||||
}
|
||||
],
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "weather_forecast",
|
||||
"description": "function to retrieve weather forecast",
|
||||
"parameters": {
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "str",
|
||||
"description": "city for weather forecast",
|
||||
"default": "test"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"city"
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
],
|
||||
"stream": true,
|
||||
"stream_options": {
|
||||
"include_usage": true
|
||||
}
|
||||
}"#;
|
||||
|
||||
#[test]
|
||||
fn test_tool_type_request() {
|
||||
use super::open_ai::{
|
||||
ChatCompletionsRequest, FunctionDefinition, FunctionParameter, ParameterType, ToolType,
|
||||
};
|
||||
|
||||
let mut properties = HashMap::new();
|
||||
properties.insert(
|
||||
"city".to_string(),
|
||||
FunctionParameter {
|
||||
parameter_type: ParameterType::String,
|
||||
description: "city for weather forecast".to_string(),
|
||||
required: Some(true),
|
||||
enum_values: None,
|
||||
default: Some("test".to_string()),
|
||||
},
|
||||
);
|
||||
|
||||
let function_definition = FunctionDefinition {
|
||||
name: "weather_forecast".to_string(),
|
||||
description: "function to retrieve weather forecast".to_string(),
|
||||
parameters: super::open_ai::FunctionParameters { properties },
|
||||
};
|
||||
|
||||
let chat_completions_request = ChatCompletionsRequest {
|
||||
model: "gpt-3.5-turbo".to_string(),
|
||||
messages: vec![Message {
|
||||
role: "user".to_string(),
|
||||
content: Some("What city do you want to know the weather for?".to_string()),
|
||||
model: None,
|
||||
tool_calls: None,
|
||||
}],
|
||||
tools: Some(vec![super::open_ai::ChatCompletionTool {
|
||||
tool_type: ToolType::Function,
|
||||
function: function_definition,
|
||||
}]),
|
||||
stream: true,
|
||||
stream_options: Some(super::open_ai::StreamOptions {
|
||||
include_usage: true,
|
||||
}),
|
||||
metadata: None,
|
||||
};
|
||||
|
||||
let serialized = serde_json::to_string_pretty(&chat_completions_request).unwrap();
|
||||
println!("{}", serialized);
|
||||
assert_eq!(TOOL_SERIALIZED, serialized);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parameter_types() {
|
||||
use super::open_ai::{
|
||||
ChatCompletionsRequest, FunctionDefinition, FunctionParameter, ParameterType, ToolType,
|
||||
};
|
||||
|
||||
const PARAMETER_SERIALZIED: &str = r#"{
|
||||
"city": {
|
||||
"type": "str",
|
||||
"description": "city for weather forecast",
|
||||
"default": "test"
|
||||
}
|
||||
}"#;
|
||||
|
||||
let properties = HashMap::from([(
|
||||
"city".to_string(),
|
||||
FunctionParameter {
|
||||
parameter_type: ParameterType::String,
|
||||
description: "city for weather forecast".to_string(),
|
||||
required: Some(true),
|
||||
enum_values: None,
|
||||
default: Some("test".to_string()),
|
||||
},
|
||||
)]);
|
||||
|
||||
let serialized = serde_json::to_string_pretty(&properties).unwrap();
|
||||
assert_eq!(PARAMETER_SERIALZIED, serialized);
|
||||
|
||||
// ensure that if type is missing it is set to string
|
||||
const PARAMETER_SERIALZIED_MISSING_TYPE: &str = r#"
|
||||
{
|
||||
"city": {
|
||||
"description": "city for weather forecast"
|
||||
}
|
||||
}"#;
|
||||
|
||||
let missing_type_deserialized: HashMap<String, FunctionParameter> =
|
||||
serde_json::from_str(PARAMETER_SERIALZIED_MISSING_TYPE).unwrap();
|
||||
println!("{:?}", missing_type_deserialized);
|
||||
assert_eq!(
|
||||
missing_type_deserialized
|
||||
.get("city")
|
||||
.unwrap()
|
||||
.parameter_type,
|
||||
ParameterType::String
|
||||
);
|
||||
}
|
||||
}
|
||||
308
crates/public_types/src/configuration.rs
Normal file
308
crates/public_types/src/configuration.rs
Normal file
|
|
@ -0,0 +1,308 @@
|
|||
use duration_string::DurationString;
|
||||
use serde::{Deserialize, Deserializer, Serialize};
|
||||
use std::fmt::Display;
|
||||
use std::{collections::HashMap, time::Duration};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct Overrides {
|
||||
pub prompt_target_intent_matching_threshold: Option<f64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct Tracing {
|
||||
pub sampling_rate: Option<f64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
|
||||
pub enum GatewayMode {
|
||||
#[serde(rename = "llm")]
|
||||
Llm,
|
||||
#[serde(rename = "prompt")]
|
||||
Prompt,
|
||||
}
|
||||
|
||||
impl Default for GatewayMode {
|
||||
fn default() -> Self {
|
||||
GatewayMode::Prompt
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Configuration {
|
||||
pub version: String,
|
||||
pub listener: Listener,
|
||||
pub endpoints: HashMap<String, Endpoint>,
|
||||
pub llm_providers: Vec<LlmProvider>,
|
||||
pub overrides: Option<Overrides>,
|
||||
pub system_prompt: Option<String>,
|
||||
pub prompt_guards: Option<PromptGuards>,
|
||||
pub prompt_targets: Vec<PromptTarget>,
|
||||
pub error_target: Option<ErrorTargetDetail>,
|
||||
pub ratelimits: Option<Vec<Ratelimit>>,
|
||||
pub tracing: Option<Tracing>,
|
||||
pub mode: Option<GatewayMode>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ErrorTargetDetail {
|
||||
pub endpoint: Option<EndpointDetails>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Listener {
|
||||
pub address: String,
|
||||
pub port: u16,
|
||||
pub message_format: MessageFormat,
|
||||
// pub connect_timeout: Option<DurationString>,
|
||||
}
|
||||
|
||||
impl Default for Listener {
|
||||
fn default() -> Self {
|
||||
Listener {
|
||||
address: "".to_string(),
|
||||
port: 0,
|
||||
message_format: MessageFormat::default(),
|
||||
// connect_timeout: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub enum MessageFormat {
|
||||
#[serde(rename = "huggingface")]
|
||||
#[default]
|
||||
Huggingface,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct PromptGuards {
|
||||
pub input_guards: HashMap<GuardType, GuardOptions>,
|
||||
}
|
||||
|
||||
impl PromptGuards {
|
||||
pub fn jailbreak_on_exception_message(&self) -> Option<&str> {
|
||||
self.input_guards
|
||||
.get(&GuardType::Jailbreak)?
|
||||
.on_exception
|
||||
.as_ref()?
|
||||
.message
|
||||
.as_ref()?
|
||||
.as_str()
|
||||
.into()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
|
||||
pub enum GuardType {
|
||||
#[serde(rename = "jailbreak")]
|
||||
Jailbreak,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct GuardOptions {
|
||||
pub on_exception: Option<OnExceptionDetails>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct OnExceptionDetails {
|
||||
pub forward_to_error_target: Option<bool>,
|
||||
pub error_handler: Option<String>,
|
||||
pub message: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct LlmRatelimit {
|
||||
pub selector: LlmRatelimitSelector,
|
||||
pub limit: Limit,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct LlmRatelimitSelector {
|
||||
pub http_header: Option<RatelimitHeader>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
|
||||
pub struct Header {
|
||||
pub key: String,
|
||||
pub value: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Ratelimit {
|
||||
pub model: String,
|
||||
pub selector: Header,
|
||||
pub limit: Limit,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Limit {
|
||||
pub tokens: u32,
|
||||
pub unit: TimeUnit,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub enum TimeUnit {
|
||||
#[serde(rename = "second")]
|
||||
Second,
|
||||
#[serde(rename = "minute")]
|
||||
Minute,
|
||||
#[serde(rename = "hour")]
|
||||
Hour,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
|
||||
pub struct RatelimitHeader {
|
||||
pub name: String,
|
||||
pub value: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
//TODO: use enum for model, but if there is a new model, we need to update the code
|
||||
pub struct EmbeddingProviver {
|
||||
pub name: String,
|
||||
pub model: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
//TODO: use enum for model, but if there is a new model, we need to update the code
|
||||
pub struct LlmProvider {
|
||||
pub name: String,
|
||||
pub provider: String,
|
||||
pub access_key: Option<String>,
|
||||
pub model: String,
|
||||
pub default: Option<bool>,
|
||||
pub stream: Option<bool>,
|
||||
pub rate_limits: Option<LlmRatelimit>,
|
||||
}
|
||||
|
||||
impl Display for LlmProvider {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", self.name)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Endpoint {
|
||||
pub endpoint: Option<String>,
|
||||
// pub connect_timeout: Option<DurationString>,
|
||||
// pub timeout: Option<DurationString>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Parameter {
|
||||
pub name: String,
|
||||
#[serde(rename = "type")]
|
||||
pub parameter_type: Option<String>,
|
||||
pub description: String,
|
||||
pub required: Option<bool>,
|
||||
#[serde(rename = "enum")]
|
||||
pub enum_values: Option<Vec<String>>,
|
||||
pub default: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct EndpointDetails {
|
||||
pub name: String,
|
||||
pub path: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PromptTarget {
|
||||
pub name: String,
|
||||
pub default: Option<bool>,
|
||||
pub description: String,
|
||||
pub endpoint: Option<EndpointDetails>,
|
||||
pub parameters: Option<Vec<Parameter>>,
|
||||
pub system_prompt: Option<String>,
|
||||
pub auto_llm_dispatch_on_response: Option<bool>,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use std::fs;
|
||||
|
||||
use crate::configuration::GuardType;
|
||||
|
||||
#[test]
|
||||
fn test_deserialize_configuration() {
|
||||
let ref_config =
|
||||
fs::read_to_string("../docs/source/resources/includes/arch_config_full_reference.yaml")
|
||||
.expect("reference config file not found");
|
||||
|
||||
let config: super::Configuration = serde_yaml::from_str(&ref_config).unwrap();
|
||||
assert_eq!(config.version, "v0.1");
|
||||
|
||||
let open_ai_provider = config
|
||||
.llm_providers
|
||||
.iter()
|
||||
.find(|p| p.name.to_lowercase() == "openai")
|
||||
.unwrap();
|
||||
assert_eq!(open_ai_provider.name.to_lowercase(), "openai");
|
||||
assert_eq!(
|
||||
open_ai_provider.access_key,
|
||||
Some("OPENAI_API_KEY".to_string())
|
||||
);
|
||||
assert_eq!(open_ai_provider.model, "gpt-4o");
|
||||
assert_eq!(open_ai_provider.default, Some(true));
|
||||
assert_eq!(open_ai_provider.stream, Some(true));
|
||||
|
||||
let prompt_guards = config.prompt_guards.as_ref().unwrap();
|
||||
let input_guards = &prompt_guards.input_guards;
|
||||
let jailbreak_guard = input_guards.get(&GuardType::Jailbreak).unwrap();
|
||||
assert_eq!(
|
||||
jailbreak_guard
|
||||
.on_exception
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.forward_to_error_target,
|
||||
None
|
||||
);
|
||||
assert_eq!(
|
||||
jailbreak_guard.on_exception.as_ref().unwrap().error_handler,
|
||||
None
|
||||
);
|
||||
|
||||
let prompt_targets = &config.prompt_targets;
|
||||
assert_eq!(prompt_targets.len(), 2);
|
||||
let prompt_target = prompt_targets
|
||||
.iter()
|
||||
.find(|p| p.name == "reboot_network_device")
|
||||
.unwrap();
|
||||
assert_eq!(prompt_target.name, "reboot_network_device");
|
||||
assert_eq!(prompt_target.default, None);
|
||||
|
||||
let prompt_target = prompt_targets
|
||||
.iter()
|
||||
.find(|p| p.name == "information_extraction")
|
||||
.unwrap();
|
||||
assert_eq!(prompt_target.name, "information_extraction");
|
||||
assert_eq!(prompt_target.default, Some(true));
|
||||
assert_eq!(
|
||||
prompt_target.endpoint.as_ref().unwrap().name,
|
||||
"app_server".to_string()
|
||||
);
|
||||
assert_eq!(
|
||||
prompt_target.endpoint.as_ref().unwrap().path,
|
||||
Some("/agent/summary".to_string())
|
||||
);
|
||||
|
||||
let error_target = config.error_target.as_ref().unwrap();
|
||||
assert_eq!(
|
||||
error_target.endpoint.as_ref().unwrap().name,
|
||||
"error_target_1".to_string()
|
||||
);
|
||||
assert_eq!(
|
||||
error_target.endpoint.as_ref().unwrap().path,
|
||||
Some("/error".to_string())
|
||||
);
|
||||
|
||||
let tracing = config.tracing.as_ref().unwrap();
|
||||
assert_eq!(tracing.sampling_rate.unwrap(), 0.1);
|
||||
|
||||
let mode = config
|
||||
.mode
|
||||
.as_ref()
|
||||
.unwrap_or(&super::GatewayMode::Prompt);
|
||||
assert_eq!(*mode, super::GatewayMode::Prompt);
|
||||
}
|
||||
}
|
||||
22
crates/public_types/src/consts.rs
Normal file
22
crates/public_types/src/consts.rs
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
pub const DEFAULT_EMBEDDING_MODEL: &str = "katanemo/bge-large-en-v1.5";
|
||||
pub const DEFAULT_INTENT_MODEL: &str = "katanemo/bart-large-mnli";
|
||||
pub const DEFAULT_PROMPT_TARGET_THRESHOLD: f64 = 0.8;
|
||||
pub const DEFAULT_HALLUCINATED_THRESHOLD: f64 = 0.25;
|
||||
pub const RATELIMIT_SELECTOR_HEADER_KEY: &str = "x-arch-ratelimit-selector";
|
||||
pub const SYSTEM_ROLE: &str = "system";
|
||||
pub const USER_ROLE: &str = "user";
|
||||
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
|
||||
pub const ARC_FC_CLUSTER: &str = "arch_fc";
|
||||
pub const ARCH_FC_REQUEST_TIMEOUT_MS: u64 = 120000; // 2 minutes
|
||||
pub const MODEL_SERVER_NAME: &str = "model_server";
|
||||
pub const ARCH_ROUTING_HEADER: &str = "x-arch-llm-provider";
|
||||
pub const ARCH_MESSAGES_KEY: &str = "arch_messages";
|
||||
pub const ARCH_PROVIDER_HINT_HEADER: &str = "x-arch-llm-provider-hint";
|
||||
pub const CHAT_COMPLETIONS_PATH: &str = "v1/chat/completions";
|
||||
pub const ARCH_STATE_HEADER: &str = "x-arch-state";
|
||||
pub const ARCH_FC_MODEL_NAME: &str = "Arch-Function-1.5B";
|
||||
pub const REQUEST_ID_HEADER: &str = "x-request-id";
|
||||
pub const ARCH_INTERNAL_CLUSTER_NAME: &str = "arch_internal";
|
||||
pub const ARCH_UPSTREAM_HOST_HEADER: &str = "x-arch-upstream";
|
||||
pub const ARCH_LLM_UPSTREAM_LISTENER: &str = "arch_llm_listener";
|
||||
pub const ARCH_MODEL_PREFIX: &str = "Arch";
|
||||
|
|
@ -0,0 +1,59 @@
|
|||
/*
|
||||
* OMF Embeddings
|
||||
*
|
||||
* No description provided (generated by Openapi Generator https://github.com/openapitools/openapi-generator)
|
||||
*
|
||||
* The version of the OpenAPI document: 1.0.0
|
||||
*
|
||||
* Generated by: https://openapi-generator.tech
|
||||
*/
|
||||
|
||||
use crate::embeddings;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)]
|
||||
pub struct CreateEmbeddingRequest {
|
||||
#[serde(rename = "input")]
|
||||
pub input: Box<embeddings::CreateEmbeddingRequestInput>,
|
||||
/// ID of the model to use. You can use the [List models](/docs/api-reference/models/list) API to see all of your available models, or see our [Model overview](/docs/models/overview) for descriptions of them.
|
||||
#[serde(rename = "model")]
|
||||
pub model: String,
|
||||
/// The format to return the embeddings in. Can be either `float` or [`base64`](https://pypi.org/project/pybase64/).
|
||||
#[serde(rename = "encoding_format", skip_serializing_if = "Option::is_none")]
|
||||
pub encoding_format: Option<EncodingFormat>,
|
||||
/// The number of dimensions the resulting output embeddings should have. Only supported in `text-embedding-3` and later models.
|
||||
#[serde(rename = "dimensions", skip_serializing_if = "Option::is_none")]
|
||||
pub dimensions: Option<i32>,
|
||||
/// A unique identifier representing your end-user, which can help to monitor and detect abuse. [Learn more](/docs/guides/safety-best-practices/end-user-ids).
|
||||
#[serde(rename = "user", skip_serializing_if = "Option::is_none")]
|
||||
pub user: Option<String>,
|
||||
}
|
||||
|
||||
impl CreateEmbeddingRequest {
|
||||
pub fn new(
|
||||
input: embeddings::CreateEmbeddingRequestInput,
|
||||
model: String,
|
||||
) -> CreateEmbeddingRequest {
|
||||
CreateEmbeddingRequest {
|
||||
input: Box::new(input),
|
||||
model,
|
||||
encoding_format: None,
|
||||
dimensions: None,
|
||||
user: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
/// The format to return the embeddings in. Can be either `float` or [`base64`](https://pypi.org/project/pybase64/).
|
||||
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, Deserialize)]
|
||||
pub enum EncodingFormat {
|
||||
#[serde(rename = "float")]
|
||||
Float,
|
||||
#[serde(rename = "base64")]
|
||||
Base64,
|
||||
}
|
||||
|
||||
impl Default for EncodingFormat {
|
||||
fn default() -> EncodingFormat {
|
||||
Self::Float
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,29 @@
|
|||
/*
|
||||
* OMF Embeddings
|
||||
*
|
||||
* No description provided (generated by Openapi Generator https://github.com/openapitools/openapi-generator)
|
||||
*
|
||||
* The version of the OpenAPI document: 1.0.0
|
||||
*
|
||||
* Generated by: https://openapi-generator.tech
|
||||
*/
|
||||
|
||||
use crate::embeddings;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// CreateEmbeddingRequestInput : Input text to embed, encoded as a string or array of tokens. To embed multiple inputs in a single request, pass an array of strings or array of token arrays. The input must not exceed the max input tokens for the model (8192 tokens for `text-embedding-ada-002`), cannot be an empty string, and any array must be 2048 dimensions or less. for counting tokens.
|
||||
/// Input text to embed, encoded as a string or array of tokens. To embed multiple inputs in a single request, pass an array of strings or array of token arrays. The input must not exceed the max input tokens for the model (8192 tokens for `text-embedding-ada-002`), cannot be an empty string, and any array must be 2048 dimensions or less. for counting tokens.
|
||||
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum CreateEmbeddingRequestInput {
|
||||
/// The string that will be turned into an embedding.
|
||||
String(String),
|
||||
/// The array of integers that will be turned into an embedding.
|
||||
Array(Vec<i32>),
|
||||
}
|
||||
|
||||
impl Default for CreateEmbeddingRequestInput {
|
||||
fn default() -> Self {
|
||||
Self::String(Default::default())
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,55 @@
|
|||
/*
|
||||
* OMF Embeddings
|
||||
*
|
||||
* No description provided (generated by Openapi Generator https://github.com/openapitools/openapi-generator)
|
||||
*
|
||||
* The version of the OpenAPI document: 1.0.0
|
||||
*
|
||||
* Generated by: https://openapi-generator.tech
|
||||
*/
|
||||
|
||||
use crate::embeddings;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)]
|
||||
pub struct CreateEmbeddingResponse {
|
||||
/// The list of embeddings generated by the model.
|
||||
#[serde(rename = "data")]
|
||||
pub data: Vec<embeddings::Embedding>,
|
||||
/// The name of the model used to generate the embedding.
|
||||
#[serde(rename = "model")]
|
||||
pub model: String,
|
||||
/// The object type, which is always \"list\".
|
||||
#[serde(rename = "object")]
|
||||
pub object: Object,
|
||||
#[serde(rename = "usage")]
|
||||
pub usage: Box<embeddings::CreateEmbeddingResponseUsage>,
|
||||
}
|
||||
|
||||
impl CreateEmbeddingResponse {
|
||||
pub fn new(
|
||||
data: Vec<embeddings::Embedding>,
|
||||
model: String,
|
||||
object: Object,
|
||||
usage: embeddings::CreateEmbeddingResponseUsage,
|
||||
) -> CreateEmbeddingResponse {
|
||||
CreateEmbeddingResponse {
|
||||
data,
|
||||
model,
|
||||
object,
|
||||
usage: Box::new(usage),
|
||||
}
|
||||
}
|
||||
}
|
||||
/// The object type, which is always \"list\".
|
||||
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, Deserialize)]
|
||||
pub enum Object {
|
||||
#[serde(rename = "list")]
|
||||
List,
|
||||
}
|
||||
|
||||
impl Default for Object {
|
||||
fn default() -> Object {
|
||||
Self::List
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,33 @@
|
|||
/*
|
||||
* OMF Embeddings
|
||||
*
|
||||
* No description provided (generated by Openapi Generator https://github.com/openapitools/openapi-generator)
|
||||
*
|
||||
* The version of the OpenAPI document: 1.0.0
|
||||
*
|
||||
* Generated by: https://openapi-generator.tech
|
||||
*/
|
||||
|
||||
use crate::embeddings;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// CreateEmbeddingResponseUsage : The usage information for the request.
|
||||
#[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)]
|
||||
pub struct CreateEmbeddingResponseUsage {
|
||||
/// The number of tokens used by the prompt.
|
||||
#[serde(rename = "prompt_tokens")]
|
||||
pub prompt_tokens: i32,
|
||||
/// The total number of tokens used by the request.
|
||||
#[serde(rename = "total_tokens")]
|
||||
pub total_tokens: i32,
|
||||
}
|
||||
|
||||
impl CreateEmbeddingResponseUsage {
|
||||
/// The usage information for the request.
|
||||
pub fn new(prompt_tokens: i32, total_tokens: i32) -> CreateEmbeddingResponseUsage {
|
||||
CreateEmbeddingResponseUsage {
|
||||
prompt_tokens,
|
||||
total_tokens,
|
||||
}
|
||||
}
|
||||
}
|
||||
49
crates/public_types/src/embeddings/embedding.rs
Normal file
49
crates/public_types/src/embeddings/embedding.rs
Normal file
|
|
@ -0,0 +1,49 @@
|
|||
/*
|
||||
* OMF Embeddings
|
||||
*
|
||||
* No description provided (generated by Openapi Generator https://github.com/openapitools/openapi-generator)
|
||||
*
|
||||
* The version of the OpenAPI document: 1.0.0
|
||||
*
|
||||
* Generated by: https://openapi-generator.tech
|
||||
*/
|
||||
|
||||
use crate::embeddings;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Embedding : Represents an embedding vector returned by embedding endpoint.
|
||||
#[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)]
|
||||
pub struct Embedding {
|
||||
/// The index of the embedding in the list of embeddings.
|
||||
#[serde(rename = "index")]
|
||||
pub index: i32,
|
||||
/// The embedding vector, which is a list of floats. The length of vector depends on the model as listed in the [embedding guide](/docs/guides/embeddings).
|
||||
#[serde(rename = "embedding")]
|
||||
pub embedding: Vec<f64>,
|
||||
/// The object type, which is always \"embedding\"
|
||||
#[serde(rename = "object")]
|
||||
pub object: Object,
|
||||
}
|
||||
|
||||
impl Embedding {
|
||||
/// Represents an embedding vector returned by embedding endpoint.
|
||||
pub fn new(index: i32, embedding: Vec<f64>, object: Object) -> Embedding {
|
||||
Embedding {
|
||||
index,
|
||||
embedding,
|
||||
object,
|
||||
}
|
||||
}
|
||||
}
|
||||
/// The object type, which is always \"embedding\"
|
||||
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, Deserialize)]
|
||||
pub enum Object {
|
||||
#[serde(rename = "embedding")]
|
||||
Embedding,
|
||||
}
|
||||
|
||||
impl Default for Object {
|
||||
fn default() -> Object {
|
||||
Self::Embedding
|
||||
}
|
||||
}
|
||||
10
crates/public_types/src/embeddings/mod.rs
Normal file
10
crates/public_types/src/embeddings/mod.rs
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
pub mod create_embedding_request;
|
||||
pub use self::create_embedding_request::CreateEmbeddingRequest;
|
||||
pub mod create_embedding_request_input;
|
||||
pub use self::create_embedding_request_input::CreateEmbeddingRequestInput;
|
||||
pub mod create_embedding_response;
|
||||
pub use self::create_embedding_response::CreateEmbeddingResponse;
|
||||
pub mod create_embedding_response_usage;
|
||||
pub use self::create_embedding_response_usage::CreateEmbeddingResponseUsage;
|
||||
pub mod embedding;
|
||||
pub use self::embedding::Embedding;
|
||||
93
crates/public_types/src/http.rs
Normal file
93
crates/public_types/src/http.rs
Normal file
|
|
@ -0,0 +1,93 @@
|
|||
use crate::stats::{Gauge, IncrementingMetric};
|
||||
use derivative::Derivative;
|
||||
use log::debug;
|
||||
use proxy_wasm::{traits::Context, types::Status};
|
||||
use serde::Serialize;
|
||||
use std::{cell::RefCell, collections::HashMap, fmt::Debug, time::Duration};
|
||||
|
||||
#[derive(Derivative, Serialize)]
|
||||
#[derivative(Debug)]
|
||||
pub struct CallArgs<'a> {
|
||||
upstream: &'a str,
|
||||
path: &'a str,
|
||||
headers: Vec<(&'a str, &'a str)>,
|
||||
#[derivative(Debug = "ignore")]
|
||||
body: Option<&'a [u8]>,
|
||||
trailers: Vec<(&'a str, &'a str)>,
|
||||
timeout: Duration,
|
||||
}
|
||||
|
||||
impl<'a> CallArgs<'a> {
|
||||
pub fn new(
|
||||
upstream: &'a str,
|
||||
path: &'a str,
|
||||
headers: Vec<(&'a str, &'a str)>,
|
||||
body: Option<&'a [u8]>,
|
||||
trailers: Vec<(&'a str, &'a str)>,
|
||||
timeout: Duration,
|
||||
) -> Self {
|
||||
CallArgs {
|
||||
upstream,
|
||||
path,
|
||||
headers,
|
||||
body,
|
||||
trailers,
|
||||
timeout,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum ClientError {
|
||||
#[error("Error dispatching HTTP call to `{upstream_name}/{path}`, error: {internal_status:?}")]
|
||||
DispatchError {
|
||||
upstream_name: String,
|
||||
path: String,
|
||||
internal_status: Status,
|
||||
},
|
||||
}
|
||||
|
||||
pub trait Client: Context {
|
||||
type CallContext: Debug;
|
||||
|
||||
fn http_call(
|
||||
&self,
|
||||
call_args: CallArgs,
|
||||
call_context: Self::CallContext,
|
||||
) -> Result<u32, ClientError> {
|
||||
debug!(
|
||||
"dispatching http call with args={:?} context={:?}",
|
||||
call_args, call_context
|
||||
);
|
||||
|
||||
match self.dispatch_http_call(
|
||||
call_args.upstream,
|
||||
call_args.headers,
|
||||
call_args.body,
|
||||
call_args.trailers,
|
||||
call_args.timeout,
|
||||
) {
|
||||
Ok(id) => {
|
||||
self.add_call_context(id, call_context);
|
||||
Ok(id)
|
||||
}
|
||||
Err(status) => Err(ClientError::DispatchError {
|
||||
upstream_name: String::from(call_args.upstream),
|
||||
path: String::from(call_args.path),
|
||||
internal_status: status,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
fn add_call_context(&self, id: u32, call_context: Self::CallContext) {
|
||||
let callouts = self.callouts();
|
||||
if callouts.borrow_mut().insert(id, call_context).is_some() {
|
||||
panic!("Duplicate http call with id={}", id);
|
||||
}
|
||||
self.active_http_calls().increment(1);
|
||||
}
|
||||
|
||||
fn callouts(&self) -> &RefCell<HashMap<u32, Self::CallContext>>;
|
||||
|
||||
fn active_http_calls(&self) -> &Gauge;
|
||||
}
|
||||
8
crates/public_types/src/lib.rs
Normal file
8
crates/public_types/src/lib.rs
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
#![allow(unused_imports)]
|
||||
|
||||
pub mod common_types;
|
||||
pub mod configuration;
|
||||
pub mod embeddings;
|
||||
pub mod consts;
|
||||
pub mod http;
|
||||
pub mod stats;
|
||||
103
crates/public_types/src/stats.rs
Normal file
103
crates/public_types/src/stats.rs
Normal file
|
|
@ -0,0 +1,103 @@
|
|||
use log::error;
|
||||
use proxy_wasm::hostcalls;
|
||||
use proxy_wasm::types::*;
|
||||
|
||||
#[allow(unused)]
|
||||
pub trait Metric {
|
||||
fn id(&self) -> u32;
|
||||
fn value(&self) -> Result<u64, String> {
|
||||
match hostcalls::get_metric(self.id()) {
|
||||
Ok(value) => Ok(value),
|
||||
Err(Status::NotFound) => Err(format!("metric not found: {}", self.id())),
|
||||
Err(err) => Err(format!("unexpected status: {:?}", err)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
pub trait IncrementingMetric: Metric {
|
||||
fn increment(&self, offset: i64) {
|
||||
match hostcalls::increment_metric(self.id(), offset) {
|
||||
Ok(_) => (),
|
||||
Err(err) => error!("error incrementing metric: {:?}", err),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
pub trait RecordingMetric: Metric {
|
||||
fn record(&self, value: u64) {
|
||||
match hostcalls::record_metric(self.id(), value) {
|
||||
Ok(_) => (),
|
||||
Err(err) => error!("error recording metric: {:?}", err),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
pub struct Counter {
|
||||
id: u32,
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
impl Counter {
|
||||
pub fn new(name: String) -> Counter {
|
||||
let returned_id = hostcalls::define_metric(MetricType::Counter, &name)
|
||||
.expect("failed to define counter '{}', name");
|
||||
Counter { id: returned_id }
|
||||
}
|
||||
}
|
||||
|
||||
impl Metric for Counter {
|
||||
fn id(&self) -> u32 {
|
||||
self.id
|
||||
}
|
||||
}
|
||||
|
||||
impl IncrementingMetric for Counter {}
|
||||
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
pub struct Gauge {
|
||||
id: u32,
|
||||
}
|
||||
|
||||
impl Gauge {
|
||||
pub fn new(name: String) -> Gauge {
|
||||
let returned_id = hostcalls::define_metric(MetricType::Gauge, &name)
|
||||
.expect("failed to define gauge '{}', name");
|
||||
Gauge { id: returned_id }
|
||||
}
|
||||
}
|
||||
|
||||
impl Metric for Gauge {
|
||||
fn id(&self) -> u32 {
|
||||
self.id
|
||||
}
|
||||
}
|
||||
|
||||
/// For state of the world updates
|
||||
impl RecordingMetric for Gauge {}
|
||||
/// For offset deltas
|
||||
impl IncrementingMetric for Gauge {}
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
pub struct Histogram {
|
||||
id: u32,
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
impl Histogram {
|
||||
pub fn new(name: String) -> Histogram {
|
||||
let returned_id = hostcalls::define_metric(MetricType::Histogram, &name)
|
||||
.expect("failed to define histogram '{}', name");
|
||||
Histogram { id: returned_id }
|
||||
}
|
||||
}
|
||||
|
||||
impl Metric for Histogram {
|
||||
fn id(&self) -> u32 {
|
||||
self.id
|
||||
}
|
||||
}
|
||||
|
||||
impl RecordingMetric for Histogram {}
|
||||
Loading…
Add table
Add a link
Reference in a new issue