mirror of
https://github.com/katanemo/plano.git
synced 2026-05-04 13:23:00 +02:00
Add the ability to use LLM Providers from the Arch config (#112)
Signed-off-by: José Ulises Niño Rivera <junr03@users.noreply.github.com>
This commit is contained in:
parent
1b57a49c9d
commit
8ea917aae5
16 changed files with 295 additions and 210 deletions
|
|
@ -10,3 +10,4 @@ pub const ARCH_FC_REQUEST_TIMEOUT_MS: u64 = 120000; // 2 minutes
|
|||
pub const MODEL_SERVER_NAME: &str = "model_server";
|
||||
pub const ARCH_ROUTING_HEADER: &str = "x-arch-llm-provider";
|
||||
pub const ARCH_MESSAGES_KEY: &str = "arch_messages";
|
||||
pub const ARCH_PROVIDER_HINT_HEADER: &str = "x-arch-llm-provider-hint";
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
use crate::consts::{DEFAULT_EMBEDDING_MODEL, MODEL_SERVER_NAME};
|
||||
use crate::llm_providers::LlmProviders;
|
||||
use crate::ratelimit;
|
||||
use crate::stats::{Counter, Gauge, RecordingMetric};
|
||||
use crate::stream_context::StreamContext;
|
||||
|
|
@ -44,10 +45,11 @@ pub struct FilterContext {
|
|||
metrics: Rc<WasmMetrics>,
|
||||
// callouts stores token_id to request mapping that we use during #on_http_call_response to match the response to the request.
|
||||
callouts: HashMap<u32, CallContext>,
|
||||
config: Option<Configuration>,
|
||||
overrides: Rc<Option<Overrides>>,
|
||||
prompt_targets: Rc<RwLock<HashMap<String, PromptTarget>>>,
|
||||
// This should be Option<Rc<PromptGuards>>, because StreamContext::new() should get an Rc<PromptGuards> not Option<Rc<PromptGuards>>.
|
||||
prompt_guards: Rc<Option<PromptGuards>>,
|
||||
llm_providers: Option<Rc<LlmProviders>>,
|
||||
}
|
||||
|
||||
pub fn embeddings_store() -> &'static RwLock<HashMap<String, EmbeddingTypeMap>> {
|
||||
|
|
@ -62,11 +64,11 @@ impl FilterContext {
|
|||
pub fn new() -> FilterContext {
|
||||
FilterContext {
|
||||
callouts: HashMap::new(),
|
||||
config: None,
|
||||
metrics: Rc::new(WasmMetrics::new()),
|
||||
prompt_targets: Rc::new(RwLock::new(HashMap::new())),
|
||||
overrides: Rc::new(None),
|
||||
prompt_guards: Rc::new(Some(PromptGuards::default())),
|
||||
llm_providers: None,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -219,42 +221,35 @@ impl Context for FilterContext {
|
|||
// RootContext allows the Rust code to reach into the Envoy Config
|
||||
impl RootContext for FilterContext {
|
||||
fn on_configure(&mut self, _: usize) -> bool {
|
||||
if let Some(config_bytes) = self.get_plugin_configuration() {
|
||||
self.config = serde_yaml::from_slice(&config_bytes).unwrap();
|
||||
let config_bytes = self
|
||||
.get_plugin_configuration()
|
||||
.expect("Arch config cannot be empty");
|
||||
|
||||
if let Some(overrides_config) = self
|
||||
.config
|
||||
.as_mut()
|
||||
.and_then(|config| config.overrides.as_mut())
|
||||
{
|
||||
self.overrides = Rc::new(Some(std::mem::take(overrides_config)));
|
||||
}
|
||||
let config: Configuration = match serde_yaml::from_slice(&config_bytes) {
|
||||
Ok(config) => config,
|
||||
Err(err) => panic!("Invalid arch config \"{:?}\"", err),
|
||||
};
|
||||
|
||||
for pt in self.config.clone().unwrap().prompt_targets {
|
||||
self.prompt_targets
|
||||
.write()
|
||||
.unwrap()
|
||||
.insert(pt.name.clone(), pt.clone());
|
||||
}
|
||||
self.overrides = Rc::new(config.overrides);
|
||||
|
||||
debug!("set configuration object");
|
||||
|
||||
if let Some(ratelimits_config) = self
|
||||
.config
|
||||
.as_mut()
|
||||
.and_then(|config| config.ratelimits.as_mut())
|
||||
{
|
||||
ratelimit::ratelimits(Some(std::mem::take(ratelimits_config)));
|
||||
}
|
||||
|
||||
if let Some(prompt_guards) = self
|
||||
.config
|
||||
.as_mut()
|
||||
.and_then(|config| config.prompt_guards.as_mut())
|
||||
{
|
||||
self.prompt_guards = Rc::new(Some(std::mem::take(prompt_guards)));
|
||||
}
|
||||
for pt in config.prompt_targets {
|
||||
self.prompt_targets
|
||||
.write()
|
||||
.unwrap()
|
||||
.insert(pt.name.clone(), pt.clone());
|
||||
}
|
||||
|
||||
ratelimit::ratelimits(config.ratelimits);
|
||||
|
||||
if let Some(prompt_guards) = config.prompt_guards {
|
||||
self.prompt_guards = Rc::new(Some(prompt_guards))
|
||||
}
|
||||
|
||||
match config.llm_providers.try_into() {
|
||||
Ok(llm_providers) => self.llm_providers = Some(Rc::new(llm_providers)),
|
||||
Err(err) => panic!("{err}"),
|
||||
}
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
|
|
@ -269,6 +264,11 @@ impl RootContext for FilterContext {
|
|||
Rc::clone(&self.prompt_targets),
|
||||
Rc::clone(&self.prompt_guards),
|
||||
Rc::clone(&self.overrides),
|
||||
Rc::clone(
|
||||
self.llm_providers
|
||||
.as_ref()
|
||||
.expect("LLM Providers must exist when Streams are being created"),
|
||||
),
|
||||
)))
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,47 +1,69 @@
|
|||
#[non_exhaustive]
|
||||
pub struct LlmProviders;
|
||||
use public_types::configuration::LlmProvider;
|
||||
use std::collections::HashMap;
|
||||
use std::rc::Rc;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct LlmProviders {
|
||||
providers: HashMap<String, Rc<LlmProvider>>,
|
||||
default: Option<Rc<LlmProvider>>,
|
||||
}
|
||||
|
||||
impl LlmProviders {
|
||||
pub const OPENAI_PROVIDER: LlmProvider<'static> = LlmProvider {
|
||||
name: "openai",
|
||||
api_key_header: "x-arch-openai-api-key",
|
||||
model: "gpt-3.5-turbo",
|
||||
};
|
||||
pub const MISTRAL_PROVIDER: LlmProvider<'static> = LlmProvider {
|
||||
name: "mistral",
|
||||
api_key_header: "x-arch-mistral-api-key",
|
||||
model: "mistral-large-latest",
|
||||
};
|
||||
pub fn iter(&self) -> std::collections::hash_map::Iter<'_, String, Rc<LlmProvider>> {
|
||||
self.providers.iter()
|
||||
}
|
||||
|
||||
pub const VARIANTS: &'static [LlmProvider<'static>] =
|
||||
&[Self::OPENAI_PROVIDER, Self::MISTRAL_PROVIDER];
|
||||
}
|
||||
pub fn default(&self) -> Option<Rc<LlmProvider>> {
|
||||
self.default.as_ref().map(|rc| rc.clone())
|
||||
}
|
||||
|
||||
pub struct LlmProvider<'prov> {
|
||||
name: &'prov str,
|
||||
api_key_header: &'prov str,
|
||||
model: &'prov str,
|
||||
}
|
||||
|
||||
impl AsRef<str> for LlmProvider<'_> {
|
||||
fn as_ref(&self) -> &str {
|
||||
self.name
|
||||
pub fn get(&self, name: &str) -> Option<Rc<LlmProvider>> {
|
||||
self.providers.get(name).map(|rc| rc.clone())
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for LlmProvider<'_> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", self.name)
|
||||
}
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum LlmProvidersNewError {
|
||||
#[error("There must be at least one LLM Provider")]
|
||||
EmptySource,
|
||||
#[error("There must be at most one default LLM Provider")]
|
||||
MoreThanOneDefault,
|
||||
#[error("\'{0}\' is not a unique name")]
|
||||
DuplicateName(String),
|
||||
}
|
||||
|
||||
impl LlmProvider<'_> {
|
||||
pub fn api_key_header(&self) -> &str {
|
||||
self.api_key_header
|
||||
}
|
||||
impl TryFrom<Vec<LlmProvider>> for LlmProviders {
|
||||
type Error = LlmProvidersNewError;
|
||||
|
||||
pub fn choose_model(&self) -> &str {
|
||||
// In the future this can be a more complex function balancing reliability, cost, performance, etc.
|
||||
self.model
|
||||
fn try_from(llm_providers_config: Vec<LlmProvider>) -> Result<Self, Self::Error> {
|
||||
if llm_providers_config.is_empty() {
|
||||
return Err(LlmProvidersNewError::EmptySource);
|
||||
}
|
||||
|
||||
let mut llm_providers = LlmProviders {
|
||||
providers: HashMap::new(),
|
||||
default: None,
|
||||
};
|
||||
|
||||
for llm_provider in llm_providers_config {
|
||||
let llm_provider: Rc<LlmProvider> = Rc::new(llm_provider);
|
||||
if llm_provider.default.unwrap_or_default() {
|
||||
match llm_providers.default {
|
||||
Some(_) => return Err(LlmProvidersNewError::MoreThanOneDefault),
|
||||
None => llm_providers.default = Some(Rc::clone(&llm_provider)),
|
||||
}
|
||||
}
|
||||
|
||||
// Insert and check that there is no other provider with the same name.
|
||||
let name = llm_provider.name.clone();
|
||||
if llm_providers
|
||||
.providers
|
||||
.insert(name.clone(), llm_provider)
|
||||
.is_some()
|
||||
{
|
||||
return Err(LlmProvidersNewError::DuplicateName(name));
|
||||
}
|
||||
}
|
||||
Ok(llm_providers)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -54,10 +54,7 @@ impl RatelimitMap {
|
|||
for ratelimit_config in ratelimits_config {
|
||||
let limit = DefaultKeyedRateLimiter::keyed(get_quota(ratelimit_config.limit));
|
||||
|
||||
match new_ratelimit_map
|
||||
.datastore
|
||||
.get_mut(&ratelimit_config.provider)
|
||||
{
|
||||
match new_ratelimit_map.datastore.get_mut(&ratelimit_config.model) {
|
||||
Some(limits) => match limits.get_mut(&ratelimit_config.selector) {
|
||||
Some(_) => {
|
||||
panic!("repeated selector. Selectors per provider must be unique")
|
||||
|
|
@ -72,7 +69,7 @@ impl RatelimitMap {
|
|||
let new_hash_map = HashMap::from([(ratelimit_config.selector, limit)]);
|
||||
new_ratelimit_map
|
||||
.datastore
|
||||
.insert(ratelimit_config.provider, new_hash_map);
|
||||
.insert(ratelimit_config.model, new_hash_map);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -142,7 +139,7 @@ fn get_quota(limit: Limit) -> Quota {
|
|||
#[test]
|
||||
fn non_existent_provider_is_ok() {
|
||||
let ratelimits_config = vec![Ratelimit {
|
||||
provider: String::from("provider"),
|
||||
model: String::from("provider"),
|
||||
selector: configuration::Header {
|
||||
key: String::from("only-key"),
|
||||
value: None,
|
||||
|
|
@ -170,7 +167,7 @@ fn non_existent_provider_is_ok() {
|
|||
#[test]
|
||||
fn non_existent_key_is_ok() {
|
||||
let ratelimits_config = vec![Ratelimit {
|
||||
provider: String::from("provider"),
|
||||
model: String::from("provider"),
|
||||
selector: configuration::Header {
|
||||
key: String::from("only-key"),
|
||||
value: None,
|
||||
|
|
@ -198,7 +195,7 @@ fn non_existent_key_is_ok() {
|
|||
#[test]
|
||||
fn specific_limit_does_not_catch_non_specific_value() {
|
||||
let ratelimits_config = vec![Ratelimit {
|
||||
provider: String::from("provider"),
|
||||
model: String::from("provider"),
|
||||
selector: configuration::Header {
|
||||
key: String::from("key"),
|
||||
value: Some(String::from("value")),
|
||||
|
|
@ -226,7 +223,7 @@ fn specific_limit_does_not_catch_non_specific_value() {
|
|||
#[test]
|
||||
fn specific_limit_is_hit() {
|
||||
let ratelimits_config = vec![Ratelimit {
|
||||
provider: String::from("provider"),
|
||||
model: String::from("provider"),
|
||||
selector: configuration::Header {
|
||||
key: String::from("key"),
|
||||
value: Some(String::from("value")),
|
||||
|
|
@ -254,7 +251,7 @@ fn specific_limit_is_hit() {
|
|||
#[test]
|
||||
fn non_specific_key_has_different_limits_for_different_values() {
|
||||
let ratelimits_config = vec![Ratelimit {
|
||||
provider: String::from("provider"),
|
||||
model: String::from("provider"),
|
||||
selector: configuration::Header {
|
||||
key: String::from("only-key"),
|
||||
value: None,
|
||||
|
|
@ -308,7 +305,7 @@ fn non_specific_key_has_different_limits_for_different_values() {
|
|||
fn different_provider_can_have_different_limits_with_the_same_keys() {
|
||||
let ratelimits_config = vec![
|
||||
Ratelimit {
|
||||
provider: String::from("first_provider"),
|
||||
model: String::from("first_provider"),
|
||||
selector: configuration::Header {
|
||||
key: String::from("key"),
|
||||
value: Some(String::from("value")),
|
||||
|
|
@ -319,7 +316,7 @@ fn different_provider_can_have_different_limits_with_the_same_keys() {
|
|||
},
|
||||
},
|
||||
Ratelimit {
|
||||
provider: String::from("second_provider"),
|
||||
model: String::from("second_provider"),
|
||||
selector: configuration::Header {
|
||||
key: String::from("key"),
|
||||
value: Some(String::from("value")),
|
||||
|
|
@ -391,7 +388,7 @@ mod test {
|
|||
#[test]
|
||||
fn different_threads_have_same_ratelimit_data_structure() {
|
||||
let ratelimits_config = Some(vec![Ratelimit {
|
||||
provider: String::from("provider"),
|
||||
model: String::from("provider"),
|
||||
selector: configuration::Header {
|
||||
key: String::from("key"),
|
||||
value: Some(String::from("value")),
|
||||
|
|
|
|||
|
|
@ -1,13 +1,42 @@
|
|||
use crate::llm_providers::{LlmProvider, LlmProviders};
|
||||
use rand::{seq::SliceRandom, thread_rng};
|
||||
use std::rc::Rc;
|
||||
|
||||
pub fn get_llm_provider<'hostname>(deterministic: bool) -> &'static LlmProvider<'hostname> {
|
||||
if deterministic {
|
||||
&LlmProviders::OPENAI_PROVIDER
|
||||
} else {
|
||||
let mut rng = thread_rng();
|
||||
LlmProviders::VARIANTS
|
||||
.choose(&mut rng)
|
||||
.expect("There should always be at least one llm provider")
|
||||
use crate::llm_providers::LlmProviders;
|
||||
use public_types::configuration::LlmProvider;
|
||||
use rand::{seq::IteratorRandom, thread_rng};
|
||||
|
||||
pub enum ProviderHint {
|
||||
Default,
|
||||
Name(String),
|
||||
}
|
||||
|
||||
impl From<String> for ProviderHint {
|
||||
fn from(value: String) -> Self {
|
||||
match value.as_str() {
|
||||
"default" => ProviderHint::Default,
|
||||
_ => ProviderHint::Name(value),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_llm_provider(
|
||||
llm_providers: &LlmProviders,
|
||||
provider_hint: Option<ProviderHint>,
|
||||
) -> Rc<LlmProvider> {
|
||||
let maybe_provider = provider_hint.and_then(|hint| match hint {
|
||||
ProviderHint::Default => llm_providers.default(),
|
||||
// FIXME: should a non-existent name in the hint be more explicit? i.e, return a BAD_REQUEST?
|
||||
ProviderHint::Name(name) => llm_providers.get(&name),
|
||||
});
|
||||
|
||||
if let Some(provider) = maybe_provider {
|
||||
return provider;
|
||||
}
|
||||
|
||||
let mut rng = thread_rng();
|
||||
llm_providers
|
||||
.iter()
|
||||
.choose(&mut rng)
|
||||
.expect("There should always be at least one llm provider")
|
||||
.1
|
||||
.clone()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,14 +1,13 @@
|
|||
use crate::consts::{
|
||||
ARCH_FC_REQUEST_TIMEOUT_MS, ARCH_MESSAGES_KEY, ARCH_ROUTING_HEADER, ARC_FC_CLUSTER,
|
||||
DEFAULT_EMBEDDING_MODEL, DEFAULT_INTENT_MODEL, DEFAULT_PROMPT_TARGET_THRESHOLD, GPT_35_TURBO,
|
||||
MODEL_SERVER_NAME, RATELIMIT_SELECTOR_HEADER_KEY, SYSTEM_ROLE, USER_ROLE,
|
||||
ARCH_FC_REQUEST_TIMEOUT_MS, ARCH_MESSAGES_KEY, ARCH_PROVIDER_HINT_HEADER, ARCH_ROUTING_HEADER,
|
||||
ARC_FC_CLUSTER, DEFAULT_EMBEDDING_MODEL, DEFAULT_INTENT_MODEL, DEFAULT_PROMPT_TARGET_THRESHOLD,
|
||||
GPT_35_TURBO, MODEL_SERVER_NAME, RATELIMIT_SELECTOR_HEADER_KEY, SYSTEM_ROLE, USER_ROLE,
|
||||
};
|
||||
use crate::filter_context::{embeddings_store, WasmMetrics};
|
||||
use crate::llm_providers::{LlmProvider, LlmProviders};
|
||||
use crate::llm_providers::LlmProviders;
|
||||
use crate::ratelimit::Header;
|
||||
use crate::stats::IncrementingMetric;
|
||||
use crate::tokenizer;
|
||||
use crate::{ratelimit, routing};
|
||||
use crate::{ratelimit, routing, tokenizer};
|
||||
use acap::cos;
|
||||
use http::StatusCode;
|
||||
use log::{debug, info, warn};
|
||||
|
|
@ -23,6 +22,7 @@ use public_types::common_types::{
|
|||
EmbeddingType, PromptGuardRequest, PromptGuardResponse, PromptGuardTask,
|
||||
ZeroShotClassificationRequest, ZeroShotClassificationResponse,
|
||||
};
|
||||
use public_types::configuration::LlmProvider;
|
||||
use public_types::configuration::{Overrides, PromptGuards, PromptTarget};
|
||||
use public_types::embeddings::{
|
||||
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
|
||||
|
|
@ -53,16 +53,17 @@ pub struct CallContext {
|
|||
}
|
||||
|
||||
pub struct StreamContext {
|
||||
pub context_id: u32,
|
||||
pub metrics: Rc<WasmMetrics>,
|
||||
pub prompt_targets: Rc<RwLock<HashMap<String, PromptTarget>>>,
|
||||
pub overrides: Rc<Option<Overrides>>,
|
||||
context_id: u32,
|
||||
metrics: Rc<WasmMetrics>,
|
||||
prompt_targets: Rc<RwLock<HashMap<String, PromptTarget>>>,
|
||||
overrides: Rc<Option<Overrides>>,
|
||||
callouts: HashMap<u32, CallContext>,
|
||||
ratelimit_selector: Option<Header>,
|
||||
streaming_response: bool,
|
||||
response_tokens: usize,
|
||||
chat_completions_request: bool,
|
||||
llm_provider: Option<&'static LlmProvider<'static>>,
|
||||
llm_providers: Rc<LlmProviders>,
|
||||
llm_provider: Option<Rc<LlmProvider>>,
|
||||
prompt_guards: Rc<Option<PromptGuards>>,
|
||||
}
|
||||
|
||||
|
|
@ -73,6 +74,7 @@ impl StreamContext {
|
|||
prompt_targets: Rc<RwLock<HashMap<String, PromptTarget>>>,
|
||||
prompt_guards: Rc<Option<PromptGuards>>,
|
||||
overrides: Rc<Option<Overrides>>,
|
||||
llm_providers: Rc<LlmProviders>,
|
||||
) -> Self {
|
||||
StreamContext {
|
||||
context_id,
|
||||
|
|
@ -83,6 +85,7 @@ impl StreamContext {
|
|||
streaming_response: false,
|
||||
response_tokens: 0,
|
||||
chat_completions_request: false,
|
||||
llm_providers,
|
||||
llm_provider: None,
|
||||
prompt_guards,
|
||||
overrides,
|
||||
|
|
@ -90,27 +93,35 @@ impl StreamContext {
|
|||
}
|
||||
fn llm_provider(&self) -> &LlmProvider {
|
||||
self.llm_provider
|
||||
.as_ref()
|
||||
.expect("the provider should be set when asked for it")
|
||||
}
|
||||
|
||||
fn select_llm_provider(&mut self) {
|
||||
let provider_hint = self
|
||||
.get_http_request_header(ARCH_PROVIDER_HINT_HEADER)
|
||||
.map(|provider_name| provider_name.into());
|
||||
|
||||
self.llm_provider = Some(routing::get_llm_provider(
|
||||
&self.llm_providers,
|
||||
provider_hint,
|
||||
));
|
||||
}
|
||||
|
||||
fn add_routing_header(&mut self) {
|
||||
self.add_http_request_header(ARCH_ROUTING_HEADER, self.llm_provider().as_ref());
|
||||
self.add_http_request_header(ARCH_ROUTING_HEADER, &self.llm_provider().name);
|
||||
}
|
||||
|
||||
fn modify_auth_headers(&mut self) -> Result<(), String> {
|
||||
let llm_provider_api_key_value = self
|
||||
.get_http_request_header(self.llm_provider().api_key_header())
|
||||
.ok_or(format!("missing {} api key", self.llm_provider()))?;
|
||||
let llm_provider_api_key_value = self.llm_provider().access_key.as_ref().ok_or(format!(
|
||||
"No access key configured for selected LLM Provider \"{}\"",
|
||||
self.llm_provider()
|
||||
))?;
|
||||
|
||||
let authorization_header_value = format!("Bearer {}", llm_provider_api_key_value);
|
||||
|
||||
self.set_http_request_header("Authorization", Some(&authorization_header_value));
|
||||
|
||||
// sanitize passed in api keys
|
||||
for provider in LlmProviders::VARIANTS.iter() {
|
||||
self.set_http_request_header(provider.api_key_header(), None);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
|
@ -728,29 +739,13 @@ impl StreamContext {
|
|||
let prompt_guard_resp: PromptGuardResponse = serde_json::from_slice(&body).unwrap();
|
||||
debug!("prompt_guard_resp: {:?}", prompt_guard_resp);
|
||||
|
||||
if prompt_guard_resp.jailbreak_verdict.is_some()
|
||||
&& prompt_guard_resp.jailbreak_verdict.unwrap()
|
||||
{
|
||||
if prompt_guard_resp.jailbreak_verdict.unwrap_or_default() {
|
||||
//TODO: handle other scenarios like forward to error target
|
||||
let default_err = "Jailbreak detected. Please refrain from discussing jailbreaking.";
|
||||
let error_msg = match self.prompt_guards.as_ref() {
|
||||
Some(prompt_guards) => match prompt_guards
|
||||
.input_guards
|
||||
.get(&public_types::configuration::GuardType::Jailbreak)
|
||||
{
|
||||
Some(jailbreak) => match jailbreak.on_exception.as_ref() {
|
||||
Some(on_exception_details) => match on_exception_details.message.as_ref() {
|
||||
Some(error_msg) => error_msg,
|
||||
None => default_err,
|
||||
},
|
||||
None => default_err,
|
||||
},
|
||||
None => default_err,
|
||||
},
|
||||
None => default_err,
|
||||
};
|
||||
|
||||
return self.send_server_error(error_msg.to_string(), Some(StatusCode::BAD_REQUEST));
|
||||
let msg = (*self.prompt_guards)
|
||||
.as_ref()
|
||||
.and_then(|pg| pg.jailbreak_on_exception_message())
|
||||
.unwrap_or("Jailbreak detected. Please refrain from discussing jailbreaking.");
|
||||
return self.send_server_error(msg.to_string(), Some(StatusCode::BAD_REQUEST));
|
||||
}
|
||||
|
||||
self.get_embeddings(callout_context);
|
||||
|
|
@ -900,11 +895,7 @@ impl HttpContext for StreamContext {
|
|||
// Envoy's HTTP model is event driven. The WASM ABI has given implementors events to hook onto
|
||||
// the lifecycle of the http request and response.
|
||||
fn on_http_request_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action {
|
||||
let provider_hint = self
|
||||
.get_http_request_header("x-arch-deterministic-provider")
|
||||
.is_some();
|
||||
self.llm_provider = Some(routing::get_llm_provider(provider_hint));
|
||||
|
||||
self.select_llm_provider();
|
||||
self.add_routing_header();
|
||||
if let Err(error) = self.modify_auth_headers() {
|
||||
self.send_server_error(error, Some(StatusCode::BAD_REQUEST));
|
||||
|
|
@ -959,7 +950,7 @@ impl HttpContext for StreamContext {
|
|||
};
|
||||
|
||||
// Set the model based on the chosen LLM Provider
|
||||
deserialized_body.model = String::from(self.llm_provider().choose_model());
|
||||
deserialized_body.model = String::from(&self.llm_provider().model);
|
||||
|
||||
self.streaming_response = deserialized_body.stream;
|
||||
if deserialized_body.stream && deserialized_body.stream_options.is_none() {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue