mirror of https://github.com/katanemo/plano.git
synced 2026-05-03 21:02:56 +02:00
llm listener split (#155)
parent 8b5db45507
commit e81ca8d5cf
16 changed files with 305 additions and 54 deletions
@@ -1,8 +1,9 @@
 use crate::consts::{
-    ARCH_FC_MODEL_NAME, ARCH_FC_REQUEST_TIMEOUT_MS, ARCH_INTERNAL_CLUSTER_NAME, ARCH_MESSAGES_KEY,
-    ARCH_PROVIDER_HINT_HEADER, ARCH_ROUTING_HEADER, ARCH_STATE_HEADER, ARCH_UPSTREAM_HOST_HEADER,
-    ARC_FC_CLUSTER, CHAT_COMPLETIONS_PATH, DEFAULT_EMBEDDING_MODEL, DEFAULT_HALLUCINATED_THRESHOLD,
-    DEFAULT_INTENT_MODEL, DEFAULT_PROMPT_TARGET_THRESHOLD, GPT_35_TURBO, MODEL_SERVER_NAME,
+    ARCH_FC_MODEL_NAME, ARCH_FC_REQUEST_TIMEOUT_MS, ARCH_INTERNAL_CLUSTER_NAME,
+    ARCH_LLM_UPSTREAM_LISTENER, ARCH_MESSAGES_KEY, ARCH_PROVIDER_HINT_HEADER, ARCH_ROUTING_HEADER,
+    ARCH_STATE_HEADER, ARCH_UPSTREAM_HOST_HEADER, ARC_FC_CLUSTER, CHAT_COMPLETIONS_PATH,
+    DEFAULT_EMBEDDING_MODEL, DEFAULT_HALLUCINATED_THRESHOLD, DEFAULT_INTENT_MODEL,
+    DEFAULT_PROMPT_TARGET_THRESHOLD, GPT_35_TURBO, MODEL_SERVER_NAME,
     RATELIMIT_SELECTOR_HEADER_KEY, REQUEST_ID_HEADER, SYSTEM_ROLE, USER_ROLE,
 };
 use crate::filter_context::{EmbeddingsStore, WasmMetrics};
@@ -26,7 +27,7 @@ use public_types::common_types::{
     PromptGuardRequest, PromptGuardResponse, PromptGuardTask, ZeroShotClassificationRequest,
     ZeroShotClassificationResponse,
 };
-use public_types::configuration::LlmProvider;
+use public_types::configuration::{GatewayMode, LlmProvider};
 use public_types::configuration::{Overrides, PromptGuards, PromptTarget};
 use public_types::embeddings::{
     CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
@@ -93,7 +94,7 @@ pub struct StreamContext {
     metrics: Rc<WasmMetrics>,
     system_prompt: Rc<Option<String>>,
     prompt_targets: Rc<HashMap<String, PromptTarget>>,
-    embeddings_store: Rc<EmbeddingsStore>,
+    embeddings_store: Option<Rc<EmbeddingsStore>>,
     overrides: Rc<Option<Overrides>>,
     callouts: RefCell<HashMap<u32, StreamCallContext>>,
     tool_calls: Option<Vec<ToolCall>>,
@@ -110,6 +111,7 @@ pub struct StreamContext {
     llm_providers: Rc<LlmProviders>,
     llm_provider: Option<Rc<LlmProvider>>,
     request_id: Option<String>,
+    mode: GatewayMode,
 }

 impl StreamContext {
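For orientation, here is a minimal sketch of the mode that this hunk threads into StreamContext. The two variants are inferred from their uses later in this diff (GatewayMode::Prompt and GatewayMode::Llm); the real enum lives in public_types::configuration, and needs_embeddings_store is a hypothetical helper illustrating why embeddings_store becomes optional in this commit.

```rust
// Sketch only: the two variants are inferred from their uses in this diff
// (GatewayMode::Prompt and GatewayMode::Llm); the real enum is defined in
// public_types::configuration.
#[derive(Clone, Copy, Debug, PartialEq)]
enum GatewayMode {
    Prompt, // prompt gateway: evaluates prompt targets, so it needs the embeddings store
    Llm,    // llm gateway: provider routing only
}

// Hypothetical helper illustrating why embeddings_store becomes Option in this
// commit: presumably only the prompt gateway has to supply one.
fn needs_embeddings_store(mode: GatewayMode) -> bool {
    matches!(mode, GatewayMode::Prompt)
}

fn main() {
    assert!(needs_embeddings_store(GatewayMode::Prompt));
    assert!(!needs_embeddings_store(GatewayMode::Llm));
}
```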
@@ -122,7 +124,8 @@ impl StreamContext {
         prompt_guards: Rc<PromptGuards>,
         overrides: Rc<Option<Overrides>>,
         llm_providers: Rc<LlmProviders>,
-        embeddings_store: Rc<EmbeddingsStore>,
+        embeddings_store: Option<Rc<EmbeddingsStore>>,
+        mode: GatewayMode,
     ) -> Self {
         StreamContext {
             context_id,
@@ -146,6 +149,7 @@ impl StreamContext {
             prompt_guards,
             overrides,
             request_id: None,
+            mode,
         }
     }
     fn llm_provider(&self) -> &LlmProvider {
@@ -154,19 +158,35 @@ impl StreamContext {
             .expect("the provider should be set when asked for it")
     }

+    fn embeddings_store(&self) -> &EmbeddingsStore {
+        self.embeddings_store
+            .as_ref()
+            .expect("embeddings store is not set")
+    }
+
     fn select_llm_provider(&mut self) {
         let provider_hint = self
             .get_http_request_header(ARCH_PROVIDER_HINT_HEADER)
             .map(|provider_name| provider_name.into());
+
+        debug!("llm provider hint: {:?}", provider_hint);
         self.llm_provider = Some(routing::get_llm_provider(
             &self.llm_providers,
             provider_hint,
         ));
         debug!("selected llm: {}", self.llm_provider.as_ref().unwrap().name);
     }

     fn add_routing_header(&mut self) {
-        self.add_http_request_header(ARCH_ROUTING_HEADER, &self.llm_provider().name);
+        match self.mode {
+            GatewayMode::Prompt => {
+                // in prompt gateway mode, we need to route to llm upstream listener
+                self.add_http_request_header(ARCH_UPSTREAM_HOST_HEADER, ARCH_LLM_UPSTREAM_LISTENER);
+            }
+            _ => {
+                self.add_http_request_header(ARCH_ROUTING_HEADER, &self.llm_provider().name);
+            }
+        }
     }

     fn modify_auth_headers(&mut self) -> Result<(), ServerError> {
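The rewritten add_routing_header is the core of the listener split: in prompt-gateway mode the request is pointed at the internal LLM listener via ARCH_UPSTREAM_HOST_HEADER, while any other mode writes the selected provider's name into ARCH_ROUTING_HEADER. Below is a stripped-down sketch of that decision; the header names, listener name, and the local GatewayMode enum are placeholders redeclared so the sketch compiles on its own, not the real ARCH_* constants.

```rust
// Sketch only: header names and the listener name are placeholders, not the
// real ARCH_* constants from crate::consts.
const ROUTING_HEADER: &str = "x-arch-llm-provider"; // hypothetical value
const UPSTREAM_HOST_HEADER: &str = "x-arch-upstream-host"; // hypothetical value
const LLM_UPSTREAM_LISTENER: &str = "arch_llm_listener"; // hypothetical value

enum GatewayMode {
    Prompt,
    Llm,
}

// Returns the single (header, value) pair the gateway adds for a request.
fn routing_header(mode: GatewayMode, provider_name: &str) -> (&'static str, String) {
    match mode {
        // prompt gateway: hand the request to the internal llm listener
        GatewayMode::Prompt => (UPSTREAM_HOST_HEADER, LLM_UPSTREAM_LISTENER.to_string()),
        // llm gateway (the catch-all arm in the real code): route to the provider
        GatewayMode::Llm => (ROUTING_HEADER, provider_name.to_string()),
    }
}

fn main() {
    assert_eq!(
        routing_header(GatewayMode::Prompt, "openai"),
        (UPSTREAM_HOST_HEADER, LLM_UPSTREAM_LISTENER.to_string())
    );
    assert_eq!(
        routing_header(GatewayMode::Llm, "openai"),
        (ROUTING_HEADER, "openai".to_string())
    );
    println!("routing sketch ok");
}
```

Presumably the prompt gateway leaves provider selection to the LLM listener it forwards to, which then runs this same filter in LLM mode.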
@@ -247,7 +267,7 @@ impl StreamContext {
             // exclude default prompt target
             .filter(|(_, prompt_target)| !prompt_target.default.unwrap_or(false))
             .map(|(prompt_name, _)| {
-                let pte = match self.embeddings_store.get(prompt_name) {
+                let pte = match self.embeddings_store().get(prompt_name) {
                     Some(embeddings) => embeddings,
                     None => {
                         warn!(
@@ -901,32 +921,37 @@ impl StreamContext {
         debug!("arch => openai request body: {}", json_string);

-        // Tokenize and Ratelimit.
-        if let Some(selector) = self.ratelimit_selector.take() {
-            if let Ok(token_count) =
-                tokenizer::token_count(&chat_completions_request.model, &json_string)
-            {
-                match ratelimit::ratelimits(None).read().unwrap().check_limit(
-                    chat_completions_request.model,
-                    selector,
-                    NonZero::new(token_count as u32).unwrap(),
-                ) {
-                    Ok(_) => (),
-                    Err(err) => {
-                        self.send_server_error(
-                            ServerError::ExceededRatelimit(err),
-                            Some(StatusCode::TOO_MANY_REQUESTS),
-                        );
-                        self.metrics.ratelimited_rq.increment(1);
-                        return;
-                    }
-                }
-            }
-        }
+        if let Err(e) = self.enforce_ratelimits(&chat_completions_request.model, &json_string) {
+            self.send_server_error(
+                ServerError::ExceededRatelimit(e),
+                Some(StatusCode::TOO_MANY_REQUESTS),
+            );
+            self.metrics.ratelimited_rq.increment(1);
+            return;
+        }

         self.set_http_request_body(0, self.request_body_size, &json_string.into_bytes());
         self.resume_http_request();
     }

+    fn enforce_ratelimits(
+        &mut self,
+        model: &str,
+        json_string: &str,
+    ) -> Result<(), ratelimit::Error> {
+        if let Some(selector) = self.ratelimit_selector.take() {
+            // Tokenize and Ratelimit.
+            if let Ok(token_count) = tokenizer::token_count(model, &json_string) {
+                ratelimit::ratelimits(None).read().unwrap().check_limit(
+                    model.to_owned(),
+                    selector,
+                    NonZero::new(token_count as u32).unwrap(),
+                )?;
+            }
+        }
+        Ok(())
+    }
+
     fn arch_guard_handler(&mut self, body: Vec<u8>, callout_context: StreamCallContext) {
         debug!("response received for arch guard");
         let prompt_guard_resp: PromptGuardResponse = serde_json::from_slice(&body).unwrap();
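Extracting the token-count and rate-limit check into enforce_ratelimits lets the existing prompt path above and the new LLM-gateway fast path (next hunk) share one Result-based check, with each caller mapping Err to a 429 and bumping the ratelimited_rq metric. A self-contained sketch of that shape follows; token_count, check_limit, and the budget value are stand-ins, not the project's tokenizer or ratelimit APIs.

```rust
// Sketch of the refactor shape only; token_count, check_limit, and the budget
// are stand-ins for the project's tokenizer and ratelimit modules.
#[derive(Debug)]
struct RatelimitError(String);

fn token_count(_model: &str, body: &str) -> Result<usize, RatelimitError> {
    // placeholder: the real implementation tokenizes the serialized request
    Ok(body.len() / 4)
}

fn check_limit(model: &str, tokens: usize) -> Result<(), RatelimitError> {
    // placeholder budget so the sketch is runnable
    const BUDGET: usize = 4096;
    if tokens > BUDGET {
        return Err(RatelimitError(format!("{model}: {tokens} tokens over budget")));
    }
    Ok(())
}

// Result-returning helper: both gateway paths call this and translate Err into
// an HTTP 429 plus a ratelimited_rq metric increment at the call site.
fn enforce_ratelimits(model: &str, json_body: &str) -> Result<(), RatelimitError> {
    let tokens = token_count(model, json_body)?;
    check_limit(model, tokens)
}

fn main() {
    match enforce_ratelimits("gpt-4o-mini", r#"{"messages": []}"#) {
        Ok(()) => println!("within limits, request continues"),
        Err(e) => println!("would respond 429: {}", e.0),
    }
}
```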
@@ -1140,6 +1165,41 @@ impl HttpContext for StreamContext {
                 return Action::Pause;
             }
         };
+        self.is_chat_completions_request = true;
+
+        if self.mode == GatewayMode::Llm {
+            debug!("llm gateway mode, skipping over all prompt targets");
+
+            // remove metadata from the request body
+            deserialized_body.metadata = None;
+            // delete model key from message array
+            for message in deserialized_body.messages.iter_mut() {
+                message.model = None;
+            }
+            deserialized_body
+                .model
+                .clone_from(&self.llm_provider.as_ref().unwrap().model);
+            let chat_completion_request_str = serde_json::to_string(&deserialized_body).unwrap();
+
+            // enforce ratelimits
+            if let Err(e) =
+                self.enforce_ratelimits(&deserialized_body.model, &chat_completion_request_str)
+            {
+                self.send_server_error(
+                    ServerError::ExceededRatelimit(e),
+                    Some(StatusCode::TOO_MANY_REQUESTS),
+                );
+                self.metrics.ratelimited_rq.increment(1);
+                return Action::Continue;
+            }
+
+            debug!(
+                "arch => {:?}, body: {}",
+                deserialized_body.model, chat_completion_request_str
+            );
+            self.set_http_request_body(0, body_size, chat_completion_request_str.as_bytes());
+            return Action::Continue;
+        }

         self.arch_state = match deserialized_body.metadata {
             Some(ref metadata) => {
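In LLM-gateway mode the filter now short-circuits: it strips the prompt-gateway metadata and per-message model hints, forces the model configured for the selected provider, re-serializes the body, applies the shared rate limit, and forwards the request without evaluating any prompt targets. A hedged sketch of that body rewrite with serde is shown below; the struct shapes and field set are illustrative, not the repo's actual chat-completions request types.

```rust
// Sketch only: field names mirror what this hunk touches (metadata, per-message
// model, top-level model); the real request types live elsewhere in the repo.
use serde::{Deserialize, Serialize};

#[derive(Serialize, Deserialize)]
struct Message {
    role: String,
    content: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    model: Option<String>, // per-message model hint used by the prompt gateway
}

#[derive(Serialize, Deserialize)]
struct ChatRequest {
    model: String,
    messages: Vec<Message>,
    #[serde(skip_serializing_if = "Option::is_none")]
    metadata: Option<serde_json::Value>, // prompt-gateway state, dropped here
}

// Mirrors the llm-mode path: drop metadata, drop per-message model hints,
// and overwrite the model with the one configured for the selected provider.
fn sanitize_for_llm_gateway(req: &mut ChatRequest, provider_model: &str) {
    req.metadata = None;
    for m in req.messages.iter_mut() {
        m.model = None;
    }
    req.model = provider_model.to_string();
}

fn main() {
    let mut req = ChatRequest {
        model: "router-decides".to_string(),
        messages: vec![Message {
            role: "user".to_string(),
            content: "hello".to_string(),
            model: Some("gpt-3.5-turbo".to_string()),
        }],
        metadata: Some(serde_json::json!({ "note": "prompt-gateway state would appear here" })),
    };
    sanitize_for_llm_gateway(&mut req, "gpt-4o-mini");
    println!("{}", serde_json::to_string(&req).unwrap());
}
```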
@@ -1154,7 +1214,6 @@
             None => None,
         };

-        self.is_chat_completions_request = true;
         // Set the model based on the chosen LLM Provider
         deserialized_body.model = String::from(&self.llm_provider().model);
