diff --git a/chatbot_ui/app/run.py b/chatbot_ui/app/run.py
index 034378a7..4d06287d 100644
--- a/chatbot_ui/app/run.py
+++ b/chatbot_ui/app/run.py
@@ -6,27 +6,32 @@ from dotenv import load_dotenv
 
 load_dotenv()
 
-OPEN_API_KEY=os.getenv("OPENAI_API_KEY")
+OPENAI_API_KEY=os.getenv("OPENAI_API_KEY")
+MISTRAL_API_KEY = os.getenv("MISTRAL_API_KEY")
 CHAT_COMPLETION_ENDPOINT = os.getenv("CHAT_COMPLETION_ENDPOINT")
 MODEL_NAME = os.getenv("MODEL_NAME", "gpt-3.5-turbo")
 
 log.info("CHAT_COMPLETION_ENDPOINT: ", CHAT_COMPLETION_ENDPOINT)
-client = OpenAI(api_key=OPEN_API_KEY, base_url=CHAT_COMPLETION_ENDPOINT)
+client = OpenAI(api_key=OPENAI_API_KEY, base_url=CHAT_COMPLETION_ENDPOINT)
 
 def predict(message, history):
-    # history_openai_format = []
-    # for human, assistant in history:
-    #     history_openai_format.append({"role": "user", "content": human })
-    #     history_openai_format.append({"role": "assistant", "content": assistant})
     history.append({"role": "user", "content": message})
 
     log.info("CHAT_COMPLETION_ENDPOINT: ", CHAT_COMPLETION_ENDPOINT)
     log.info("history: ", history)
 
+    # Per-provider API keys plus the routing hint consumed by the bolt gateway
+    custom_headers = {
+        'x-bolt-openai-api-key': f"{OPENAI_API_KEY}",
+        'x-bolt-mistral-api-key': f"{MISTRAL_API_KEY}",
+        'x-bolt-deterministic-provider': 'openai',
+    }
+
     try:
         response = client.chat.completions.create(model=MODEL_NAME,
                                                   messages= history,
-                                                  temperature=1.0
+                                                  temperature=1.0,
+                                                  extra_headers=custom_headers
         )
     except Exception as e:
         log.info(e)
@@ -36,10 +41,6 @@ def predict(message, history):
         log.info("Error calling gateway API: {}".format(e.message))
         raise gr.Error("Error calling gateway API: {}".format(e.message))
 
-    # for chunk in response:
-    #     if chunk.choices[0].delta.content is not None:
-    #         partial_message = partial_message + chunk.choices[0].delta.content
-    #         yield partial_message
     choices = response.choices
     message = choices[0].message
     content = message.content
diff --git a/demos/function_calling/docker-compose.yaml b/demos/function_calling/docker-compose.yaml
index 5e05ae3b..ad994592 100644
--- a/demos/function_calling/docker-compose.yaml
+++ b/demos/function_calling/docker-compose.yaml
@@ -108,6 +108,7 @@ services:
       - "18080:8080"
     environment:
       - OPENAI_API_KEY=${OPENAI_API_KEY:?error}
+      - MISTRAL_API_KEY=${MISTRAL_API_KEY:?error}
       - CHAT_COMPLETION_ENDPOINT=http://bolt:10000/v1
 
   prometheus:
diff --git a/envoyfilter/Cargo.lock b/envoyfilter/Cargo.lock
index a2647ead..b72c678c 100644
--- a/envoyfilter/Cargo.lock
+++ b/envoyfilter/Cargo.lock
@@ -745,6 +745,7 @@ dependencies = [
  "proxy-wasm",
  "proxy-wasm-test-framework",
  "public_types",
+ "rand",
  "serde",
  "serde_json",
  "serde_yaml",
diff --git a/envoyfilter/Cargo.toml b/envoyfilter/Cargo.toml
index a2418486..69750f5c 100644
--- a/envoyfilter/Cargo.toml
+++ b/envoyfilter/Cargo.toml
@@ -19,6 +19,7 @@ http = "1.1.0"
governor = { version = "0.6.3", default-features = false, features = ["no_std"]}
 tiktoken-rs = "0.5.9"
 acap = "0.3.0"
+rand = "0.8.5"
 
 [dev-dependencies]
 proxy-wasm-test-framework = { git = "https://github.com/katanemo/test-framework.git", branch = "main" }
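For reference, the run.py change above forwards both provider keys on every request and pins routing with x-bolt-deterministic-provider. A minimal sketch of a standalone client doing the same, assuming the openai Python package v1.x and a gateway reachable at http://localhost:10000/v1 (the URL is illustrative; adjust to your deployment):

    import os
    from openai import OpenAI

    # The gateway ignores this api_key; real keys travel in the x-bolt-* headers.
    client = OpenAI(api_key="unused", base_url="http://localhost:10000/v1")

    response = client.chat.completions.create(
        model="gpt-3.5-turbo",  # overwritten by the gateway via choose_model()
        messages=[{"role": "user", "content": "Hello!"}],
        extra_headers={
            "x-bolt-openai-api-key": os.environ["OPENAI_API_KEY"],
            "x-bolt-mistral-api-key": os.environ["MISTRAL_API_KEY"],
            # Omit this header to let the gateway pick a provider at random.
            "x-bolt-deterministic-provider": "openai",
        },
    )
    print(response.choices[0].message.content)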
prefix: "/v1/chat/completions" headers: - name: "Authorization" - present_match: true + - name: "x-bolt-llm-provider" + string_match: + exact: openai route: auto_host_rewrite: true cluster: openai timeout: 60s - match: prefix: "/v1/chat/completions" + headers: + - name: "x-bolt-llm-provider" + string_match: + exact: mistral route: auto_host_rewrite: true - cluster: mistral_7b_instruct + cluster: mistral timeout: 60s - - match: - prefix: "/embeddings" - route: - cluster: model_server - - match: - prefix: "/" - direct_response: - status: 200 - body: - inline_string: "Inspect the HTTP header: custom-header.\n" http_filters: - name: envoy.filters.http.wasm typed_config: @@ -122,6 +108,31 @@ static_resources: tls_params: tls_minimum_protocol_version: TLSv1_2 tls_maximum_protocol_version: TLSv1_3 + - name: mistral + connect_timeout: 5s + dns_lookup_family: V4_ONLY + type: LOGICAL_DNS + lb_policy: ROUND_ROBIN + typed_extension_protocol_options: + envoy.extensions.upstreams.http.v3.HttpProtocolOptions: + "@type": type.googleapis.com/envoy.extensions.upstreams.http.v3.HttpProtocolOptions + explicit_http_config: + http2_protocol_options: {} + load_assignment: + cluster_name: mistral + endpoints: + - lb_endpoints: + - endpoint: + address: + socket_address: + address: api.mistral.ai + port_value: 443 + hostname: "api.mistral.ai" + transport_socket: + name: envoy.transport_sockets.tls + typed_config: + "@type": type.googleapis.com/envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext + sni: api.mistral.ai - name: model_server connect_timeout: 5s type: STRICT_DNS diff --git a/envoyfilter/envoy.yaml b/envoyfilter/envoy.yaml new file mode 100644 index 00000000..f0236bf6 --- /dev/null +++ b/envoyfilter/envoy.yaml @@ -0,0 +1,233 @@ +admin: + address: + socket_address: { address: 0.0.0.0, port_value: 9901 } +static_resources: + listeners: + address: + socket_address: + address: 0.0.0.0 + port_value: 10000 + filter_chains: + - filters: + - name: envoy.filters.network.http_connection_manager + typed_config: + "@type": type.googleapis.com/envoy.extensions.filters.network.http_connection_manager.v3.HttpConnectionManager + stat_prefix: ingress_http + codec_type: AUTO + scheme_header_transformation: + scheme_to_overwrite: https + route_config: + - name: bolt + domains: + - "*" + routes: + - match: + headers: + - name: "x-bolt-llm-provider" + string_match: + exact: openai + route: + auto_host_rewrite: true + cluster: openai + timeout: 60s + - match: + headers: + - name: "x-bolt-llm-provider" + string_match: + exact: mistral + route: + auto_host_rewrite: true + cluster: mistral + timeout: 60s + - match: + prefix: "/embeddings" + route: + cluster: embeddingserver + http_filters: + - name: envoy.filters.http.wasm + typed_config: + "@type": type.googleapis.com/udpa.type.v1.TypedStruct + type_url: type.googleapis.com/envoy.extensions.filters.http.wasm.v3.Wasm + value: + config: + name: "http_config" + configuration: + "@type": "type.googleapis.com/google.protobuf.StringValue" + value: | + default_prompt_endpoint: "127.0.0.1" + load_balancing: "round_robin" + timeout_ms: 5000 + + embedding_provider: + name: "SentenceTransformer" + model: "all-MiniLM-L6-v2" + + llm_providers: + + - name: open-ai-gpt-4 + api_key: "$OPEN_AI_API_KEY" + model: gpt-4 + + - name: mistral_7b_instruct + model: mistral-7b-instruct + endpoint: http://mistral_7b_instruct:10001/v1/chat/completions + default: true + + + prompt_targets: + + - type: context_resolver + name: weather_forecast + few_shot_examples: + - what is the weather in New 
diff --git a/envoyfilter/envoy.yaml b/envoyfilter/envoy.yaml
new file mode 100644
index 00000000..f0236bf6
--- /dev/null
+++ b/envoyfilter/envoy.yaml
@@ -0,0 +1,233 @@
+admin:
+  address:
+    socket_address: { address: 0.0.0.0, port_value: 9901 }
+static_resources:
+  listeners:
+  - address:
+      socket_address:
+        address: 0.0.0.0
+        port_value: 10000
+    filter_chains:
+    - filters:
+      - name: envoy.filters.network.http_connection_manager
+        typed_config:
+          "@type": type.googleapis.com/envoy.extensions.filters.network.http_connection_manager.v3.HttpConnectionManager
+          stat_prefix: ingress_http
+          codec_type: AUTO
+          scheme_header_transformation:
+            scheme_to_overwrite: https
+          route_config:
+            name: bolt
+            virtual_hosts:
+            - name: bolt
+              domains:
+              - "*"
+              routes:
+              - match:
+                  prefix: "/v1/chat/completions"
+                  headers:
+                  - name: "x-bolt-llm-provider"
+                    string_match:
+                      exact: openai
+                route:
+                  auto_host_rewrite: true
+                  cluster: openai
+                  timeout: 60s
+              - match:
+                  prefix: "/v1/chat/completions"
+                  headers:
+                  - name: "x-bolt-llm-provider"
+                    string_match:
+                      exact: mistral
+                route:
+                  auto_host_rewrite: true
+                  cluster: mistral
+                  timeout: 60s
+              - match:
+                  prefix: "/embeddings"
+                route:
+                  cluster: embeddingserver
+          http_filters:
+          - name: envoy.filters.http.wasm
+            typed_config:
+              "@type": type.googleapis.com/udpa.type.v1.TypedStruct
+              type_url: type.googleapis.com/envoy.extensions.filters.http.wasm.v3.Wasm
+              value:
+                config:
+                  name: "http_config"
+                  configuration:
+                    "@type": "type.googleapis.com/google.protobuf.StringValue"
+                    value: |
+                      default_prompt_endpoint: "127.0.0.1"
+                      load_balancing: "round_robin"
+                      timeout_ms: 5000
+
+                      embedding_provider:
+                        name: "SentenceTransformer"
+                        model: "all-MiniLM-L6-v2"
+
+                      llm_providers:
+
+                        - name: open-ai-gpt-4
+                          api_key: "$OPEN_AI_API_KEY"
+                          model: gpt-4
+
+                        - name: mistral_7b_instruct
+                          model: mistral-7b-instruct
+                          endpoint: http://mistral_7b_instruct:10001/v1/chat/completions
+                          default: true
+
+                      prompt_targets:
+
+                        - type: context_resolver
+                          name: weather_forecast
+                          few_shot_examples:
+                            - what is the weather in New York?
+                            - how is the weather in San Francisco?
+                            - what is the forecast in Seattle?
+                          entities:
+                            - name: city
+                              required: true
+                            - name: days
+                          endpoint:
+                            cluster: weatherhost
+                            path: /weather
+                          system_prompt: |
+                            You are a helpful weather forecaster. Use the weather data provided to you, and follow these guidelines when responding to user queries:
+                            - Use Fahrenheit for temperature
+                            - Use miles per hour for wind speed
+                  vm_config:
+                    runtime: "envoy.wasm.runtime.v8"
+                    code:
+                      local:
+                        filename: "/etc/envoy/proxy-wasm-plugins/intelligent_prompt_gateway.wasm"
+          - name: envoy.filters.http.router
+            typed_config:
+              "@type": type.googleapis.com/envoy.extensions.filters.http.router.v3.Router
+  clusters:
+  # LLM Host
+  # Embedding Providers
+  # External LLM Providers
+  - name: openai
+    connect_timeout: 5s
+    type: LOGICAL_DNS
+    lb_policy: ROUND_ROBIN
+    typed_extension_protocol_options:
+      envoy.extensions.upstreams.http.v3.HttpProtocolOptions:
+        "@type": type.googleapis.com/envoy.extensions.upstreams.http.v3.HttpProtocolOptions
+        explicit_http_config:
+          http2_protocol_options: {}
+    load_assignment:
+      cluster_name: openai
+      endpoints:
+      - lb_endpoints:
+        - endpoint:
+            address:
+              socket_address:
+                address: api.openai.com
+                port_value: 443
+            hostname: "api.openai.com"
+    transport_socket:
+      name: envoy.transport_sockets.tls
+      typed_config:
+        "@type": type.googleapis.com/envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext
+        sni: api.openai.com
+        common_tls_context:
+          tls_params:
+            tls_minimum_protocol_version: TLSv1_2
+            tls_maximum_protocol_version: TLSv1_3
+  - name: mistral
+    connect_timeout: 5s
+    type: LOGICAL_DNS
+    lb_policy: ROUND_ROBIN
+    typed_extension_protocol_options:
+      envoy.extensions.upstreams.http.v3.HttpProtocolOptions:
+        "@type": type.googleapis.com/envoy.extensions.upstreams.http.v3.HttpProtocolOptions
+        explicit_http_config:
+          http2_protocol_options: {}
+    load_assignment:
+      cluster_name: mistral
+      endpoints:
+      - lb_endpoints:
+        - endpoint:
+            address:
+              socket_address:
+                address: api.mistral.ai
+                port_value: 443
+            hostname: "api.mistral.ai"
+    transport_socket:
+      name: envoy.transport_sockets.tls
+      typed_config:
+        "@type": type.googleapis.com/envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext
+        sni: api.mistral.ai
+        common_tls_context:
+          tls_params:
+            tls_minimum_protocol_version: TLSv1_2
+            tls_maximum_protocol_version: TLSv1_3
+  - name: embeddingserver
+    connect_timeout: 5s
+    type: STRICT_DNS
+    lb_policy: ROUND_ROBIN
+    load_assignment:
+      cluster_name: embeddingserver
+      endpoints:
+      - lb_endpoints:
+        - endpoint:
+            address:
+              socket_address:
+                address: host.docker.internal
+                port_value: 8000
+            hostname: "embeddingserver"
+  - name: weatherhost
+    connect_timeout: 5s
+    type: STRICT_DNS
+    lb_policy: ROUND_ROBIN
+    load_assignment:
+      cluster_name: weatherhost
+      endpoints:
+      - lb_endpoints:
+        - endpoint:
+            address:
+              socket_address:
+                address: host.docker.internal
+                port_value: 8000
+            hostname: "weatherhost"
+  - name: nerhost
+    connect_timeout: 5s
+    type: STRICT_DNS
+    lb_policy: ROUND_ROBIN
+    load_assignment:
+      cluster_name: nerhost
+      endpoints:
+      - lb_endpoints:
+        - endpoint:
+            address:
+              socket_address:
+                address: host.docker.internal
+                port_value: 8000
+            hostname: "nerhost"
+  - name: qdrant
+    connect_timeout: 5s
+    type: STRICT_DNS
+    lb_policy: ROUND_ROBIN
+    load_assignment:
+      cluster_name: qdrant
+      endpoints:
+      - lb_endpoints:
+        - endpoint:
+            address:
+              socket_address:
+                address: qdrant
+                port_value: 6333
+            hostname: "qdrant"
+  - name: mistral_7b_instruct
+    connect_timeout: 5s
+    type: STRICT_DNS
+    lb_policy: ROUND_ROBIN
+    load_assignment:
+      cluster_name: mistral_7b_instruct
+      endpoints:
+      - lb_endpoints:
+        - endpoint:
+            address:
+              socket_address:
+                address: mistral_7b_instruct
+                port_value: 10001
+            hostname: "mistral_7b_instruct"
diff --git a/envoyfilter/src/consts.rs b/envoyfilter/src/consts.rs
index 6b5f17e2..250bc145 100644
--- a/envoyfilter/src/consts.rs
+++ b/envoyfilter/src/consts.rs
@@ -1,11 +1,11 @@
 pub const DEFAULT_EMBEDDING_MODEL: &str = "BAAI/bge-large-en-v1.5";
 pub const DEFAULT_INTENT_MODEL: &str = "tasksource/deberta-base-long-nli";
 pub const DEFAULT_PROMPT_TARGET_THRESHOLD: f64 = 0.8;
-pub const RATELIMIT_SELECTOR_HEADER_KEY: &str = "x-katanemo-ratelimit-selector";
+pub const RATELIMIT_SELECTOR_HEADER_KEY: &str = "x-bolt-ratelimit-selector";
 pub const SYSTEM_ROLE: &str = "system";
 pub const USER_ROLE: &str = "user";
 pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
 pub const BOLT_FC_CLUSTER: &str = "bolt_fc_1b";
 pub const BOLT_FC_REQUEST_TIMEOUT_MS: u64 = 120000; // 2 minutes
-pub const OPENAI_CHAT_COMPLETIONS_PATH: &str = "/v1/chat/completions";
 pub const MODEL_SERVER_NAME: &str = "model_server";
+pub const BOLT_ROUTING_HEADER: &str = "x-bolt-llm-provider";
diff --git a/envoyfilter/src/lib.rs b/envoyfilter/src/lib.rs
index 78c1153d..a6449695 100644
--- a/envoyfilter/src/lib.rs
+++ b/envoyfilter/src/lib.rs
@@ -4,7 +4,9 @@ use proxy_wasm::types::*;
 
 mod consts;
 mod filter_context;
+mod llm_providers;
 mod ratelimit;
+mod routing;
 mod stats;
 mod stream_context;
 mod tokenizer;
diff --git a/envoyfilter/src/llm_providers.rs b/envoyfilter/src/llm_providers.rs
new file mode 100644
index 00000000..91039ed2
--- /dev/null
+++ b/envoyfilter/src/llm_providers.rs
@@ -0,0 +1,47 @@
+#[non_exhaustive]
+pub struct LlmProviders;
+
+impl LlmProviders {
+    pub const OPENAI_PROVIDER: LlmProvider<'static> = LlmProvider {
+        name: "openai",
+        api_key_header: "x-bolt-openai-api-key",
+        model: "gpt-3.5-turbo",
+    };
+    pub const MISTRAL_PROVIDER: LlmProvider<'static> = LlmProvider {
+        name: "mistral",
+        api_key_header: "x-bolt-mistral-api-key",
+        model: "mistral-large-latest",
+    };
+
+    pub const VARIANTS: &'static [LlmProvider<'static>] =
+        &[Self::OPENAI_PROVIDER, Self::MISTRAL_PROVIDER];
+}
+
+pub struct LlmProvider<'prov> {
+    name: &'prov str,
+    api_key_header: &'prov str,
+    model: &'prov str,
+}
+
+impl AsRef<str> for LlmProvider<'_> {
+    fn as_ref(&self) -> &str {
+        self.name
+    }
+}
+
+impl std::fmt::Display for LlmProvider<'_> {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "{}", self.name)
+    }
+}
+
+impl LlmProvider<'_> {
+    pub fn api_key_header(&self) -> &str {
+        self.api_key_header
+    }
+
+    pub fn choose_model(&self) -> &str {
+        // In the future this can be a more complex function balancing reliability, cost, performance, etc.
+        self.model
+    }
+}
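The provider table above is compile-time data. For illustration only, roughly the same shape in Python (names mirror the Rust constants; nothing here is part of the build):

    from dataclasses import dataclass

    @dataclass(frozen=True)
    class LlmProvider:
        name: str
        api_key_header: str
        model: str

        def choose_model(self) -> str:
            # Mirrors LlmProvider::choose_model(): a fixed model per provider
            # today; later this could weigh reliability, cost, performance.
            return self.model

    OPENAI_PROVIDER = LlmProvider("openai", "x-bolt-openai-api-key", "gpt-3.5-turbo")
    MISTRAL_PROVIDER = LlmProvider("mistral", "x-bolt-mistral-api-key", "mistral-large-latest")
    VARIANTS = (OPENAI_PROVIDER, MISTRAL_PROVIDER)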
diff --git a/envoyfilter/src/routing.rs b/envoyfilter/src/routing.rs
new file mode 100644
index 00000000..5b0f883d
--- /dev/null
+++ b/envoyfilter/src/routing.rs
@@ -0,0 +1,13 @@
+use crate::llm_providers::{LlmProvider, LlmProviders};
+use rand::{seq::SliceRandom, thread_rng};
+
+pub fn get_llm_provider<'hostname>(deterministic: bool) -> &'static LlmProvider<'hostname> {
+    if deterministic {
+        &LlmProviders::OPENAI_PROVIDER
+    } else {
+        let mut rng = thread_rng();
+        LlmProviders::VARIANTS
+            .choose(&mut rng)
+            .expect("There should always be at least one llm provider")
+    }
+}
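The routing policy in routing.rs is deliberately simple: a deterministic hint pins OpenAI, otherwise a provider is drawn uniformly at random. A Python sketch of the same policy (illustrative stand-in, not part of the build):

    import random

    # Illustrative stand-in for LlmProviders::VARIANTS.
    VARIANTS = ("openai", "mistral")

    def get_llm_provider(deterministic: bool) -> str:
        # Mirrors routing::get_llm_provider(): the deterministic hint always
        # selects openai; otherwise choose uniformly at random.
        if deterministic:
            return VARIANTS[0]
        return random.choice(VARIANTS)

    assert get_llm_provider(True) == "openai"
    print(get_llm_provider(False))  # "openai" or "mistral"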
diff --git a/envoyfilter/src/stream_context.rs b/envoyfilter/src/stream_context.rs
index b84bcf02..5d2bdb5c 100644
--- a/envoyfilter/src/stream_context.rs
+++ b/envoyfilter/src/stream_context.rs
@@ -1,13 +1,14 @@
 use crate::consts::{
-    BOLT_FC_CLUSTER, BOLT_FC_REQUEST_TIMEOUT_MS, DEFAULT_EMBEDDING_MODEL, DEFAULT_INTENT_MODEL,
-    DEFAULT_PROMPT_TARGET_THRESHOLD, GPT_35_TURBO, MODEL_SERVER_NAME, OPENAI_CHAT_COMPLETIONS_PATH,
+    BOLT_FC_CLUSTER, BOLT_FC_REQUEST_TIMEOUT_MS, BOLT_ROUTING_HEADER, DEFAULT_EMBEDDING_MODEL,
+    DEFAULT_INTENT_MODEL, DEFAULT_PROMPT_TARGET_THRESHOLD, GPT_35_TURBO, MODEL_SERVER_NAME,
     RATELIMIT_SELECTOR_HEADER_KEY, SYSTEM_ROLE, USER_ROLE,
 };
 use crate::filter_context::{embeddings_store, WasmMetrics};
-use crate::ratelimit;
+use crate::llm_providers::{LlmProvider, LlmProviders};
 use crate::ratelimit::Header;
 use crate::stats::IncrementingMetric;
 use crate::tokenizer;
+use crate::{ratelimit, routing};
 use acap::cos;
 use http::StatusCode;
 use log::{debug, info, warn};
@@ -56,11 +57,11 @@ pub struct StreamContext {
     pub prompt_targets: Rc<RwLock<HashMap<String, PromptTarget>>>,
     pub overrides: Rc<Option<Overrides>>,
     callouts: HashMap<u32, CallContext>,
-    host_header: Option<String>,
     ratelimit_selector: Option<Header>,
     streaming_response: bool,
     response_tokens: usize,
     chat_completions_request: bool,
+    llm_provider: Option<&'static LlmProvider<'static>>,
     prompt_guards: Rc<Option<PromptGuards>>,
 }
@@ -77,18 +78,39 @@ impl StreamContext {
             metrics,
             prompt_targets,
             callouts: HashMap::new(),
-            host_header: None,
             ratelimit_selector: None,
             streaming_response: false,
             response_tokens: 0,
             chat_completions_request: false,
+            llm_provider: None,
             prompt_guards,
             overrides,
         }
     }
 
-    fn save_host_header(&mut self) {
-        // Save the host header to be used by filter logic later on.
-        self.host_header = self.get_http_request_header(":host");
+    fn llm_provider(&self) -> &LlmProvider {
+        self.llm_provider
+            .expect("the provider should be set when asked for it")
+    }
+
+    fn add_routing_header(&mut self) {
+        self.add_http_request_header(BOLT_ROUTING_HEADER, self.llm_provider().as_ref());
+    }
+
+    fn modify_auth_headers(&mut self) -> Result<(), String> {
+        let llm_provider_api_key_value = self
+            .get_http_request_header(self.llm_provider().api_key_header())
+            .ok_or(format!("missing {} api key", self.llm_provider()))?;
+
+        let authorization_header_value = format!("Bearer {}", llm_provider_api_key_value);
+
+        self.set_http_request_header("Authorization", Some(&authorization_header_value));
+
+        // sanitize passed in api keys
+        for provider in LlmProviders::VARIANTS.iter() {
+            self.set_http_request_header(provider.api_key_header(), None);
+        }
+
+        Ok(())
     }
 
     fn delete_content_length_header(&mut self) {
@@ -99,19 +121,6 @@ impl StreamContext {
         self.set_http_request_header("content-length", None);
     }
 
-    fn modify_path_header(&mut self) {
-        match self.get_http_request_header(":path") {
-            // The gateway can start gathering information necessary for routing. For now change the path to an
-            // OpenAI API path.
-            Some(path) if path == "/llmrouting" => {
-                self.set_http_request_header(":path", Some(OPENAI_CHAT_COMPLETIONS_PATH));
-                self.chat_completions_request = true;
-            }
-            // Otherwise let the filter continue.
-            _ => (),
-        }
-    }
-
     fn save_ratelimit_header(&mut self) {
         self.ratelimit_selector = self
             .get_http_request_header(RATELIMIT_SELECTOR_HEADER_KEY)
@@ -237,6 +246,7 @@ impl StreamContext {
             token_id
         );
+        self.metrics.active_http_calls.increment(1);
         callout_context.response_handler_type = ResponseHandlerType::ZeroShotIntent;
         if self.callouts.insert(token_id, callout_context).is_some() {
@@ -431,6 +441,7 @@ impl StreamContext {
             BOLT_FC_CLUSTER,
             token_id
         );
+        self.metrics.active_http_calls.increment(1);
         callout_context.response_handler_type = ResponseHandlerType::FunctionResolver;
         callout_context.prompt_target_name = Some(prompt_target.name);
         if self.callouts.insert(token_id, callout_context).is_some() {
@@ -438,7 +449,6 @@ impl StreamContext {
             }
         }
-        self.metrics.active_http_calls.increment(1);
     }
 
     fn function_resolver_handler(&mut self, body: Vec<u8>, mut callout_context: CallContext) {
@@ -595,7 +605,7 @@ impl StreamContext {
         });
 
         let chat_completions_request: ChatCompletionsRequest = ChatCompletionsRequest {
-            model: GPT_35_TURBO.to_string(),
+            model: callout_context.request_body.model,
             messages,
             tools: None,
             stream: callout_context.request_body.stream,
@@ -751,11 +761,24 @@ impl HttpContext for StreamContext {
     // Envoy's HTTP model is event driven. The WASM ABI has given implementors events to hook onto
     // the lifecycle of the http request and response.
     fn on_http_request_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action {
-        self.save_host_header();
+        let provider_hint = self
+            .get_http_request_header("x-bolt-deterministic-provider")
+            .is_some();
+        self.llm_provider = Some(routing::get_llm_provider(provider_hint));
+
+        self.add_routing_header();
+        if let Err(error) = self.modify_auth_headers() {
+            self.send_server_error(error, Some(StatusCode::BAD_REQUEST));
+        }
         self.delete_content_length_header();
-        self.modify_path_header();
         self.save_ratelimit_header();
 
+        debug!(
+            "S[{}] req_headers={:?}",
+            self.context_id,
+            self.get_http_request_headers()
+        );
+
         Action::Continue
     }
@@ -796,6 +819,9 @@ impl HttpContext for StreamContext {
             }
         };
 
+        // Set the model based on the chosen LLM Provider
+        deserialized_body.model = String::from(self.llm_provider().choose_model());
+
         self.streaming_response = deserialized_body.stream;
         if deserialized_body.stream && deserialized_body.stream_options.is_none() {
             deserialized_body.stream_options = Some(StreamOptions {
@@ -917,15 +943,21 @@ impl HttpContext for StreamContext {
     }
 
     fn on_http_response_body(&mut self, body_size: usize, end_of_stream: bool) -> Action {
-        if !self.chat_completions_request {
-            return Action::Continue;
-        }
-
         debug!(
             "recv [S={}] bytes={} end_stream={}",
             self.context_id, body_size, end_of_stream
         );
 
+        if !self.chat_completions_request {
+            if let Some(body_str) = self
+                .get_http_response_body(0, body_size)
+                .and_then(|bytes| String::from_utf8(bytes).ok())
+            {
+                debug!("recv [S={}] body_str={}", self.context_id, body_str);
+            }
+            return Action::Continue;
+        }
+
         if !end_of_stream && !self.streaming_response {
             return Action::Pause;
         }
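Taken together, add_routing_header() and modify_auth_headers() rewrite each client request before it leaves the gateway: the chosen provider's key becomes the Authorization bearer token, the routing header is added for Envoy's route table, and every x-bolt-*-api-key header is stripped so client-supplied keys never reach an upstream. The transformation, sketched in Python with illustrative names:

    PROVIDER_KEY_HEADERS = {
        "openai": "x-bolt-openai-api-key",
        "mistral": "x-bolt-mistral-api-key",
    }

    def rewrite_request_headers(headers: dict, provider: str) -> dict:
        rewritten = dict(headers)
        api_key = rewritten.get(PROVIDER_KEY_HEADERS[provider])
        if api_key is None:
            # The filter answers 400 BAD_REQUEST in this case.
            raise ValueError(f"missing {provider} api key")
        rewritten["Authorization"] = f"Bearer {api_key}"
        rewritten["x-bolt-llm-provider"] = provider  # consumed by the route table
        for key_header in PROVIDER_KEY_HEADERS.values():
            rewritten.pop(key_header, None)  # sanitize client-supplied keys
        return rewritten

    print(rewrite_request_headers({"x-bolt-openai-api-key": "sk-test"}, "openai"))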
diff --git a/envoyfilter/tests/integration.rs b/envoyfilter/tests/integration.rs
index cb8646f5..f45cde7c 100644
--- a/envoyfilter/tests/integration.rs
+++ b/envoyfilter/tests/integration.rs
@@ -25,6 +25,52 @@ fn wasm_module() -> String {
     wasm_file.to_str().unwrap().to_string()
 }
 
+fn request_headers_expectations(module: &mut Tester, http_context: i32) {
+    module
+        .call_proxy_on_request_headers(http_context, 0, false)
+        .expect_get_header_map_value(
+            Some(MapType::HttpRequestHeaders),
+            Some("x-bolt-deterministic-provider"),
+        )
+        .returning(Some("true"))
+        .expect_add_header_map_value(
+            Some(MapType::HttpRequestHeaders),
+            Some("x-bolt-llm-provider"),
+            Some("openai"),
+        )
+        .expect_get_header_map_value(
+            Some(MapType::HttpRequestHeaders),
+            Some("x-bolt-openai-api-key"),
+        )
+        .returning(Some("api-key"))
+        .expect_replace_header_map_value(
+            Some(MapType::HttpRequestHeaders),
+            Some("Authorization"),
+            Some("Bearer api-key"),
+        )
+        .expect_remove_header_map_value(
+            Some(MapType::HttpRequestHeaders),
+            Some("x-bolt-openai-api-key"),
+        )
+        .expect_remove_header_map_value(
+            Some(MapType::HttpRequestHeaders),
+            Some("x-bolt-mistral-api-key"),
+        )
+        .expect_remove_header_map_value(Some(MapType::HttpRequestHeaders), Some("content-length"))
+        .expect_get_header_map_value(
+            Some(MapType::HttpRequestHeaders),
+            Some("x-bolt-ratelimit-selector"),
+        )
+        .returning(Some("selector-key"))
+        .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("selector-key"))
+        .returning(Some("selector-value"))
+        .expect_get_header_map_pairs(Some(MapType::HttpRequestHeaders))
+        .returning(None)
+        .expect_log(Some(LogLevel::Debug), None)
+        .execute_and_expect(ReturnType::Action(Action::Continue))
+        .unwrap();
+}
+
 fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
     module
         .call_proxy_on_context_create(http_context, filter_context)
@@ -32,28 +78,7 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
         .execute_and_expect(ReturnType::None)
         .unwrap();
 
-    // Request Headers
-    module
-        .call_proxy_on_request_headers(http_context, 0, false)
-        .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":host"))
-        .returning(Some("api.openai.com"))
-        .expect_remove_header_map_value(Some(MapType::HttpRequestHeaders), Some("content-length"))
-        .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":path"))
-        .returning(Some("/llmrouting"))
-        .expect_replace_header_map_value(
-            Some(MapType::HttpRequestHeaders),
-            Some(":path"),
-            Some("/v1/chat/completions"),
-        )
-        .expect_get_header_map_value(
-            Some(MapType::HttpRequestHeaders),
-            Some("x-katanemo-ratelimit-selector"),
-        )
-        .returning(Some("selector-key"))
-        .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("selector-key"))
-        .returning(Some("selector-value"))
-        .execute_and_expect(ReturnType::Action(Action::Continue))
-        .unwrap();
+    request_headers_expectations(module, http_context);
 
     // Request Body
     let chat_completions_request_body = "\
@@ -82,8 +107,8 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
         // The actual call is not important in this test, we just need to grab the token_id
         .expect_http_call(Some("model_server"), None, None, None, None)
         .returning(Some(1))
-        .expect_metric_increment("active_http_calls", 1)
         .expect_log(Some(LogLevel::Debug), None)
+        .expect_metric_increment("active_http_calls", 1)
         .expect_log(Some(LogLevel::Info), None)
         .execute_and_expect(ReturnType::Action(Action::Pause))
         .unwrap();
@@ -115,6 +140,7 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
         .expect_http_call(Some("model_server"), None, None, None, None)
         .returning(Some(2))
         .expect_metric_increment("active_http_calls", 1)
+        .expect_log(Some(LogLevel::Debug), None)
         .execute_and_expect(ReturnType::None)
         .unwrap();
 
@@ -235,26 +261,7 @@ fn successful_request_to_open_ai_chat_completions() {
         .execute_and_expect(ReturnType::None)
         .unwrap();
 
-    // Request Headers
-    module
-        .call_proxy_on_request_headers(http_context, 0, false)
-        .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":host"))
-        .returning(Some("api.openai.com"))
-        .expect_remove_header_map_value(Some(MapType::HttpRequestHeaders), Some("content-length"))
-        .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":path"))
-        .returning(Some("/llmrouting"))
-        .expect_replace_header_map_value(
-            Some(MapType::HttpRequestHeaders),
-            Some(":path"),
-            Some("/v1/chat/completions"),
-        )
-        .expect_get_header_map_value(
-            Some(MapType::HttpRequestHeaders),
-            Some("x-katanemo-ratelimit-selector"),
-        )
-        .returning(None)
-        .execute_and_expect(ReturnType::Action(Action::Continue))
-        .unwrap();
+    request_headers_expectations(&mut module, http_context);
 
     // Request Body
     let chat_completions_request_body = "\
@@ -323,26 +330,7 @@ fn bad_request_to_open_ai_chat_completions() {
         .execute_and_expect(ReturnType::None)
         .unwrap();
 
-    // Request Headers
-    module
-        .call_proxy_on_request_headers(http_context, 0, false)
-        .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":host"))
-        .returning(Some("api.openai.com"))
-        .expect_remove_header_map_value(Some(MapType::HttpRequestHeaders), Some("content-length"))
-        .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":path"))
-        .returning(Some("/llmrouting"))
-        .expect_replace_header_map_value(
-            Some(MapType::HttpRequestHeaders),
-            Some(":path"),
-            Some("/v1/chat/completions"),
-        )
-        .expect_get_header_map_value(
-            Some(MapType::HttpRequestHeaders),
-            Some("x-katanemo-ratelimit-selector"),
-        )
-        .returning(None)
-        .execute_and_expect(ReturnType::Action(Action::Continue))
-        .unwrap();
+    request_headers_expectations(&mut module, http_context);
 
     // Request Body
     let incomplete_chat_completions_request_body = "\