mirror of
https://github.com/katanemo/plano.git
synced 2026-05-24 14:05:14 +02:00
Revert "Add support for multiple LLM Providers (#60)"
This reverts commit bd8206742a.
This commit is contained in:
parent
d970b214f4
commit
43d6bc80e9
12 changed files with 127 additions and 456 deletions
1
envoyfilter/Cargo.lock
generated
1
envoyfilter/Cargo.lock
generated
|
|
@ -745,7 +745,6 @@ dependencies = [
|
|||
"proxy-wasm",
|
||||
"proxy-wasm-test-framework",
|
||||
"public_types",
|
||||
"rand",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"serde_yaml",
|
||||
|
|
|
|||
|
|
@ -19,7 +19,6 @@ http = "1.1.0"
|
|||
governor = { version = "0.6.3", default-features = false, features = ["no_std"]}
|
||||
tiktoken-rs = "0.5.9"
|
||||
acap = "0.3.0"
|
||||
rand = "0.8.5"
|
||||
|
||||
[dev-dependencies]
|
||||
proxy-wasm-test-framework = { git = "https://github.com/katanemo/test-framework.git", branch = "main" }
|
||||
|
|
|
|||
|
|
@ -19,6 +19,15 @@ static_resources:
|
|||
route_config:
|
||||
name: local_routes
|
||||
virtual_hosts:
|
||||
- name: openai
|
||||
domains:
|
||||
- "api.openai.com"
|
||||
routes:
|
||||
- match:
|
||||
prefix: "/"
|
||||
route:
|
||||
auto_host_rewrite: true
|
||||
cluster: openai
|
||||
- name: local_service
|
||||
domains:
|
||||
- "*"
|
||||
|
|
@ -39,23 +48,28 @@ static_resources:
|
|||
- match:
|
||||
prefix: "/v1/chat/completions"
|
||||
headers:
|
||||
- name: "x-bolt-llm-provider"
|
||||
string_match:
|
||||
exact: openai
|
||||
name: "Authorization"
|
||||
present_match: true
|
||||
route:
|
||||
auto_host_rewrite: true
|
||||
cluster: openai
|
||||
timeout: 60s
|
||||
- match:
|
||||
prefix: "/v1/chat/completions"
|
||||
headers:
|
||||
- name: "x-bolt-llm-provider"
|
||||
string_match:
|
||||
exact: mistral
|
||||
route:
|
||||
auto_host_rewrite: true
|
||||
cluster: mistral
|
||||
cluster: mistral_7b_instruct
|
||||
timeout: 60s
|
||||
- match:
|
||||
prefix: "/embeddings"
|
||||
route:
|
||||
cluster: model_server
|
||||
- match:
|
||||
prefix: "/"
|
||||
direct_response:
|
||||
status: 200
|
||||
body:
|
||||
inline_string: "Inspect the HTTP header: custom-header.\n"
|
||||
http_filters:
|
||||
- name: envoy.filters.http.wasm
|
||||
typed_config:
|
||||
|
|
@ -108,31 +122,6 @@ static_resources:
|
|||
tls_params:
|
||||
tls_minimum_protocol_version: TLSv1_2
|
||||
tls_maximum_protocol_version: TLSv1_3
|
||||
- name: mistral
|
||||
connect_timeout: 5s
|
||||
dns_lookup_family: V4_ONLY
|
||||
type: LOGICAL_DNS
|
||||
lb_policy: ROUND_ROBIN
|
||||
typed_extension_protocol_options:
|
||||
envoy.extensions.upstreams.http.v3.HttpProtocolOptions:
|
||||
"@type": type.googleapis.com/envoy.extensions.upstreams.http.v3.HttpProtocolOptions
|
||||
explicit_http_config:
|
||||
http2_protocol_options: {}
|
||||
load_assignment:
|
||||
cluster_name: mistral
|
||||
endpoints:
|
||||
- lb_endpoints:
|
||||
- endpoint:
|
||||
address:
|
||||
socket_address:
|
||||
address: api.mistral.ai
|
||||
port_value: 443
|
||||
hostname: "api.mistral.ai"
|
||||
transport_socket:
|
||||
name: envoy.transport_sockets.tls
|
||||
typed_config:
|
||||
"@type": type.googleapis.com/envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext
|
||||
sni: api.mistral.ai
|
||||
- name: model_server
|
||||
connect_timeout: 5s
|
||||
type: STRICT_DNS
|
||||
|
|
|
|||
|
|
@ -1,233 +0,0 @@
|
|||
admin:
|
||||
address:
|
||||
socket_address: { address: 0.0.0.0, port_value: 9901 }
|
||||
static_resources:
|
||||
listeners:
|
||||
address:
|
||||
socket_address:
|
||||
address: 0.0.0.0
|
||||
port_value: 10000
|
||||
filter_chains:
|
||||
- filters:
|
||||
- name: envoy.filters.network.http_connection_manager
|
||||
typed_config:
|
||||
"@type": type.googleapis.com/envoy.extensions.filters.network.http_connection_manager.v3.HttpConnectionManager
|
||||
stat_prefix: ingress_http
|
||||
codec_type: AUTO
|
||||
scheme_header_transformation:
|
||||
scheme_to_overwrite: https
|
||||
route_config:
|
||||
- name: bolt
|
||||
domains:
|
||||
- "*"
|
||||
routes:
|
||||
- match:
|
||||
headers:
|
||||
- name: "x-bolt-llm-provider"
|
||||
string_match:
|
||||
exact: openai
|
||||
route:
|
||||
auto_host_rewrite: true
|
||||
cluster: openai
|
||||
timeout: 60s
|
||||
- match:
|
||||
headers:
|
||||
- name: "x-bolt-llm-provider"
|
||||
string_match:
|
||||
exact: mistral
|
||||
route:
|
||||
auto_host_rewrite: true
|
||||
cluster: mistral
|
||||
timeout: 60s
|
||||
- match:
|
||||
prefix: "/embeddings"
|
||||
route:
|
||||
cluster: embeddingserver
|
||||
http_filters:
|
||||
- name: envoy.filters.http.wasm
|
||||
typed_config:
|
||||
"@type": type.googleapis.com/udpa.type.v1.TypedStruct
|
||||
type_url: type.googleapis.com/envoy.extensions.filters.http.wasm.v3.Wasm
|
||||
value:
|
||||
config:
|
||||
name: "http_config"
|
||||
configuration:
|
||||
"@type": "type.googleapis.com/google.protobuf.StringValue"
|
||||
value: |
|
||||
default_prompt_endpoint: "127.0.0.1"
|
||||
load_balancing: "round_robin"
|
||||
timeout_ms: 5000
|
||||
|
||||
embedding_provider:
|
||||
name: "SentenceTransformer"
|
||||
model: "all-MiniLM-L6-v2"
|
||||
|
||||
llm_providers:
|
||||
|
||||
- name: open-ai-gpt-4
|
||||
api_key: "$OPEN_AI_API_KEY"
|
||||
model: gpt-4
|
||||
|
||||
- name: mistral_7b_instruct
|
||||
model: mistral-7b-instruct
|
||||
endpoint: http://mistral_7b_instruct:10001/v1/chat/completions
|
||||
default: true
|
||||
|
||||
|
||||
prompt_targets:
|
||||
|
||||
- type: context_resolver
|
||||
name: weather_forecast
|
||||
few_shot_examples:
|
||||
- what is the weather in New York?
|
||||
- how is the weather in San Francisco?
|
||||
- what is the forecast in Seattle?
|
||||
entities:
|
||||
- name: city
|
||||
required: true
|
||||
- name: days
|
||||
endpoint:
|
||||
cluster: weatherhost
|
||||
path: /weather
|
||||
system_prompt: |
|
||||
You are a helpful weather forecaster. Use weater data that is provided to you. Please following following guidelines when responding to user queries:
|
||||
- Use farenheight for temperature
|
||||
- Use miles per hour for wind speed
|
||||
vm_config:
|
||||
runtime: "envoy.wasm.runtime.v8"
|
||||
code:
|
||||
local:
|
||||
filename: "/etc/envoy/proxy-wasm-plugins/intelligent_prompt_gateway.wasm"
|
||||
- name: envoy.filters.http.router
|
||||
typed_config:
|
||||
"@type": type.googleapis.com/envoy.extensions.filters.http.router.v3.Router
|
||||
clusters:
|
||||
# LLM Host
|
||||
# Embedding Providers
|
||||
# External LLM Providers
|
||||
- name: openai
|
||||
connect_timeout: 5s
|
||||
type: LOGICAL_DNS
|
||||
lb_policy: ROUND_ROBIN
|
||||
typed_extension_protocol_options:
|
||||
envoy.extensions.upstreams.http.v3.HttpProtocolOptions:
|
||||
"@type": type.googleapis.com/envoy.extensions.upstreams.http.v3.HttpProtocolOptions
|
||||
explicit_http_config:
|
||||
http2_protocol_options: {}
|
||||
load_assignment:
|
||||
cluster_name: openai
|
||||
endpoints:
|
||||
- lb_endpoints:
|
||||
- endpoint:
|
||||
address:
|
||||
socket_address:
|
||||
address: api.openai.com
|
||||
port_value: 443
|
||||
hostname: "api.openai.com"
|
||||
transport_socket:
|
||||
name: envoy.transport_sockets.tls
|
||||
typed_config:
|
||||
"@type": type.googleapis.com/envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext
|
||||
sni: api.openai.com
|
||||
common_tls_context:
|
||||
tls_params:
|
||||
tls_minimum_protocol_version: TLSv1_2
|
||||
tls_maximum_protocol_version: TLSv1_3
|
||||
- name: mistral
|
||||
connect_timeout: 5s
|
||||
type: LOGICAL_DNS
|
||||
lb_policy: ROUND_ROBIN
|
||||
typed_extension_protocol_options:
|
||||
envoy.extensions.upstreams.http.v3.HttpProtocolOptions:
|
||||
"@type": type.googleapis.com/envoy.extensions.upstreams.http.v3.HttpProtocolOptions
|
||||
explicit_http_config:
|
||||
http2_protocol_options: {}
|
||||
load_assignment:
|
||||
cluster_name: mistral
|
||||
endpoints:
|
||||
- lb_endpoints:
|
||||
- endpoint:
|
||||
address:
|
||||
socket_address:
|
||||
address: api.mistral.ai
|
||||
port_value: 443
|
||||
hostname: "api.mistral.ai"
|
||||
transport_socket:
|
||||
name: envoy.transport_sockets.tls
|
||||
typed_config:
|
||||
"@type": type.googleapis.com/envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext
|
||||
sni: api.mistral.ai
|
||||
common_tls_context:
|
||||
tls_params:
|
||||
tls_minimum_protocol_version: TLSv1_2
|
||||
tls_maximum_protocol_version: TLSv1_3
|
||||
- name: embeddingserver
|
||||
connect_timeout: 5s
|
||||
type: STRICT_DNS
|
||||
lb_policy: ROUND_ROBIN
|
||||
load_assignment:
|
||||
cluster_name: embeddingserver
|
||||
endpoints:
|
||||
- lb_endpoints:
|
||||
- endpoint:
|
||||
address:
|
||||
socket_address:
|
||||
address: host.docker.internal
|
||||
port_value: 8000
|
||||
hostname: "embeddingserver"
|
||||
- name: weatherhost
|
||||
connect_timeout: 5s
|
||||
type: STRICT_DNS
|
||||
lb_policy: ROUND_ROBIN
|
||||
load_assignment:
|
||||
cluster_name: weatherhost
|
||||
endpoints:
|
||||
- lb_endpoints:
|
||||
- endpoint:
|
||||
address:
|
||||
socket_address:
|
||||
address: host.docker.internal
|
||||
port_value: 8000
|
||||
hostname: "embeddingserver"
|
||||
- name: nerhost
|
||||
connect_timeout: 5s
|
||||
type: STRICT_DNS
|
||||
lb_policy: ROUND_ROBIN
|
||||
load_assignment:
|
||||
cluster_name: nerhost
|
||||
endpoints:
|
||||
- lb_endpoints:
|
||||
- endpoint:
|
||||
address:
|
||||
socket_address:
|
||||
address: host.docker.internal
|
||||
port_value: 8000
|
||||
hostname: "embeddingserver"
|
||||
- name: qdrant
|
||||
connect_timeout: 5s
|
||||
type: STRICT_DNS
|
||||
lb_policy: ROUND_ROBIN
|
||||
load_assignment:
|
||||
cluster_name: qdrant
|
||||
endpoints:
|
||||
- lb_endpoints:
|
||||
- endpoint:
|
||||
address:
|
||||
socket_address:
|
||||
address: qdrant
|
||||
port_value: 6333
|
||||
hostname: "qdrant"
|
||||
- name: mistral_7b_instruct
|
||||
connect_timeout: 5s
|
||||
type: STRICT_DNS
|
||||
lb_policy: ROUND_ROBIN
|
||||
load_assignment:
|
||||
cluster_name: qdrant
|
||||
endpoints:
|
||||
- lb_endpoints:
|
||||
- endpoint:
|
||||
address:
|
||||
socket_address:
|
||||
address: mistral_7b_instruct
|
||||
port_value: 10001
|
||||
hostname: "mistral_7b_instruct"
|
||||
|
|
@ -1,11 +1,11 @@
|
|||
pub const DEFAULT_EMBEDDING_MODEL: &str = "BAAI/bge-large-en-v1.5";
|
||||
pub const DEFAULT_INTENT_MODEL: &str = "tasksource/deberta-base-long-nli";
|
||||
pub const DEFAULT_PROMPT_TARGET_THRESHOLD: f64 = 0.8;
|
||||
pub const RATELIMIT_SELECTOR_HEADER_KEY: &str = "x-bolt-ratelimit-selector";
|
||||
pub const RATELIMIT_SELECTOR_HEADER_KEY: &str = "x-katanemo-ratelimit-selector";
|
||||
pub const SYSTEM_ROLE: &str = "system";
|
||||
pub const USER_ROLE: &str = "user";
|
||||
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
|
||||
pub const BOLT_FC_CLUSTER: &str = "bolt_fc_1b";
|
||||
pub const BOLT_FC_REQUEST_TIMEOUT_MS: u64 = 120000; // 2 minutes
|
||||
pub const OPENAI_CHAT_COMPLETIONS_PATH: &str = "/v1/chat/completions";
|
||||
pub const MODEL_SERVER_NAME: &str = "model_server";
|
||||
pub const BOLT_ROUTING_HEADER: &str = "x-bolt-llm-provider";
|
||||
|
|
|
|||
|
|
@ -4,9 +4,7 @@ use proxy_wasm::types::*;
|
|||
|
||||
mod consts;
|
||||
mod filter_context;
|
||||
mod llm_providers;
|
||||
mod ratelimit;
|
||||
mod routing;
|
||||
mod stats;
|
||||
mod stream_context;
|
||||
mod tokenizer;
|
||||
|
|
|
|||
|
|
@ -1,47 +0,0 @@
|
|||
#[non_exhaustive]
|
||||
pub struct LlmProviders;
|
||||
|
||||
impl LlmProviders {
|
||||
pub const OPENAI_PROVIDER: LlmProvider<'static> = LlmProvider {
|
||||
name: "openai",
|
||||
api_key_header: "x-bolt-openai-api-key",
|
||||
model: "gpt-3.5-turbo",
|
||||
};
|
||||
pub const MISTRAL_PROVIDER: LlmProvider<'static> = LlmProvider {
|
||||
name: "mistral",
|
||||
api_key_header: "x-bolt-mistral-api-key",
|
||||
model: "mistral-large-latest",
|
||||
};
|
||||
|
||||
pub const VARIANTS: &'static [LlmProvider<'static>] =
|
||||
&[Self::OPENAI_PROVIDER, Self::MISTRAL_PROVIDER];
|
||||
}
|
||||
|
||||
pub struct LlmProvider<'prov> {
|
||||
name: &'prov str,
|
||||
api_key_header: &'prov str,
|
||||
model: &'prov str,
|
||||
}
|
||||
|
||||
impl AsRef<str> for LlmProvider<'_> {
|
||||
fn as_ref(&self) -> &str {
|
||||
self.name
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for LlmProvider<'_> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", self.name)
|
||||
}
|
||||
}
|
||||
|
||||
impl LlmProvider<'_> {
|
||||
pub fn api_key_header(&self) -> &str {
|
||||
self.api_key_header
|
||||
}
|
||||
|
||||
pub fn choose_model(&self) -> &str {
|
||||
// In the future this can be a more complex function balancing reliability, cost, performance, etc.
|
||||
self.model
|
||||
}
|
||||
}
|
||||
|
|
@ -1,13 +0,0 @@
|
|||
use crate::llm_providers::{LlmProvider, LlmProviders};
|
||||
use rand::{seq::SliceRandom, thread_rng};
|
||||
|
||||
pub fn get_llm_provider<'hostname>(deterministic: bool) -> &'static LlmProvider<'hostname> {
|
||||
if deterministic {
|
||||
&LlmProviders::OPENAI_PROVIDER
|
||||
} else {
|
||||
let mut rng = thread_rng();
|
||||
LlmProviders::VARIANTS
|
||||
.choose(&mut rng)
|
||||
.expect("There should always be at least one llm provider")
|
||||
}
|
||||
}
|
||||
|
|
@ -1,14 +1,13 @@
|
|||
use crate::consts::{
|
||||
BOLT_FC_CLUSTER, BOLT_FC_REQUEST_TIMEOUT_MS, BOLT_ROUTING_HEADER, DEFAULT_EMBEDDING_MODEL,
|
||||
DEFAULT_INTENT_MODEL, DEFAULT_PROMPT_TARGET_THRESHOLD, GPT_35_TURBO, MODEL_SERVER_NAME,
|
||||
BOLT_FC_CLUSTER, BOLT_FC_REQUEST_TIMEOUT_MS, DEFAULT_EMBEDDING_MODEL, DEFAULT_INTENT_MODEL,
|
||||
DEFAULT_PROMPT_TARGET_THRESHOLD, GPT_35_TURBO, MODEL_SERVER_NAME, OPENAI_CHAT_COMPLETIONS_PATH,
|
||||
RATELIMIT_SELECTOR_HEADER_KEY, SYSTEM_ROLE, USER_ROLE,
|
||||
};
|
||||
use crate::filter_context::{embeddings_store, WasmMetrics};
|
||||
use crate::llm_providers::{LlmProvider, LlmProviders};
|
||||
use crate::ratelimit;
|
||||
use crate::ratelimit::Header;
|
||||
use crate::stats::IncrementingMetric;
|
||||
use crate::tokenizer;
|
||||
use crate::{ratelimit, routing};
|
||||
use acap::cos;
|
||||
use http::StatusCode;
|
||||
use log::{debug, info, warn};
|
||||
|
|
@ -57,11 +56,11 @@ pub struct StreamContext {
|
|||
pub prompt_targets: Rc<RwLock<HashMap<String, PromptTarget>>>,
|
||||
pub overrides: Rc<Option<Overrides>>,
|
||||
callouts: HashMap<u32, CallContext>,
|
||||
host_header: Option<String>,
|
||||
ratelimit_selector: Option<Header>,
|
||||
streaming_response: bool,
|
||||
response_tokens: usize,
|
||||
chat_completions_request: bool,
|
||||
llm_provider: Option<&'static LlmProvider<'static>>,
|
||||
prompt_guards: Rc<Option<PromptGuards>>,
|
||||
}
|
||||
|
||||
|
|
@ -78,39 +77,18 @@ impl StreamContext {
|
|||
metrics,
|
||||
prompt_targets,
|
||||
callouts: HashMap::new(),
|
||||
host_header: None,
|
||||
ratelimit_selector: None,
|
||||
streaming_response: false,
|
||||
response_tokens: 0,
|
||||
chat_completions_request: false,
|
||||
llm_provider: None,
|
||||
prompt_guards,
|
||||
overrides,
|
||||
}
|
||||
}
|
||||
fn llm_provider(&self) -> &LlmProvider {
|
||||
self.llm_provider
|
||||
.expect("the provider should be set when asked for it")
|
||||
}
|
||||
|
||||
fn add_routing_header(&mut self) {
|
||||
self.add_http_request_header(BOLT_ROUTING_HEADER, self.llm_provider().as_ref());
|
||||
}
|
||||
|
||||
fn modify_auth_headers(&mut self) -> Result<(), String> {
|
||||
let llm_provider_api_key_value = self
|
||||
.get_http_request_header(self.llm_provider().api_key_header())
|
||||
.ok_or(format!("missing {} api key", self.llm_provider()))?;
|
||||
|
||||
let authorization_header_value = format!("Bearer {}", llm_provider_api_key_value);
|
||||
|
||||
self.set_http_request_header("Authorization", Some(&authorization_header_value));
|
||||
|
||||
// sanitize passed in api keys
|
||||
for provider in LlmProviders::VARIANTS.iter() {
|
||||
self.set_http_request_header(provider.api_key_header(), None);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
fn save_host_header(&mut self) {
|
||||
// Save the host header to be used by filter logic later on.
|
||||
self.host_header = self.get_http_request_header(":host");
|
||||
}
|
||||
|
||||
fn delete_content_length_header(&mut self) {
|
||||
|
|
@ -121,6 +99,19 @@ impl StreamContext {
|
|||
self.set_http_request_header("content-length", None);
|
||||
}
|
||||
|
||||
fn modify_path_header(&mut self) {
|
||||
match self.get_http_request_header(":path") {
|
||||
// The gateway can start gathering information necessary for routing. For now change the path to an
|
||||
// OpenAI API path.
|
||||
Some(path) if path == "/llmrouting" => {
|
||||
self.set_http_request_header(":path", Some(OPENAI_CHAT_COMPLETIONS_PATH));
|
||||
self.chat_completions_request = true;
|
||||
}
|
||||
// Otherwise let the filter continue.
|
||||
_ => (),
|
||||
}
|
||||
}
|
||||
|
||||
fn save_ratelimit_header(&mut self) {
|
||||
self.ratelimit_selector = self
|
||||
.get_http_request_header(RATELIMIT_SELECTOR_HEADER_KEY)
|
||||
|
|
@ -246,7 +237,6 @@ impl StreamContext {
|
|||
token_id
|
||||
);
|
||||
|
||||
self.metrics.active_http_calls.increment(1);
|
||||
callout_context.response_handler_type = ResponseHandlerType::ZeroShotIntent;
|
||||
|
||||
if self.callouts.insert(token_id, callout_context).is_some() {
|
||||
|
|
@ -441,7 +431,6 @@ impl StreamContext {
|
|||
BOLT_FC_CLUSTER, token_id
|
||||
);
|
||||
|
||||
self.metrics.active_http_calls.increment(1);
|
||||
callout_context.response_handler_type = ResponseHandlerType::FunctionResolver;
|
||||
callout_context.prompt_target_name = Some(prompt_target.name);
|
||||
if self.callouts.insert(token_id, callout_context).is_some() {
|
||||
|
|
@ -449,6 +438,7 @@ impl StreamContext {
|
|||
}
|
||||
}
|
||||
}
|
||||
self.metrics.active_http_calls.increment(1);
|
||||
}
|
||||
|
||||
fn function_resolver_handler(&mut self, body: Vec<u8>, mut callout_context: CallContext) {
|
||||
|
|
@ -605,7 +595,7 @@ impl StreamContext {
|
|||
});
|
||||
|
||||
let chat_completions_request: ChatCompletionsRequest = ChatCompletionsRequest {
|
||||
model: callout_context.request_body.model,
|
||||
model: GPT_35_TURBO.to_string(),
|
||||
messages,
|
||||
tools: None,
|
||||
stream: callout_context.request_body.stream,
|
||||
|
|
@ -761,24 +751,11 @@ impl HttpContext for StreamContext {
|
|||
// Envoy's HTTP model is event driven. The WASM ABI has given implementors events to hook onto
|
||||
// the lifecycle of the http request and response.
|
||||
fn on_http_request_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action {
|
||||
let provider_hint = self
|
||||
.get_http_request_header("x-bolt-deterministic-provider")
|
||||
.is_some();
|
||||
self.llm_provider = Some(routing::get_llm_provider(provider_hint));
|
||||
|
||||
self.add_routing_header();
|
||||
if let Err(error) = self.modify_auth_headers() {
|
||||
self.send_server_error(error, Some(StatusCode::BAD_REQUEST));
|
||||
}
|
||||
self.save_host_header();
|
||||
self.delete_content_length_header();
|
||||
self.modify_path_header();
|
||||
self.save_ratelimit_header();
|
||||
|
||||
debug!(
|
||||
"S[{}] req_headers={:?}",
|
||||
self.context_id,
|
||||
self.get_http_request_headers()
|
||||
);
|
||||
|
||||
Action::Continue
|
||||
}
|
||||
|
||||
|
|
@ -819,9 +796,6 @@ impl HttpContext for StreamContext {
|
|||
}
|
||||
};
|
||||
|
||||
// Set the model based on the chosen LLM Provider
|
||||
deserialized_body.model = String::from(self.llm_provider().choose_model());
|
||||
|
||||
self.streaming_response = deserialized_body.stream;
|
||||
if deserialized_body.stream && deserialized_body.stream_options.is_none() {
|
||||
deserialized_body.stream_options = Some(StreamOptions {
|
||||
|
|
@ -943,21 +917,15 @@ impl HttpContext for StreamContext {
|
|||
}
|
||||
|
||||
fn on_http_response_body(&mut self, body_size: usize, end_of_stream: bool) -> Action {
|
||||
if !self.chat_completions_request {
|
||||
return Action::Continue;
|
||||
}
|
||||
|
||||
debug!(
|
||||
"recv [S={}] bytes={} end_stream={}",
|
||||
self.context_id, body_size, end_of_stream
|
||||
);
|
||||
|
||||
if !self.chat_completions_request {
|
||||
if let Some(body_str) = self
|
||||
.get_http_response_body(0, body_size)
|
||||
.and_then(|bytes| String::from_utf8(bytes).ok())
|
||||
{
|
||||
debug!("recv [S={}] body_str={}", self.context_id, body_str);
|
||||
}
|
||||
return Action::Continue;
|
||||
}
|
||||
|
||||
if !end_of_stream && !self.streaming_response {
|
||||
return Action::Pause;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -25,52 +25,6 @@ fn wasm_module() -> String {
|
|||
wasm_file.to_str().unwrap().to_string()
|
||||
}
|
||||
|
||||
fn request_headers_expectations(module: &mut Tester, http_context: i32) {
|
||||
module
|
||||
.call_proxy_on_request_headers(http_context, 0, false)
|
||||
.expect_get_header_map_value(
|
||||
Some(MapType::HttpRequestHeaders),
|
||||
Some("x-bolt-deterministic-provider"),
|
||||
)
|
||||
.returning(Some("true"))
|
||||
.expect_add_header_map_value(
|
||||
Some(MapType::HttpRequestHeaders),
|
||||
Some("x-bolt-llm-provider"),
|
||||
Some("openai"),
|
||||
)
|
||||
.expect_get_header_map_value(
|
||||
Some(MapType::HttpRequestHeaders),
|
||||
Some("x-bolt-openai-api-key"),
|
||||
)
|
||||
.returning(Some("api-key"))
|
||||
.expect_replace_header_map_value(
|
||||
Some(MapType::HttpRequestHeaders),
|
||||
Some("Authorization"),
|
||||
Some("Bearer api-key"),
|
||||
)
|
||||
.expect_remove_header_map_value(
|
||||
Some(MapType::HttpRequestHeaders),
|
||||
Some("x-bolt-openai-api-key"),
|
||||
)
|
||||
.expect_remove_header_map_value(
|
||||
Some(MapType::HttpRequestHeaders),
|
||||
Some("x-bolt-mistral-api-key"),
|
||||
)
|
||||
.expect_remove_header_map_value(Some(MapType::HttpRequestHeaders), Some("content-length"))
|
||||
.expect_get_header_map_value(
|
||||
Some(MapType::HttpRequestHeaders),
|
||||
Some("x-bolt-ratelimit-selector"),
|
||||
)
|
||||
.returning(Some("selector-key"))
|
||||
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("selector-key"))
|
||||
.returning(Some("selector-value"))
|
||||
.expect_get_header_map_pairs(Some(MapType::HttpRequestHeaders))
|
||||
.returning(None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.execute_and_expect(ReturnType::Action(Action::Continue))
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
|
||||
module
|
||||
.call_proxy_on_context_create(http_context, filter_context)
|
||||
|
|
@ -78,7 +32,28 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
|
|||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
request_headers_expectations(module, http_context);
|
||||
// Request Headers
|
||||
module
|
||||
.call_proxy_on_request_headers(http_context, 0, false)
|
||||
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":host"))
|
||||
.returning(Some("api.openai.com"))
|
||||
.expect_remove_header_map_value(Some(MapType::HttpRequestHeaders), Some("content-length"))
|
||||
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":path"))
|
||||
.returning(Some("/llmrouting"))
|
||||
.expect_replace_header_map_value(
|
||||
Some(MapType::HttpRequestHeaders),
|
||||
Some(":path"),
|
||||
Some("/v1/chat/completions"),
|
||||
)
|
||||
.expect_get_header_map_value(
|
||||
Some(MapType::HttpRequestHeaders),
|
||||
Some("x-katanemo-ratelimit-selector"),
|
||||
)
|
||||
.returning(Some("selector-key"))
|
||||
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("selector-key"))
|
||||
.returning(Some("selector-value"))
|
||||
.execute_and_expect(ReturnType::Action(Action::Continue))
|
||||
.unwrap();
|
||||
|
||||
// Request Body
|
||||
let chat_completions_request_body = "\
|
||||
|
|
@ -107,8 +82,8 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
|
|||
// The actual call is not important in this test, we just need to grab the token_id
|
||||
.expect_http_call(Some("model_server"), None, None, None, None)
|
||||
.returning(Some(1))
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_metric_increment("active_http_calls", 1)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Info), None)
|
||||
.execute_and_expect(ReturnType::Action(Action::Pause))
|
||||
.unwrap();
|
||||
|
|
@ -140,7 +115,6 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
|
|||
.expect_http_call(Some("model_server"), None, None, None, None)
|
||||
.returning(Some(2))
|
||||
.expect_metric_increment("active_http_calls", 1)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
|
|
@ -261,7 +235,26 @@ fn successful_request_to_open_ai_chat_completions() {
|
|||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
request_headers_expectations(&mut module, http_context);
|
||||
// Request Headers
|
||||
module
|
||||
.call_proxy_on_request_headers(http_context, 0, false)
|
||||
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":host"))
|
||||
.returning(Some("api.openai.com"))
|
||||
.expect_remove_header_map_value(Some(MapType::HttpRequestHeaders), Some("content-length"))
|
||||
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":path"))
|
||||
.returning(Some("/llmrouting"))
|
||||
.expect_replace_header_map_value(
|
||||
Some(MapType::HttpRequestHeaders),
|
||||
Some(":path"),
|
||||
Some("/v1/chat/completions"),
|
||||
)
|
||||
.expect_get_header_map_value(
|
||||
Some(MapType::HttpRequestHeaders),
|
||||
Some("x-katanemo-ratelimit-selector"),
|
||||
)
|
||||
.returning(None)
|
||||
.execute_and_expect(ReturnType::Action(Action::Continue))
|
||||
.unwrap();
|
||||
|
||||
// Request Body
|
||||
let chat_completions_request_body = "\
|
||||
|
|
@ -330,7 +323,26 @@ fn bad_request_to_open_ai_chat_completions() {
|
|||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
request_headers_expectations(&mut module, http_context);
|
||||
// Request Headers
|
||||
module
|
||||
.call_proxy_on_request_headers(http_context, 0, false)
|
||||
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":host"))
|
||||
.returning(Some("api.openai.com"))
|
||||
.expect_remove_header_map_value(Some(MapType::HttpRequestHeaders), Some("content-length"))
|
||||
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":path"))
|
||||
.returning(Some("/llmrouting"))
|
||||
.expect_replace_header_map_value(
|
||||
Some(MapType::HttpRequestHeaders),
|
||||
Some(":path"),
|
||||
Some("/v1/chat/completions"),
|
||||
)
|
||||
.expect_get_header_map_value(
|
||||
Some(MapType::HttpRequestHeaders),
|
||||
Some("x-katanemo-ratelimit-selector"),
|
||||
)
|
||||
.returning(None)
|
||||
.execute_and_expect(ReturnType::Action(Action::Continue))
|
||||
.unwrap();
|
||||
|
||||
// Request Body
|
||||
let incomplete_chat_completions_request_body = "\
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue