mirror of
https://github.com/katanemo/plano.git
synced 2026-05-18 13:45:15 +02:00
Revert "Add support for multiple LLM Providers (#60)"
This reverts commit bd8206742a.
This commit is contained in:
parent
d970b214f4
commit
43d6bc80e9
12 changed files with 127 additions and 456 deletions
|
|
@ -1,11 +1,11 @@
|
|||
pub const DEFAULT_EMBEDDING_MODEL: &str = "BAAI/bge-large-en-v1.5";
|
||||
pub const DEFAULT_INTENT_MODEL: &str = "tasksource/deberta-base-long-nli";
|
||||
pub const DEFAULT_PROMPT_TARGET_THRESHOLD: f64 = 0.8;
|
||||
pub const RATELIMIT_SELECTOR_HEADER_KEY: &str = "x-bolt-ratelimit-selector";
|
||||
pub const RATELIMIT_SELECTOR_HEADER_KEY: &str = "x-katanemo-ratelimit-selector";
|
||||
pub const SYSTEM_ROLE: &str = "system";
|
||||
pub const USER_ROLE: &str = "user";
|
||||
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
|
||||
pub const BOLT_FC_CLUSTER: &str = "bolt_fc_1b";
|
||||
pub const BOLT_FC_REQUEST_TIMEOUT_MS: u64 = 120000; // 2 minutes
|
||||
pub const OPENAI_CHAT_COMPLETIONS_PATH: &str = "/v1/chat/completions";
|
||||
pub const MODEL_SERVER_NAME: &str = "model_server";
|
||||
pub const BOLT_ROUTING_HEADER: &str = "x-bolt-llm-provider";
|
||||
|
|
|
|||
|
|
@ -4,9 +4,7 @@ use proxy_wasm::types::*;
|
|||
|
||||
mod consts;
|
||||
mod filter_context;
|
||||
mod llm_providers;
|
||||
mod ratelimit;
|
||||
mod routing;
|
||||
mod stats;
|
||||
mod stream_context;
|
||||
mod tokenizer;
|
||||
|
|
|
|||
|
|
@ -1,47 +0,0 @@
|
|||
#[non_exhaustive]
|
||||
pub struct LlmProviders;
|
||||
|
||||
impl LlmProviders {
|
||||
pub const OPENAI_PROVIDER: LlmProvider<'static> = LlmProvider {
|
||||
name: "openai",
|
||||
api_key_header: "x-bolt-openai-api-key",
|
||||
model: "gpt-3.5-turbo",
|
||||
};
|
||||
pub const MISTRAL_PROVIDER: LlmProvider<'static> = LlmProvider {
|
||||
name: "mistral",
|
||||
api_key_header: "x-bolt-mistral-api-key",
|
||||
model: "mistral-large-latest",
|
||||
};
|
||||
|
||||
pub const VARIANTS: &'static [LlmProvider<'static>] =
|
||||
&[Self::OPENAI_PROVIDER, Self::MISTRAL_PROVIDER];
|
||||
}
|
||||
|
||||
pub struct LlmProvider<'prov> {
|
||||
name: &'prov str,
|
||||
api_key_header: &'prov str,
|
||||
model: &'prov str,
|
||||
}
|
||||
|
||||
impl AsRef<str> for LlmProvider<'_> {
|
||||
fn as_ref(&self) -> &str {
|
||||
self.name
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for LlmProvider<'_> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", self.name)
|
||||
}
|
||||
}
|
||||
|
||||
impl LlmProvider<'_> {
|
||||
pub fn api_key_header(&self) -> &str {
|
||||
self.api_key_header
|
||||
}
|
||||
|
||||
pub fn choose_model(&self) -> &str {
|
||||
// In the future this can be a more complex function balancing reliability, cost, performance, etc.
|
||||
self.model
|
||||
}
|
||||
}
|
||||
|
|
@ -1,13 +0,0 @@
|
|||
use crate::llm_providers::{LlmProvider, LlmProviders};
|
||||
use rand::{seq::SliceRandom, thread_rng};
|
||||
|
||||
pub fn get_llm_provider<'hostname>(deterministic: bool) -> &'static LlmProvider<'hostname> {
|
||||
if deterministic {
|
||||
&LlmProviders::OPENAI_PROVIDER
|
||||
} else {
|
||||
let mut rng = thread_rng();
|
||||
LlmProviders::VARIANTS
|
||||
.choose(&mut rng)
|
||||
.expect("There should always be at least one llm provider")
|
||||
}
|
||||
}
|
||||
|
|
@ -1,14 +1,13 @@
|
|||
use crate::consts::{
|
||||
BOLT_FC_CLUSTER, BOLT_FC_REQUEST_TIMEOUT_MS, BOLT_ROUTING_HEADER, DEFAULT_EMBEDDING_MODEL,
|
||||
DEFAULT_INTENT_MODEL, DEFAULT_PROMPT_TARGET_THRESHOLD, GPT_35_TURBO, MODEL_SERVER_NAME,
|
||||
BOLT_FC_CLUSTER, BOLT_FC_REQUEST_TIMEOUT_MS, DEFAULT_EMBEDDING_MODEL, DEFAULT_INTENT_MODEL,
|
||||
DEFAULT_PROMPT_TARGET_THRESHOLD, GPT_35_TURBO, MODEL_SERVER_NAME, OPENAI_CHAT_COMPLETIONS_PATH,
|
||||
RATELIMIT_SELECTOR_HEADER_KEY, SYSTEM_ROLE, USER_ROLE,
|
||||
};
|
||||
use crate::filter_context::{embeddings_store, WasmMetrics};
|
||||
use crate::llm_providers::{LlmProvider, LlmProviders};
|
||||
use crate::ratelimit;
|
||||
use crate::ratelimit::Header;
|
||||
use crate::stats::IncrementingMetric;
|
||||
use crate::tokenizer;
|
||||
use crate::{ratelimit, routing};
|
||||
use acap::cos;
|
||||
use http::StatusCode;
|
||||
use log::{debug, info, warn};
|
||||
|
|
@ -57,11 +56,11 @@ pub struct StreamContext {
|
|||
pub prompt_targets: Rc<RwLock<HashMap<String, PromptTarget>>>,
|
||||
pub overrides: Rc<Option<Overrides>>,
|
||||
callouts: HashMap<u32, CallContext>,
|
||||
host_header: Option<String>,
|
||||
ratelimit_selector: Option<Header>,
|
||||
streaming_response: bool,
|
||||
response_tokens: usize,
|
||||
chat_completions_request: bool,
|
||||
llm_provider: Option<&'static LlmProvider<'static>>,
|
||||
prompt_guards: Rc<Option<PromptGuards>>,
|
||||
}
|
||||
|
||||
|
|
@ -78,39 +77,18 @@ impl StreamContext {
|
|||
metrics,
|
||||
prompt_targets,
|
||||
callouts: HashMap::new(),
|
||||
host_header: None,
|
||||
ratelimit_selector: None,
|
||||
streaming_response: false,
|
||||
response_tokens: 0,
|
||||
chat_completions_request: false,
|
||||
llm_provider: None,
|
||||
prompt_guards,
|
||||
overrides,
|
||||
}
|
||||
}
|
||||
fn llm_provider(&self) -> &LlmProvider {
|
||||
self.llm_provider
|
||||
.expect("the provider should be set when asked for it")
|
||||
}
|
||||
|
||||
fn add_routing_header(&mut self) {
|
||||
self.add_http_request_header(BOLT_ROUTING_HEADER, self.llm_provider().as_ref());
|
||||
}
|
||||
|
||||
fn modify_auth_headers(&mut self) -> Result<(), String> {
|
||||
let llm_provider_api_key_value = self
|
||||
.get_http_request_header(self.llm_provider().api_key_header())
|
||||
.ok_or(format!("missing {} api key", self.llm_provider()))?;
|
||||
|
||||
let authorization_header_value = format!("Bearer {}", llm_provider_api_key_value);
|
||||
|
||||
self.set_http_request_header("Authorization", Some(&authorization_header_value));
|
||||
|
||||
// sanitize passed in api keys
|
||||
for provider in LlmProviders::VARIANTS.iter() {
|
||||
self.set_http_request_header(provider.api_key_header(), None);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
fn save_host_header(&mut self) {
|
||||
// Save the host header to be used by filter logic later on.
|
||||
self.host_header = self.get_http_request_header(":host");
|
||||
}
|
||||
|
||||
fn delete_content_length_header(&mut self) {
|
||||
|
|
@ -121,6 +99,19 @@ impl StreamContext {
|
|||
self.set_http_request_header("content-length", None);
|
||||
}
|
||||
|
||||
fn modify_path_header(&mut self) {
|
||||
match self.get_http_request_header(":path") {
|
||||
// The gateway can start gathering information necessary for routing. For now change the path to an
|
||||
// OpenAI API path.
|
||||
Some(path) if path == "/llmrouting" => {
|
||||
self.set_http_request_header(":path", Some(OPENAI_CHAT_COMPLETIONS_PATH));
|
||||
self.chat_completions_request = true;
|
||||
}
|
||||
// Otherwise let the filter continue.
|
||||
_ => (),
|
||||
}
|
||||
}
|
||||
|
||||
fn save_ratelimit_header(&mut self) {
|
||||
self.ratelimit_selector = self
|
||||
.get_http_request_header(RATELIMIT_SELECTOR_HEADER_KEY)
|
||||
|
|
@ -246,7 +237,6 @@ impl StreamContext {
|
|||
token_id
|
||||
);
|
||||
|
||||
self.metrics.active_http_calls.increment(1);
|
||||
callout_context.response_handler_type = ResponseHandlerType::ZeroShotIntent;
|
||||
|
||||
if self.callouts.insert(token_id, callout_context).is_some() {
|
||||
|
|
@ -441,7 +431,6 @@ impl StreamContext {
|
|||
BOLT_FC_CLUSTER, token_id
|
||||
);
|
||||
|
||||
self.metrics.active_http_calls.increment(1);
|
||||
callout_context.response_handler_type = ResponseHandlerType::FunctionResolver;
|
||||
callout_context.prompt_target_name = Some(prompt_target.name);
|
||||
if self.callouts.insert(token_id, callout_context).is_some() {
|
||||
|
|
@ -449,6 +438,7 @@ impl StreamContext {
|
|||
}
|
||||
}
|
||||
}
|
||||
self.metrics.active_http_calls.increment(1);
|
||||
}
|
||||
|
||||
fn function_resolver_handler(&mut self, body: Vec<u8>, mut callout_context: CallContext) {
|
||||
|
|
@ -605,7 +595,7 @@ impl StreamContext {
|
|||
});
|
||||
|
||||
let chat_completions_request: ChatCompletionsRequest = ChatCompletionsRequest {
|
||||
model: callout_context.request_body.model,
|
||||
model: GPT_35_TURBO.to_string(),
|
||||
messages,
|
||||
tools: None,
|
||||
stream: callout_context.request_body.stream,
|
||||
|
|
@ -761,24 +751,11 @@ impl HttpContext for StreamContext {
|
|||
// Envoy's HTTP model is event driven. The WASM ABI has given implementors events to hook onto
|
||||
// the lifecycle of the http request and response.
|
||||
fn on_http_request_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action {
|
||||
let provider_hint = self
|
||||
.get_http_request_header("x-bolt-deterministic-provider")
|
||||
.is_some();
|
||||
self.llm_provider = Some(routing::get_llm_provider(provider_hint));
|
||||
|
||||
self.add_routing_header();
|
||||
if let Err(error) = self.modify_auth_headers() {
|
||||
self.send_server_error(error, Some(StatusCode::BAD_REQUEST));
|
||||
}
|
||||
self.save_host_header();
|
||||
self.delete_content_length_header();
|
||||
self.modify_path_header();
|
||||
self.save_ratelimit_header();
|
||||
|
||||
debug!(
|
||||
"S[{}] req_headers={:?}",
|
||||
self.context_id,
|
||||
self.get_http_request_headers()
|
||||
);
|
||||
|
||||
Action::Continue
|
||||
}
|
||||
|
||||
|
|
@ -819,9 +796,6 @@ impl HttpContext for StreamContext {
|
|||
}
|
||||
};
|
||||
|
||||
// Set the model based on the chosen LLM Provider
|
||||
deserialized_body.model = String::from(self.llm_provider().choose_model());
|
||||
|
||||
self.streaming_response = deserialized_body.stream;
|
||||
if deserialized_body.stream && deserialized_body.stream_options.is_none() {
|
||||
deserialized_body.stream_options = Some(StreamOptions {
|
||||
|
|
@ -943,21 +917,15 @@ impl HttpContext for StreamContext {
|
|||
}
|
||||
|
||||
fn on_http_response_body(&mut self, body_size: usize, end_of_stream: bool) -> Action {
|
||||
if !self.chat_completions_request {
|
||||
return Action::Continue;
|
||||
}
|
||||
|
||||
debug!(
|
||||
"recv [S={}] bytes={} end_stream={}",
|
||||
self.context_id, body_size, end_of_stream
|
||||
);
|
||||
|
||||
if !self.chat_completions_request {
|
||||
if let Some(body_str) = self
|
||||
.get_http_response_body(0, body_size)
|
||||
.and_then(|bytes| String::from_utf8(bytes).ok())
|
||||
{
|
||||
debug!("recv [S={}] body_str={}", self.context_id, body_str);
|
||||
}
|
||||
return Action::Continue;
|
||||
}
|
||||
|
||||
if !end_of_stream && !self.streaming_response {
|
||||
return Action::Pause;
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue