Add the ability to use LLM Providers from the Arch config (#112)

Signed-off-by: José Ulises Niño Rivera <junr03@users.noreply.github.com>
This commit is contained in:
José Ulises Niño Rivera 2024-10-03 10:57:01 -07:00 committed by GitHub
parent 1b57a49c9d
commit 8ea917aae5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 295 additions and 210 deletions

View file

@ -10,3 +10,4 @@ pub const ARCH_FC_REQUEST_TIMEOUT_MS: u64 = 120000; // 2 minutes
pub const MODEL_SERVER_NAME: &str = "model_server";
pub const ARCH_ROUTING_HEADER: &str = "x-arch-llm-provider";
pub const ARCH_MESSAGES_KEY: &str = "arch_messages";
pub const ARCH_PROVIDER_HINT_HEADER: &str = "x-arch-llm-provider-hint";

View file

@ -1,4 +1,5 @@
use crate::consts::{DEFAULT_EMBEDDING_MODEL, MODEL_SERVER_NAME};
use crate::llm_providers::LlmProviders;
use crate::ratelimit;
use crate::stats::{Counter, Gauge, RecordingMetric};
use crate::stream_context::StreamContext;
@ -44,10 +45,11 @@ pub struct FilterContext {
metrics: Rc<WasmMetrics>,
// callouts stores token_id to request mapping that we use during #on_http_call_response to match the response to the request.
callouts: HashMap<u32, CallContext>,
config: Option<Configuration>,
overrides: Rc<Option<Overrides>>,
prompt_targets: Rc<RwLock<HashMap<String, PromptTarget>>>,
// This should be Option<Rc<PromptGuards>>, because StreamContext::new() should get an Rc<PromptGuards> not Option<Rc<PromptGuards>>.
prompt_guards: Rc<Option<PromptGuards>>,
llm_providers: Option<Rc<LlmProviders>>,
}
pub fn embeddings_store() -> &'static RwLock<HashMap<String, EmbeddingTypeMap>> {
@ -62,11 +64,11 @@ impl FilterContext {
pub fn new() -> FilterContext {
FilterContext {
callouts: HashMap::new(),
config: None,
metrics: Rc::new(WasmMetrics::new()),
prompt_targets: Rc::new(RwLock::new(HashMap::new())),
overrides: Rc::new(None),
prompt_guards: Rc::new(Some(PromptGuards::default())),
llm_providers: None,
}
}
@ -219,42 +221,35 @@ impl Context for FilterContext {
// RootContext allows the Rust code to reach into the Envoy Config
impl RootContext for FilterContext {
fn on_configure(&mut self, _: usize) -> bool {
if let Some(config_bytes) = self.get_plugin_configuration() {
self.config = serde_yaml::from_slice(&config_bytes).unwrap();
let config_bytes = self
.get_plugin_configuration()
.expect("Arch config cannot be empty");
if let Some(overrides_config) = self
.config
.as_mut()
.and_then(|config| config.overrides.as_mut())
{
self.overrides = Rc::new(Some(std::mem::take(overrides_config)));
}
let config: Configuration = match serde_yaml::from_slice(&config_bytes) {
Ok(config) => config,
Err(err) => panic!("Invalid arch config \"{:?}\"", err),
};
for pt in self.config.clone().unwrap().prompt_targets {
self.prompt_targets
.write()
.unwrap()
.insert(pt.name.clone(), pt.clone());
}
self.overrides = Rc::new(config.overrides);
debug!("set configuration object");
if let Some(ratelimits_config) = self
.config
.as_mut()
.and_then(|config| config.ratelimits.as_mut())
{
ratelimit::ratelimits(Some(std::mem::take(ratelimits_config)));
}
if let Some(prompt_guards) = self
.config
.as_mut()
.and_then(|config| config.prompt_guards.as_mut())
{
self.prompt_guards = Rc::new(Some(std::mem::take(prompt_guards)));
}
for pt in config.prompt_targets {
self.prompt_targets
.write()
.unwrap()
.insert(pt.name.clone(), pt.clone());
}
ratelimit::ratelimits(config.ratelimits);
if let Some(prompt_guards) = config.prompt_guards {
self.prompt_guards = Rc::new(Some(prompt_guards))
}
match config.llm_providers.try_into() {
Ok(llm_providers) => self.llm_providers = Some(Rc::new(llm_providers)),
Err(err) => panic!("{err}"),
}
true
}
@ -269,6 +264,11 @@ impl RootContext for FilterContext {
Rc::clone(&self.prompt_targets),
Rc::clone(&self.prompt_guards),
Rc::clone(&self.overrides),
Rc::clone(
self.llm_providers
.as_ref()
.expect("LLM Providers must exist when Streams are being created"),
),
)))
}

View file

@ -1,47 +1,69 @@
#[non_exhaustive]
pub struct LlmProviders;
use public_types::configuration::LlmProvider;
use std::collections::HashMap;
use std::rc::Rc;
/// Registry of LLM providers built from the Arch config, keyed by provider
/// name, with at most one provider designated as the default.
#[derive(Debug)]
pub struct LlmProviders {
    // Provider name -> shared provider config; Rc lets callers hold cheap clones.
    providers: HashMap<String, Rc<LlmProvider>>,
    // The provider flagged `default` in the config, if any (uniqueness is
    // enforced when constructing via TryFrom).
    default: Option<Rc<LlmProvider>>,
}
impl LlmProviders {
pub const OPENAI_PROVIDER: LlmProvider<'static> = LlmProvider {
name: "openai",
api_key_header: "x-arch-openai-api-key",
model: "gpt-3.5-turbo",
};
pub const MISTRAL_PROVIDER: LlmProvider<'static> = LlmProvider {
name: "mistral",
api_key_header: "x-arch-mistral-api-key",
model: "mistral-large-latest",
};
/// Iterates over all configured providers as `(name, provider)` pairs.
pub fn iter(&self) -> std::collections::hash_map::Iter<'_, String, Rc<LlmProvider>> {
    self.providers.iter()
}
pub const VARIANTS: &'static [LlmProvider<'static>] =
&[Self::OPENAI_PROVIDER, Self::MISTRAL_PROVIDER];
}
/// Returns the provider flagged as default in the config, if any.
pub fn default(&self) -> Option<Rc<LlmProvider>> {
    // `Option<Rc<_>>` is itself cheaply cloneable (refcount bump); no need
    // for `as_ref().map(|rc| rc.clone())` (clippy::map_clone).
    self.default.clone()
}
pub struct LlmProvider<'prov> {
name: &'prov str,
api_key_header: &'prov str,
model: &'prov str,
}
impl AsRef<str> for LlmProvider<'_> {
fn as_ref(&self) -> &str {
self.name
/// Looks up a provider by its configured name.
pub fn get(&self, name: &str) -> Option<Rc<LlmProvider>> {
    // `.cloned()` replaces `.map(|rc| rc.clone())` (clippy::map_clone);
    // cloning an Rc only bumps the refcount.
    self.providers.get(name).cloned()
}
}
impl std::fmt::Display for LlmProvider<'_> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.name)
}
#[derive(thiserror::Error, Debug)]
pub enum LlmProvidersNewError {
#[error("There must be at least one LLM Provider")]
EmptySource,
#[error("There must be at most one default LLM Provider")]
MoreThanOneDefault,
#[error("\'{0}\' is not a unique name")]
DuplicateName(String),
}
impl LlmProvider<'_> {
pub fn api_key_header(&self) -> &str {
self.api_key_header
}
impl TryFrom<Vec<LlmProvider>> for LlmProviders {
type Error = LlmProvidersNewError;
pub fn choose_model(&self) -> &str {
// In the future this can be a more complex function balancing reliability, cost, performance, etc.
self.model
fn try_from(llm_providers_config: Vec<LlmProvider>) -> Result<Self, Self::Error> {
if llm_providers_config.is_empty() {
return Err(LlmProvidersNewError::EmptySource);
}
let mut llm_providers = LlmProviders {
providers: HashMap::new(),
default: None,
};
for llm_provider in llm_providers_config {
let llm_provider: Rc<LlmProvider> = Rc::new(llm_provider);
if llm_provider.default.unwrap_or_default() {
match llm_providers.default {
Some(_) => return Err(LlmProvidersNewError::MoreThanOneDefault),
None => llm_providers.default = Some(Rc::clone(&llm_provider)),
}
}
// Insert and check that there is no other provider with the same name.
let name = llm_provider.name.clone();
if llm_providers
.providers
.insert(name.clone(), llm_provider)
.is_some()
{
return Err(LlmProvidersNewError::DuplicateName(name));
}
}
Ok(llm_providers)
}
}

View file

@ -54,10 +54,7 @@ impl RatelimitMap {
for ratelimit_config in ratelimits_config {
let limit = DefaultKeyedRateLimiter::keyed(get_quota(ratelimit_config.limit));
match new_ratelimit_map
.datastore
.get_mut(&ratelimit_config.provider)
{
match new_ratelimit_map.datastore.get_mut(&ratelimit_config.model) {
Some(limits) => match limits.get_mut(&ratelimit_config.selector) {
Some(_) => {
panic!("repeated selector. Selectors per provider must be unique")
@ -72,7 +69,7 @@ impl RatelimitMap {
let new_hash_map = HashMap::from([(ratelimit_config.selector, limit)]);
new_ratelimit_map
.datastore
.insert(ratelimit_config.provider, new_hash_map);
.insert(ratelimit_config.model, new_hash_map);
}
}
}
@ -142,7 +139,7 @@ fn get_quota(limit: Limit) -> Quota {
#[test]
fn non_existent_provider_is_ok() {
let ratelimits_config = vec![Ratelimit {
provider: String::from("provider"),
model: String::from("provider"),
selector: configuration::Header {
key: String::from("only-key"),
value: None,
@ -170,7 +167,7 @@ fn non_existent_provider_is_ok() {
#[test]
fn non_existent_key_is_ok() {
let ratelimits_config = vec![Ratelimit {
provider: String::from("provider"),
model: String::from("provider"),
selector: configuration::Header {
key: String::from("only-key"),
value: None,
@ -198,7 +195,7 @@ fn non_existent_key_is_ok() {
#[test]
fn specific_limit_does_not_catch_non_specific_value() {
let ratelimits_config = vec![Ratelimit {
provider: String::from("provider"),
model: String::from("provider"),
selector: configuration::Header {
key: String::from("key"),
value: Some(String::from("value")),
@ -226,7 +223,7 @@ fn specific_limit_does_not_catch_non_specific_value() {
#[test]
fn specific_limit_is_hit() {
let ratelimits_config = vec![Ratelimit {
provider: String::from("provider"),
model: String::from("provider"),
selector: configuration::Header {
key: String::from("key"),
value: Some(String::from("value")),
@ -254,7 +251,7 @@ fn specific_limit_is_hit() {
#[test]
fn non_specific_key_has_different_limits_for_different_values() {
let ratelimits_config = vec![Ratelimit {
provider: String::from("provider"),
model: String::from("provider"),
selector: configuration::Header {
key: String::from("only-key"),
value: None,
@ -308,7 +305,7 @@ fn non_specific_key_has_different_limits_for_different_values() {
fn different_provider_can_have_different_limits_with_the_same_keys() {
let ratelimits_config = vec![
Ratelimit {
provider: String::from("first_provider"),
model: String::from("first_provider"),
selector: configuration::Header {
key: String::from("key"),
value: Some(String::from("value")),
@ -319,7 +316,7 @@ fn different_provider_can_have_different_limits_with_the_same_keys() {
},
},
Ratelimit {
provider: String::from("second_provider"),
model: String::from("second_provider"),
selector: configuration::Header {
key: String::from("key"),
value: Some(String::from("value")),
@ -391,7 +388,7 @@ mod test {
#[test]
fn different_threads_have_same_ratelimit_data_structure() {
let ratelimits_config = Some(vec![Ratelimit {
provider: String::from("provider"),
model: String::from("provider"),
selector: configuration::Header {
key: String::from("key"),
value: Some(String::from("value")),

View file

@ -1,13 +1,42 @@
use crate::llm_providers::{LlmProvider, LlmProviders};
use rand::{seq::SliceRandom, thread_rng};
use std::rc::Rc;
pub fn get_llm_provider<'hostname>(deterministic: bool) -> &'static LlmProvider<'hostname> {
if deterministic {
&LlmProviders::OPENAI_PROVIDER
} else {
let mut rng = thread_rng();
LlmProviders::VARIANTS
.choose(&mut rng)
.expect("There should always be at least one llm provider")
use crate::llm_providers::LlmProviders;
use public_types::configuration::LlmProvider;
use rand::{seq::IteratorRandom, thread_rng};
/// A caller-supplied hint for which LLM provider should serve the request
/// (read from the `x-arch-llm-provider-hint` request header).
// Derives added: public enums should be Debug/Clone/PartialEq/Eq-capable so
// callers can log, copy, and compare hints.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ProviderHint {
    /// The literal hint value "default": use the configured default provider.
    Default,
    /// Any other value: look up a provider by this exact name.
    Name(String),
}

impl From<String> for ProviderHint {
    fn from(value: String) -> Self {
        // Only the exact (case-sensitive) string "default" is special; every
        // other value is treated as a provider name.
        if value == "default" {
            ProviderHint::Default
        } else {
            ProviderHint::Name(value)
        }
    }
}
/// Chooses the provider for a request: honor the hint when it resolves to a
/// configured provider, otherwise pick one uniformly at random.
pub fn get_llm_provider(
    llm_providers: &LlmProviders,
    provider_hint: Option<ProviderHint>,
) -> Rc<LlmProvider> {
    // FIXME: should a non-existent name in the hint be more explicit? i.e, return a BAD_REQUEST?
    let hinted = provider_hint.and_then(|hint| match hint {
        ProviderHint::Default => llm_providers.default(),
        ProviderHint::Name(name) => llm_providers.get(&name),
    });

    hinted.unwrap_or_else(|| {
        // No usable hint: fall back to a random provider from the registry.
        let mut rng = thread_rng();
        let (_name, provider) = llm_providers
            .iter()
            .choose(&mut rng)
            .expect("There should always be at least one llm provider");
        Rc::clone(provider)
    })
}

View file

@ -1,14 +1,13 @@
use crate::consts::{
ARCH_FC_REQUEST_TIMEOUT_MS, ARCH_MESSAGES_KEY, ARCH_ROUTING_HEADER, ARC_FC_CLUSTER,
DEFAULT_EMBEDDING_MODEL, DEFAULT_INTENT_MODEL, DEFAULT_PROMPT_TARGET_THRESHOLD, GPT_35_TURBO,
MODEL_SERVER_NAME, RATELIMIT_SELECTOR_HEADER_KEY, SYSTEM_ROLE, USER_ROLE,
ARCH_FC_REQUEST_TIMEOUT_MS, ARCH_MESSAGES_KEY, ARCH_PROVIDER_HINT_HEADER, ARCH_ROUTING_HEADER,
ARC_FC_CLUSTER, DEFAULT_EMBEDDING_MODEL, DEFAULT_INTENT_MODEL, DEFAULT_PROMPT_TARGET_THRESHOLD,
GPT_35_TURBO, MODEL_SERVER_NAME, RATELIMIT_SELECTOR_HEADER_KEY, SYSTEM_ROLE, USER_ROLE,
};
use crate::filter_context::{embeddings_store, WasmMetrics};
use crate::llm_providers::{LlmProvider, LlmProviders};
use crate::llm_providers::LlmProviders;
use crate::ratelimit::Header;
use crate::stats::IncrementingMetric;
use crate::tokenizer;
use crate::{ratelimit, routing};
use crate::{ratelimit, routing, tokenizer};
use acap::cos;
use http::StatusCode;
use log::{debug, info, warn};
@ -23,6 +22,7 @@ use public_types::common_types::{
EmbeddingType, PromptGuardRequest, PromptGuardResponse, PromptGuardTask,
ZeroShotClassificationRequest, ZeroShotClassificationResponse,
};
use public_types::configuration::LlmProvider;
use public_types::configuration::{Overrides, PromptGuards, PromptTarget};
use public_types::embeddings::{
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
@ -53,16 +53,17 @@ pub struct CallContext {
}
pub struct StreamContext {
pub context_id: u32,
pub metrics: Rc<WasmMetrics>,
pub prompt_targets: Rc<RwLock<HashMap<String, PromptTarget>>>,
pub overrides: Rc<Option<Overrides>>,
context_id: u32,
metrics: Rc<WasmMetrics>,
prompt_targets: Rc<RwLock<HashMap<String, PromptTarget>>>,
overrides: Rc<Option<Overrides>>,
callouts: HashMap<u32, CallContext>,
ratelimit_selector: Option<Header>,
streaming_response: bool,
response_tokens: usize,
chat_completions_request: bool,
llm_provider: Option<&'static LlmProvider<'static>>,
llm_providers: Rc<LlmProviders>,
llm_provider: Option<Rc<LlmProvider>>,
prompt_guards: Rc<Option<PromptGuards>>,
}
@ -73,6 +74,7 @@ impl StreamContext {
prompt_targets: Rc<RwLock<HashMap<String, PromptTarget>>>,
prompt_guards: Rc<Option<PromptGuards>>,
overrides: Rc<Option<Overrides>>,
llm_providers: Rc<LlmProviders>,
) -> Self {
StreamContext {
context_id,
@ -83,6 +85,7 @@ impl StreamContext {
streaming_response: false,
response_tokens: 0,
chat_completions_request: false,
llm_providers,
llm_provider: None,
prompt_guards,
overrides,
@ -90,27 +93,35 @@ impl StreamContext {
}
/// Returns the provider selected for this stream.
///
/// Panics if called before `select_llm_provider()` has populated the field
/// (it runs at the top of `on_http_request_headers`, so the provider is
/// always set by the time request processing asks for it).
fn llm_provider(&self) -> &LlmProvider {
    self.llm_provider
        .as_ref()
        .expect("the provider should be set when asked for it")
}
/// Resolves which LLM provider will serve this stream, optionally steered by
/// the provider-hint request header.
fn select_llm_provider(&mut self) {
    // A missing header yields None; routing then falls back to its own
    // provider-selection policy.
    let hint = self
        .get_http_request_header(ARCH_PROVIDER_HINT_HEADER)
        .map(Into::into);
    self.llm_provider = Some(routing::get_llm_provider(&self.llm_providers, hint));
}
fn add_routing_header(&mut self) {
self.add_http_request_header(ARCH_ROUTING_HEADER, self.llm_provider().as_ref());
self.add_http_request_header(ARCH_ROUTING_HEADER, &self.llm_provider().name);
}
fn modify_auth_headers(&mut self) -> Result<(), String> {
let llm_provider_api_key_value = self
.get_http_request_header(self.llm_provider().api_key_header())
.ok_or(format!("missing {} api key", self.llm_provider()))?;
let llm_provider_api_key_value = self.llm_provider().access_key.as_ref().ok_or(format!(
"No access key configured for selected LLM Provider \"{}\"",
self.llm_provider()
))?;
let authorization_header_value = format!("Bearer {}", llm_provider_api_key_value);
self.set_http_request_header("Authorization", Some(&authorization_header_value));
// sanitize passed in api keys
for provider in LlmProviders::VARIANTS.iter() {
self.set_http_request_header(provider.api_key_header(), None);
}
Ok(())
}
@ -728,29 +739,13 @@ impl StreamContext {
let prompt_guard_resp: PromptGuardResponse = serde_json::from_slice(&body).unwrap();
debug!("prompt_guard_resp: {:?}", prompt_guard_resp);
if prompt_guard_resp.jailbreak_verdict.is_some()
&& prompt_guard_resp.jailbreak_verdict.unwrap()
{
if prompt_guard_resp.jailbreak_verdict.unwrap_or_default() {
//TODO: handle other scenarios like forward to error target
let default_err = "Jailbreak detected. Please refrain from discussing jailbreaking.";
let error_msg = match self.prompt_guards.as_ref() {
Some(prompt_guards) => match prompt_guards
.input_guards
.get(&public_types::configuration::GuardType::Jailbreak)
{
Some(jailbreak) => match jailbreak.on_exception.as_ref() {
Some(on_exception_details) => match on_exception_details.message.as_ref() {
Some(error_msg) => error_msg,
None => default_err,
},
None => default_err,
},
None => default_err,
},
None => default_err,
};
return self.send_server_error(error_msg.to_string(), Some(StatusCode::BAD_REQUEST));
let msg = (*self.prompt_guards)
.as_ref()
.and_then(|pg| pg.jailbreak_on_exception_message())
.unwrap_or("Jailbreak detected. Please refrain from discussing jailbreaking.");
return self.send_server_error(msg.to_string(), Some(StatusCode::BAD_REQUEST));
}
self.get_embeddings(callout_context);
@ -900,11 +895,7 @@ impl HttpContext for StreamContext {
// Envoy's HTTP model is event driven. The WASM ABI has given implementors events to hook onto
// the lifecycle of the http request and response.
fn on_http_request_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action {
let provider_hint = self
.get_http_request_header("x-arch-deterministic-provider")
.is_some();
self.llm_provider = Some(routing::get_llm_provider(provider_hint));
self.select_llm_provider();
self.add_routing_header();
if let Err(error) = self.modify_auth_headers() {
self.send_server_error(error, Some(StatusCode::BAD_REQUEST));
@ -959,7 +950,7 @@ impl HttpContext for StreamContext {
};
// Set the model based on the chosen LLM Provider
deserialized_body.model = String::from(self.llm_provider().choose_model());
deserialized_body.model = String::from(&self.llm_provider().model);
self.streaming_response = deserialized_body.stream;
if deserialized_body.stream && deserialized_body.stream_options.is_none() {