Add the ability to use LLM Providers from the Arch config (#112)

Signed-off-by: José Ulises Niño Rivera <junr03@users.noreply.github.com>
José Ulises Niño Rivera 2024-10-03 10:57:01 -07:00 committed by GitHub
parent 1b57a49c9d
commit 8ea917aae5
16 changed files with 295 additions and 210 deletions
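
This commit moves LLM provider definitions out of hard-coded enum variants and into the Arch config. For orientation, the config-side provider type the filter now consumes looks roughly like the sketch below; the actual definition lives in public_types::configuration (one of the 16 changed files, not shown in this excerpt), and the field names are inferred from how the diff accesses them (.name, .model, .access_key):

    // Sketch only: inferred from field accesses in this diff, not the verbatim definition.
    #[derive(Clone, Debug)]
    pub struct LlmProvider {
        pub name: String,               // used as the routing key in ARCH_ROUTING_HEADER
        pub model: String,              // forced onto outgoing chat-completion requests
        pub access_key: Option<String>, // sent as "Authorization: Bearer <key>" when present
    }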

@@ -1,14 +1,13 @@
 use crate::consts::{
-    ARCH_FC_REQUEST_TIMEOUT_MS, ARCH_MESSAGES_KEY, ARCH_ROUTING_HEADER, ARC_FC_CLUSTER,
-    DEFAULT_EMBEDDING_MODEL, DEFAULT_INTENT_MODEL, DEFAULT_PROMPT_TARGET_THRESHOLD, GPT_35_TURBO,
-    MODEL_SERVER_NAME, RATELIMIT_SELECTOR_HEADER_KEY, SYSTEM_ROLE, USER_ROLE,
+    ARCH_FC_REQUEST_TIMEOUT_MS, ARCH_MESSAGES_KEY, ARCH_PROVIDER_HINT_HEADER, ARCH_ROUTING_HEADER,
+    ARC_FC_CLUSTER, DEFAULT_EMBEDDING_MODEL, DEFAULT_INTENT_MODEL, DEFAULT_PROMPT_TARGET_THRESHOLD,
+    GPT_35_TURBO, MODEL_SERVER_NAME, RATELIMIT_SELECTOR_HEADER_KEY, SYSTEM_ROLE, USER_ROLE,
 };
 use crate::filter_context::{embeddings_store, WasmMetrics};
-use crate::llm_providers::{LlmProvider, LlmProviders};
+use crate::llm_providers::LlmProviders;
 use crate::ratelimit::Header;
 use crate::stats::IncrementingMetric;
-use crate::tokenizer;
-use crate::{ratelimit, routing};
+use crate::{ratelimit, routing, tokenizer};
 use acap::cos;
 use http::StatusCode;
 use log::{debug, info, warn};
@@ -23,6 +22,7 @@ use public_types::common_types::{
     EmbeddingType, PromptGuardRequest, PromptGuardResponse, PromptGuardTask,
     ZeroShotClassificationRequest, ZeroShotClassificationResponse,
 };
+use public_types::configuration::LlmProvider;
 use public_types::configuration::{Overrides, PromptGuards, PromptTarget};
 use public_types::embeddings::{
     CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
@@ -53,16 +53,17 @@ pub struct CallContext {
 }

 pub struct StreamContext {
-    pub context_id: u32,
-    pub metrics: Rc<WasmMetrics>,
-    pub prompt_targets: Rc<RwLock<HashMap<String, PromptTarget>>>,
-    pub overrides: Rc<Option<Overrides>>,
+    context_id: u32,
+    metrics: Rc<WasmMetrics>,
+    prompt_targets: Rc<RwLock<HashMap<String, PromptTarget>>>,
+    overrides: Rc<Option<Overrides>>,
     callouts: HashMap<u32, CallContext>,
     ratelimit_selector: Option<Header>,
     streaming_response: bool,
     response_tokens: usize,
     chat_completions_request: bool,
-    llm_provider: Option<&'static LlmProvider<'static>>,
+    llm_providers: Rc<LlmProviders>,
+    llm_provider: Option<Rc<LlmProvider>>,
     prompt_guards: Rc<Option<PromptGuards>>,
 }

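
The StreamContext now holds the whole set of configured providers (llm_providers) alongside the one selected for the current request (llm_provider), with Rc replacing the old &'static borrow so providers parsed from config can be shared across stream contexts. A hypothetical shape for the new crate::llm_providers::LlmProviders collection, which this excerpt does not show, is:

    // Hypothetical sketch: the diff only requires lookup-by-name plus a default provider.
    use std::collections::HashMap;
    use std::rc::Rc;
    use public_types::configuration::LlmProvider;

    pub struct LlmProviders {
        pub providers: HashMap<String, Rc<LlmProvider>>, // keyed by configured provider name
        pub default: Rc<LlmProvider>,                    // chosen when no hint is supplied
    }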
@@ -73,6 +74,7 @@ impl StreamContext {
         prompt_targets: Rc<RwLock<HashMap<String, PromptTarget>>>,
         prompt_guards: Rc<Option<PromptGuards>>,
         overrides: Rc<Option<Overrides>>,
+        llm_providers: Rc<LlmProviders>,
     ) -> Self {
         StreamContext {
             context_id,
@@ -83,6 +85,7 @@ impl StreamContext {
             streaming_response: false,
             response_tokens: 0,
             chat_completions_request: false,
+            llm_providers,
             llm_provider: None,
             prompt_guards,
             overrides,
@@ -90,27 +93,35 @@ impl StreamContext {
     }

     fn llm_provider(&self) -> &LlmProvider {
         self.llm_provider
             .as_ref()
             .expect("the provider should be set when asked for it")
     }

+    fn select_llm_provider(&mut self) {
+        let provider_hint = self
+            .get_http_request_header(ARCH_PROVIDER_HINT_HEADER)
+            .map(|provider_name| provider_name.into());
+
+        self.llm_provider = Some(routing::get_llm_provider(
+            &self.llm_providers,
+            provider_hint,
+        ));
+    }
+
     fn add_routing_header(&mut self) {
-        self.add_http_request_header(ARCH_ROUTING_HEADER, self.llm_provider().as_ref());
+        self.add_http_request_header(ARCH_ROUTING_HEADER, &self.llm_provider().name);
     }

     fn modify_auth_headers(&mut self) -> Result<(), String> {
-        let llm_provider_api_key_value = self
-            .get_http_request_header(self.llm_provider().api_key_header())
-            .ok_or(format!("missing {} api key", self.llm_provider()))?;
+        let llm_provider_api_key_value = self.llm_provider().access_key.as_ref().ok_or(format!(
+            "No access key configured for selected LLM Provider \"{}\"",
+            self.llm_provider()
+        ))?;

         let authorization_header_value = format!("Bearer {}", llm_provider_api_key_value);
         self.set_http_request_header("Authorization", Some(&authorization_header_value));

-        // sanitize passed in api keys
-        for provider in LlmProviders::VARIANTS.iter() {
-            self.set_http_request_header(provider.api_key_header(), None);
-        }
-
         Ok(())
     }
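
select_llm_provider reads the ARCH_PROVIDER_HINT_HEADER from the incoming request and hands it to routing::get_llm_provider, which now takes the configured providers instead of the old boolean flag. The routing change itself lives in another changed file; a plausible sketch, assuming the LlmProviders shape suggested above:

    // Plausible sketch of routing::get_llm_provider; the field accesses here are
    // assumptions, not confirmed by this excerpt.
    pub fn get_llm_provider(
        llm_providers: &LlmProviders,
        provider_hint: Option<String>,
    ) -> Rc<LlmProvider> {
        provider_hint
            .and_then(|name| llm_providers.providers.get(&name).cloned()) // honor the hint if it names a configured provider
            .unwrap_or_else(|| Rc::clone(&llm_providers.default))         // otherwise fall back to the default
    }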
@@ -728,29 +739,13 @@ impl StreamContext {
         let prompt_guard_resp: PromptGuardResponse = serde_json::from_slice(&body).unwrap();
         debug!("prompt_guard_resp: {:?}", prompt_guard_resp);

-        if prompt_guard_resp.jailbreak_verdict.is_some()
-            && prompt_guard_resp.jailbreak_verdict.unwrap()
-        {
+        if prompt_guard_resp.jailbreak_verdict.unwrap_or_default() {
             //TODO: handle other scenarios like forward to error target
-            let default_err = "Jailbreak detected. Please refrain from discussing jailbreaking.";
-            let error_msg = match self.prompt_guards.as_ref() {
-                Some(prompt_guards) => match prompt_guards
-                    .input_guards
-                    .get(&public_types::configuration::GuardType::Jailbreak)
-                {
-                    Some(jailbreak) => match jailbreak.on_exception.as_ref() {
-                        Some(on_exception_details) => match on_exception_details.message.as_ref() {
-                            Some(error_msg) => error_msg,
-                            None => default_err,
-                        },
-                        None => default_err,
-                    },
-                    None => default_err,
-                },
-                None => default_err,
-            };
-            return self.send_server_error(error_msg.to_string(), Some(StatusCode::BAD_REQUEST));
+            let msg = (*self.prompt_guards)
+                .as_ref()
+                .and_then(|pg| pg.jailbreak_on_exception_message())
+                .unwrap_or("Jailbreak detected. Please refrain from discussing jailbreaking.");
+            return self.send_server_error(msg.to_string(), Some(StatusCode::BAD_REQUEST));
         }

         self.get_embeddings(callout_context);
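
The four-level nested match is collapsed into a single Option chain plus a jailbreak_on_exception_message() helper on PromptGuards, presumably added elsewhere in this commit. Its body is not part of this excerpt, but following the data path of the removed match it would look something like:

    // Sketch of the helper implied by the new call site; the old nested match
    // walked exactly this path: input_guards -> Jailbreak -> on_exception -> message.
    impl PromptGuards {
        pub fn jailbreak_on_exception_message(&self) -> Option<&str> {
            self.input_guards
                .get(&GuardType::Jailbreak)?
                .on_exception
                .as_ref()?
                .message
                .as_deref()
        }
    }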
@@ -900,11 +895,7 @@ impl HttpContext for StreamContext {
     // Envoy's HTTP model is event driven. The WASM ABI has given implementors events to hook onto
     // the lifecycle of the http request and response.
     fn on_http_request_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action {
-        let provider_hint = self
-            .get_http_request_header("x-arch-deterministic-provider")
-            .is_some();
-        self.llm_provider = Some(routing::get_llm_provider(provider_hint));
-
+        self.select_llm_provider();
         self.add_routing_header();

         if let Err(error) = self.modify_auth_headers() {
             self.send_server_error(error, Some(StatusCode::BAD_REQUEST));
@@ -959,7 +950,7 @@ impl HttpContext for StreamContext {
         };

         // Set the model based on the chosen LLM Provider
-        deserialized_body.model = String::from(self.llm_provider().choose_model());
+        deserialized_body.model = String::from(&self.llm_provider().model);

         self.streaming_response = deserialized_body.stream;
         if deserialized_body.stream && deserialized_body.stream_options.is_none() {