mirror of
https://github.com/katanemo/plano.git
synced 2026-06-08 14:55:14 +02:00
Add the ability to use LLM Providers from the Arch config (#112)
Signed-off-by: José Ulises Niño Rivera <junr03@users.noreply.github.com>
This commit is contained in:
parent
1b57a49c9d
commit
8ea917aae5
16 changed files with 295 additions and 210 deletions
11
arch/Cargo.lock
generated
11
arch/Cargo.lock
generated
|
|
@ -759,6 +759,7 @@ dependencies = [
|
|||
"serde_json",
|
||||
"serde_yaml",
|
||||
"serial_test",
|
||||
"thiserror",
|
||||
"tiktoken-rs",
|
||||
]
|
||||
|
||||
|
|
@ -1060,7 +1061,7 @@ dependencies = [
|
|||
[[package]]
|
||||
name = "proxy-wasm-test-framework"
|
||||
version = "0.1.0"
|
||||
source = "git+https://github.com/katanemo/test-framework.git?branch=main#c2511cd9030705e14d5f60aca77d6c96c81c6dfa"
|
||||
source = "git+https://github.com/katanemo/test-framework.git?branch=new#c2511cd9030705e14d5f60aca77d6c96c81c6dfa"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"cfg-if 0.1.10",
|
||||
|
|
@ -1490,18 +1491,18 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "thiserror"
|
||||
version = "1.0.63"
|
||||
version = "1.0.64"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c0342370b38b6a11b6cc11d6a805569958d54cfa061a29969c3b5ce2ea405724"
|
||||
checksum = "d50af8abc119fb8bb6dbabcfa89656f46f84aa0ac7688088608076ad2b459a84"
|
||||
dependencies = [
|
||||
"thiserror-impl",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "thiserror-impl"
|
||||
version = "1.0.63"
|
||||
version = "1.0.64"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a4558b58466b9ad7ca0f102865eccc95938dca1a74a856f2b57b6629050da261"
|
||||
checksum = "08904e7672f5eb876eaaf87e0ce17857500934f4981c4a0ab2b4aa98baac7fc3"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
|
|
|
|||
|
|
@ -20,7 +20,8 @@ governor = { version = "0.6.3", default-features = false, features = ["no_std"]}
|
|||
tiktoken-rs = "0.5.9"
|
||||
acap = "0.3.0"
|
||||
rand = "0.8.5"
|
||||
thiserror = "1.0.64"
|
||||
|
||||
[dev-dependencies]
|
||||
proxy-wasm-test-framework = { git = "https://github.com/katanemo/test-framework.git", branch = "main" }
|
||||
proxy-wasm-test-framework = { git = "https://github.com/katanemo/test-framework.git", branch = "new" }
|
||||
serial_test = "3.1.1"
|
||||
|
|
|
|||
|
|
@ -38,6 +38,8 @@ properties:
|
|||
properties:
|
||||
name:
|
||||
type: string
|
||||
provider:
|
||||
type: string
|
||||
access_key:
|
||||
type: string
|
||||
model:
|
||||
|
|
@ -47,6 +49,7 @@ properties:
|
|||
additionalProperties: false
|
||||
required:
|
||||
- name
|
||||
- provider
|
||||
- access_key
|
||||
- model
|
||||
overrides:
|
||||
|
|
@ -112,7 +115,7 @@ properties:
|
|||
items:
|
||||
type: object
|
||||
properties:
|
||||
provider:
|
||||
model:
|
||||
type: string
|
||||
selector:
|
||||
type: object
|
||||
|
|
@ -138,7 +141,7 @@ properties:
|
|||
- unit
|
||||
additionalProperties: false
|
||||
required:
|
||||
- provider
|
||||
- model
|
||||
- selector
|
||||
- limit
|
||||
additionalProperties: false
|
||||
|
|
|
|||
|
|
@ -7,17 +7,32 @@ ENVOY_CONFIG_TEMPLATE_FILE = os.getenv('ENVOY_CONFIG_TEMPLATE_FILE', 'envoy.temp
|
|||
ARCH_CONFIG_FILE = os.getenv('ARCH_CONFIG_FILE', '/config/arch_config.yaml')
|
||||
ENVOY_CONFIG_FILE_RENDERED = os.getenv('ENVOY_CONFIG_FILE_RENDERED', '/etc/envoy/envoy.yaml')
|
||||
ARCH_CONFIG_SCHEMA_FILE = os.getenv('ARCH_CONFIG_SCHEMA_FILE', 'arch_config_schema.yaml')
|
||||
OPENAI_API_KEY = os.getenv('OPENAI_API_KEY', False)
|
||||
MISTRAL_API_KEY = os.getenv('MISTRAL_API_KEY', False)
|
||||
|
||||
def add_secret_key_to_llm_providers(config_yaml, secrets=None):
    """Resolve access-key placeholders on every entry of ``llm_providers``.

    Each provider whose ``access_key`` is a known placeholder
    (``$MISTRAL_ACCESS_KEY`` or ``$OPENAI_ACCESS_KEY``) gets the matching
    secret substituted. Any other value, and any placeholder for which no
    secret is available, has its ``access_key`` removed so that a boolean
    ``False`` (the fallback used when the environment variable is unset) is
    never written into the rendered configuration.

    Args:
        config_yaml: Parsed Arch config (a dict). It is updated in place.
        secrets: Optional mapping of placeholder string to secret value.
            Defaults to the module-level ``MISTRAL_API_KEY`` and
            ``OPENAI_API_KEY`` values.

    Returns:
        The same ``config_yaml`` object, with ``llm_providers`` always set
        to a list.
    """
    if secrets is None:
        secrets = {
            "$MISTRAL_ACCESS_KEY": MISTRAL_API_KEY,
            "$OPENAI_ACCESS_KEY": OPENAI_API_KEY,
        }

    llm_providers = []
    # `or []` also covers an explicit `llm_providers: null` in the config.
    for llm_provider in config_yaml.get("llm_providers") or []:
        placeholder = llm_provider.get("access_key")
        secret = secrets.get(placeholder) if isinstance(placeholder, str) else None
        if secret:
            llm_provider["access_key"] = secret
        else:
            # Unknown placeholder, literal value, or unset secret: drop the key
            # rather than forwarding something that is not a usable credential.
            llm_provider.pop("access_key", None)
        llm_providers.append(llm_provider)

    config_yaml["llm_providers"] = llm_providers
    return config_yaml
|
||||
|
||||
env = Environment(loader=FileSystemLoader('./'))
|
||||
template = env.get_template('envoy.template.yaml')
|
||||
|
||||
with open(ARCH_CONFIG_FILE, 'r') as file:
|
||||
katanemo_config = file.read()
|
||||
arch_config_string = file.read()
|
||||
|
||||
with open(ARCH_CONFIG_SCHEMA_FILE, 'r') as file:
|
||||
arch_config_schema = file.read()
|
||||
|
||||
config_yaml = yaml.safe_load(katanemo_config)
|
||||
config_yaml = yaml.safe_load(arch_config_string)
|
||||
config_schema_yaml = yaml.safe_load(arch_config_schema)
|
||||
|
||||
try:
|
||||
|
|
@ -54,9 +69,16 @@ for name, endpoint_details in endpoints.items():
|
|||
|
||||
print("updated clusters", inferred_clusters)
|
||||
|
||||
config_yaml = add_secret_key_to_llm_providers(config_yaml)
|
||||
arch_llm_providers = config_yaml["llm_providers"]
|
||||
arch_config_string = yaml.dump(config_yaml)
|
||||
|
||||
print("llm_providers:", arch_llm_providers)
|
||||
|
||||
data = {
|
||||
'katanemo_config': katanemo_config,
|
||||
'arch_clusters': inferred_clusters
|
||||
'arch_config': arch_config_string,
|
||||
'arch_clusters': inferred_clusters,
|
||||
'arch_llm_providers': arch_llm_providers
|
||||
}
|
||||
|
||||
rendered = template.render(data)
|
||||
|
|
|
|||
|
|
@ -34,26 +34,18 @@ static_resources:
|
|||
auto_host_rewrite: true
|
||||
cluster: mistral_7b_instruct
|
||||
timeout: 60s
|
||||
{% for provider in arch_llm_providers %}
|
||||
- match:
|
||||
prefix: "/v1/chat/completions"
|
||||
prefix: "/"
|
||||
headers:
|
||||
- name: "x-arch-llm-provider"
|
||||
string_match:
|
||||
exact: openai
|
||||
exact: {{ provider.name }}
|
||||
route:
|
||||
auto_host_rewrite: true
|
||||
cluster: openai
|
||||
timeout: 60s
|
||||
- match:
|
||||
prefix: "/v1/chat/completions"
|
||||
headers:
|
||||
- name: "x-arch-llm-provider"
|
||||
string_match:
|
||||
exact: mistral
|
||||
route:
|
||||
auto_host_rewrite: true
|
||||
cluster: mistral
|
||||
cluster: {{ provider.provider }}
|
||||
timeout: 60s
|
||||
{% endfor %}
|
||||
http_filters:
|
||||
- name: envoy.filters.http.wasm
|
||||
typed_config:
|
||||
|
|
@ -65,7 +57,7 @@ static_resources:
|
|||
configuration:
|
||||
"@type": "type.googleapis.com/google.protobuf.StringValue"
|
||||
value: |
|
||||
{{ katanemo_config | indent(30) }}
|
||||
{{ arch_config | indent(30) }}
|
||||
vm_config:
|
||||
runtime: "envoy.wasm.runtime.v8"
|
||||
code:
|
||||
|
|
@ -75,9 +67,6 @@ static_resources:
|
|||
typed_config:
|
||||
"@type": type.googleapis.com/envoy.extensions.filters.http.router.v3.Router
|
||||
clusters:
|
||||
# LLM Host
|
||||
# Embedding Providers
|
||||
# External LLM Providers
|
||||
- name: openai
|
||||
connect_timeout: 5s
|
||||
dns_lookup_family: V4_ONLY
|
||||
|
|
|
|||
|
|
@ -10,3 +10,4 @@ pub const ARCH_FC_REQUEST_TIMEOUT_MS: u64 = 120000; // 2 minutes
|
|||
pub const MODEL_SERVER_NAME: &str = "model_server";
|
||||
pub const ARCH_ROUTING_HEADER: &str = "x-arch-llm-provider";
|
||||
pub const ARCH_MESSAGES_KEY: &str = "arch_messages";
|
||||
pub const ARCH_PROVIDER_HINT_HEADER: &str = "x-arch-llm-provider-hint";
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
use crate::consts::{DEFAULT_EMBEDDING_MODEL, MODEL_SERVER_NAME};
|
||||
use crate::llm_providers::LlmProviders;
|
||||
use crate::ratelimit;
|
||||
use crate::stats::{Counter, Gauge, RecordingMetric};
|
||||
use crate::stream_context::StreamContext;
|
||||
|
|
@ -44,10 +45,11 @@ pub struct FilterContext {
|
|||
metrics: Rc<WasmMetrics>,
|
||||
// callouts stores token_id to request mapping that we use during #on_http_call_response to match the response to the request.
|
||||
callouts: HashMap<u32, CallContext>,
|
||||
config: Option<Configuration>,
|
||||
overrides: Rc<Option<Overrides>>,
|
||||
prompt_targets: Rc<RwLock<HashMap<String, PromptTarget>>>,
|
||||
// This should be Option<Rc<PromptGuards>>, because StreamContext::new() should get an Rc<PromptGuards> not Option<Rc<PromptGuards>>.
|
||||
prompt_guards: Rc<Option<PromptGuards>>,
|
||||
llm_providers: Option<Rc<LlmProviders>>,
|
||||
}
|
||||
|
||||
pub fn embeddings_store() -> &'static RwLock<HashMap<String, EmbeddingTypeMap>> {
|
||||
|
|
@ -62,11 +64,11 @@ impl FilterContext {
|
|||
pub fn new() -> FilterContext {
|
||||
FilterContext {
|
||||
callouts: HashMap::new(),
|
||||
config: None,
|
||||
metrics: Rc::new(WasmMetrics::new()),
|
||||
prompt_targets: Rc::new(RwLock::new(HashMap::new())),
|
||||
overrides: Rc::new(None),
|
||||
prompt_guards: Rc::new(Some(PromptGuards::default())),
|
||||
llm_providers: None,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -219,42 +221,35 @@ impl Context for FilterContext {
|
|||
// RootContext allows the Rust code to reach into the Envoy Config
|
||||
impl RootContext for FilterContext {
|
||||
fn on_configure(&mut self, _: usize) -> bool {
|
||||
if let Some(config_bytes) = self.get_plugin_configuration() {
|
||||
self.config = serde_yaml::from_slice(&config_bytes).unwrap();
|
||||
let config_bytes = self
|
||||
.get_plugin_configuration()
|
||||
.expect("Arch config cannot be empty");
|
||||
|
||||
if let Some(overrides_config) = self
|
||||
.config
|
||||
.as_mut()
|
||||
.and_then(|config| config.overrides.as_mut())
|
||||
{
|
||||
self.overrides = Rc::new(Some(std::mem::take(overrides_config)));
|
||||
}
|
||||
let config: Configuration = match serde_yaml::from_slice(&config_bytes) {
|
||||
Ok(config) => config,
|
||||
Err(err) => panic!("Invalid arch config \"{:?}\"", err),
|
||||
};
|
||||
|
||||
for pt in self.config.clone().unwrap().prompt_targets {
|
||||
self.prompt_targets
|
||||
.write()
|
||||
.unwrap()
|
||||
.insert(pt.name.clone(), pt.clone());
|
||||
}
|
||||
self.overrides = Rc::new(config.overrides);
|
||||
|
||||
debug!("set configuration object");
|
||||
|
||||
if let Some(ratelimits_config) = self
|
||||
.config
|
||||
.as_mut()
|
||||
.and_then(|config| config.ratelimits.as_mut())
|
||||
{
|
||||
ratelimit::ratelimits(Some(std::mem::take(ratelimits_config)));
|
||||
}
|
||||
|
||||
if let Some(prompt_guards) = self
|
||||
.config
|
||||
.as_mut()
|
||||
.and_then(|config| config.prompt_guards.as_mut())
|
||||
{
|
||||
self.prompt_guards = Rc::new(Some(std::mem::take(prompt_guards)));
|
||||
}
|
||||
for pt in config.prompt_targets {
|
||||
self.prompt_targets
|
||||
.write()
|
||||
.unwrap()
|
||||
.insert(pt.name.clone(), pt.clone());
|
||||
}
|
||||
|
||||
ratelimit::ratelimits(config.ratelimits);
|
||||
|
||||
if let Some(prompt_guards) = config.prompt_guards {
|
||||
self.prompt_guards = Rc::new(Some(prompt_guards))
|
||||
}
|
||||
|
||||
match config.llm_providers.try_into() {
|
||||
Ok(llm_providers) => self.llm_providers = Some(Rc::new(llm_providers)),
|
||||
Err(err) => panic!("{err}"),
|
||||
}
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
|
|
@ -269,6 +264,11 @@ impl RootContext for FilterContext {
|
|||
Rc::clone(&self.prompt_targets),
|
||||
Rc::clone(&self.prompt_guards),
|
||||
Rc::clone(&self.overrides),
|
||||
Rc::clone(
|
||||
self.llm_providers
|
||||
.as_ref()
|
||||
.expect("LLM Providers must exist when Streams are being created"),
|
||||
),
|
||||
)))
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,47 +1,69 @@
|
|||
#[non_exhaustive]
|
||||
pub struct LlmProviders;
|
||||
use public_types::configuration::LlmProvider;
|
||||
use std::collections::HashMap;
|
||||
use std::rc::Rc;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct LlmProviders {
|
||||
providers: HashMap<String, Rc<LlmProvider>>,
|
||||
default: Option<Rc<LlmProvider>>,
|
||||
}
|
||||
|
||||
impl LlmProviders {
|
||||
pub const OPENAI_PROVIDER: LlmProvider<'static> = LlmProvider {
|
||||
name: "openai",
|
||||
api_key_header: "x-arch-openai-api-key",
|
||||
model: "gpt-3.5-turbo",
|
||||
};
|
||||
pub const MISTRAL_PROVIDER: LlmProvider<'static> = LlmProvider {
|
||||
name: "mistral",
|
||||
api_key_header: "x-arch-mistral-api-key",
|
||||
model: "mistral-large-latest",
|
||||
};
|
||||
pub fn iter(&self) -> std::collections::hash_map::Iter<'_, String, Rc<LlmProvider>> {
|
||||
self.providers.iter()
|
||||
}
|
||||
|
||||
pub const VARIANTS: &'static [LlmProvider<'static>] =
|
||||
&[Self::OPENAI_PROVIDER, Self::MISTRAL_PROVIDER];
|
||||
}
|
||||
pub fn default(&self) -> Option<Rc<LlmProvider>> {
|
||||
self.default.as_ref().map(|rc| rc.clone())
|
||||
}
|
||||
|
||||
pub struct LlmProvider<'prov> {
|
||||
name: &'prov str,
|
||||
api_key_header: &'prov str,
|
||||
model: &'prov str,
|
||||
}
|
||||
|
||||
impl AsRef<str> for LlmProvider<'_> {
|
||||
fn as_ref(&self) -> &str {
|
||||
self.name
|
||||
pub fn get(&self, name: &str) -> Option<Rc<LlmProvider>> {
|
||||
self.providers.get(name).map(|rc| rc.clone())
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for LlmProvider<'_> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", self.name)
|
||||
}
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum LlmProvidersNewError {
|
||||
#[error("There must be at least one LLM Provider")]
|
||||
EmptySource,
|
||||
#[error("There must be at most one default LLM Provider")]
|
||||
MoreThanOneDefault,
|
||||
#[error("\'{0}\' is not a unique name")]
|
||||
DuplicateName(String),
|
||||
}
|
||||
|
||||
impl LlmProvider<'_> {
|
||||
pub fn api_key_header(&self) -> &str {
|
||||
self.api_key_header
|
||||
}
|
||||
impl TryFrom<Vec<LlmProvider>> for LlmProviders {
|
||||
type Error = LlmProvidersNewError;
|
||||
|
||||
pub fn choose_model(&self) -> &str {
|
||||
// In the future this can be a more complex function balancing reliability, cost, performance, etc.
|
||||
self.model
|
||||
fn try_from(llm_providers_config: Vec<LlmProvider>) -> Result<Self, Self::Error> {
|
||||
if llm_providers_config.is_empty() {
|
||||
return Err(LlmProvidersNewError::EmptySource);
|
||||
}
|
||||
|
||||
let mut llm_providers = LlmProviders {
|
||||
providers: HashMap::new(),
|
||||
default: None,
|
||||
};
|
||||
|
||||
for llm_provider in llm_providers_config {
|
||||
let llm_provider: Rc<LlmProvider> = Rc::new(llm_provider);
|
||||
if llm_provider.default.unwrap_or_default() {
|
||||
match llm_providers.default {
|
||||
Some(_) => return Err(LlmProvidersNewError::MoreThanOneDefault),
|
||||
None => llm_providers.default = Some(Rc::clone(&llm_provider)),
|
||||
}
|
||||
}
|
||||
|
||||
// Insert and check that there is no other provider with the same name.
|
||||
let name = llm_provider.name.clone();
|
||||
if llm_providers
|
||||
.providers
|
||||
.insert(name.clone(), llm_provider)
|
||||
.is_some()
|
||||
{
|
||||
return Err(LlmProvidersNewError::DuplicateName(name));
|
||||
}
|
||||
}
|
||||
Ok(llm_providers)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -54,10 +54,7 @@ impl RatelimitMap {
|
|||
for ratelimit_config in ratelimits_config {
|
||||
let limit = DefaultKeyedRateLimiter::keyed(get_quota(ratelimit_config.limit));
|
||||
|
||||
match new_ratelimit_map
|
||||
.datastore
|
||||
.get_mut(&ratelimit_config.provider)
|
||||
{
|
||||
match new_ratelimit_map.datastore.get_mut(&ratelimit_config.model) {
|
||||
Some(limits) => match limits.get_mut(&ratelimit_config.selector) {
|
||||
Some(_) => {
|
||||
panic!("repeated selector. Selectors per provider must be unique")
|
||||
|
|
@ -72,7 +69,7 @@ impl RatelimitMap {
|
|||
let new_hash_map = HashMap::from([(ratelimit_config.selector, limit)]);
|
||||
new_ratelimit_map
|
||||
.datastore
|
||||
.insert(ratelimit_config.provider, new_hash_map);
|
||||
.insert(ratelimit_config.model, new_hash_map);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -142,7 +139,7 @@ fn get_quota(limit: Limit) -> Quota {
|
|||
#[test]
|
||||
fn non_existent_provider_is_ok() {
|
||||
let ratelimits_config = vec![Ratelimit {
|
||||
provider: String::from("provider"),
|
||||
model: String::from("provider"),
|
||||
selector: configuration::Header {
|
||||
key: String::from("only-key"),
|
||||
value: None,
|
||||
|
|
@ -170,7 +167,7 @@ fn non_existent_provider_is_ok() {
|
|||
#[test]
|
||||
fn non_existent_key_is_ok() {
|
||||
let ratelimits_config = vec![Ratelimit {
|
||||
provider: String::from("provider"),
|
||||
model: String::from("provider"),
|
||||
selector: configuration::Header {
|
||||
key: String::from("only-key"),
|
||||
value: None,
|
||||
|
|
@ -198,7 +195,7 @@ fn non_existent_key_is_ok() {
|
|||
#[test]
|
||||
fn specific_limit_does_not_catch_non_specific_value() {
|
||||
let ratelimits_config = vec![Ratelimit {
|
||||
provider: String::from("provider"),
|
||||
model: String::from("provider"),
|
||||
selector: configuration::Header {
|
||||
key: String::from("key"),
|
||||
value: Some(String::from("value")),
|
||||
|
|
@ -226,7 +223,7 @@ fn specific_limit_does_not_catch_non_specific_value() {
|
|||
#[test]
|
||||
fn specific_limit_is_hit() {
|
||||
let ratelimits_config = vec![Ratelimit {
|
||||
provider: String::from("provider"),
|
||||
model: String::from("provider"),
|
||||
selector: configuration::Header {
|
||||
key: String::from("key"),
|
||||
value: Some(String::from("value")),
|
||||
|
|
@ -254,7 +251,7 @@ fn specific_limit_is_hit() {
|
|||
#[test]
|
||||
fn non_specific_key_has_different_limits_for_different_values() {
|
||||
let ratelimits_config = vec![Ratelimit {
|
||||
provider: String::from("provider"),
|
||||
model: String::from("provider"),
|
||||
selector: configuration::Header {
|
||||
key: String::from("only-key"),
|
||||
value: None,
|
||||
|
|
@ -308,7 +305,7 @@ fn non_specific_key_has_different_limits_for_different_values() {
|
|||
fn different_provider_can_have_different_limits_with_the_same_keys() {
|
||||
let ratelimits_config = vec![
|
||||
Ratelimit {
|
||||
provider: String::from("first_provider"),
|
||||
model: String::from("first_provider"),
|
||||
selector: configuration::Header {
|
||||
key: String::from("key"),
|
||||
value: Some(String::from("value")),
|
||||
|
|
@ -319,7 +316,7 @@ fn different_provider_can_have_different_limits_with_the_same_keys() {
|
|||
},
|
||||
},
|
||||
Ratelimit {
|
||||
provider: String::from("second_provider"),
|
||||
model: String::from("second_provider"),
|
||||
selector: configuration::Header {
|
||||
key: String::from("key"),
|
||||
value: Some(String::from("value")),
|
||||
|
|
@ -391,7 +388,7 @@ mod test {
|
|||
#[test]
|
||||
fn different_threads_have_same_ratelimit_data_structure() {
|
||||
let ratelimits_config = Some(vec![Ratelimit {
|
||||
provider: String::from("provider"),
|
||||
model: String::from("provider"),
|
||||
selector: configuration::Header {
|
||||
key: String::from("key"),
|
||||
value: Some(String::from("value")),
|
||||
|
|
|
|||
|
|
@ -1,13 +1,42 @@
|
|||
use crate::llm_providers::{LlmProvider, LlmProviders};
|
||||
use rand::{seq::SliceRandom, thread_rng};
|
||||
use std::rc::Rc;
|
||||
|
||||
pub fn get_llm_provider<'hostname>(deterministic: bool) -> &'static LlmProvider<'hostname> {
|
||||
if deterministic {
|
||||
&LlmProviders::OPENAI_PROVIDER
|
||||
} else {
|
||||
let mut rng = thread_rng();
|
||||
LlmProviders::VARIANTS
|
||||
.choose(&mut rng)
|
||||
.expect("There should always be at least one llm provider")
|
||||
use crate::llm_providers::LlmProviders;
|
||||
use public_types::configuration::LlmProvider;
|
||||
use rand::{seq::IteratorRandom, thread_rng};
|
||||
|
||||
/// A caller-supplied preference for which configured LLM provider should
/// handle a request.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ProviderHint {
    /// Ask for the provider that is configured as the default.
    Default,
    /// Ask for the provider registered under the given name.
    Name(String),
}

impl From<String> for ProviderHint {
    /// Converts a raw hint value into a [`ProviderHint`].
    ///
    /// The exact, case-sensitive string `"default"` becomes
    /// [`ProviderHint::Default`]; every other value (including the empty
    /// string) is kept verbatim as [`ProviderHint::Name`].
    fn from(value: String) -> Self {
        match value.as_str() {
            "default" => ProviderHint::Default,
            _ => ProviderHint::Name(value),
        }
    }
}
|
||||
|
||||
pub fn get_llm_provider(
|
||||
llm_providers: &LlmProviders,
|
||||
provider_hint: Option<ProviderHint>,
|
||||
) -> Rc<LlmProvider> {
|
||||
let maybe_provider = provider_hint.and_then(|hint| match hint {
|
||||
ProviderHint::Default => llm_providers.default(),
|
||||
// FIXME: should a non-existent name in the hint be more explicit? i.e, return a BAD_REQUEST?
|
||||
ProviderHint::Name(name) => llm_providers.get(&name),
|
||||
});
|
||||
|
||||
if let Some(provider) = maybe_provider {
|
||||
return provider;
|
||||
}
|
||||
|
||||
let mut rng = thread_rng();
|
||||
llm_providers
|
||||
.iter()
|
||||
.choose(&mut rng)
|
||||
.expect("There should always be at least one llm provider")
|
||||
.1
|
||||
.clone()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,14 +1,13 @@
|
|||
use crate::consts::{
|
||||
ARCH_FC_REQUEST_TIMEOUT_MS, ARCH_MESSAGES_KEY, ARCH_ROUTING_HEADER, ARC_FC_CLUSTER,
|
||||
DEFAULT_EMBEDDING_MODEL, DEFAULT_INTENT_MODEL, DEFAULT_PROMPT_TARGET_THRESHOLD, GPT_35_TURBO,
|
||||
MODEL_SERVER_NAME, RATELIMIT_SELECTOR_HEADER_KEY, SYSTEM_ROLE, USER_ROLE,
|
||||
ARCH_FC_REQUEST_TIMEOUT_MS, ARCH_MESSAGES_KEY, ARCH_PROVIDER_HINT_HEADER, ARCH_ROUTING_HEADER,
|
||||
ARC_FC_CLUSTER, DEFAULT_EMBEDDING_MODEL, DEFAULT_INTENT_MODEL, DEFAULT_PROMPT_TARGET_THRESHOLD,
|
||||
GPT_35_TURBO, MODEL_SERVER_NAME, RATELIMIT_SELECTOR_HEADER_KEY, SYSTEM_ROLE, USER_ROLE,
|
||||
};
|
||||
use crate::filter_context::{embeddings_store, WasmMetrics};
|
||||
use crate::llm_providers::{LlmProvider, LlmProviders};
|
||||
use crate::llm_providers::LlmProviders;
|
||||
use crate::ratelimit::Header;
|
||||
use crate::stats::IncrementingMetric;
|
||||
use crate::tokenizer;
|
||||
use crate::{ratelimit, routing};
|
||||
use crate::{ratelimit, routing, tokenizer};
|
||||
use acap::cos;
|
||||
use http::StatusCode;
|
||||
use log::{debug, info, warn};
|
||||
|
|
@ -23,6 +22,7 @@ use public_types::common_types::{
|
|||
EmbeddingType, PromptGuardRequest, PromptGuardResponse, PromptGuardTask,
|
||||
ZeroShotClassificationRequest, ZeroShotClassificationResponse,
|
||||
};
|
||||
use public_types::configuration::LlmProvider;
|
||||
use public_types::configuration::{Overrides, PromptGuards, PromptTarget};
|
||||
use public_types::embeddings::{
|
||||
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
|
||||
|
|
@ -53,16 +53,17 @@ pub struct CallContext {
|
|||
}
|
||||
|
||||
pub struct StreamContext {
|
||||
pub context_id: u32,
|
||||
pub metrics: Rc<WasmMetrics>,
|
||||
pub prompt_targets: Rc<RwLock<HashMap<String, PromptTarget>>>,
|
||||
pub overrides: Rc<Option<Overrides>>,
|
||||
context_id: u32,
|
||||
metrics: Rc<WasmMetrics>,
|
||||
prompt_targets: Rc<RwLock<HashMap<String, PromptTarget>>>,
|
||||
overrides: Rc<Option<Overrides>>,
|
||||
callouts: HashMap<u32, CallContext>,
|
||||
ratelimit_selector: Option<Header>,
|
||||
streaming_response: bool,
|
||||
response_tokens: usize,
|
||||
chat_completions_request: bool,
|
||||
llm_provider: Option<&'static LlmProvider<'static>>,
|
||||
llm_providers: Rc<LlmProviders>,
|
||||
llm_provider: Option<Rc<LlmProvider>>,
|
||||
prompt_guards: Rc<Option<PromptGuards>>,
|
||||
}
|
||||
|
||||
|
|
@ -73,6 +74,7 @@ impl StreamContext {
|
|||
prompt_targets: Rc<RwLock<HashMap<String, PromptTarget>>>,
|
||||
prompt_guards: Rc<Option<PromptGuards>>,
|
||||
overrides: Rc<Option<Overrides>>,
|
||||
llm_providers: Rc<LlmProviders>,
|
||||
) -> Self {
|
||||
StreamContext {
|
||||
context_id,
|
||||
|
|
@ -83,6 +85,7 @@ impl StreamContext {
|
|||
streaming_response: false,
|
||||
response_tokens: 0,
|
||||
chat_completions_request: false,
|
||||
llm_providers,
|
||||
llm_provider: None,
|
||||
prompt_guards,
|
||||
overrides,
|
||||
|
|
@ -90,27 +93,35 @@ impl StreamContext {
|
|||
}
|
||||
fn llm_provider(&self) -> &LlmProvider {
|
||||
self.llm_provider
|
||||
.as_ref()
|
||||
.expect("the provider should be set when asked for it")
|
||||
}
|
||||
|
||||
fn select_llm_provider(&mut self) {
|
||||
let provider_hint = self
|
||||
.get_http_request_header(ARCH_PROVIDER_HINT_HEADER)
|
||||
.map(|provider_name| provider_name.into());
|
||||
|
||||
self.llm_provider = Some(routing::get_llm_provider(
|
||||
&self.llm_providers,
|
||||
provider_hint,
|
||||
));
|
||||
}
|
||||
|
||||
fn add_routing_header(&mut self) {
|
||||
self.add_http_request_header(ARCH_ROUTING_HEADER, self.llm_provider().as_ref());
|
||||
self.add_http_request_header(ARCH_ROUTING_HEADER, &self.llm_provider().name);
|
||||
}
|
||||
|
||||
fn modify_auth_headers(&mut self) -> Result<(), String> {
|
||||
let llm_provider_api_key_value = self
|
||||
.get_http_request_header(self.llm_provider().api_key_header())
|
||||
.ok_or(format!("missing {} api key", self.llm_provider()))?;
|
||||
let llm_provider_api_key_value = self.llm_provider().access_key.as_ref().ok_or(format!(
|
||||
"No access key configured for selected LLM Provider \"{}\"",
|
||||
self.llm_provider()
|
||||
))?;
|
||||
|
||||
let authorization_header_value = format!("Bearer {}", llm_provider_api_key_value);
|
||||
|
||||
self.set_http_request_header("Authorization", Some(&authorization_header_value));
|
||||
|
||||
// sanitize passed in api keys
|
||||
for provider in LlmProviders::VARIANTS.iter() {
|
||||
self.set_http_request_header(provider.api_key_header(), None);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
|
@ -728,29 +739,13 @@ impl StreamContext {
|
|||
let prompt_guard_resp: PromptGuardResponse = serde_json::from_slice(&body).unwrap();
|
||||
debug!("prompt_guard_resp: {:?}", prompt_guard_resp);
|
||||
|
||||
if prompt_guard_resp.jailbreak_verdict.is_some()
|
||||
&& prompt_guard_resp.jailbreak_verdict.unwrap()
|
||||
{
|
||||
if prompt_guard_resp.jailbreak_verdict.unwrap_or_default() {
|
||||
//TODO: handle other scenarios like forward to error target
|
||||
let default_err = "Jailbreak detected. Please refrain from discussing jailbreaking.";
|
||||
let error_msg = match self.prompt_guards.as_ref() {
|
||||
Some(prompt_guards) => match prompt_guards
|
||||
.input_guards
|
||||
.get(&public_types::configuration::GuardType::Jailbreak)
|
||||
{
|
||||
Some(jailbreak) => match jailbreak.on_exception.as_ref() {
|
||||
Some(on_exception_details) => match on_exception_details.message.as_ref() {
|
||||
Some(error_msg) => error_msg,
|
||||
None => default_err,
|
||||
},
|
||||
None => default_err,
|
||||
},
|
||||
None => default_err,
|
||||
},
|
||||
None => default_err,
|
||||
};
|
||||
|
||||
return self.send_server_error(error_msg.to_string(), Some(StatusCode::BAD_REQUEST));
|
||||
let msg = (*self.prompt_guards)
|
||||
.as_ref()
|
||||
.and_then(|pg| pg.jailbreak_on_exception_message())
|
||||
.unwrap_or("Jailbreak detected. Please refrain from discussing jailbreaking.");
|
||||
return self.send_server_error(msg.to_string(), Some(StatusCode::BAD_REQUEST));
|
||||
}
|
||||
|
||||
self.get_embeddings(callout_context);
|
||||
|
|
@ -900,11 +895,7 @@ impl HttpContext for StreamContext {
|
|||
// Envoy's HTTP model is event driven. The WASM ABI has given implementors events to hook onto
|
||||
// the lifecycle of the http request and response.
|
||||
fn on_http_request_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action {
|
||||
let provider_hint = self
|
||||
.get_http_request_header("x-arch-deterministic-provider")
|
||||
.is_some();
|
||||
self.llm_provider = Some(routing::get_llm_provider(provider_hint));
|
||||
|
||||
self.select_llm_provider();
|
||||
self.add_routing_header();
|
||||
if let Err(error) = self.modify_auth_headers() {
|
||||
self.send_server_error(error, Some(StatusCode::BAD_REQUEST));
|
||||
|
|
@ -959,7 +950,7 @@ impl HttpContext for StreamContext {
|
|||
};
|
||||
|
||||
// Set the model based on the chosen LLM Provider
|
||||
deserialized_body.model = String::from(self.llm_provider().choose_model());
|
||||
deserialized_body.model = String::from(&self.llm_provider().model);
|
||||
|
||||
self.streaming_response = deserialized_body.stream;
|
||||
if deserialized_body.stream && deserialized_body.stream_options.is_none() {
|
||||
|
|
|
|||
|
|
@ -29,31 +29,18 @@ fn request_headers_expectations(module: &mut Tester, http_context: i32) {
|
|||
.call_proxy_on_request_headers(http_context, 0, false)
|
||||
.expect_get_header_map_value(
|
||||
Some(MapType::HttpRequestHeaders),
|
||||
Some("x-arch-deterministic-provider"),
|
||||
Some("x-arch-llm-provider-hint"),
|
||||
)
|
||||
.returning(Some("true"))
|
||||
.returning(Some("default"))
|
||||
.expect_add_header_map_value(
|
||||
Some(MapType::HttpRequestHeaders),
|
||||
Some("x-arch-llm-provider"),
|
||||
Some("openai"),
|
||||
Some("open-ai-gpt-4"),
|
||||
)
|
||||
.expect_get_header_map_value(
|
||||
Some(MapType::HttpRequestHeaders),
|
||||
Some("x-arch-openai-api-key"),
|
||||
)
|
||||
.returning(Some("api-key"))
|
||||
.expect_replace_header_map_value(
|
||||
Some(MapType::HttpRequestHeaders),
|
||||
Some("Authorization"),
|
||||
Some("Bearer api-key"),
|
||||
)
|
||||
.expect_remove_header_map_value(
|
||||
Some(MapType::HttpRequestHeaders),
|
||||
Some("x-arch-openai-api-key"),
|
||||
)
|
||||
.expect_remove_header_map_value(
|
||||
Some(MapType::HttpRequestHeaders),
|
||||
Some("x-arch-mistral-api-key"),
|
||||
Some("Bearer secret_key"),
|
||||
)
|
||||
.expect_remove_header_map_value(Some(MapType::HttpRequestHeaders), Some("content-length"))
|
||||
.expect_get_header_map_value(
|
||||
|
|
@ -190,7 +177,8 @@ endpoints:
|
|||
|
||||
llm_providers:
|
||||
- name: open-ai-gpt-4
|
||||
access_key: $OPEN_AI_API_KEY
|
||||
provider: openai
|
||||
access_key: secret_key
|
||||
model: gpt-4
|
||||
default: true
|
||||
|
||||
|
|
@ -240,7 +228,7 @@ prompt_targets:
|
|||
You are a helpful insurance claim details provider. Use insurance claim data that is provided to you. Please following following guidelines when responding to user queries:
|
||||
- Use policy number to retrieve insurance claim details
|
||||
ratelimits:
|
||||
- provider: gpt-3.5-turbo
|
||||
- model: gpt-4
|
||||
selector:
|
||||
key: selector-key
|
||||
value: selector-value
|
||||
|
|
@ -267,20 +255,28 @@ fn successful_request_to_open_ai_chat_completions() {
|
|||
.unwrap();
|
||||
|
||||
// Setup Filter
|
||||
let root_context = 1;
|
||||
let filter_context = 1;
|
||||
let config = serde_json::to_string(&default_config()).unwrap();
|
||||
|
||||
module
|
||||
.call_proxy_on_context_create(root_context, 0)
|
||||
.call_proxy_on_context_create(filter_context, 0)
|
||||
.expect_metric_creation(MetricType::Gauge, "active_http_calls")
|
||||
.expect_metric_creation(MetricType::Counter, "ratelimited_rq")
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
module
|
||||
.call_proxy_on_configure(filter_context, config.len() as i32)
|
||||
.expect_get_buffer_bytes(Some(BufferType::PluginConfiguration))
|
||||
.returning(Some(&config))
|
||||
.execute_and_expect(ReturnType::Bool(true))
|
||||
.unwrap();
|
||||
|
||||
// Setup HTTP Stream
|
||||
let http_context = 2;
|
||||
|
||||
module
|
||||
.call_proxy_on_context_create(http_context, root_context)
|
||||
.call_proxy_on_context_create(http_context, filter_context)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
|
@ -336,20 +332,28 @@ fn bad_request_to_open_ai_chat_completions() {
|
|||
.unwrap();
|
||||
|
||||
// Setup Filter
|
||||
let root_context = 1;
|
||||
let filter_context = 1;
|
||||
let config = serde_json::to_string(&default_config()).unwrap();
|
||||
|
||||
module
|
||||
.call_proxy_on_context_create(root_context, 0)
|
||||
.call_proxy_on_context_create(filter_context, 0)
|
||||
.expect_metric_creation(MetricType::Gauge, "active_http_calls")
|
||||
.expect_metric_creation(MetricType::Counter, "ratelimited_rq")
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
module
|
||||
.call_proxy_on_configure(filter_context, config.len() as i32)
|
||||
.expect_get_buffer_bytes(Some(BufferType::PluginConfiguration))
|
||||
.returning(Some(&config))
|
||||
.execute_and_expect(ReturnType::Bool(true))
|
||||
.unwrap();
|
||||
|
||||
// Setup HTTP Stream
|
||||
let http_context = 2;
|
||||
|
||||
module
|
||||
.call_proxy_on_context_create(http_context, root_context)
|
||||
.call_proxy_on_context_create(http_context, filter_context)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
|
@ -416,7 +420,6 @@ fn request_ratelimited() {
|
|||
.unwrap();
|
||||
module
|
||||
.call_proxy_on_configure(filter_context, config.len() as i32)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_get_buffer_bytes(Some(BufferType::PluginConfiguration))
|
||||
.returning(Some(&config))
|
||||
.execute_and_expect(ReturnType::Bool(true))
|
||||
|
|
@ -531,7 +534,6 @@ fn request_not_ratelimited() {
|
|||
.unwrap();
|
||||
module
|
||||
.call_proxy_on_configure(filter_context, config_str.len() as i32)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_get_buffer_bytes(Some(BufferType::PluginConfiguration))
|
||||
.returning(Some(&config_str))
|
||||
.execute_and_expect(ReturnType::Bool(true))
|
||||
|
|
|
|||
|
|
@ -11,21 +11,24 @@ endpoints:
|
|||
endpoint: api_server:80
|
||||
connect_timeout: 0.005s
|
||||
|
||||
llm_providers:
|
||||
- name: open-ai-gpt-4
|
||||
access_key: $OPEN_AI_API_KEY
|
||||
model: gpt-4
|
||||
default: true
|
||||
|
||||
overrides:
|
||||
# confidence threshold for prompt target intent matching
|
||||
prompt_target_intent_matching_threshold: 0.6
|
||||
|
||||
system_prompt: |
|
||||
You are a helpful assistant.
|
||||
llm_providers:
|
||||
- name: open-ai-gpt-4
|
||||
access_key: $OPENAI_ACCESS_KEY
|
||||
provider: openai
|
||||
model: gpt-4
|
||||
default: true
|
||||
- name: mistral-large-latest
|
||||
access_key: $MISTRAL_ACCESS_KEY
|
||||
provider: mistral
|
||||
model: large-latest
|
||||
|
||||
system_prompt: You are a helpful assistant.
|
||||
|
||||
prompt_targets:
|
||||
|
||||
- name: weather_forecast
|
||||
description: This function provides realtime weather forecast information for a given city.
|
||||
parameters:
|
||||
|
|
@ -78,7 +81,7 @@ prompt_targets:
|
|||
auto_llm_dispatch_on_response: true
|
||||
|
||||
ratelimits:
|
||||
- provider: gpt-3.5-turbo
|
||||
- model: gpt-4
|
||||
selector:
|
||||
key: selector-key
|
||||
value: selector-value
|
||||
|
|
|
|||
|
|
@ -24,6 +24,8 @@ services:
|
|||
condition: service_healthy
|
||||
environment:
|
||||
- LOG_LEVEL=debug
|
||||
- OPENAI_API_KEY=${OPENAI_API_KEY:?error}
|
||||
- MISTRAL_API_KEY=${MISTRAL_API_KEY:?error}
|
||||
|
||||
model_server:
|
||||
build:
|
||||
|
|
|
|||
|
|
@ -31,6 +31,7 @@ endpoints:
|
|||
# Centralized way to manage LLMs: keys, retry logic, failover, and limits
|
||||
llm_providers:
|
||||
- name: "OpenAI"
|
||||
provider: "openai"
|
||||
access_key: $OPENAI_API_KEY
|
||||
model: gpt-4o
|
||||
default: true
|
||||
|
|
@ -45,10 +46,12 @@ llm_providers:
|
|||
unit: "minute"
|
||||
|
||||
- name: "Mistral8x7b"
|
||||
provider: "mistral"
|
||||
access_key: $MISTRAL_API_KEY
|
||||
model: "mistral-8x7b"
|
||||
|
||||
- name: "MistralLocal7b"
|
||||
provider: "local"
|
||||
model: "mistral-7b-instruct"
|
||||
endpoint: "mistral_local"
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
use std::{collections::HashMap, time::Duration};
|
||||
|
||||
use duration_string::DurationString;
|
||||
use serde::{Deserialize, Deserializer, Serialize};
|
||||
use std::fmt::Display;
|
||||
use std::{collections::HashMap, time::Duration};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct Overrides {
|
||||
|
|
@ -59,6 +59,19 @@ pub struct PromptGuards {
|
|||
pub input_guards: HashMap<GuardType, GuardOptions>,
|
||||
}
|
||||
|
||||
impl PromptGuards {
|
||||
pub fn jailbreak_on_exception_message(&self) -> Option<&str> {
|
||||
self.input_guards
|
||||
.get(&GuardType::Jailbreak)?
|
||||
.on_exception
|
||||
.as_ref()?
|
||||
.message
|
||||
.as_ref()?
|
||||
.as_str()
|
||||
.into()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
|
||||
pub enum GuardType {
|
||||
#[serde(rename = "jailbreak")]
|
||||
|
|
@ -96,7 +109,7 @@ pub struct Header {
|
|||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Ratelimit {
|
||||
pub provider: String,
|
||||
pub model: String,
|
||||
pub selector: Header,
|
||||
pub limit: Limit,
|
||||
}
|
||||
|
|
@ -134,7 +147,7 @@ pub struct EmbeddingProviver {
|
|||
//TODO: use enum for model, but if there is a new model, we need to update the code
|
||||
pub struct LlmProvider {
|
||||
pub name: String,
|
||||
//TODO: handle env var replacement
|
||||
pub provider: String,
|
||||
pub access_key: Option<String>,
|
||||
pub model: String,
|
||||
pub default: Option<bool>,
|
||||
|
|
@ -142,6 +155,12 @@ pub struct LlmProvider {
|
|||
pub rate_limits: Option<LlmRatelimit>,
|
||||
}
|
||||
|
||||
impl Display for LlmProvider {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", self.name)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Endpoint {
|
||||
pub endpoint: Option<String>,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue