Revert "Add support for multiple LLM Providers (#60)"

This reverts commit bd8206742a.
This commit is contained in:
Adil Hafeez 2024-09-25 08:15:22 -07:00
parent d970b214f4
commit 43d6bc80e9
12 changed files with 127 additions and 456 deletions

View file

@ -1,11 +1,11 @@
pub const DEFAULT_EMBEDDING_MODEL: &str = "BAAI/bge-large-en-v1.5";
pub const DEFAULT_INTENT_MODEL: &str = "tasksource/deberta-base-long-nli";
pub const DEFAULT_PROMPT_TARGET_THRESHOLD: f64 = 0.8;
pub const RATELIMIT_SELECTOR_HEADER_KEY: &str = "x-bolt-ratelimit-selector";
pub const RATELIMIT_SELECTOR_HEADER_KEY: &str = "x-katanemo-ratelimit-selector";
pub const SYSTEM_ROLE: &str = "system";
pub const USER_ROLE: &str = "user";
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
pub const BOLT_FC_CLUSTER: &str = "bolt_fc_1b";
pub const BOLT_FC_REQUEST_TIMEOUT_MS: u64 = 120000; // 2 minutes
pub const OPENAI_CHAT_COMPLETIONS_PATH: &str = "/v1/chat/completions";
pub const MODEL_SERVER_NAME: &str = "model_server";
pub const BOLT_ROUTING_HEADER: &str = "x-bolt-llm-provider";

View file

@ -4,9 +4,7 @@ use proxy_wasm::types::*;
mod consts;
mod filter_context;
mod llm_providers;
mod ratelimit;
mod routing;
mod stats;
mod stream_context;
mod tokenizer;

View file

@ -1,47 +0,0 @@
#[non_exhaustive]
pub struct LlmProviders;
impl LlmProviders {
pub const OPENAI_PROVIDER: LlmProvider<'static> = LlmProvider {
name: "openai",
api_key_header: "x-bolt-openai-api-key",
model: "gpt-3.5-turbo",
};
pub const MISTRAL_PROVIDER: LlmProvider<'static> = LlmProvider {
name: "mistral",
api_key_header: "x-bolt-mistral-api-key",
model: "mistral-large-latest",
};
pub const VARIANTS: &'static [LlmProvider<'static>] =
&[Self::OPENAI_PROVIDER, Self::MISTRAL_PROVIDER];
}
pub struct LlmProvider<'prov> {
name: &'prov str,
api_key_header: &'prov str,
model: &'prov str,
}
impl AsRef<str> for LlmProvider<'_> {
fn as_ref(&self) -> &str {
self.name
}
}
impl std::fmt::Display for LlmProvider<'_> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.name)
}
}
impl LlmProvider<'_> {
pub fn api_key_header(&self) -> &str {
self.api_key_header
}
pub fn choose_model(&self) -> &str {
// In the future this can be a more complex function balancing reliability, cost, performance, etc.
self.model
}
}

View file

@ -1,13 +0,0 @@
use crate::llm_providers::{LlmProvider, LlmProviders};
use rand::{seq::SliceRandom, thread_rng};
pub fn get_llm_provider<'hostname>(deterministic: bool) -> &'static LlmProvider<'hostname> {
if deterministic {
&LlmProviders::OPENAI_PROVIDER
} else {
let mut rng = thread_rng();
LlmProviders::VARIANTS
.choose(&mut rng)
.expect("There should always be at least one llm provider")
}
}

View file

@ -1,14 +1,13 @@
use crate::consts::{
BOLT_FC_CLUSTER, BOLT_FC_REQUEST_TIMEOUT_MS, BOLT_ROUTING_HEADER, DEFAULT_EMBEDDING_MODEL,
DEFAULT_INTENT_MODEL, DEFAULT_PROMPT_TARGET_THRESHOLD, GPT_35_TURBO, MODEL_SERVER_NAME,
BOLT_FC_CLUSTER, BOLT_FC_REQUEST_TIMEOUT_MS, DEFAULT_EMBEDDING_MODEL, DEFAULT_INTENT_MODEL,
DEFAULT_PROMPT_TARGET_THRESHOLD, GPT_35_TURBO, MODEL_SERVER_NAME, OPENAI_CHAT_COMPLETIONS_PATH,
RATELIMIT_SELECTOR_HEADER_KEY, SYSTEM_ROLE, USER_ROLE,
};
use crate::filter_context::{embeddings_store, WasmMetrics};
use crate::llm_providers::{LlmProvider, LlmProviders};
use crate::ratelimit;
use crate::ratelimit::Header;
use crate::stats::IncrementingMetric;
use crate::tokenizer;
use crate::{ratelimit, routing};
use acap::cos;
use http::StatusCode;
use log::{debug, info, warn};
@ -57,11 +56,11 @@ pub struct StreamContext {
pub prompt_targets: Rc<RwLock<HashMap<String, PromptTarget>>>,
pub overrides: Rc<Option<Overrides>>,
callouts: HashMap<u32, CallContext>,
host_header: Option<String>,
ratelimit_selector: Option<Header>,
streaming_response: bool,
response_tokens: usize,
chat_completions_request: bool,
llm_provider: Option<&'static LlmProvider<'static>>,
prompt_guards: Rc<Option<PromptGuards>>,
}
@ -78,39 +77,18 @@ impl StreamContext {
metrics,
prompt_targets,
callouts: HashMap::new(),
host_header: None,
ratelimit_selector: None,
streaming_response: false,
response_tokens: 0,
chat_completions_request: false,
llm_provider: None,
prompt_guards,
overrides,
}
}
fn llm_provider(&self) -> &LlmProvider {
self.llm_provider
.expect("the provider should be set when asked for it")
}
fn add_routing_header(&mut self) {
self.add_http_request_header(BOLT_ROUTING_HEADER, self.llm_provider().as_ref());
}
fn modify_auth_headers(&mut self) -> Result<(), String> {
let llm_provider_api_key_value = self
.get_http_request_header(self.llm_provider().api_key_header())
.ok_or(format!("missing {} api key", self.llm_provider()))?;
let authorization_header_value = format!("Bearer {}", llm_provider_api_key_value);
self.set_http_request_header("Authorization", Some(&authorization_header_value));
// sanitize passed in api keys
for provider in LlmProviders::VARIANTS.iter() {
self.set_http_request_header(provider.api_key_header(), None);
}
Ok(())
fn save_host_header(&mut self) {
// Save the host header to be used by filter logic later on.
self.host_header = self.get_http_request_header(":host");
}
fn delete_content_length_header(&mut self) {
@ -121,6 +99,19 @@ impl StreamContext {
self.set_http_request_header("content-length", None);
}
fn modify_path_header(&mut self) {
match self.get_http_request_header(":path") {
// The gateway can start gathering information necessary for routing. For now change the path to an
// OpenAI API path.
Some(path) if path == "/llmrouting" => {
self.set_http_request_header(":path", Some(OPENAI_CHAT_COMPLETIONS_PATH));
self.chat_completions_request = true;
}
// Otherwise let the filter continue.
_ => (),
}
}
fn save_ratelimit_header(&mut self) {
self.ratelimit_selector = self
.get_http_request_header(RATELIMIT_SELECTOR_HEADER_KEY)
@ -246,7 +237,6 @@ impl StreamContext {
token_id
);
self.metrics.active_http_calls.increment(1);
callout_context.response_handler_type = ResponseHandlerType::ZeroShotIntent;
if self.callouts.insert(token_id, callout_context).is_some() {
@ -441,7 +431,6 @@ impl StreamContext {
BOLT_FC_CLUSTER, token_id
);
self.metrics.active_http_calls.increment(1);
callout_context.response_handler_type = ResponseHandlerType::FunctionResolver;
callout_context.prompt_target_name = Some(prompt_target.name);
if self.callouts.insert(token_id, callout_context).is_some() {
@ -449,6 +438,7 @@ impl StreamContext {
}
}
}
self.metrics.active_http_calls.increment(1);
}
fn function_resolver_handler(&mut self, body: Vec<u8>, mut callout_context: CallContext) {
@ -605,7 +595,7 @@ impl StreamContext {
});
let chat_completions_request: ChatCompletionsRequest = ChatCompletionsRequest {
model: callout_context.request_body.model,
model: GPT_35_TURBO.to_string(),
messages,
tools: None,
stream: callout_context.request_body.stream,
@ -761,24 +751,11 @@ impl HttpContext for StreamContext {
// Envoy's HTTP model is event driven. The WASM ABI has given implementors events to hook onto
// the lifecycle of the http request and response.
fn on_http_request_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action {
let provider_hint = self
.get_http_request_header("x-bolt-deterministic-provider")
.is_some();
self.llm_provider = Some(routing::get_llm_provider(provider_hint));
self.add_routing_header();
if let Err(error) = self.modify_auth_headers() {
self.send_server_error(error, Some(StatusCode::BAD_REQUEST));
}
self.save_host_header();
self.delete_content_length_header();
self.modify_path_header();
self.save_ratelimit_header();
debug!(
"S[{}] req_headers={:?}",
self.context_id,
self.get_http_request_headers()
);
Action::Continue
}
@ -819,9 +796,6 @@ impl HttpContext for StreamContext {
}
};
// Set the model based on the chosen LLM Provider
deserialized_body.model = String::from(self.llm_provider().choose_model());
self.streaming_response = deserialized_body.stream;
if deserialized_body.stream && deserialized_body.stream_options.is_none() {
deserialized_body.stream_options = Some(StreamOptions {
@ -943,21 +917,15 @@ impl HttpContext for StreamContext {
}
fn on_http_response_body(&mut self, body_size: usize, end_of_stream: bool) -> Action {
if !self.chat_completions_request {
return Action::Continue;
}
debug!(
"recv [S={}] bytes={} end_stream={}",
self.context_id, body_size, end_of_stream
);
if !self.chat_completions_request {
if let Some(body_str) = self
.get_http_response_body(0, body_size)
.and_then(|bytes| String::from_utf8(bytes).ok())
{
debug!("recv [S={}] body_str={}", self.context_id, body_str);
}
return Action::Continue;
}
if !end_of_stream && !self.streaming_response {
return Action::Pause;
}