diff --git a/arch/Cargo.lock b/arch/Cargo.lock
index 201f2e68..4886134a 100644
--- a/arch/Cargo.lock
+++ b/arch/Cargo.lock
@@ -759,6 +759,7 @@ dependencies = [
  "serde_json",
  "serde_yaml",
  "serial_test",
+ "thiserror",
  "tiktoken-rs",
 ]
 
@@ -1060,7 +1061,7 @@ dependencies = [
 [[package]]
 name = "proxy-wasm-test-framework"
 version = "0.1.0"
-source = "git+https://github.com/katanemo/test-framework.git?branch=main#c2511cd9030705e14d5f60aca77d6c96c81c6dfa"
+source = "git+https://github.com/katanemo/test-framework.git?branch=new#c2511cd9030705e14d5f60aca77d6c96c81c6dfa"
 dependencies = [
  "anyhow",
  "cfg-if 0.1.10",
@@ -1490,18 +1491,18 @@ dependencies = [
 
 [[package]]
 name = "thiserror"
-version = "1.0.63"
+version = "1.0.64"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c0342370b38b6a11b6cc11d6a805569958d54cfa061a29969c3b5ce2ea405724"
+checksum = "d50af8abc119fb8bb6dbabcfa89656f46f84aa0ac7688088608076ad2b459a84"
 dependencies = [
  "thiserror-impl",
 ]
 
 [[package]]
 name = "thiserror-impl"
-version = "1.0.63"
+version = "1.0.64"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a4558b58466b9ad7ca0f102865eccc95938dca1a74a856f2b57b6629050da261"
+checksum = "08904e7672f5eb876eaaf87e0ce17857500934f4981c4a0ab2b4aa98baac7fc3"
 dependencies = [
  "proc-macro2",
  "quote",
diff --git a/arch/Cargo.toml b/arch/Cargo.toml
index 69750f5c..e7a5d721 100644
--- a/arch/Cargo.toml
+++ b/arch/Cargo.toml
@@ -20,7 +20,8 @@ governor = { version = "0.6.3", default-features = false, features = ["no_std"]}
 tiktoken-rs = "0.5.9"
 acap = "0.3.0"
 rand = "0.8.5"
+thiserror = "1.0.64"
 
 [dev-dependencies]
-proxy-wasm-test-framework = { git = "https://github.com/katanemo/test-framework.git", branch = "main" }
+proxy-wasm-test-framework = { git = "https://github.com/katanemo/test-framework.git", branch = "new" }
 serial_test = "3.1.1"
diff --git a/arch/arch_config_schema.yaml b/arch/arch_config_schema.yaml
index 5055f682..f6349681 100644
--- a/arch/arch_config_schema.yaml
+++ b/arch/arch_config_schema.yaml
@@ -38,6 +38,8 @@ properties:
       properties:
         name:
           type: string
+        provider:
+          type: string
         access_key:
           type: string
         model:
@@ -47,6 +49,7 @@ properties:
       additionalProperties: false
       required:
         - name
+        - provider
         - access_key
        - model
   overrides:
@@ -112,7 +115,7 @@ properties:
     items:
       type: object
       properties:
-        provider:
+        model:
           type: string
         selector:
           type: object
@@ -138,7 +141,7 @@ properties:
             - unit
           additionalProperties: false
       required:
-        - provider
+        - model
         - selector
         - limit
     additionalProperties: false
diff --git a/arch/config_generator.py b/arch/config_generator.py
index 2fb0ded7..c3282f31 100644
--- a/arch/config_generator.py
+++ b/arch/config_generator.py
@@ -7,17 +7,32 @@ ENVOY_CONFIG_TEMPLATE_FILE = os.getenv('ENVOY_CONFIG_TEMPLATE_FILE', 'envoy.temp
 ARCH_CONFIG_FILE = os.getenv('ARCH_CONFIG_FILE', '/config/arch_config.yaml')
 ENVOY_CONFIG_FILE_RENDERED = os.getenv('ENVOY_CONFIG_FILE_RENDERED', '/etc/envoy/envoy.yaml')
 ARCH_CONFIG_SCHEMA_FILE = os.getenv('ARCH_CONFIG_SCHEMA_FILE', 'arch_config_schema.yaml')
+OPENAI_API_KEY = os.getenv('OPENAI_API_KEY', False)
+MISTRAL_API_KEY = os.getenv('MISTRAL_API_KEY', False)
+
+
+def add_secret_key_to_llm_providers(config_yaml):
+    llm_providers = []
+    for llm_provider in config_yaml.get("llm_providers", []):
+        # use .get()/.pop(..., None) so providers without an access_key are valid
+        access_key = llm_provider.get('access_key')
+        if access_key == "$MISTRAL_ACCESS_KEY":
+            llm_provider['access_key'] = MISTRAL_API_KEY
+        elif access_key == "$OPENAI_ACCESS_KEY":
+            llm_provider['access_key'] = OPENAI_API_KEY
+        else:
+            llm_provider.pop('access_key', None)
+        llm_providers.append(llm_provider)
+    config_yaml["llm_providers"] = llm_providers
+    return config_yaml
+
 env = Environment(loader=FileSystemLoader('./'))
 template = env.get_template('envoy.template.yaml')
 
 with open(ARCH_CONFIG_FILE, 'r') as file:
-    katanemo_config = file.read()
+    arch_config_string = file.read()
 
 with open(ARCH_CONFIG_SCHEMA_FILE, 'r') as file:
     arch_config_schema = file.read()
 
-config_yaml = yaml.safe_load(katanemo_config)
+config_yaml = yaml.safe_load(arch_config_string)
 config_schema_yaml = yaml.safe_load(arch_config_schema)
 
 try:
@@ -54,9 +69,16 @@ for name, endpoint_details in endpoints.items():
 print("updated clusters", inferred_clusters)
 
+config_yaml = add_secret_key_to_llm_providers(config_yaml)
+arch_llm_providers = config_yaml["llm_providers"]
+arch_config_string = yaml.dump(config_yaml)
+
+print("llm_providers:", arch_llm_providers)
+
 data = {
-    'katanemo_config': katanemo_config,
-    'arch_clusters': inferred_clusters
+    'arch_config': arch_config_string,
+    'arch_clusters': inferred_clusters,
+    'arch_llm_providers': arch_llm_providers
 }
 
 rendered = template.render(data)
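Note on the schema change above: `provider` is now required for every `llm_providers` entry. A minimal sketch of an entry that passes the updated schema and deserializes into the `LlmProvider` struct this PR extends in `public_types/src/configuration.rs` (the YAML literal and the `main` wrapper are illustrative only):

```rust
use public_types::configuration::LlmProvider;

fn main() -> Result<(), serde_yaml::Error> {
    // Illustrative config entry; the field set mirrors arch_config_schema.yaml.
    let entry = r#"
name: open-ai-gpt-4
provider: openai
access_key: secret_key
model: gpt-4
default: true
"#;
    let provider: LlmProvider = serde_yaml::from_str(entry)?;
    assert_eq!(provider.provider, "openai");
    assert_eq!(provider.access_key.as_deref(), Some("secret_key"));
    Ok(())
}
```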
diff --git a/arch/envoy.template.yaml b/arch/envoy.template.yaml
index 5d453e67..eb83f328 100644
--- a/arch/envoy.template.yaml
+++ b/arch/envoy.template.yaml
@@ -34,26 +34,18 @@ static_resources:
                   auto_host_rewrite: true
                   cluster: mistral_7b_instruct
                 timeout: 60s
+              {% for provider in arch_llm_providers %}
               - match:
-                  prefix: "/v1/chat/completions"
+                  prefix: "/"
                   headers:
                     - name: "x-arch-llm-provider"
                       string_match:
-                        exact: openai
+                        exact: {{ provider.name }}
                 route:
                   auto_host_rewrite: true
-                  cluster: openai
-                  timeout: 60s
-              - match:
-                  prefix: "/v1/chat/completions"
-                  headers:
-                    - name: "x-arch-llm-provider"
-                      string_match:
-                        exact: mistral
-                route:
-                  auto_host_rewrite: true
-                  cluster: mistral
+                  cluster: {{ provider.provider }}
                 timeout: 60s
+              {% endfor %}
           http_filters:
           - name: envoy.filters.http.wasm
             typed_config:
@@ -65,7 +57,7 @@
               configuration:
                 "@type": "type.googleapis.com/google.protobuf.StringValue"
                 value: |
-                  {{ katanemo_config | indent(30) }}
+                  {{ arch_config | indent(30) }}
               vm_config:
                 runtime: "envoy.wasm.runtime.v8"
                 code:
@@ -75,9 +67,6 @@
             typed_config:
               "@type": type.googleapis.com/envoy.extensions.filters.http.router.v3.Router
   clusters:
-  # LLM Host
-  # Embedding Providers
-  # External LLM Providers
   - name: openai
     connect_timeout: 5s
     dns_lookup_family: V4_ONLY
diff --git a/arch/src/consts.rs b/arch/src/consts.rs
index ccee4640..4962a75a 100644
--- a/arch/src/consts.rs
+++ b/arch/src/consts.rs
@@ -10,3 +10,4 @@ pub const ARCH_FC_REQUEST_TIMEOUT_MS: u64 = 120000; // 2 minutes
 pub const MODEL_SERVER_NAME: &str = "model_server";
 pub const ARCH_ROUTING_HEADER: &str = "x-arch-llm-provider";
 pub const ARCH_MESSAGES_KEY: &str = "arch_messages";
+pub const ARCH_PROVIDER_HINT_HEADER: &str = "x-arch-llm-provider-hint";
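With `ARCH_PROVIDER_HINT_HEADER` in place, callers can pin a request to a configured provider by name, or send `default` to get the configured default. A hypothetical client call; the gateway address and port, and the use of `reqwest` (with its `blocking` and `json` features), are assumptions for illustration, not part of this change:

```rust
fn main() -> Result<(), reqwest::Error> {
    let response = reqwest::blocking::Client::new()
        .post("http://localhost:10000/v1/chat/completions") // placeholder address
        .header("x-arch-llm-provider-hint", "open-ai-gpt-4")
        .json(&serde_json::json!({
            // The filter overwrites `model` with the selected provider's configured model.
            "model": "placeholder",
            "messages": [{ "role": "user", "content": "Hello" }]
        }))
        .send()?;
    println!("status: {}", response.status());
    Ok(())
}
```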
diff --git a/arch/src/filter_context.rs b/arch/src/filter_context.rs
index 61d146ed..853f5cdc 100644
--- a/arch/src/filter_context.rs
+++ b/arch/src/filter_context.rs
@@ -1,4 +1,5 @@
 use crate::consts::{DEFAULT_EMBEDDING_MODEL, MODEL_SERVER_NAME};
+use crate::llm_providers::LlmProviders;
 use crate::ratelimit;
 use crate::stats::{Counter, Gauge, RecordingMetric};
 use crate::stream_context::StreamContext;
@@ -44,10 +45,11 @@ pub struct FilterContext {
     metrics: Rc<WasmMetrics>,
     // callouts stores token_id to request mapping that we use during #on_http_call_response to match the response to the request.
     callouts: HashMap<u32, CallContext>,
-    config: Option<Configuration>,
     overrides: Rc<Option<Overrides>>,
     prompt_targets: Rc<RwLock<HashMap<String, PromptTarget>>>,
+    // This should be Option<Rc<PromptGuards>>, because StreamContext::new() should get an Rc<PromptGuards>, not Rc<Option<PromptGuards>>.
     prompt_guards: Rc<Option<PromptGuards>>,
+    llm_providers: Option<Rc<LlmProviders>>,
 }
 
 pub fn embeddings_store() -> &'static RwLock<HashMap<String, Vec<f64>>> {
@@ -62,11 +64,11 @@ impl FilterContext {
     pub fn new() -> FilterContext {
         FilterContext {
             callouts: HashMap::new(),
-            config: None,
             metrics: Rc::new(WasmMetrics::new()),
             prompt_targets: Rc::new(RwLock::new(HashMap::new())),
             overrides: Rc::new(None),
             prompt_guards: Rc::new(Some(PromptGuards::default())),
+            llm_providers: None,
         }
     }
@@ -219,42 +221,35 @@ impl Context for FilterContext {
 // RootContext allows the Rust code to reach into the Envoy Config
 impl RootContext for FilterContext {
     fn on_configure(&mut self, _: usize) -> bool {
-        if let Some(config_bytes) = self.get_plugin_configuration() {
-            self.config = serde_yaml::from_slice(&config_bytes).unwrap();
+        let config_bytes = self
+            .get_plugin_configuration()
+            .expect("Arch config cannot be empty");
 
-            if let Some(overrides_config) = self
-                .config
-                .as_mut()
-                .and_then(|config| config.overrides.as_mut())
-            {
-                self.overrides = Rc::new(Some(std::mem::take(overrides_config)));
-            }
+        let config: Configuration = match serde_yaml::from_slice(&config_bytes) {
+            Ok(config) => config,
+            Err(err) => panic!("Invalid arch config \"{:?}\"", err),
+        };
 
-            for pt in self.config.clone().unwrap().prompt_targets {
-                self.prompt_targets
-                    .write()
-                    .unwrap()
-                    .insert(pt.name.clone(), pt.clone());
-            }
+        self.overrides = Rc::new(config.overrides);
 
-            debug!("set configuration object");
-
-            if let Some(ratelimits_config) = self
-                .config
-                .as_mut()
-                .and_then(|config| config.ratelimits.as_mut())
-            {
-                ratelimit::ratelimits(Some(std::mem::take(ratelimits_config)));
-            }
-
-            if let Some(prompt_guards) = self
-                .config
-                .as_mut()
-                .and_then(|config| config.prompt_guards.as_mut())
-            {
-                self.prompt_guards = Rc::new(Some(std::mem::take(prompt_guards)));
-            }
+        for pt in config.prompt_targets {
+            self.prompt_targets
+                .write()
+                .unwrap()
+                .insert(pt.name.clone(), pt.clone());
         }
+
+        ratelimit::ratelimits(config.ratelimits);
+
+        if let Some(prompt_guards) = config.prompt_guards {
+            self.prompt_guards = Rc::new(Some(prompt_guards))
+        }
+
+        match config.llm_providers.try_into() {
+            Ok(llm_providers) => self.llm_providers = Some(Rc::new(llm_providers)),
+            Err(err) => panic!("{err}"),
+        }
+
         true
     }
@@ -269,6 +264,11 @@ impl RootContext for FilterContext {
             Rc::clone(&self.prompt_targets),
             Rc::clone(&self.prompt_guards),
             Rc::clone(&self.overrides),
+            Rc::clone(
+                self.llm_providers
+                    .as_ref()
+                    .expect("LLM Providers must exist when Streams are being created"),
+            ),
         )))
     }
diff --git a/arch/src/llm_providers.rs b/arch/src/llm_providers.rs
index c698bd1f..75d57817 100644
--- a/arch/src/llm_providers.rs
+++ b/arch/src/llm_providers.rs
@@ -1,47 +1,69 @@
-#[non_exhaustive]
-pub struct LlmProviders;
+use public_types::configuration::LlmProvider;
+use std::collections::HashMap;
+use std::rc::Rc;
+
+#[derive(Debug)]
+pub struct LlmProviders {
+    providers: HashMap<String, Rc<LlmProvider>>,
+    default: Option<Rc<LlmProvider>>,
+}
 
 impl LlmProviders {
-    pub const OPENAI_PROVIDER: LlmProvider<'static> = LlmProvider {
-        name: "openai",
-        api_key_header: "x-arch-openai-api-key",
-        model: "gpt-3.5-turbo",
-    };
-    pub const MISTRAL_PROVIDER: LlmProvider<'static> = LlmProvider {
-        name: "mistral",
-        api_key_header: "x-arch-mistral-api-key",
-        model: "mistral-large-latest",
-    };
+    pub fn iter(&self) -> std::collections::hash_map::Iter<'_, String, Rc<LlmProvider>> {
+        self.providers.iter()
+    }
 
-    pub const VARIANTS: &'static [LlmProvider<'static>] =
-        &[Self::OPENAI_PROVIDER, Self::MISTRAL_PROVIDER];
-}
+    pub fn default(&self) -> Option<Rc<LlmProvider>> {
+        self.default.as_ref().map(|rc| rc.clone())
+    }
 
-pub struct LlmProvider<'prov> {
-    name: &'prov str,
-    api_key_header: &'prov str,
-    model: &'prov str,
-}
-
-impl AsRef<str> for LlmProvider<'_> {
-    fn as_ref(&self) -> &str {
-        self.name
+    pub fn get(&self, name: &str) -> Option<Rc<LlmProvider>> {
+        self.providers.get(name).map(|rc| rc.clone())
     }
 }
 
-impl std::fmt::Display for LlmProvider<'_> {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        write!(f, "{}", self.name)
-    }
+#[derive(thiserror::Error, Debug)]
+pub enum LlmProvidersNewError {
+    #[error("There must be at least one LLM Provider")]
+    EmptySource,
+    #[error("There must be at most one default LLM Provider")]
+    MoreThanOneDefault,
+    #[error("\'{0}\' is not a unique name")]
+    DuplicateName(String),
 }
 
-impl LlmProvider<'_> {
-    pub fn api_key_header(&self) -> &str {
-        self.api_key_header
-    }
+impl TryFrom<Vec<LlmProvider>> for LlmProviders {
+    type Error = LlmProvidersNewError;
 
-    pub fn choose_model(&self) -> &str {
-        // In the future this can be a more complex function balancing reliability, cost, performance, etc.
-        self.model
+    fn try_from(llm_providers_config: Vec<LlmProvider>) -> Result<Self, Self::Error> {
+        if llm_providers_config.is_empty() {
+            return Err(LlmProvidersNewError::EmptySource);
+        }
+
+        let mut llm_providers = LlmProviders {
+            providers: HashMap::new(),
+            default: None,
+        };
+
+        for llm_provider in llm_providers_config {
+            let llm_provider: Rc<LlmProvider> = Rc::new(llm_provider);
+            if llm_provider.default.unwrap_or_default() {
+                match llm_providers.default {
+                    Some(_) => return Err(LlmProvidersNewError::MoreThanOneDefault),
+                    None => llm_providers.default = Some(Rc::clone(&llm_provider)),
+                }
+            }
+
+            // Insert and check that there is no other provider with the same name.
+            let name = llm_provider.name.clone();
+            if llm_providers
+                .providers
+                .insert(name.clone(), llm_provider)
+                .is_some()
+            {
+                return Err(LlmProvidersNewError::DuplicateName(name));
+            }
+        }
+        Ok(llm_providers)
    }
 }
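The `TryFrom` above encodes three invariants: at least one provider, at most one default, and unique names. A test-style sketch of the failure modes; the `provider()` helper, and the assumption that `LlmProvider` has exactly the fields shown in this diff, are illustrative:

```rust
#[cfg(test)]
mod llm_providers_sketch {
    use super::*;
    use public_types::configuration::LlmProvider;

    // Hypothetical helper; assumes the field set shown in this PR.
    fn provider(name: &str, default: bool) -> LlmProvider {
        LlmProvider {
            name: name.to_string(),
            provider: "openai".to_string(),
            access_key: None,
            model: "gpt-4".to_string(),
            default: Some(default),
            rate_limits: None,
        }
    }

    #[test]
    fn empty_config_is_rejected() {
        assert!(matches!(
            LlmProviders::try_from(vec![]),
            Err(LlmProvidersNewError::EmptySource)
        ));
    }

    #[test]
    fn duplicate_names_are_rejected() {
        assert!(matches!(
            LlmProviders::try_from(vec![provider("a", false), provider("a", false)]),
            Err(LlmProvidersNewError::DuplicateName(_))
        ));
    }

    #[test]
    fn second_default_is_rejected() {
        assert!(matches!(
            LlmProviders::try_from(vec![provider("a", true), provider("b", true)]),
            Err(LlmProvidersNewError::MoreThanOneDefault)
        ));
    }
}
```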
diff --git a/arch/src/ratelimit.rs b/arch/src/ratelimit.rs
index b7e206be..42554bbe 100644
--- a/arch/src/ratelimit.rs
+++ b/arch/src/ratelimit.rs
@@ -54,10 +54,7 @@ impl RatelimitMap {
         for ratelimit_config in ratelimits_config {
             let limit = DefaultKeyedRateLimiter::keyed(get_quota(ratelimit_config.limit));
 
-            match new_ratelimit_map
-                .datastore
-                .get_mut(&ratelimit_config.provider)
-            {
+            match new_ratelimit_map.datastore.get_mut(&ratelimit_config.model) {
                 Some(limits) => match limits.get_mut(&ratelimit_config.selector) {
                     Some(_) => {
                         panic!("repeated selector. Selectors per provider must be unique")
@@ -72,7 +69,7 @@ impl RatelimitMap {
                     let new_hash_map = HashMap::from([(ratelimit_config.selector, limit)]);
                     new_ratelimit_map
                         .datastore
-                        .insert(ratelimit_config.provider, new_hash_map);
+                        .insert(ratelimit_config.model, new_hash_map);
                 }
             }
         }
@@ -142,7 +139,7 @@ fn get_quota(limit: Limit) -> Quota {
 #[test]
 fn non_existent_provider_is_ok() {
     let ratelimits_config = vec![Ratelimit {
-        provider: String::from("provider"),
+        model: String::from("provider"),
         selector: configuration::Header {
             key: String::from("only-key"),
             value: None,
@@ -170,7 +167,7 @@
 #[test]
 fn non_existent_key_is_ok() {
     let ratelimits_config = vec![Ratelimit {
-        provider: String::from("provider"),
+        model: String::from("provider"),
         selector: configuration::Header {
             key: String::from("only-key"),
             value: None,
@@ -198,7 +195,7 @@
 #[test]
 fn specific_limit_does_not_catch_non_specific_value() {
     let ratelimits_config = vec![Ratelimit {
-        provider: String::from("provider"),
+        model: String::from("provider"),
         selector: configuration::Header {
             key: String::from("key"),
             value: Some(String::from("value")),
@@ -226,7 +223,7 @@
 #[test]
 fn specific_limit_is_hit() {
     let ratelimits_config = vec![Ratelimit {
-        provider: String::from("provider"),
+        model: String::from("provider"),
         selector: configuration::Header {
             key: String::from("key"),
             value: Some(String::from("value")),
@@ -254,7 +251,7 @@
 #[test]
 fn non_specific_key_has_different_limits_for_different_values() {
     let ratelimits_config = vec![Ratelimit {
-        provider: String::from("provider"),
+        model: String::from("provider"),
         selector: configuration::Header {
             key: String::from("only-key"),
             value: None,
@@ -308,7 +305,7 @@ fn different_provider_can_have_different_limits_with_the_same_keys() {
     let ratelimits_config = vec![
         Ratelimit {
-            provider: String::from("first_provider"),
+            model: String::from("first_provider"),
             selector: configuration::Header {
                 key: String::from("key"),
                 value: Some(String::from("value")),
@@ -319,7 +316,7 @@
             },
         },
         Ratelimit {
-            provider: String::from("second_provider"),
+            model: String::from("second_provider"),
             selector: configuration::Header {
                 key: String::from("key"),
                 value: Some(String::from("value")),
@@ -391,7 +388,7 @@ mod test {
     #[test]
     fn different_threads_have_same_ratelimit_data_structure() {
         let ratelimits_config = Some(vec![Ratelimit {
-            provider: String::from("provider"),
+            model: String::from("provider"),
             selector: configuration::Header {
                 key: String::from("key"),
                 value: Some(String::from("value")),
diff --git a/arch/src/routing.rs b/arch/src/routing.rs
index 5b0f883d..1f23f383 100644
--- a/arch/src/routing.rs
+++ b/arch/src/routing.rs
@@ -1,13 +1,42 @@
-use crate::llm_providers::{LlmProvider, LlmProviders};
-use rand::{seq::SliceRandom, thread_rng};
+use std::rc::Rc;
 
-pub fn get_llm_provider<'hostname>(deterministic: bool) -> &'static LlmProvider<'hostname> {
-    if deterministic {
-        &LlmProviders::OPENAI_PROVIDER
-    } else {
-        let mut rng = thread_rng();
-        LlmProviders::VARIANTS
-            .choose(&mut rng)
-            .expect("There should always be at least one llm provider")
+use crate::llm_providers::LlmProviders;
+use public_types::configuration::LlmProvider;
+use rand::{seq::IteratorRandom, thread_rng};
+
+pub enum ProviderHint {
+    Default,
+    Name(String),
+}
+
+impl From<String> for ProviderHint {
+    fn from(value: String) -> Self {
+        match value.as_str() {
+            "default" => ProviderHint::Default,
+            _ => ProviderHint::Name(value),
+        }
     }
 }
+
+pub fn get_llm_provider(
+    llm_providers: &LlmProviders,
+    provider_hint: Option<ProviderHint>,
+) -> Rc<LlmProvider> {
+    let maybe_provider = provider_hint.and_then(|hint| match hint {
+        ProviderHint::Default => llm_providers.default(),
+        // FIXME: should a non-existent name in the hint be more explicit? i.e., return a BAD_REQUEST?
+        ProviderHint::Name(name) => llm_providers.get(&name),
+    });
+
+    if let Some(provider) = maybe_provider {
+        return provider;
+    }
+
+    let mut rng = thread_rng();
+    llm_providers
+        .iter()
+        .choose(&mut rng)
+        .expect("There should always be at least one llm provider")
+        .1
+        .clone()
+}
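Resolution order in `get_llm_provider`: an explicit name wins, `"default"` maps to the configured default, and no hint (or a name that does not resolve) falls back to a random configured provider. A sketch, assuming an `llm_providers` set that names `mistral-large-latest` and marks `open-ai-gpt-4` as its default:

```rust
use crate::llm_providers::LlmProviders;
use crate::routing::{get_llm_provider, ProviderHint};

fn demo(llm_providers: &LlmProviders) {
    // Explicit name: resolved directly.
    let by_name = get_llm_provider(
        llm_providers,
        Some(ProviderHint::Name("mistral-large-latest".to_string())),
    );
    assert_eq!(by_name.name, "mistral-large-latest");

    // "default" (via From<String>): the configured default provider.
    let by_default = get_llm_provider(llm_providers, Some("default".to_string().into()));
    assert_eq!(by_default.name, "open-ai-gpt-4");

    // No hint: any configured provider may be chosen at random.
    let any = get_llm_provider(llm_providers, None);
    println!("routed to {}", any.name);
}
```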
diff --git a/arch/src/stream_context.rs b/arch/src/stream_context.rs
index 97f495a7..e339a765 100644
--- a/arch/src/stream_context.rs
+++ b/arch/src/stream_context.rs
@@ -1,14 +1,13 @@
 use crate::consts::{
-    ARCH_FC_REQUEST_TIMEOUT_MS, ARCH_MESSAGES_KEY, ARCH_ROUTING_HEADER, ARC_FC_CLUSTER,
-    DEFAULT_EMBEDDING_MODEL, DEFAULT_INTENT_MODEL, DEFAULT_PROMPT_TARGET_THRESHOLD, GPT_35_TURBO,
-    MODEL_SERVER_NAME, RATELIMIT_SELECTOR_HEADER_KEY, SYSTEM_ROLE, USER_ROLE,
+    ARCH_FC_REQUEST_TIMEOUT_MS, ARCH_MESSAGES_KEY, ARCH_PROVIDER_HINT_HEADER, ARCH_ROUTING_HEADER,
+    ARC_FC_CLUSTER, DEFAULT_EMBEDDING_MODEL, DEFAULT_INTENT_MODEL, DEFAULT_PROMPT_TARGET_THRESHOLD,
+    GPT_35_TURBO, MODEL_SERVER_NAME, RATELIMIT_SELECTOR_HEADER_KEY, SYSTEM_ROLE, USER_ROLE,
 };
 use crate::filter_context::{embeddings_store, WasmMetrics};
-use crate::llm_providers::{LlmProvider, LlmProviders};
+use crate::llm_providers::LlmProviders;
 use crate::ratelimit::Header;
 use crate::stats::IncrementingMetric;
-use crate::tokenizer;
-use crate::{ratelimit, routing};
+use crate::{ratelimit, routing, tokenizer};
 use acap::cos;
 use http::StatusCode;
 use log::{debug, info, warn};
@@ -23,6 +22,7 @@ use public_types::common_types::{
     EmbeddingType, PromptGuardRequest, PromptGuardResponse, PromptGuardTask,
     ZeroShotClassificationRequest, ZeroShotClassificationResponse,
 };
+use public_types::configuration::LlmProvider;
 use public_types::configuration::{Overrides, PromptGuards, PromptTarget};
 use public_types::embeddings::{
     CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
@@ -53,16 +53,17 @@ pub struct CallContext {
 }
 
 pub struct StreamContext {
-    pub context_id: u32,
-    pub metrics: Rc<WasmMetrics>,
-    pub prompt_targets: Rc<RwLock<HashMap<String, PromptTarget>>>,
-    pub overrides: Rc<Option<Overrides>>,
+    context_id: u32,
+    metrics: Rc<WasmMetrics>,
+    prompt_targets: Rc<RwLock<HashMap<String, PromptTarget>>>,
+    overrides: Rc<Option<Overrides>>,
     callouts: HashMap<u32, CallContext>,
     ratelimit_selector: Option<Header>,
     streaming_response: bool,
     response_tokens: usize,
     chat_completions_request: bool,
-    llm_provider: Option<&'static LlmProvider<'static>>,
+    llm_providers: Rc<LlmProviders>,
+    llm_provider: Option<Rc<LlmProvider>>,
     prompt_guards: Rc<Option<PromptGuards>>,
 }
 
@@ -73,6 +74,7 @@ impl StreamContext {
         prompt_targets: Rc<RwLock<HashMap<String, PromptTarget>>>,
         prompt_guards: Rc<Option<PromptGuards>>,
         overrides: Rc<Option<Overrides>>,
+        llm_providers: Rc<LlmProviders>,
     ) -> Self {
         StreamContext {
             context_id,
@@ -83,6 +85,7 @@ impl StreamContext {
             streaming_response: false,
             response_tokens: 0,
             chat_completions_request: false,
+            llm_providers,
             llm_provider: None,
             prompt_guards,
             overrides,
@@ -90,27 +93,35 @@ impl StreamContext {
         }
     }
 
     fn llm_provider(&self) -> &LlmProvider {
         self.llm_provider
+            .as_ref()
             .expect("the provider should be set when asked for it")
     }
 
+    fn select_llm_provider(&mut self) {
+        let provider_hint = self
+            .get_http_request_header(ARCH_PROVIDER_HINT_HEADER)
+            .map(|provider_name| provider_name.into());
+
+        self.llm_provider = Some(routing::get_llm_provider(
+            &self.llm_providers,
+            provider_hint,
+        ));
+    }
+
     fn add_routing_header(&mut self) {
-        self.add_http_request_header(ARCH_ROUTING_HEADER, self.llm_provider().as_ref());
+        self.add_http_request_header(ARCH_ROUTING_HEADER, &self.llm_provider().name);
     }
 
     fn modify_auth_headers(&mut self) -> Result<(), String> {
-        let llm_provider_api_key_value = self
-            .get_http_request_header(self.llm_provider().api_key_header())
-            .ok_or(format!("missing {} api key", self.llm_provider()))?;
+        let llm_provider_api_key_value = self.llm_provider().access_key.as_ref().ok_or(format!(
+            "No access key configured for selected LLM Provider \"{}\"",
+            self.llm_provider()
+        ))?;
 
         let authorization_header_value = format!("Bearer {}", llm_provider_api_key_value);
 
         self.set_http_request_header("Authorization", Some(&authorization_header_value));
 
-        // sanitize passed in api keys
-        for provider in LlmProviders::VARIANTS.iter() {
-            self.set_http_request_header(provider.api_key_header(), None);
-        }
-
         Ok(())
     }
@@ -728,29 +739,13 @@ impl StreamContext {
         let prompt_guard_resp: PromptGuardResponse = serde_json::from_slice(&body).unwrap();
         debug!("prompt_guard_resp: {:?}", prompt_guard_resp);
 
-        if prompt_guard_resp.jailbreak_verdict.is_some()
-            && prompt_guard_resp.jailbreak_verdict.unwrap()
-        {
+        if prompt_guard_resp.jailbreak_verdict.unwrap_or_default() {
             //TODO: handle other scenarios like forward to error target
-            let default_err = "Jailbreak detected. Please refrain from discussing jailbreaking.";
-            let error_msg = match self.prompt_guards.as_ref() {
-                Some(prompt_guards) => match prompt_guards
-                    .input_guards
-                    .get(&public_types::configuration::GuardType::Jailbreak)
-                {
-                    Some(jailbreak) => match jailbreak.on_exception.as_ref() {
-                        Some(on_exception_details) => match on_exception_details.message.as_ref() {
-                            Some(error_msg) => error_msg,
-                            None => default_err,
-                        },
-                        None => default_err,
-                    },
-                    None => default_err,
-                },
-                None => default_err,
-            };
-
-            return self.send_server_error(error_msg.to_string(), Some(StatusCode::BAD_REQUEST));
+            let msg = (*self.prompt_guards)
+                .as_ref()
+                .and_then(|pg| pg.jailbreak_on_exception_message())
+                .unwrap_or("Jailbreak detected. Please refrain from discussing jailbreaking.");
+            return self.send_server_error(msg.to_string(), Some(StatusCode::BAD_REQUEST));
         }
 
         self.get_embeddings(callout_context);
@@ -900,11 +895,7 @@ impl HttpContext for StreamContext {
     // Envoy's HTTP model is event driven. The WASM ABI has given implementors events to hook onto
     // the lifecycle of the http request and response.
     fn on_http_request_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action {
-        let provider_hint = self
-            .get_http_request_header("x-arch-deterministic-provider")
-            .is_some();
-        self.llm_provider = Some(routing::get_llm_provider(provider_hint));
-
+        self.select_llm_provider();
         self.add_routing_header();
         if let Err(error) = self.modify_auth_headers() {
             self.send_server_error(error, Some(StatusCode::BAD_REQUEST));
@@ -959,7 +950,7 @@ impl HttpContext for StreamContext {
         };
 
         // Set the model based on the chosen LLM Provider
-        deserialized_body.model = String::from(self.llm_provider().choose_model());
+        deserialized_body.model = String::from(&self.llm_provider().model);
 
         self.streaming_response = deserialized_body.stream;
         if deserialized_body.stream && deserialized_body.stream_options.is_none() {
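Net effect of the stream-context changes: per request, the filter resolves a provider, stamps `x-arch-llm-provider` for the Envoy route match added above, and mints `Authorization` from the provider's configured `access_key` instead of reading per-provider API-key headers off the request. A condensed restatement (a plain header list stands in for Envoy's header map):

```rust
use public_types::configuration::LlmProvider;
use std::rc::Rc;

// Sketch of select_llm_provider + add_routing_header + modify_auth_headers.
fn route_request(
    provider: Rc<LlmProvider>,
    headers: &mut Vec<(String, String)>,
) -> Result<(), String> {
    headers.push(("x-arch-llm-provider".to_string(), provider.name.clone()));

    let access_key = provider.access_key.as_ref().ok_or(format!(
        "No access key configured for selected LLM Provider \"{}\"",
        provider
    ))?;
    headers.push(("Authorization".to_string(), format!("Bearer {access_key}")));
    Ok(())
}
```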
diff --git a/arch/tests/integration.rs b/arch/tests/integration.rs
index 336454e0..1c7f6166 100644
--- a/arch/tests/integration.rs
+++ b/arch/tests/integration.rs
@@ -29,31 +29,18 @@ fn request_headers_expectations(module: &mut Tester, http_context: i32) {
         .call_proxy_on_request_headers(http_context, 0, false)
         .expect_get_header_map_value(
             Some(MapType::HttpRequestHeaders),
-            Some("x-arch-deterministic-provider"),
+            Some("x-arch-llm-provider-hint"),
         )
-        .returning(Some("true"))
+        .returning(Some("default"))
         .expect_add_header_map_value(
             Some(MapType::HttpRequestHeaders),
             Some("x-arch-llm-provider"),
-            Some("openai"),
+            Some("open-ai-gpt-4"),
         )
-        .expect_get_header_map_value(
-            Some(MapType::HttpRequestHeaders),
-            Some("x-arch-openai-api-key"),
-        )
-        .returning(Some("api-key"))
         .expect_replace_header_map_value(
             Some(MapType::HttpRequestHeaders),
             Some("Authorization"),
-            Some("Bearer api-key"),
-        )
-        .expect_remove_header_map_value(
-            Some(MapType::HttpRequestHeaders),
-            Some("x-arch-openai-api-key"),
-        )
-        .expect_remove_header_map_value(
-            Some(MapType::HttpRequestHeaders),
-            Some("x-arch-mistral-api-key"),
+            Some("Bearer secret_key"),
         )
         .expect_remove_header_map_value(Some(MapType::HttpRequestHeaders), Some("content-length"))
         .expect_get_header_map_value(
@@ -190,7 +177,8 @@ endpoints:
 
 llm_providers:
   - name: open-ai-gpt-4
-    access_key: $OPEN_AI_API_KEY
+    provider: openai
+    access_key: secret_key
     model: gpt-4
     default: true
@@ -240,7 +228,7 @@ prompt_targets:
       You are a helpful insurance claim details provider. Use insurance claim data that is provided to you. Please follow these guidelines when responding to user queries:
       - Use policy number to retrieve insurance claim details
 ratelimits:
-  - provider: gpt-3.5-turbo
+  - model: gpt-4
     selector:
       key: selector-key
       value: selector-value
@@ -267,20 +255,28 @@ fn successful_request_to_open_ai_chat_completions() {
         .unwrap();
 
     // Setup Filter
-    let root_context = 1;
+    let filter_context = 1;
+    let config = serde_json::to_string(&default_config()).unwrap();
 
     module
-        .call_proxy_on_context_create(root_context, 0)
+        .call_proxy_on_context_create(filter_context, 0)
         .expect_metric_creation(MetricType::Gauge, "active_http_calls")
         .expect_metric_creation(MetricType::Counter, "ratelimited_rq")
         .execute_and_expect(ReturnType::None)
         .unwrap();
 
+    module
+        .call_proxy_on_configure(filter_context, config.len() as i32)
+        .expect_get_buffer_bytes(Some(BufferType::PluginConfiguration))
+        .returning(Some(&config))
+        .execute_and_expect(ReturnType::Bool(true))
+        .unwrap();
+
     // Setup HTTP Stream
     let http_context = 2;
 
     module
-        .call_proxy_on_context_create(http_context, root_context)
+        .call_proxy_on_context_create(http_context, filter_context)
         .expect_log(Some(LogLevel::Debug), None)
         .execute_and_expect(ReturnType::None)
         .unwrap();
@@ -336,20 +332,28 @@ fn bad_request_to_open_ai_chat_completions() {
         .unwrap();
 
     // Setup Filter
-    let root_context = 1;
+    let filter_context = 1;
+    let config = serde_json::to_string(&default_config()).unwrap();
 
     module
-        .call_proxy_on_context_create(root_context, 0)
+        .call_proxy_on_context_create(filter_context, 0)
         .expect_metric_creation(MetricType::Gauge, "active_http_calls")
         .expect_metric_creation(MetricType::Counter, "ratelimited_rq")
         .execute_and_expect(ReturnType::None)
         .unwrap();
 
+    module
+        .call_proxy_on_configure(filter_context, config.len() as i32)
+        .expect_get_buffer_bytes(Some(BufferType::PluginConfiguration))
+        .returning(Some(&config))
+        .execute_and_expect(ReturnType::Bool(true))
+        .unwrap();
+
     // Setup HTTP Stream
     let http_context = 2;
 
     module
-        .call_proxy_on_context_create(http_context, root_context)
+        .call_proxy_on_context_create(http_context, filter_context)
         .expect_log(Some(LogLevel::Debug), None)
         .execute_and_expect(ReturnType::None)
         .unwrap();
@@ -416,7 +420,6 @@ fn request_ratelimited() {
         .unwrap();
     module
         .call_proxy_on_configure(filter_context, config.len() as i32)
-        .expect_log(Some(LogLevel::Debug), None)
         .expect_get_buffer_bytes(Some(BufferType::PluginConfiguration))
         .returning(Some(&config))
         .execute_and_expect(ReturnType::Bool(true))
         .unwrap();
@@ -531,7 +534,6 @@ fn request_not_ratelimited() {
         .unwrap();
     module
         .call_proxy_on_configure(filter_context, config_str.len() as i32)
-        .expect_log(Some(LogLevel::Debug), None)
         .expect_get_buffer_bytes(Some(BufferType::PluginConfiguration))
         .returning(Some(&config_str))
         .execute_and_expect(ReturnType::Bool(true))
         .unwrap();
diff --git a/demos/function_calling/arch_config.yaml b/demos/function_calling/arch_config.yaml
index b13721f5..28ada761 100644
--- a/demos/function_calling/arch_config.yaml
+++ b/demos/function_calling/arch_config.yaml
@@ -11,21 +11,24 @@ endpoints:
     endpoint: api_server:80
     connect_timeout: 0.005s
 
-llm_providers:
-  - name: open-ai-gpt-4
-    access_key: $OPEN_AI_API_KEY
-    model: gpt-4
-    default: true
-
 overrides:
   # confidence threshold for prompt target intent matching
   prompt_target_intent_matching_threshold: 0.6
 
-system_prompt: |
-  You are a helpful assistant.
+llm_providers:
+  - name: open-ai-gpt-4
+    access_key: $OPENAI_ACCESS_KEY
+    provider: openai
+    model: gpt-4
+    default: true
+  - name: mistral-large-latest
+    access_key: $MISTRAL_ACCESS_KEY
+    provider: mistral
+    model: large-latest
+
+system_prompt: You are a helpful assistant.
 
 prompt_targets:
-
   - name: weather_forecast
     description: This function provides realtime weather forecast information for a given city.
     parameters:
@@ -78,7 +81,7 @@ prompt_targets:
     auto_llm_dispatch_on_response: true
 
 ratelimits:
-  - provider: gpt-3.5-turbo
+  - model: gpt-4
     selector:
       key: selector-key
       value: selector-value
diff --git a/demos/function_calling/docker-compose.yaml b/demos/function_calling/docker-compose.yaml
index 715ff581..3cc689e5 100644
--- a/demos/function_calling/docker-compose.yaml
+++ b/demos/function_calling/docker-compose.yaml
@@ -24,6 +24,8 @@ services:
         condition: service_healthy
     environment:
       - LOG_LEVEL=debug
+      - OPENAI_API_KEY=${OPENAI_API_KEY:?error}
+      - MISTRAL_API_KEY=${MISTRAL_API_KEY:?error}
 
   model_server:
     build:
diff --git a/docs/source/_config/prompt-config-full-reference.yml b/docs/source/_config/prompt-config-full-reference.yml
index 1c18a508..fad8962c 100644
--- a/docs/source/_config/prompt-config-full-reference.yml
+++ b/docs/source/_config/prompt-config-full-reference.yml
@@ -31,6 +31,7 @@ endpoints:
 # Centralized way to manage LLMs, manage keys, retry logic, failover and limits in a central way
 llm_providers:
   - name: "OpenAI"
+    provider: "openai"
     access_key: $OPENAI_API_KEY
     model: gpt-4o
     default: true
@@ -45,10 +46,12 @@ llm_providers:
         unit: "minute"
 
   - name: "Mistral8x7b"
+    provider: "mistral"
     access_key: $MISTRAL_API_KEY
     model: "mistral-8x7b"
 
   - name: "MistralLocal7b"
+    provider: "local"
     model: "mistral-7b-instruct"
     endpoint: "mistral_local"
diff --git a/public_types/src/configuration.rs b/public_types/src/configuration.rs
index a2c4fa73..fa4c4bab 100644
--- a/public_types/src/configuration.rs
+++ b/public_types/src/configuration.rs
@@ -1,7 +1,7 @@
-use std::{collections::HashMap, time::Duration};
-
 use duration_string::DurationString;
 use serde::{Deserialize, Deserializer, Serialize};
+use std::fmt::Display;
+use std::{collections::HashMap, time::Duration};
 
 #[derive(Debug, Clone, Serialize, Deserialize, Default)]
 pub struct Overrides {
@@ -59,6 +59,19 @@ pub struct PromptGuards {
     pub input_guards: HashMap<GuardType, GuardOptions>,
 }
 
+impl PromptGuards {
+    pub fn jailbreak_on_exception_message(&self) -> Option<&str> {
+        self.input_guards
+            .get(&GuardType::Jailbreak)?
+            .on_exception
+            .as_ref()?
+            .message
+            .as_ref()?
+            .as_str()
+            .into()
+    }
+}
+
 #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
 pub enum GuardType {
     #[serde(rename = "jailbreak")]
@@ -96,7 +109,7 @@ pub struct Header {
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
 pub struct Ratelimit {
-    pub provider: String,
+    pub model: String,
     pub selector: Header,
     pub limit: Limit,
 }
@@ -134,7 +147,7 @@ pub struct EmbeddingProviver {
 //TODO: use enum for model, but if there is a new model, we need to update the code
 pub struct LlmProvider {
     pub name: String,
-    //TODO: handle env var replacement
+    pub provider: String,
     pub access_key: Option<String>,
     pub model: String,
     pub default: Option<bool>,
@@ -142,6 +155,12 @@ pub struct LlmProvider {
     pub rate_limits: Option<Ratelimit>,
 }
 
+impl Display for LlmProvider {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "{}", self.name)
+    }
+}
+
 #[derive(Debug, Clone, Serialize, Deserialize)]
 pub struct Endpoint {
     pub endpoint: Option<String>,
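With `jailbreak_on_exception_message`, call sites replace the four-level `match` ladder removed from `stream_context.rs` with a single chain; a sketch of the intended use, with the same fallback string the filter used before:

```rust
use public_types::configuration::PromptGuards;

fn jailbreak_reply(prompt_guards: Option<&PromptGuards>) -> &str {
    prompt_guards
        .and_then(|pg| pg.jailbreak_on_exception_message())
        .unwrap_or("Jailbreak detected. Please refrain from discussing jailbreaking.")
}
```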