Mirror of https://github.com/katanemo/plano.git, synced 2026-05-03 21:02:56 +02:00
Add the ability to use LLM Providers from the Arch config (#112)

Signed-off-by: José Ulises Niño Rivera <junr03@users.noreply.github.com>

parent 1b57a49c9d, commit 8ea917aae5
16 changed files with 295 additions and 210 deletions

Only one of the changed files is excerpted below: the stream-context module that implements the per-request HTTP filter.
@@ -1,14 +1,13 @@
 use crate::consts::{
-    ARCH_FC_REQUEST_TIMEOUT_MS, ARCH_MESSAGES_KEY, ARCH_ROUTING_HEADER, ARC_FC_CLUSTER,
-    DEFAULT_EMBEDDING_MODEL, DEFAULT_INTENT_MODEL, DEFAULT_PROMPT_TARGET_THRESHOLD, GPT_35_TURBO,
-    MODEL_SERVER_NAME, RATELIMIT_SELECTOR_HEADER_KEY, SYSTEM_ROLE, USER_ROLE,
+    ARCH_FC_REQUEST_TIMEOUT_MS, ARCH_MESSAGES_KEY, ARCH_PROVIDER_HINT_HEADER, ARCH_ROUTING_HEADER,
+    ARC_FC_CLUSTER, DEFAULT_EMBEDDING_MODEL, DEFAULT_INTENT_MODEL, DEFAULT_PROMPT_TARGET_THRESHOLD,
+    GPT_35_TURBO, MODEL_SERVER_NAME, RATELIMIT_SELECTOR_HEADER_KEY, SYSTEM_ROLE, USER_ROLE,
 };
 use crate::filter_context::{embeddings_store, WasmMetrics};
-use crate::llm_providers::{LlmProvider, LlmProviders};
+use crate::llm_providers::LlmProviders;
 use crate::ratelimit::Header;
 use crate::stats::IncrementingMetric;
-use crate::tokenizer;
-use crate::{ratelimit, routing};
+use crate::{ratelimit, routing, tokenizer};
 use acap::cos;
 use http::StatusCode;
 use log::{debug, info, warn};

@@ -23,6 +22,7 @@ use public_types::common_types::{
     EmbeddingType, PromptGuardRequest, PromptGuardResponse, PromptGuardTask,
     ZeroShotClassificationRequest, ZeroShotClassificationResponse,
 };
+use public_types::configuration::LlmProvider;
 use public_types::configuration::{Overrides, PromptGuards, PromptTarget};
 use public_types::embeddings::{
     CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,

@@ -53,16 +53,17 @@ pub struct CallContext {
 }
 
 pub struct StreamContext {
-    pub context_id: u32,
-    pub metrics: Rc<WasmMetrics>,
-    pub prompt_targets: Rc<RwLock<HashMap<String, PromptTarget>>>,
-    pub overrides: Rc<Option<Overrides>>,
+    context_id: u32,
+    metrics: Rc<WasmMetrics>,
+    prompt_targets: Rc<RwLock<HashMap<String, PromptTarget>>>,
+    overrides: Rc<Option<Overrides>>,
     callouts: HashMap<u32, CallContext>,
     ratelimit_selector: Option<Header>,
     streaming_response: bool,
     response_tokens: usize,
     chat_completions_request: bool,
-    llm_provider: Option<&'static LlmProvider<'static>>,
+    llm_providers: Rc<LlmProviders>,
+    llm_provider: Option<Rc<LlmProvider>>,
     prompt_guards: Rc<Option<PromptGuards>>,
 }
 
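Two changes in this hunk carry the commit's intent: the fields lose their pub qualifiers, so other modules must now go through the constructor and accessors, and the provider handle changes from &'static LlmProvider<'static> (usable only for a compile-time-known provider set) to Rc<LlmProvider>, where LlmProvider is the deserialized configuration struct imported from public_types::configuration above. A minimal sketch of the new ownership shape, using hypothetical stripped-down stand-ins for the real types:

    use std::rc::Rc;

    // Hypothetical stand-ins; the real definitions live in
    // public_types::configuration and crate::llm_providers.
    pub struct LlmProvider {
        pub name: String,
        pub model: String,
        pub access_key: Option<String>,
    }

    pub struct LlmProviders {
        providers: Vec<Rc<LlmProvider>>,
    }

    impl LlmProviders {
        // Rc clones let every StreamContext share provider instances that were
        // deserialized from the Arch config at startup, something a &'static
        // borrow cannot express for runtime-loaded data.
        pub fn get(&self, name: &str) -> Option<Rc<LlmProvider>> {
            self.providers.iter().find(|p| p.name == name).cloned()
        }
    }
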
@@ -73,6 +74,7 @@ impl StreamContext {
         prompt_targets: Rc<RwLock<HashMap<String, PromptTarget>>>,
         prompt_guards: Rc<Option<PromptGuards>>,
         overrides: Rc<Option<Overrides>>,
+        llm_providers: Rc<LlmProviders>,
     ) -> Self {
         StreamContext {
             context_id,

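The constructor gains the matching parameter. The caller, presumably the filter context that spawns one StreamContext per HTTP stream, must now hand over a clone of its Rc<LlmProviders>. A hypothetical sketch of that call site (the wiring lives in another of the 16 changed files; the constructor name and the leading parameters are assumptions based on the struct fields):

    // Hypothetical call site; names and parameter order are inferred.
    fn create_http_context(&self, context_id: u32) -> Box<dyn HttpContext> {
        Box::new(StreamContext::new(
            context_id,
            Rc::clone(&self.metrics),
            Rc::clone(&self.prompt_targets),
            Rc::clone(&self.prompt_guards),
            Rc::clone(&self.overrides),
            Rc::clone(&self.llm_providers), // the parameter added by this commit
        ))
    }
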
@@ -83,6 +85,7 @@ impl StreamContext {
             streaming_response: false,
             response_tokens: 0,
             chat_completions_request: false,
+            llm_providers,
             llm_provider: None,
             prompt_guards,
             overrides,

@@ -90,27 +93,35 @@ impl StreamContext {
     }
 
+    fn llm_provider(&self) -> &LlmProvider {
+        self.llm_provider
+            .as_ref()
+            .expect("the provider should be set when asked for it")
+    }
+
+    fn select_llm_provider(&mut self) {
+        let provider_hint = self
+            .get_http_request_header(ARCH_PROVIDER_HINT_HEADER)
+            .map(|provider_name| provider_name.into());
+
+        self.llm_provider = Some(routing::get_llm_provider(
+            &self.llm_providers,
+            provider_hint,
+        ));
+    }
+
     fn add_routing_header(&mut self) {
-        self.add_http_request_header(ARCH_ROUTING_HEADER, self.llm_provider().as_ref());
+        self.add_http_request_header(ARCH_ROUTING_HEADER, &self.llm_provider().name);
     }
 
     fn modify_auth_headers(&mut self) -> Result<(), String> {
-        let llm_provider_api_key_value = self
-            .get_http_request_header(self.llm_provider().api_key_header())
-            .ok_or(format!("missing {} api key", self.llm_provider()))?;
+        let llm_provider_api_key_value = self.llm_provider().access_key.as_ref().ok_or(format!(
+            "No access key configured for selected LLM Provider \"{}\"",
+            self.llm_provider()
+        ))?;
 
         let authorization_header_value = format!("Bearer {}", llm_provider_api_key_value);
 
         self.set_http_request_header("Authorization", Some(&authorization_header_value));
 
-        // sanitize passed in api keys
-        for provider in LlmProviders::VARIANTS.iter() {
-            self.set_http_request_header(provider.api_key_header(), None);
-        }
-
         Ok(())
     }
 
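Provider selection and authentication both change character here. Selection now reads an optional hint header (ARCH_PROVIDER_HINT_HEADER, replacing the hard-coded "x-arch-deterministic-provider" check removed further down) and resolves it against the configured provider set. The Bearer token now comes from the provider's configured access_key instead of a client-supplied header, which is why the loop that scrubbed per-provider API-key headers disappears: clients no longer send provider keys at all. routing::get_llm_provider itself is not part of this excerpt; a plausible sketch, assuming LlmProviders exposes a by-name lookup and a configured default:

    use std::rc::Rc;

    // Sketch only; the real implementation lives in routing.rs, and the hint
    // type is simplified to String here (the diff converts it with .into()).
    pub fn get_llm_provider(
        llm_providers: &LlmProviders,
        provider_hint: Option<String>,
    ) -> Rc<LlmProvider> {
        provider_hint
            // honor the hint when it names a configured provider
            .and_then(|name| llm_providers.get(&name))
            // otherwise fall back to the default provider from the config
            .unwrap_or_else(|| llm_providers.default())
    }
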
@@ -728,29 +739,13 @@ impl StreamContext {
         let prompt_guard_resp: PromptGuardResponse = serde_json::from_slice(&body).unwrap();
         debug!("prompt_guard_resp: {:?}", prompt_guard_resp);
 
-        if prompt_guard_resp.jailbreak_verdict.is_some()
-            && prompt_guard_resp.jailbreak_verdict.unwrap()
-        {
+        if prompt_guard_resp.jailbreak_verdict.unwrap_or_default() {
             //TODO: handle other scenarios like forward to error target
-            let default_err = "Jailbreak detected. Please refrain from discussing jailbreaking.";
-            let error_msg = match self.prompt_guards.as_ref() {
-                Some(prompt_guards) => match prompt_guards
-                    .input_guards
-                    .get(&public_types::configuration::GuardType::Jailbreak)
-                {
-                    Some(jailbreak) => match jailbreak.on_exception.as_ref() {
-                        Some(on_exception_details) => match on_exception_details.message.as_ref() {
-                            Some(error_msg) => error_msg,
-                            None => default_err,
-                        },
-                        None => default_err,
-                    },
-                    None => default_err,
-                },
-                None => default_err,
-            };
-
-            return self.send_server_error(error_msg.to_string(), Some(StatusCode::BAD_REQUEST));
+            let msg = (*self.prompt_guards)
+                .as_ref()
+                .and_then(|pg| pg.jailbreak_on_exception_message())
+                .unwrap_or("Jailbreak detected. Please refrain from discussing jailbreaking.");
+            return self.send_server_error(msg.to_string(), Some(StatusCode::BAD_REQUEST));
         }
 
         self.get_embeddings(callout_context);

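The guard condition simplifies first: for an Option<bool>, unwrap_or_default() is exactly the old is_some() && unwrap(), since bool defaults to false. The five-level match cascade, every arm of which fell back to the same default string, then collapses into one Option combinator chain by moving the traversal into a jailbreak_on_exception_message() helper on PromptGuards. That helper is defined elsewhere in this commit; judging from the structure the removed code walked (input_guards, the Jailbreak guard, on_exception, message), it plausibly looks like:

    impl PromptGuards {
        // Sketch reconstructed from the nesting the removed code traversed.
        pub fn jailbreak_on_exception_message(&self) -> Option<&str> {
            self.input_guards
                .get(&GuardType::Jailbreak)?
                .on_exception
                .as_ref()?
                .message
                .as_deref()
        }
    }
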
@@ -900,11 +895,7 @@ impl HttpContext for StreamContext {
     // Envoy's HTTP model is event driven. The WASM ABI has given implementors events to hook onto
     // the lifecycle of the http request and response.
     fn on_http_request_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action {
-        let provider_hint = self
-            .get_http_request_header("x-arch-deterministic-provider")
-            .is_some();
-        self.llm_provider = Some(routing::get_llm_provider(provider_hint));
-
+        self.select_llm_provider();
         self.add_routing_header();
         if let Err(error) = self.modify_auth_headers() {
             self.send_server_error(error, Some(StatusCode::BAD_REQUEST));

@@ -959,7 +950,7 @@ impl HttpContext for StreamContext {
         };
 
         // Set the model based on the chosen LLM Provider
-        deserialized_body.model = String::from(self.llm_provider().choose_model());
+        deserialized_body.model = String::from(&self.llm_provider().model);
 
         self.streaming_response = deserialized_body.stream;
         if deserialized_body.stream && deserialized_body.stream_options.is_none() {

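The upstream model is no longer picked by a hard-coded choose_model() on the provider enum; it is read from the provider's configured model field, so whatever model the client put in the request body is always overwritten. Tying the pieces together with the simplified types sketched earlier, a two-provider Arch config would deserialize into something like the following (all values illustrative, not taken from the repository):

    // Illustrative data only; real values come from the Arch config file.
    let llm_providers = LlmProviders {
        providers: vec![
            Rc::new(LlmProvider {
                name: String::from("openai-gpt-3.5"),
                model: String::from("gpt-3.5-turbo"),
                access_key: Some(String::from("<key from config>")),
            }),
            Rc::new(LlmProvider {
                name: String::from("local-mistral"),
                model: String::from("mistral-7b-instruct"),
                // No key configured: modify_auth_headers() returns Err and the
                // request is rejected with 400 before reaching the upstream.
                access_key: None,
            }),
        ],
    };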