Add the ability to use LLM Providers from the Arch config (#112)

Signed-off-by: José Ulises Niño Rivera <junr03@users.noreply.github.com>
This commit is contained in:
José Ulises Niño Rivera 2024-10-03 10:57:01 -07:00 committed by GitHub
parent 1b57a49c9d
commit 8ea917aae5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 295 additions and 210 deletions

11
arch/Cargo.lock generated
View file

@ -759,6 +759,7 @@ dependencies = [
"serde_json",
"serde_yaml",
"serial_test",
"thiserror",
"tiktoken-rs",
]
@ -1060,7 +1061,7 @@ dependencies = [
[[package]]
name = "proxy-wasm-test-framework"
version = "0.1.0"
source = "git+https://github.com/katanemo/test-framework.git?branch=main#c2511cd9030705e14d5f60aca77d6c96c81c6dfa"
source = "git+https://github.com/katanemo/test-framework.git?branch=new#c2511cd9030705e14d5f60aca77d6c96c81c6dfa"
dependencies = [
"anyhow",
"cfg-if 0.1.10",
@ -1490,18 +1491,18 @@ dependencies = [
[[package]]
name = "thiserror"
version = "1.0.63"
version = "1.0.64"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c0342370b38b6a11b6cc11d6a805569958d54cfa061a29969c3b5ce2ea405724"
checksum = "d50af8abc119fb8bb6dbabcfa89656f46f84aa0ac7688088608076ad2b459a84"
dependencies = [
"thiserror-impl",
]
[[package]]
name = "thiserror-impl"
version = "1.0.63"
version = "1.0.64"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a4558b58466b9ad7ca0f102865eccc95938dca1a74a856f2b57b6629050da261"
checksum = "08904e7672f5eb876eaaf87e0ce17857500934f4981c4a0ab2b4aa98baac7fc3"
dependencies = [
"proc-macro2",
"quote",

View file

@ -20,7 +20,8 @@ governor = { version = "0.6.3", default-features = false, features = ["no_std"]}
tiktoken-rs = "0.5.9"
acap = "0.3.0"
rand = "0.8.5"
thiserror = "1.0.64"
[dev-dependencies]
proxy-wasm-test-framework = { git = "https://github.com/katanemo/test-framework.git", branch = "main" }
proxy-wasm-test-framework = { git = "https://github.com/katanemo/test-framework.git", branch = "new" }
serial_test = "3.1.1"

View file

@ -38,6 +38,8 @@ properties:
properties:
name:
type: string
provider:
type: string
access_key:
type: string
model:
@ -47,6 +49,7 @@ properties:
additionalProperties: false
required:
- name
- provider
- access_key
- model
overrides:
@ -112,7 +115,7 @@ properties:
items:
type: object
properties:
provider:
model:
type: string
selector:
type: object
@ -138,7 +141,7 @@ properties:
- unit
additionalProperties: false
required:
- provider
- model
- selector
- limit
additionalProperties: false

View file

@ -7,17 +7,32 @@ ENVOY_CONFIG_TEMPLATE_FILE = os.getenv('ENVOY_CONFIG_TEMPLATE_FILE', 'envoy.temp
ARCH_CONFIG_FILE = os.getenv('ARCH_CONFIG_FILE', '/config/arch_config.yaml')
ENVOY_CONFIG_FILE_RENDERED = os.getenv('ENVOY_CONFIG_FILE_RENDERED', '/etc/envoy/envoy.yaml')
ARCH_CONFIG_SCHEMA_FILE = os.getenv('ARCH_CONFIG_SCHEMA_FILE', 'arch_config_schema.yaml')
OPENAI_API_KEY = os.getenv('OPENAI_API_KEY', False)
MISTRAL_API_KEY = os.getenv('MISTRAL_API_KEY', False)
def add_secret_key_to_llm_providers(config_yaml) :
llm_providers = []
for llm_provider in config_yaml.get("llm_providers", []):
if llm_provider['access_key'] == "$MISTRAL_ACCESS_KEY":
llm_provider['access_key'] = MISTRAL_API_KEY
elif llm_provider['access_key'] == "$OPENAI_ACCESS_KEY":
llm_provider['access_key'] = OPENAI_API_KEY
else:
llm_provider.pop('access_key')
llm_providers.append(llm_provider)
config_yaml["llm_providers"] = llm_providers
return config_yaml
env = Environment(loader=FileSystemLoader('./'))
template = env.get_template('envoy.template.yaml')
with open(ARCH_CONFIG_FILE, 'r') as file:
katanemo_config = file.read()
arch_config_string = file.read()
with open(ARCH_CONFIG_SCHEMA_FILE, 'r') as file:
arch_config_schema = file.read()
config_yaml = yaml.safe_load(katanemo_config)
config_yaml = yaml.safe_load(arch_config_string)
config_schema_yaml = yaml.safe_load(arch_config_schema)
try:
@ -54,9 +69,16 @@ for name, endpoint_details in endpoints.items():
print("updated clusters", inferred_clusters)
config_yaml = add_secret_key_to_llm_providers(config_yaml)
arch_llm_providers = config_yaml["llm_providers"]
arch_config_string = yaml.dump(config_yaml)
print("llm_providers:", arch_llm_providers)
data = {
'katanemo_config': katanemo_config,
'arch_clusters': inferred_clusters
'arch_config': arch_config_string,
'arch_clusters': inferred_clusters,
'arch_llm_providers': arch_llm_providers
}
rendered = template.render(data)

View file

@ -34,26 +34,18 @@ static_resources:
auto_host_rewrite: true
cluster: mistral_7b_instruct
timeout: 60s
{% for provider in arch_llm_providers %}
- match:
prefix: "/v1/chat/completions"
prefix: "/"
headers:
- name: "x-arch-llm-provider"
string_match:
exact: openai
exact: {{ provider.name }}
route:
auto_host_rewrite: true
cluster: openai
timeout: 60s
- match:
prefix: "/v1/chat/completions"
headers:
- name: "x-arch-llm-provider"
string_match:
exact: mistral
route:
auto_host_rewrite: true
cluster: mistral
cluster: {{ provider.provider }}
timeout: 60s
{% endfor %}
http_filters:
- name: envoy.filters.http.wasm
typed_config:
@ -65,7 +57,7 @@ static_resources:
configuration:
"@type": "type.googleapis.com/google.protobuf.StringValue"
value: |
{{ katanemo_config | indent(30) }}
{{ arch_config | indent(30) }}
vm_config:
runtime: "envoy.wasm.runtime.v8"
code:
@ -75,9 +67,6 @@ static_resources:
typed_config:
"@type": type.googleapis.com/envoy.extensions.filters.http.router.v3.Router
clusters:
# LLM Host
# Embedding Providers
# External LLM Providers
- name: openai
connect_timeout: 5s
dns_lookup_family: V4_ONLY

View file

@ -10,3 +10,4 @@ pub const ARCH_FC_REQUEST_TIMEOUT_MS: u64 = 120000; // 2 minutes
pub const MODEL_SERVER_NAME: &str = "model_server";
pub const ARCH_ROUTING_HEADER: &str = "x-arch-llm-provider";
pub const ARCH_MESSAGES_KEY: &str = "arch_messages";
pub const ARCH_PROVIDER_HINT_HEADER: &str = "x-arch-llm-provider-hint";

View file

@ -1,4 +1,5 @@
use crate::consts::{DEFAULT_EMBEDDING_MODEL, MODEL_SERVER_NAME};
use crate::llm_providers::LlmProviders;
use crate::ratelimit;
use crate::stats::{Counter, Gauge, RecordingMetric};
use crate::stream_context::StreamContext;
@ -44,10 +45,11 @@ pub struct FilterContext {
metrics: Rc<WasmMetrics>,
// callouts stores token_id to request mapping that we use during #on_http_call_response to match the response to the request.
callouts: HashMap<u32, CallContext>,
config: Option<Configuration>,
overrides: Rc<Option<Overrides>>,
prompt_targets: Rc<RwLock<HashMap<String, PromptTarget>>>,
// This should be Option<Rc<PromptGuards>>, because StreamContext::new() should get an Rc<PromptGuards> not Option<Rc<PromptGuards>>.
prompt_guards: Rc<Option<PromptGuards>>,
llm_providers: Option<Rc<LlmProviders>>,
}
pub fn embeddings_store() -> &'static RwLock<HashMap<String, EmbeddingTypeMap>> {
@ -62,11 +64,11 @@ impl FilterContext {
pub fn new() -> FilterContext {
FilterContext {
callouts: HashMap::new(),
config: None,
metrics: Rc::new(WasmMetrics::new()),
prompt_targets: Rc::new(RwLock::new(HashMap::new())),
overrides: Rc::new(None),
prompt_guards: Rc::new(Some(PromptGuards::default())),
llm_providers: None,
}
}
@ -219,42 +221,35 @@ impl Context for FilterContext {
// RootContext allows the Rust code to reach into the Envoy Config
impl RootContext for FilterContext {
fn on_configure(&mut self, _: usize) -> bool {
if let Some(config_bytes) = self.get_plugin_configuration() {
self.config = serde_yaml::from_slice(&config_bytes).unwrap();
let config_bytes = self
.get_plugin_configuration()
.expect("Arch config cannot be empty");
if let Some(overrides_config) = self
.config
.as_mut()
.and_then(|config| config.overrides.as_mut())
{
self.overrides = Rc::new(Some(std::mem::take(overrides_config)));
}
let config: Configuration = match serde_yaml::from_slice(&config_bytes) {
Ok(config) => config,
Err(err) => panic!("Invalid arch config \"{:?}\"", err),
};
for pt in self.config.clone().unwrap().prompt_targets {
self.prompt_targets
.write()
.unwrap()
.insert(pt.name.clone(), pt.clone());
}
self.overrides = Rc::new(config.overrides);
debug!("set configuration object");
if let Some(ratelimits_config) = self
.config
.as_mut()
.and_then(|config| config.ratelimits.as_mut())
{
ratelimit::ratelimits(Some(std::mem::take(ratelimits_config)));
}
if let Some(prompt_guards) = self
.config
.as_mut()
.and_then(|config| config.prompt_guards.as_mut())
{
self.prompt_guards = Rc::new(Some(std::mem::take(prompt_guards)));
}
for pt in config.prompt_targets {
self.prompt_targets
.write()
.unwrap()
.insert(pt.name.clone(), pt.clone());
}
ratelimit::ratelimits(config.ratelimits);
if let Some(prompt_guards) = config.prompt_guards {
self.prompt_guards = Rc::new(Some(prompt_guards))
}
match config.llm_providers.try_into() {
Ok(llm_providers) => self.llm_providers = Some(Rc::new(llm_providers)),
Err(err) => panic!("{err}"),
}
true
}
@ -269,6 +264,11 @@ impl RootContext for FilterContext {
Rc::clone(&self.prompt_targets),
Rc::clone(&self.prompt_guards),
Rc::clone(&self.overrides),
Rc::clone(
self.llm_providers
.as_ref()
.expect("LLM Providers must exist when Streams are being created"),
),
)))
}

View file

@ -1,47 +1,69 @@
#[non_exhaustive]
pub struct LlmProviders;
use public_types::configuration::LlmProvider;
use std::collections::HashMap;
use std::rc::Rc;
#[derive(Debug)]
pub struct LlmProviders {
providers: HashMap<String, Rc<LlmProvider>>,
default: Option<Rc<LlmProvider>>,
}
impl LlmProviders {
pub const OPENAI_PROVIDER: LlmProvider<'static> = LlmProvider {
name: "openai",
api_key_header: "x-arch-openai-api-key",
model: "gpt-3.5-turbo",
};
pub const MISTRAL_PROVIDER: LlmProvider<'static> = LlmProvider {
name: "mistral",
api_key_header: "x-arch-mistral-api-key",
model: "mistral-large-latest",
};
pub fn iter(&self) -> std::collections::hash_map::Iter<'_, String, Rc<LlmProvider>> {
self.providers.iter()
}
pub const VARIANTS: &'static [LlmProvider<'static>] =
&[Self::OPENAI_PROVIDER, Self::MISTRAL_PROVIDER];
}
pub fn default(&self) -> Option<Rc<LlmProvider>> {
self.default.as_ref().map(|rc| rc.clone())
}
pub struct LlmProvider<'prov> {
name: &'prov str,
api_key_header: &'prov str,
model: &'prov str,
}
impl AsRef<str> for LlmProvider<'_> {
fn as_ref(&self) -> &str {
self.name
pub fn get(&self, name: &str) -> Option<Rc<LlmProvider>> {
self.providers.get(name).map(|rc| rc.clone())
}
}
impl std::fmt::Display for LlmProvider<'_> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.name)
}
#[derive(thiserror::Error, Debug)]
pub enum LlmProvidersNewError {
#[error("There must be at least one LLM Provider")]
EmptySource,
#[error("There must be at most one default LLM Provider")]
MoreThanOneDefault,
#[error("\'{0}\' is not a unique name")]
DuplicateName(String),
}
impl LlmProvider<'_> {
pub fn api_key_header(&self) -> &str {
self.api_key_header
}
impl TryFrom<Vec<LlmProvider>> for LlmProviders {
type Error = LlmProvidersNewError;
pub fn choose_model(&self) -> &str {
// In the future this can be a more complex function balancing reliability, cost, performance, etc.
self.model
fn try_from(llm_providers_config: Vec<LlmProvider>) -> Result<Self, Self::Error> {
if llm_providers_config.is_empty() {
return Err(LlmProvidersNewError::EmptySource);
}
let mut llm_providers = LlmProviders {
providers: HashMap::new(),
default: None,
};
for llm_provider in llm_providers_config {
let llm_provider: Rc<LlmProvider> = Rc::new(llm_provider);
if llm_provider.default.unwrap_or_default() {
match llm_providers.default {
Some(_) => return Err(LlmProvidersNewError::MoreThanOneDefault),
None => llm_providers.default = Some(Rc::clone(&llm_provider)),
}
}
// Insert and check that there is no other provider with the same name.
let name = llm_provider.name.clone();
if llm_providers
.providers
.insert(name.clone(), llm_provider)
.is_some()
{
return Err(LlmProvidersNewError::DuplicateName(name));
}
}
Ok(llm_providers)
}
}

View file

@ -54,10 +54,7 @@ impl RatelimitMap {
for ratelimit_config in ratelimits_config {
let limit = DefaultKeyedRateLimiter::keyed(get_quota(ratelimit_config.limit));
match new_ratelimit_map
.datastore
.get_mut(&ratelimit_config.provider)
{
match new_ratelimit_map.datastore.get_mut(&ratelimit_config.model) {
Some(limits) => match limits.get_mut(&ratelimit_config.selector) {
Some(_) => {
panic!("repeated selector. Selectors per provider must be unique")
@ -72,7 +69,7 @@ impl RatelimitMap {
let new_hash_map = HashMap::from([(ratelimit_config.selector, limit)]);
new_ratelimit_map
.datastore
.insert(ratelimit_config.provider, new_hash_map);
.insert(ratelimit_config.model, new_hash_map);
}
}
}
@ -142,7 +139,7 @@ fn get_quota(limit: Limit) -> Quota {
#[test]
fn non_existent_provider_is_ok() {
let ratelimits_config = vec![Ratelimit {
provider: String::from("provider"),
model: String::from("provider"),
selector: configuration::Header {
key: String::from("only-key"),
value: None,
@ -170,7 +167,7 @@ fn non_existent_provider_is_ok() {
#[test]
fn non_existent_key_is_ok() {
let ratelimits_config = vec![Ratelimit {
provider: String::from("provider"),
model: String::from("provider"),
selector: configuration::Header {
key: String::from("only-key"),
value: None,
@ -198,7 +195,7 @@ fn non_existent_key_is_ok() {
#[test]
fn specific_limit_does_not_catch_non_specific_value() {
let ratelimits_config = vec![Ratelimit {
provider: String::from("provider"),
model: String::from("provider"),
selector: configuration::Header {
key: String::from("key"),
value: Some(String::from("value")),
@ -226,7 +223,7 @@ fn specific_limit_does_not_catch_non_specific_value() {
#[test]
fn specific_limit_is_hit() {
let ratelimits_config = vec![Ratelimit {
provider: String::from("provider"),
model: String::from("provider"),
selector: configuration::Header {
key: String::from("key"),
value: Some(String::from("value")),
@ -254,7 +251,7 @@ fn specific_limit_is_hit() {
#[test]
fn non_specific_key_has_different_limits_for_different_values() {
let ratelimits_config = vec![Ratelimit {
provider: String::from("provider"),
model: String::from("provider"),
selector: configuration::Header {
key: String::from("only-key"),
value: None,
@ -308,7 +305,7 @@ fn non_specific_key_has_different_limits_for_different_values() {
fn different_provider_can_have_different_limits_with_the_same_keys() {
let ratelimits_config = vec![
Ratelimit {
provider: String::from("first_provider"),
model: String::from("first_provider"),
selector: configuration::Header {
key: String::from("key"),
value: Some(String::from("value")),
@ -319,7 +316,7 @@ fn different_provider_can_have_different_limits_with_the_same_keys() {
},
},
Ratelimit {
provider: String::from("second_provider"),
model: String::from("second_provider"),
selector: configuration::Header {
key: String::from("key"),
value: Some(String::from("value")),
@ -391,7 +388,7 @@ mod test {
#[test]
fn different_threads_have_same_ratelimit_data_structure() {
let ratelimits_config = Some(vec![Ratelimit {
provider: String::from("provider"),
model: String::from("provider"),
selector: configuration::Header {
key: String::from("key"),
value: Some(String::from("value")),

View file

@ -1,13 +1,42 @@
use crate::llm_providers::{LlmProvider, LlmProviders};
use rand::{seq::SliceRandom, thread_rng};
use std::rc::Rc;
pub fn get_llm_provider<'hostname>(deterministic: bool) -> &'static LlmProvider<'hostname> {
if deterministic {
&LlmProviders::OPENAI_PROVIDER
} else {
let mut rng = thread_rng();
LlmProviders::VARIANTS
.choose(&mut rng)
.expect("There should always be at least one llm provider")
use crate::llm_providers::LlmProviders;
use public_types::configuration::LlmProvider;
use rand::{seq::IteratorRandom, thread_rng};
pub enum ProviderHint {
Default,
Name(String),
}
impl From<String> for ProviderHint {
fn from(value: String) -> Self {
match value.as_str() {
"default" => ProviderHint::Default,
_ => ProviderHint::Name(value),
}
}
}
pub fn get_llm_provider(
llm_providers: &LlmProviders,
provider_hint: Option<ProviderHint>,
) -> Rc<LlmProvider> {
let maybe_provider = provider_hint.and_then(|hint| match hint {
ProviderHint::Default => llm_providers.default(),
// FIXME: should a non-existent name in the hint be more explicit? i.e, return a BAD_REQUEST?
ProviderHint::Name(name) => llm_providers.get(&name),
});
if let Some(provider) = maybe_provider {
return provider;
}
let mut rng = thread_rng();
llm_providers
.iter()
.choose(&mut rng)
.expect("There should always be at least one llm provider")
.1
.clone()
}

View file

@ -1,14 +1,13 @@
use crate::consts::{
ARCH_FC_REQUEST_TIMEOUT_MS, ARCH_MESSAGES_KEY, ARCH_ROUTING_HEADER, ARC_FC_CLUSTER,
DEFAULT_EMBEDDING_MODEL, DEFAULT_INTENT_MODEL, DEFAULT_PROMPT_TARGET_THRESHOLD, GPT_35_TURBO,
MODEL_SERVER_NAME, RATELIMIT_SELECTOR_HEADER_KEY, SYSTEM_ROLE, USER_ROLE,
ARCH_FC_REQUEST_TIMEOUT_MS, ARCH_MESSAGES_KEY, ARCH_PROVIDER_HINT_HEADER, ARCH_ROUTING_HEADER,
ARC_FC_CLUSTER, DEFAULT_EMBEDDING_MODEL, DEFAULT_INTENT_MODEL, DEFAULT_PROMPT_TARGET_THRESHOLD,
GPT_35_TURBO, MODEL_SERVER_NAME, RATELIMIT_SELECTOR_HEADER_KEY, SYSTEM_ROLE, USER_ROLE,
};
use crate::filter_context::{embeddings_store, WasmMetrics};
use crate::llm_providers::{LlmProvider, LlmProviders};
use crate::llm_providers::LlmProviders;
use crate::ratelimit::Header;
use crate::stats::IncrementingMetric;
use crate::tokenizer;
use crate::{ratelimit, routing};
use crate::{ratelimit, routing, tokenizer};
use acap::cos;
use http::StatusCode;
use log::{debug, info, warn};
@ -23,6 +22,7 @@ use public_types::common_types::{
EmbeddingType, PromptGuardRequest, PromptGuardResponse, PromptGuardTask,
ZeroShotClassificationRequest, ZeroShotClassificationResponse,
};
use public_types::configuration::LlmProvider;
use public_types::configuration::{Overrides, PromptGuards, PromptTarget};
use public_types::embeddings::{
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
@ -53,16 +53,17 @@ pub struct CallContext {
}
pub struct StreamContext {
pub context_id: u32,
pub metrics: Rc<WasmMetrics>,
pub prompt_targets: Rc<RwLock<HashMap<String, PromptTarget>>>,
pub overrides: Rc<Option<Overrides>>,
context_id: u32,
metrics: Rc<WasmMetrics>,
prompt_targets: Rc<RwLock<HashMap<String, PromptTarget>>>,
overrides: Rc<Option<Overrides>>,
callouts: HashMap<u32, CallContext>,
ratelimit_selector: Option<Header>,
streaming_response: bool,
response_tokens: usize,
chat_completions_request: bool,
llm_provider: Option<&'static LlmProvider<'static>>,
llm_providers: Rc<LlmProviders>,
llm_provider: Option<Rc<LlmProvider>>,
prompt_guards: Rc<Option<PromptGuards>>,
}
@ -73,6 +74,7 @@ impl StreamContext {
prompt_targets: Rc<RwLock<HashMap<String, PromptTarget>>>,
prompt_guards: Rc<Option<PromptGuards>>,
overrides: Rc<Option<Overrides>>,
llm_providers: Rc<LlmProviders>,
) -> Self {
StreamContext {
context_id,
@ -83,6 +85,7 @@ impl StreamContext {
streaming_response: false,
response_tokens: 0,
chat_completions_request: false,
llm_providers,
llm_provider: None,
prompt_guards,
overrides,
@ -90,27 +93,35 @@ impl StreamContext {
}
fn llm_provider(&self) -> &LlmProvider {
self.llm_provider
.as_ref()
.expect("the provider should be set when asked for it")
}
fn select_llm_provider(&mut self) {
let provider_hint = self
.get_http_request_header(ARCH_PROVIDER_HINT_HEADER)
.map(|provider_name| provider_name.into());
self.llm_provider = Some(routing::get_llm_provider(
&self.llm_providers,
provider_hint,
));
}
fn add_routing_header(&mut self) {
self.add_http_request_header(ARCH_ROUTING_HEADER, self.llm_provider().as_ref());
self.add_http_request_header(ARCH_ROUTING_HEADER, &self.llm_provider().name);
}
fn modify_auth_headers(&mut self) -> Result<(), String> {
let llm_provider_api_key_value = self
.get_http_request_header(self.llm_provider().api_key_header())
.ok_or(format!("missing {} api key", self.llm_provider()))?;
let llm_provider_api_key_value = self.llm_provider().access_key.as_ref().ok_or(format!(
"No access key configured for selected LLM Provider \"{}\"",
self.llm_provider()
))?;
let authorization_header_value = format!("Bearer {}", llm_provider_api_key_value);
self.set_http_request_header("Authorization", Some(&authorization_header_value));
// sanitize passed in api keys
for provider in LlmProviders::VARIANTS.iter() {
self.set_http_request_header(provider.api_key_header(), None);
}
Ok(())
}
@ -728,29 +739,13 @@ impl StreamContext {
let prompt_guard_resp: PromptGuardResponse = serde_json::from_slice(&body).unwrap();
debug!("prompt_guard_resp: {:?}", prompt_guard_resp);
if prompt_guard_resp.jailbreak_verdict.is_some()
&& prompt_guard_resp.jailbreak_verdict.unwrap()
{
if prompt_guard_resp.jailbreak_verdict.unwrap_or_default() {
//TODO: handle other scenarios like forward to error target
let default_err = "Jailbreak detected. Please refrain from discussing jailbreaking.";
let error_msg = match self.prompt_guards.as_ref() {
Some(prompt_guards) => match prompt_guards
.input_guards
.get(&public_types::configuration::GuardType::Jailbreak)
{
Some(jailbreak) => match jailbreak.on_exception.as_ref() {
Some(on_exception_details) => match on_exception_details.message.as_ref() {
Some(error_msg) => error_msg,
None => default_err,
},
None => default_err,
},
None => default_err,
},
None => default_err,
};
return self.send_server_error(error_msg.to_string(), Some(StatusCode::BAD_REQUEST));
let msg = (*self.prompt_guards)
.as_ref()
.and_then(|pg| pg.jailbreak_on_exception_message())
.unwrap_or("Jailbreak detected. Please refrain from discussing jailbreaking.");
return self.send_server_error(msg.to_string(), Some(StatusCode::BAD_REQUEST));
}
self.get_embeddings(callout_context);
@ -900,11 +895,7 @@ impl HttpContext for StreamContext {
// Envoy's HTTP model is event driven. The WASM ABI has given implementors events to hook onto
// the lifecycle of the http request and response.
fn on_http_request_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action {
let provider_hint = self
.get_http_request_header("x-arch-deterministic-provider")
.is_some();
self.llm_provider = Some(routing::get_llm_provider(provider_hint));
self.select_llm_provider();
self.add_routing_header();
if let Err(error) = self.modify_auth_headers() {
self.send_server_error(error, Some(StatusCode::BAD_REQUEST));
@ -959,7 +950,7 @@ impl HttpContext for StreamContext {
};
// Set the model based on the chosen LLM Provider
deserialized_body.model = String::from(self.llm_provider().choose_model());
deserialized_body.model = String::from(&self.llm_provider().model);
self.streaming_response = deserialized_body.stream;
if deserialized_body.stream && deserialized_body.stream_options.is_none() {

View file

@ -29,31 +29,18 @@ fn request_headers_expectations(module: &mut Tester, http_context: i32) {
.call_proxy_on_request_headers(http_context, 0, false)
.expect_get_header_map_value(
Some(MapType::HttpRequestHeaders),
Some("x-arch-deterministic-provider"),
Some("x-arch-llm-provider-hint"),
)
.returning(Some("true"))
.returning(Some("default"))
.expect_add_header_map_value(
Some(MapType::HttpRequestHeaders),
Some("x-arch-llm-provider"),
Some("openai"),
Some("open-ai-gpt-4"),
)
.expect_get_header_map_value(
Some(MapType::HttpRequestHeaders),
Some("x-arch-openai-api-key"),
)
.returning(Some("api-key"))
.expect_replace_header_map_value(
Some(MapType::HttpRequestHeaders),
Some("Authorization"),
Some("Bearer api-key"),
)
.expect_remove_header_map_value(
Some(MapType::HttpRequestHeaders),
Some("x-arch-openai-api-key"),
)
.expect_remove_header_map_value(
Some(MapType::HttpRequestHeaders),
Some("x-arch-mistral-api-key"),
Some("Bearer secret_key"),
)
.expect_remove_header_map_value(Some(MapType::HttpRequestHeaders), Some("content-length"))
.expect_get_header_map_value(
@ -190,7 +177,8 @@ endpoints:
llm_providers:
- name: open-ai-gpt-4
access_key: $OPEN_AI_API_KEY
provider: openai
access_key: secret_key
model: gpt-4
default: true
@ -240,7 +228,7 @@ prompt_targets:
You are a helpful insurance claim details provider. Use insurance claim data that is provided to you. Please following following guidelines when responding to user queries:
- Use policy number to retrieve insurance claim details
ratelimits:
- provider: gpt-3.5-turbo
- model: gpt-4
selector:
key: selector-key
value: selector-value
@ -267,20 +255,28 @@ fn successful_request_to_open_ai_chat_completions() {
.unwrap();
// Setup Filter
let root_context = 1;
let filter_context = 1;
let config = serde_json::to_string(&default_config()).unwrap();
module
.call_proxy_on_context_create(root_context, 0)
.call_proxy_on_context_create(filter_context, 0)
.expect_metric_creation(MetricType::Gauge, "active_http_calls")
.expect_metric_creation(MetricType::Counter, "ratelimited_rq")
.execute_and_expect(ReturnType::None)
.unwrap();
module
.call_proxy_on_configure(filter_context, config.len() as i32)
.expect_get_buffer_bytes(Some(BufferType::PluginConfiguration))
.returning(Some(&config))
.execute_and_expect(ReturnType::Bool(true))
.unwrap();
// Setup HTTP Stream
let http_context = 2;
module
.call_proxy_on_context_create(http_context, root_context)
.call_proxy_on_context_create(http_context, filter_context)
.expect_log(Some(LogLevel::Debug), None)
.execute_and_expect(ReturnType::None)
.unwrap();
@ -336,20 +332,28 @@ fn bad_request_to_open_ai_chat_completions() {
.unwrap();
// Setup Filter
let root_context = 1;
let filter_context = 1;
let config = serde_json::to_string(&default_config()).unwrap();
module
.call_proxy_on_context_create(root_context, 0)
.call_proxy_on_context_create(filter_context, 0)
.expect_metric_creation(MetricType::Gauge, "active_http_calls")
.expect_metric_creation(MetricType::Counter, "ratelimited_rq")
.execute_and_expect(ReturnType::None)
.unwrap();
module
.call_proxy_on_configure(filter_context, config.len() as i32)
.expect_get_buffer_bytes(Some(BufferType::PluginConfiguration))
.returning(Some(&config))
.execute_and_expect(ReturnType::Bool(true))
.unwrap();
// Setup HTTP Stream
let http_context = 2;
module
.call_proxy_on_context_create(http_context, root_context)
.call_proxy_on_context_create(http_context, filter_context)
.expect_log(Some(LogLevel::Debug), None)
.execute_and_expect(ReturnType::None)
.unwrap();
@ -416,7 +420,6 @@ fn request_ratelimited() {
.unwrap();
module
.call_proxy_on_configure(filter_context, config.len() as i32)
.expect_log(Some(LogLevel::Debug), None)
.expect_get_buffer_bytes(Some(BufferType::PluginConfiguration))
.returning(Some(&config))
.execute_and_expect(ReturnType::Bool(true))
@ -531,7 +534,6 @@ fn request_not_ratelimited() {
.unwrap();
module
.call_proxy_on_configure(filter_context, config_str.len() as i32)
.expect_log(Some(LogLevel::Debug), None)
.expect_get_buffer_bytes(Some(BufferType::PluginConfiguration))
.returning(Some(&config_str))
.execute_and_expect(ReturnType::Bool(true))

View file

@ -11,21 +11,24 @@ endpoints:
endpoint: api_server:80
connect_timeout: 0.005s
llm_providers:
- name: open-ai-gpt-4
access_key: $OPEN_AI_API_KEY
model: gpt-4
default: true
overrides:
# confidence threshold for prompt target intent matching
prompt_target_intent_matching_threshold: 0.6
system_prompt: |
You are a helpful assistant.
llm_providers:
- name: open-ai-gpt-4
access_key: $OPENAI_ACCESS_KEY
provider: openai
model: gpt-4
default: true
- name: mistral-large-latest
access_key: $MISTRAL_ACCESS_KEY
provider: mistral
model: large-latest
system_prompt: You are a helpful assistant.
prompt_targets:
- name: weather_forecast
description: This function provides realtime weather forecast information for a given city.
parameters:
@ -78,7 +81,7 @@ prompt_targets:
auto_llm_dispatch_on_response: true
ratelimits:
- provider: gpt-3.5-turbo
- model: gpt-4
selector:
key: selector-key
value: selector-value

View file

@ -24,6 +24,8 @@ services:
condition: service_healthy
environment:
- LOG_LEVEL=debug
- OPENAI_API_KEY=${OPENAI_API_KEY:?error}
- MISTRAL_API_KEY=${MISTRAL_API_KEY:?error}
model_server:
build:

View file

@ -31,6 +31,7 @@ endpoints:
# Centralized way to manage LLMs, manage keys, retry logic, failover and limits in a central way
llm_providers:
- name: "OpenAI"
provider: "openai"
access_key: $OPENAI_API_KEY
model: gpt-4o
default: true
@ -45,10 +46,12 @@ llm_providers:
unit: "minute"
- name: "Mistral8x7b"
provider: "mistral"
access_key: $MISTRAL_API_KEY
model: "mistral-8x7b"
- name: "MistralLocal7b"
provider: "local"
model: "mistral-7b-instruct"
endpoint: "mistral_local"

View file

@ -1,7 +1,7 @@
use std::{collections::HashMap, time::Duration};
use duration_string::DurationString;
use serde::{Deserialize, Deserializer, Serialize};
use std::fmt::Display;
use std::{collections::HashMap, time::Duration};
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct Overrides {
@ -59,6 +59,19 @@ pub struct PromptGuards {
pub input_guards: HashMap<GuardType, GuardOptions>,
}
impl PromptGuards {
pub fn jailbreak_on_exception_message(&self) -> Option<&str> {
self.input_guards
.get(&GuardType::Jailbreak)?
.on_exception
.as_ref()?
.message
.as_ref()?
.as_str()
.into()
}
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub enum GuardType {
#[serde(rename = "jailbreak")]
@ -96,7 +109,7 @@ pub struct Header {
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Ratelimit {
pub provider: String,
pub model: String,
pub selector: Header,
pub limit: Limit,
}
@ -134,7 +147,7 @@ pub struct EmbeddingProviver {
//TODO: use enum for model, but if there is a new model, we need to update the code
pub struct LlmProvider {
pub name: String,
//TODO: handle env var replacement
pub provider: String,
pub access_key: Option<String>,
pub model: String,
pub default: Option<bool>,
@ -142,6 +155,12 @@ pub struct LlmProvider {
pub rate_limits: Option<LlmRatelimit>,
}
impl Display for LlmProvider {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.name)
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Endpoint {
pub endpoint: Option<String>,