llm listener split (#155)

Adil Hafeez 2024-10-09 15:47:32 -07:00 committed by GitHub
parent 8b5db45507
commit e81ca8d5cf
16 changed files with 305 additions and 54 deletions


@@ -1,8 +1,9 @@
use crate::consts::{
ARCH_FC_MODEL_NAME, ARCH_FC_REQUEST_TIMEOUT_MS, ARCH_INTERNAL_CLUSTER_NAME, ARCH_MESSAGES_KEY,
ARCH_PROVIDER_HINT_HEADER, ARCH_ROUTING_HEADER, ARCH_STATE_HEADER, ARCH_UPSTREAM_HOST_HEADER,
ARC_FC_CLUSTER, CHAT_COMPLETIONS_PATH, DEFAULT_EMBEDDING_MODEL, DEFAULT_HALLUCINATED_THRESHOLD,
DEFAULT_INTENT_MODEL, DEFAULT_PROMPT_TARGET_THRESHOLD, GPT_35_TURBO, MODEL_SERVER_NAME,
ARCH_FC_MODEL_NAME, ARCH_FC_REQUEST_TIMEOUT_MS, ARCH_INTERNAL_CLUSTER_NAME,
ARCH_LLM_UPSTREAM_LISTENER, ARCH_MESSAGES_KEY, ARCH_PROVIDER_HINT_HEADER, ARCH_ROUTING_HEADER,
ARCH_STATE_HEADER, ARCH_UPSTREAM_HOST_HEADER, ARC_FC_CLUSTER, CHAT_COMPLETIONS_PATH,
DEFAULT_EMBEDDING_MODEL, DEFAULT_HALLUCINATED_THRESHOLD, DEFAULT_INTENT_MODEL,
DEFAULT_PROMPT_TARGET_THRESHOLD, GPT_35_TURBO, MODEL_SERVER_NAME,
RATELIMIT_SELECTOR_HEADER_KEY, REQUEST_ID_HEADER, SYSTEM_ROLE, USER_ROLE,
};
use crate::filter_context::{EmbeddingsStore, WasmMetrics};
@@ -26,7 +27,7 @@ use public_types::common_types::{
PromptGuardRequest, PromptGuardResponse, PromptGuardTask, ZeroShotClassificationRequest,
ZeroShotClassificationResponse,
};
use public_types::configuration::LlmProvider;
use public_types::configuration::{GatewayMode, LlmProvider};
use public_types::configuration::{Overrides, PromptGuards, PromptTarget};
use public_types::embeddings::{
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
@@ -93,7 +94,7 @@ pub struct StreamContext {
metrics: Rc<WasmMetrics>,
system_prompt: Rc<Option<String>>,
prompt_targets: Rc<HashMap<String, PromptTarget>>,
embeddings_store: Rc<EmbeddingsStore>,
embeddings_store: Option<Rc<EmbeddingsStore>>,
overrides: Rc<Option<Overrides>>,
callouts: RefCell<HashMap<u32, StreamCallContext>>,
tool_calls: Option<Vec<ToolCall>>,
@@ -110,6 +111,7 @@ pub struct StreamContext {
llm_providers: Rc<LlmProviders>,
llm_provider: Option<Rc<LlmProvider>>,
request_id: Option<String>,
mode: GatewayMode,
}
impl StreamContext {
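
The new `mode` field and the `GatewayMode::Prompt` / `GatewayMode::Llm` branches below rely on a `GatewayMode` type from `public_types::configuration` that is not shown in this diff. A minimal sketch of what such an enum could look like, inferred only from how it is used here (the derives and doc comments are assumptions, not the actual definition):

```rust
// Hypothetical sketch of the GatewayMode referenced by this diff; the real
// definition lives in public_types::configuration and is not part of this change.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GatewayMode {
    /// Prompt gateway listener: resolves prompt targets, then forwards the
    /// request to the internal LLM listener.
    Prompt,
    /// LLM gateway listener: routes straight to the selected LLM provider.
    Llm,
}

fn main() {
    let mode = GatewayMode::Llm;
    // The filter branches on the mode directly, e.g. `self.mode == GatewayMode::Llm`.
    assert_eq!(mode, GatewayMode::Llm);
}
```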
@@ -122,7 +124,8 @@ impl StreamContext {
prompt_guards: Rc<PromptGuards>,
overrides: Rc<Option<Overrides>>,
llm_providers: Rc<LlmProviders>,
embeddings_store: Rc<EmbeddingsStore>,
embeddings_store: Option<Rc<EmbeddingsStore>>,
mode: GatewayMode,
) -> Self {
StreamContext {
context_id,
@@ -146,6 +149,7 @@
prompt_guards,
overrides,
request_id: None,
mode,
}
}
fn llm_provider(&self) -> &LlmProvider {
@@ -154,19 +158,35 @@
.expect("the provider should be set when asked for it")
}
fn embeddings_store(&self) -> &EmbeddingsStore {
self.embeddings_store
.as_ref()
.expect("embeddings store is not set")
}
fn select_llm_provider(&mut self) {
let provider_hint = self
.get_http_request_header(ARCH_PROVIDER_HINT_HEADER)
.map(|provider_name| provider_name.into());
debug!("llm provider hint: {:?}", provider_hint);
self.llm_provider = Some(routing::get_llm_provider(
&self.llm_providers,
provider_hint,
));
debug!("selected llm: {}", self.llm_provider.as_ref().unwrap().name);
}
fn add_routing_header(&mut self) {
self.add_http_request_header(ARCH_ROUTING_HEADER, &self.llm_provider().name);
match self.mode {
GatewayMode::Prompt => {
// in prompt gateway mode, we need to route to llm upstream listener
self.add_http_request_header(ARCH_UPSTREAM_HOST_HEADER, ARCH_LLM_UPSTREAM_LISTENER);
}
_ => {
self.add_http_request_header(ARCH_ROUTING_HEADER, &self.llm_provider().name);
}
}
}
fn modify_auth_headers(&mut self) -> Result<(), ServerError> {
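
This hunk is the core of the listener split: in prompt-gateway mode the request is no longer routed to a provider cluster directly but is handed to the dedicated LLM listener via the upstream-host header, presumably so both entry points share a single provider-selection, auth, and ratelimiting path. A runnable sketch of that branch, re-using the `GatewayMode` sketch above; the string literals stand in for the `ARCH_*` constants from `crate::consts`, whose actual values are not shown in this diff:

```rust
use std::collections::HashMap;

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum GatewayMode { Prompt, Llm }

// Illustrative stand-in for add_routing_header; header names are placeholders
// for ARCH_UPSTREAM_HOST_HEADER / ARCH_ROUTING_HEADER / ARCH_LLM_UPSTREAM_LISTENER.
fn add_routing_header(headers: &mut HashMap<String, String>, mode: GatewayMode, provider: &str) {
    match mode {
        // Prompt gateway: forward the whole request to the internal LLM listener.
        GatewayMode::Prompt => {
            headers.insert(
                "ARCH_UPSTREAM_HOST_HEADER".into(),
                "ARCH_LLM_UPSTREAM_LISTENER".into(),
            );
        }
        // LLM gateway (any other mode): tag the request with the selected
        // provider so the matching upstream cluster is picked.
        _ => {
            headers.insert("ARCH_ROUTING_HEADER".into(), provider.into());
        }
    }
}

fn main() {
    let mut headers = HashMap::new();
    add_routing_header(&mut headers, GatewayMode::Prompt, "openai");
    println!("{headers:?}");
}
```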
@@ -247,7 +267,7 @@ impl StreamContext {
// exclude default prompt target
.filter(|(_, prompt_target)| !prompt_target.default.unwrap_or(false))
.map(|(prompt_name, _)| {
let pte = match self.embeddings_store.get(prompt_name) {
let pte = match self.embeddings_store().get(prompt_name) {
Some(embeddings) => embeddings,
None => {
warn!(
@@ -901,32 +921,37 @@ impl StreamContext {
debug!("arch => openai request body: {}", json_string);
// Tokenize and Ratelimit.
if let Some(selector) = self.ratelimit_selector.take() {
if let Ok(token_count) =
tokenizer::token_count(&chat_completions_request.model, &json_string)
{
match ratelimit::ratelimits(None).read().unwrap().check_limit(
chat_completions_request.model,
selector,
NonZero::new(token_count as u32).unwrap(),
) {
Ok(_) => (),
Err(err) => {
self.send_server_error(
ServerError::ExceededRatelimit(err),
Some(StatusCode::TOO_MANY_REQUESTS),
);
self.metrics.ratelimited_rq.increment(1);
return;
}
}
}
if let Err(e) = self.enforce_ratelimits(&chat_completions_request.model, &json_string) {
self.send_server_error(
ServerError::ExceededRatelimit(e),
Some(StatusCode::TOO_MANY_REQUESTS),
);
self.metrics.ratelimited_rq.increment(1);
return;
}
self.set_http_request_body(0, self.request_body_size, &json_string.into_bytes());
self.resume_http_request();
}
fn enforce_ratelimits(
&mut self,
model: &str,
json_string: &str,
) -> Result<(), ratelimit::Error> {
if let Some(selector) = self.ratelimit_selector.take() {
// Tokenize and Ratelimit.
if let Ok(token_count) = tokenizer::token_count(model, &json_string) {
ratelimit::ratelimits(None).read().unwrap().check_limit(
model.to_owned(),
selector,
NonZero::new(token_count as u32).unwrap(),
)?;
}
}
Ok(())
}
fn arch_guard_handler(&mut self, body: Vec<u8>, callout_context: StreamCallContext) {
debug!("response received for arch guard");
let prompt_guard_resp: PromptGuardResponse = serde_json::from_slice(&body).unwrap();
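
The inline tokenize-and-check block is factored out into `enforce_ratelimits`, which propagates the limiter error with `?` and leaves the response handling (429 plus the ratelimited-request metric) to the caller; note that it still consumes `ratelimit_selector` via `take()` and silently skips enforcement when token counting fails. A self-contained sketch of the same shape, with stand-ins for the crate's `tokenizer` and `ratelimit` modules (the placeholder limits and counting are assumptions):

```rust
use std::num::NonZeroU32;

#[derive(Debug)]
struct RatelimitExceeded(String);

// Placeholder token counter; the real code uses the crate's tokenizer module.
fn token_count(_model: &str, body: &str) -> Result<usize, ()> {
    Ok(body.split_whitespace().count())
}

// Placeholder limiter; the real code consults limits configured per selector.
fn check_limit(model: &str, tokens: NonZeroU32) -> Result<(), RatelimitExceeded> {
    if tokens.get() > 1_000 {
        return Err(RatelimitExceeded(format!("{model}: token budget exceeded")));
    }
    Ok(())
}

// Same shape as the new helper: tokenize, check, and bubble the error up with
// `?` so the caller can reply with 429 and bump the ratelimited-request metric.
fn enforce_ratelimits(model: &str, json_string: &str) -> Result<(), RatelimitExceeded> {
    if let Ok(count) = token_count(model, json_string) {
        if let Some(tokens) = NonZeroU32::new(count as u32) {
            check_limit(model, tokens)?;
        }
    }
    Ok(())
}

fn main() {
    println!("{:?}", enforce_ratelimits("gpt-3.5-turbo", r#"{"messages": []}"#));
}
```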
@@ -1140,6 +1165,41 @@ impl HttpContext for StreamContext {
return Action::Pause;
}
};
self.is_chat_completions_request = true;
if self.mode == GatewayMode::Llm {
debug!("llm gateway mode, skipping over all prompt targets");
// remove metadata from the request body
deserialized_body.metadata = None;
// delete model key from message array
for message in deserialized_body.messages.iter_mut() {
message.model = None;
}
deserialized_body
.model
.clone_from(&self.llm_provider.as_ref().unwrap().model);
let chat_completion_request_str = serde_json::to_string(&deserialized_body).unwrap();
// enforce ratelimits
if let Err(e) =
self.enforce_ratelimits(&deserialized_body.model, &chat_completion_request_str)
{
self.send_server_error(
ServerError::ExceededRatelimit(e),
Some(StatusCode::TOO_MANY_REQUESTS),
);
self.metrics.ratelimited_rq.increment(1);
return Action::Continue;
}
debug!(
"arch => {:?}, body: {}",
deserialized_body.model, chat_completion_request_str
);
self.set_http_request_body(0, body_size, chat_completion_request_str.as_bytes());
return Action::Continue;
}
self.arch_state = match deserialized_body.metadata {
Some(ref metadata) => {
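
In LLM-gateway mode the filter skips prompt targets entirely and only normalizes the body before resuming: request `metadata` is dropped, any per-message `model` keys are removed, the top-level `model` is overwritten with the selected provider's model, and ratelimits are enforced on the rewritten JSON. A rough equivalent of that rewrite expressed over `serde_json::Value` (the real code mutates the typed `ChatCompletionsRequest`; field names here follow the diff):

```rust
use serde_json::{json, Value};

// Rough equivalent of the LLM-gateway body rewrite: drop metadata, strip
// per-message "model" keys, and set the model to the provider's model.
fn rewrite_for_llm_gateway(mut body: Value, provider_model: &str) -> Value {
    if let Some(obj) = body.as_object_mut() {
        // drop request metadata
        obj.remove("metadata");
        // overwrite the client-supplied model with the provider's model
        obj.insert("model".to_string(), Value::String(provider_model.to_string()));
        // strip any per-message "model" keys
        if let Some(messages) = obj.get_mut("messages").and_then(|m| m.as_array_mut()) {
            for message in messages {
                if let Some(m) = message.as_object_mut() {
                    m.remove("model");
                }
            }
        }
    }
    body
}

fn main() {
    let body = json!({
        "model": "client-supplied-model",
        "metadata": { "trace_id": "abc" },
        "messages": [{ "role": "user", "content": "hi", "model": "ignored" }]
    });
    println!("{}", rewrite_for_llm_gateway(body, "gpt-3.5-turbo"));
}
```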
@@ -1154,7 +1214,6 @@ impl HttpContext for StreamContext {
None => None,
};
self.is_chat_completions_request = true;
// Set the model based on the chosen LLM Provider
deserialized_body.model = String::from(&self.llm_provider().model);