diff --git a/chatbot_ui/app/run.py b/chatbot_ui/app/run.py
index 9f81a220..034378a7 100644
--- a/chatbot_ui/app/run.py
+++ b/chatbot_ui/app/run.py
@@ -6,31 +6,27 @@ from dotenv import load_dotenv
 
 load_dotenv()
 
-OPENAI_API_KEY=os.getenv("OPENAI_API_KEY")
-MISTRAL_API_KEY = os.getenv("MISTRAL_API_KEY")
+OPEN_API_KEY=os.getenv("OPENAI_API_KEY")
 CHAT_COMPLETION_ENDPOINT = os.getenv("CHAT_COMPLETION_ENDPOINT")
 MODEL_NAME = os.getenv("MODEL_NAME", "gpt-3.5-turbo")
 
 log.info("CHAT_COMPLETION_ENDPOINT: ", CHAT_COMPLETION_ENDPOINT)
-client = OpenAI(api_key=OPENAI_API_KEY, base_url=CHAT_COMPLETION_ENDPOINT)
+client = OpenAI(api_key=OPEN_API_KEY, base_url=CHAT_COMPLETION_ENDPOINT)
 
 def predict(message, history):
+    # history_openai_format = []
+    # for human, assistant in history:
+    #     history_openai_format.append({"role": "user", "content": human })
+    #     history_openai_format.append({"role": "assistant", "content":assistant})
     history.append({"role": "user", "content": message})
 
     log.info("CHAT_COMPLETION_ENDPOINT: ", CHAT_COMPLETION_ENDPOINT)
    log.info("history: ", history)
 
-    # Custom headers
-    custom_headers = {
-        'x-bolt-openai-api-key': f"{OPENAI_API_KEY}",
-        'x-bolt-mistral-api-key': f"{MISTRAL_API_KEY}",
-    }
-
     try:
         response = client.chat.completions.create(model=MODEL_NAME,
                                                    messages= history,
-                                                   temperature=1.0,
-                                                   headers=custom_headers
+                                                   temperature=1.0
                                                    )
     except Exception as e:
         log.info(e)
@@ -40,6 +36,10 @@ def predict(message, history):
         log.info("Error calling gateway API: {}".format(e.message))
         raise gr.Error("Error calling gateway API: {}".format(e.message))
 
+    # for chunk in response:
+    #     if chunk.choices[0].delta.content is not None:
+    #         partial_message = partial_message + chunk.choices[0].delta.content
+    #         yield partial_message
     choices = response.choices
     message = choices[0].message
     content = message.content
diff --git a/demos/function_calling/docker-compose.yaml b/demos/function_calling/docker-compose.yaml
index ad994592..5e05ae3b 100644
--- a/demos/function_calling/docker-compose.yaml
+++ b/demos/function_calling/docker-compose.yaml
@@ -108,7 +108,6 @@ services:
       - "18080:8080"
     environment:
       - OPENAI_API_KEY=${OPENAI_API_KEY:?error}
-      - MISTRAL_API_KEY=${MISTRAL_API_KEY:?error}
      - CHAT_COMPLETION_ENDPOINT=http://bolt:10000/v1
 
   prometheus:
diff --git a/envoyfilter/Cargo.lock b/envoyfilter/Cargo.lock
index b72c678c..a2647ead 100644
--- a/envoyfilter/Cargo.lock
+++ b/envoyfilter/Cargo.lock
@@ -745,7 +745,6 @@ dependencies = [
  "proxy-wasm",
  "proxy-wasm-test-framework",
  "public_types",
- "rand",
  "serde",
  "serde_json",
  "serde_yaml",
diff --git a/envoyfilter/Cargo.toml b/envoyfilter/Cargo.toml
index 69750f5c..a2418486 100644
--- a/envoyfilter/Cargo.toml
+++ b/envoyfilter/Cargo.toml
@@ -19,7 +19,6 @@ http = "1.1.0"
 governor = { version = "0.6.3", default-features = false, features = ["no_std"]}
 tiktoken-rs = "0.5.9"
 acap = "0.3.0"
-rand = "0.8.5"
 
 [dev-dependencies]
 proxy-wasm-test-framework = { git = "https://github.com/katanemo/test-framework.git", branch = "main" }
diff --git a/envoyfilter/envoy.template.yaml b/envoyfilter/envoy.template.yaml
index 249d3879..fb653021 100644
--- a/envoyfilter/envoy.template.yaml
+++ b/envoyfilter/envoy.template.yaml
@@ -19,6 +19,15 @@ static_resources:
           route_config:
             name: local_routes
             virtual_hosts:
+            - name: openai
+              domains:
+              - "api.openai.com"
+              routes:
+              - match:
+                  prefix: "/"
+                route:
+                  auto_host_rewrite: true
+                  cluster: openai
             - name: local_service
               domains:
               - "*"
@@ -39,23 +48,28 @@ static_resources:
              - match:
                  prefix: "/v1/chat/completions"
                  headers:
-                  - name: "x-bolt-llm-provider"
-                    string_match:
-                      exact: openai
+                  - name: "Authorization"
+                    present_match: true
                route:
                  auto_host_rewrite: true
                  cluster: openai
                  timeout: 60s
              - match:
                  prefix: "/v1/chat/completions"
-                  headers:
-                  - name: "x-bolt-llm-provider"
-                    string_match:
-                      exact: mistral
                route:
                  auto_host_rewrite: true
-                  cluster: mistral
+                  cluster: mistral_7b_instruct
                  timeout: 60s
+              - match:
+                  prefix: "/embeddings"
+                route:
+                  cluster: model_server
+              - match:
+                  prefix: "/"
+                direct_response:
+                  status: 200
+                  body:
+                    inline_string: "Inspect the HTTP header: custom-header.\n"
          http_filters:
          - name: envoy.filters.http.wasm
            typed_config:
@@ -108,31 +122,6 @@ static_resources:
           tls_params:
             tls_minimum_protocol_version: TLSv1_2
             tls_maximum_protocol_version: TLSv1_3
-  - name: mistral
-    connect_timeout: 5s
-    dns_lookup_family: V4_ONLY
-    type: LOGICAL_DNS
-    lb_policy: ROUND_ROBIN
-    typed_extension_protocol_options:
-      envoy.extensions.upstreams.http.v3.HttpProtocolOptions:
-        "@type": type.googleapis.com/envoy.extensions.upstreams.http.v3.HttpProtocolOptions
-        explicit_http_config:
-          http2_protocol_options: {}
-    load_assignment:
-      cluster_name: mistral
-      endpoints:
-      - lb_endpoints:
-        - endpoint:
-            address:
-              socket_address:
-                address: api.mistral.ai
-                port_value: 443
-            hostname: "api.mistral.ai"
-    transport_socket:
-      name: envoy.transport_sockets.tls
-      typed_config:
-        "@type": type.googleapis.com/envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext
-        sni: api.mistral.ai
  - name: model_server
    connect_timeout: 5s
    type: STRICT_DNS
diff --git a/envoyfilter/envoy.yaml b/envoyfilter/envoy.yaml
deleted file mode 100644
index f0236bf6..00000000
--- a/envoyfilter/envoy.yaml
+++ /dev/null
@@ -1,233 +0,0 @@
-admin:
-  address:
-    socket_address: { address: 0.0.0.0, port_value: 9901 }
-static_resources:
-  listeners:
-    address:
-      socket_address:
-        address: 0.0.0.0
-        port_value: 10000
-    filter_chains:
-    - filters:
-      - name: envoy.filters.network.http_connection_manager
-        typed_config:
-          "@type": type.googleapis.com/envoy.extensions.filters.network.http_connection_manager.v3.HttpConnectionManager
-          stat_prefix: ingress_http
-          codec_type: AUTO
-          scheme_header_transformation:
-            scheme_to_overwrite: https
-          route_config:
-            - name: bolt
-              domains:
-              - "*"
-              routes:
-              - match:
-                  headers:
-                  - name: "x-bolt-llm-provider"
-                    string_match:
-                      exact: openai
-                route:
-                  auto_host_rewrite: true
-                  cluster: openai
-                  timeout: 60s
-              - match:
-                  headers:
-                  - name: "x-bolt-llm-provider"
-                    string_match:
-                      exact: mistral
-                route:
-                  auto_host_rewrite: true
-                  cluster: mistral
-                  timeout: 60s
-              - match:
-                  prefix: "/embeddings"
-                route:
-                  cluster: embeddingserver
-          http_filters:
-          - name: envoy.filters.http.wasm
-            typed_config:
-              "@type": type.googleapis.com/udpa.type.v1.TypedStruct
-              type_url: type.googleapis.com/envoy.extensions.filters.http.wasm.v3.Wasm
-              value:
-                config:
-                  name: "http_config"
-                  configuration:
-                    "@type": "type.googleapis.com/google.protobuf.StringValue"
-                    value: |
-                      default_prompt_endpoint: "127.0.0.1"
-                      load_balancing: "round_robin"
-                      timeout_ms: 5000
-
-                      embedding_provider:
-                        name: "SentenceTransformer"
-                        model: "all-MiniLM-L6-v2"
-
-                      llm_providers:
-
-                        - name: open-ai-gpt-4
-                          api_key: "$OPEN_AI_API_KEY"
-                          model: gpt-4
-
-                        - name: mistral_7b_instruct
-                          model: mistral-7b-instruct
-                          endpoint: http://mistral_7b_instruct:10001/v1/chat/completions
-                          default: true
-
-                      prompt_targets:
-
-                        - type: context_resolver
-                          name: weather_forecast
-                          few_shot_examples:
-                          - what is the weather in New York?
-                          - how is the weather in San Francisco?
-                          - what is the forecast in Seattle?
-                          entities:
-                          - name: city
-                            required: true
-                          - name: days
-                          endpoint:
-                            cluster: weatherhost
-                            path: /weather
-                          system_prompt: |
-                            You are a helpful weather forecaster. Use weater data that is provided to you. Please following following guidelines when responding to user queries:
-                            - Use farenheight for temperature
-                            - Use miles per hour for wind speed
-                  vm_config:
-                    runtime: "envoy.wasm.runtime.v8"
-                    code:
-                      local:
-                        filename: "/etc/envoy/proxy-wasm-plugins/intelligent_prompt_gateway.wasm"
-          - name: envoy.filters.http.router
-            typed_config:
-              "@type": type.googleapis.com/envoy.extensions.filters.http.router.v3.Router
-  clusters:
-    # LLM Host
-    # Embedding Providers
-    # External LLM Providers
-  - name: openai
-    connect_timeout: 5s
-    type: LOGICAL_DNS
-    lb_policy: ROUND_ROBIN
-    typed_extension_protocol_options:
-      envoy.extensions.upstreams.http.v3.HttpProtocolOptions:
-        "@type": type.googleapis.com/envoy.extensions.upstreams.http.v3.HttpProtocolOptions
-        explicit_http_config:
-          http2_protocol_options: {}
-    load_assignment:
-      cluster_name: openai
-      endpoints:
-      - lb_endpoints:
-        - endpoint:
-            address:
-              socket_address:
-                address: api.openai.com
-                port_value: 443
-            hostname: "api.openai.com"
-    transport_socket:
-      name: envoy.transport_sockets.tls
-      typed_config:
-        "@type": type.googleapis.com/envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext
-        sni: api.openai.com
-        common_tls_context:
-          tls_params:
-            tls_minimum_protocol_version: TLSv1_2
-            tls_maximum_protocol_version: TLSv1_3
-  - name: mistral
-    connect_timeout: 5s
-    type: LOGICAL_DNS
-    lb_policy: ROUND_ROBIN
-    typed_extension_protocol_options:
-      envoy.extensions.upstreams.http.v3.HttpProtocolOptions:
-        "@type": type.googleapis.com/envoy.extensions.upstreams.http.v3.HttpProtocolOptions
-        explicit_http_config:
-          http2_protocol_options: {}
-    load_assignment:
-      cluster_name: mistral
-      endpoints:
-      - lb_endpoints:
-        - endpoint:
-            address:
-              socket_address:
-                address: api.mistral.ai
-                port_value: 443
-            hostname: "api.mistral.ai"
-    transport_socket:
-      name: envoy.transport_sockets.tls
-      typed_config:
-        "@type": type.googleapis.com/envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext
-        sni: api.mistral.ai
-        common_tls_context:
-          tls_params:
-            tls_minimum_protocol_version: TLSv1_2
-            tls_maximum_protocol_version: TLSv1_3
-  - name: embeddingserver
-    connect_timeout: 5s
-    type: STRICT_DNS
-    lb_policy: ROUND_ROBIN
-    load_assignment:
-      cluster_name: embeddingserver
-      endpoints:
-      - lb_endpoints:
-        - endpoint:
-            address:
-              socket_address:
-                address: host.docker.internal
-                port_value: 8000
-            hostname: "embeddingserver"
-  - name: weatherhost
-    connect_timeout: 5s
-    type: STRICT_DNS
-    lb_policy: ROUND_ROBIN
-    load_assignment:
-      cluster_name: weatherhost
-      endpoints:
-      - lb_endpoints:
-        - endpoint:
-            address:
-              socket_address:
-                address: host.docker.internal
-                port_value: 8000
-            hostname: "embeddingserver"
-  - name: nerhost
-    connect_timeout: 5s
-    type: STRICT_DNS
-    lb_policy: ROUND_ROBIN
-    load_assignment:
-      cluster_name: nerhost
-      endpoints:
-      - lb_endpoints:
-        - endpoint:
-            address:
-              socket_address:
-                address: host.docker.internal
-                port_value: 8000
-            hostname: "embeddingserver"
-  - name: qdrant
-    connect_timeout: 5s
-    type: STRICT_DNS
-    lb_policy: ROUND_ROBIN
-    load_assignment:
-      cluster_name: qdrant
-      endpoints:
-      - lb_endpoints:
-        - endpoint:
-            address:
-              socket_address:
-                address: qdrant
-                port_value: 6333
-            hostname: "qdrant"
-  - name: mistral_7b_instruct
-    connect_timeout: 5s
-    type: STRICT_DNS
-    lb_policy: ROUND_ROBIN
-    load_assignment:
-      cluster_name: qdrant
-      endpoints:
-      - lb_endpoints:
-        - endpoint:
-            address:
-              socket_address:
-                address: mistral_7b_instruct
-                port_value: 10001
-            hostname: "mistral_7b_instruct"
diff --git a/envoyfilter/src/consts.rs b/envoyfilter/src/consts.rs
index 250bc145..6b5f17e2 100644
--- a/envoyfilter/src/consts.rs
+++ b/envoyfilter/src/consts.rs
@@ -1,11 +1,11 @@
 pub const DEFAULT_EMBEDDING_MODEL: &str = "BAAI/bge-large-en-v1.5";
 pub const DEFAULT_INTENT_MODEL: &str = "tasksource/deberta-base-long-nli";
 pub const DEFAULT_PROMPT_TARGET_THRESHOLD: f64 = 0.8;
-pub const RATELIMIT_SELECTOR_HEADER_KEY: &str = "x-bolt-ratelimit-selector";
+pub const RATELIMIT_SELECTOR_HEADER_KEY: &str = "x-katanemo-ratelimit-selector";
 pub const SYSTEM_ROLE: &str = "system";
 pub const USER_ROLE: &str = "user";
 pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
 pub const BOLT_FC_CLUSTER: &str = "bolt_fc_1b";
 pub const BOLT_FC_REQUEST_TIMEOUT_MS: u64 = 120000; // 2 minutes
+pub const OPENAI_CHAT_COMPLETIONS_PATH: &str = "/v1/chat/completions";
 pub const MODEL_SERVER_NAME: &str = "model_server";
-pub const BOLT_ROUTING_HEADER: &str = "x-bolt-llm-provider";
diff --git a/envoyfilter/src/lib.rs b/envoyfilter/src/lib.rs
index a6449695..78c1153d 100644
--- a/envoyfilter/src/lib.rs
+++ b/envoyfilter/src/lib.rs
@@ -4,9 +4,7 @@ use proxy_wasm::types::*;
 
 mod consts;
 mod filter_context;
-mod llm_providers;
 mod ratelimit;
-mod routing;
 mod stats;
 mod stream_context;
 mod tokenizer;
diff --git a/envoyfilter/src/llm_providers.rs b/envoyfilter/src/llm_providers.rs
deleted file mode 100644
index 91039ed2..00000000
--- a/envoyfilter/src/llm_providers.rs
+++ /dev/null
@@ -1,47 +0,0 @@
-#[non_exhaustive]
-pub struct LlmProviders;
-
-impl LlmProviders {
-    pub const OPENAI_PROVIDER: LlmProvider<'static> = LlmProvider {
-        name: "openai",
-        api_key_header: "x-bolt-openai-api-key",
-        model: "gpt-3.5-turbo",
-    };
-    pub const MISTRAL_PROVIDER: LlmProvider<'static> = LlmProvider {
-        name: "mistral",
-        api_key_header: "x-bolt-mistral-api-key",
-        model: "mistral-large-latest",
-    };
-
-    pub const VARIANTS: &'static [LlmProvider<'static>] =
-        &[Self::OPENAI_PROVIDER, Self::MISTRAL_PROVIDER];
-}
-
-pub struct LlmProvider<'prov> {
-    name: &'prov str,
-    api_key_header: &'prov str,
-    model: &'prov str,
-}
-
-impl AsRef<str> for LlmProvider<'_> {
-    fn as_ref(&self) -> &str {
-        self.name
-    }
-}
-
-impl std::fmt::Display for LlmProvider<'_> {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        write!(f, "{}", self.name)
-    }
-}
-
-impl LlmProvider<'_> {
-    pub fn api_key_header(&self) -> &str {
-        self.api_key_header
-    }
-
-    pub fn choose_model(&self) -> &str {
-        // In the future this can be a more complex function balancing reliability, cost, performance, etc.
-        self.model
-    }
-}
diff --git a/envoyfilter/src/routing.rs b/envoyfilter/src/routing.rs
deleted file mode 100644
index 5b0f883d..00000000
--- a/envoyfilter/src/routing.rs
+++ /dev/null
@@ -1,13 +0,0 @@
-use crate::llm_providers::{LlmProvider, LlmProviders};
-use rand::{seq::SliceRandom, thread_rng};
-
-pub fn get_llm_provider<'hostname>(deterministic: bool) -> &'static LlmProvider<'hostname> {
-    if deterministic {
-        &LlmProviders::OPENAI_PROVIDER
-    } else {
-        let mut rng = thread_rng();
-        LlmProviders::VARIANTS
-            .choose(&mut rng)
-            .expect("There should always be at least one llm provider")
-    }
-}
diff --git a/envoyfilter/src/stream_context.rs b/envoyfilter/src/stream_context.rs
index 5d2bdb5c..b84bcf02 100644
--- a/envoyfilter/src/stream_context.rs
+++ b/envoyfilter/src/stream_context.rs
@@ -1,14 +1,13 @@
 use crate::consts::{
-    BOLT_FC_CLUSTER, BOLT_FC_REQUEST_TIMEOUT_MS, BOLT_ROUTING_HEADER, DEFAULT_EMBEDDING_MODEL,
-    DEFAULT_INTENT_MODEL, DEFAULT_PROMPT_TARGET_THRESHOLD, GPT_35_TURBO, MODEL_SERVER_NAME,
+    BOLT_FC_CLUSTER, BOLT_FC_REQUEST_TIMEOUT_MS, DEFAULT_EMBEDDING_MODEL, DEFAULT_INTENT_MODEL,
+    DEFAULT_PROMPT_TARGET_THRESHOLD, GPT_35_TURBO, MODEL_SERVER_NAME, OPENAI_CHAT_COMPLETIONS_PATH,
     RATELIMIT_SELECTOR_HEADER_KEY, SYSTEM_ROLE, USER_ROLE,
 };
 use crate::filter_context::{embeddings_store, WasmMetrics};
-use crate::llm_providers::{LlmProvider, LlmProviders};
+use crate::ratelimit;
 use crate::ratelimit::Header;
 use crate::stats::IncrementingMetric;
 use crate::tokenizer;
-use crate::{ratelimit, routing};
 use acap::cos;
 use http::StatusCode;
 use log::{debug, info, warn};
@@ -57,11 +56,11 @@ pub struct StreamContext {
     pub prompt_targets: Rc>>,
     pub overrides: Rc>,
     callouts: HashMap<u32, CallContext>,
+    host_header: Option<String>,
     ratelimit_selector: Option<Header>,
     streaming_response: bool,
     response_tokens: usize,
     chat_completions_request: bool,
-    llm_provider: Option<&'static LlmProvider<'static>>,
     prompt_guards: Rc>,
 }
 
@@ -78,39 +77,18 @@ impl StreamContext {
             metrics,
             prompt_targets,
             callouts: HashMap::new(),
+            host_header: None,
             ratelimit_selector: None,
             streaming_response: false,
             response_tokens: 0,
             chat_completions_request: false,
-            llm_provider: None,
             prompt_guards,
             overrides,
         }
     }
 
-    fn llm_provider(&self) -> &LlmProvider {
-        self.llm_provider
-            .expect("the provider should be set when asked for it")
-    }
-
-    fn add_routing_header(&mut self) {
-        self.add_http_request_header(BOLT_ROUTING_HEADER, self.llm_provider().as_ref());
-    }
-
-    fn modify_auth_headers(&mut self) -> Result<(), String> {
-        let llm_provider_api_key_value = self
-            .get_http_request_header(self.llm_provider().api_key_header())
-            .ok_or(format!("missing {} api key", self.llm_provider()))?;
-
-        let authorization_header_value = format!("Bearer {}", llm_provider_api_key_value);
-
-        self.set_http_request_header("Authorization", Some(&authorization_header_value));
-
-        // sanitize passed in api keys
-        for provider in LlmProviders::VARIANTS.iter() {
-            self.set_http_request_header(provider.api_key_header(), None);
-        }
-
-        Ok(())
+    fn save_host_header(&mut self) {
+        // Save the host header to be used by filter logic later on.
+        self.host_header = self.get_http_request_header(":host");
     }
 
     fn delete_content_length_header(&mut self) {
@@ -121,6 +99,19 @@ impl StreamContext {
         self.set_http_request_header("content-length", None);
     }
 
+    fn modify_path_header(&mut self) {
+        match self.get_http_request_header(":path") {
+            // The gateway can start gathering information necessary for routing. For now change the path to an
+            // OpenAI API path.
+            Some(path) if path == "/llmrouting" => {
+                self.set_http_request_header(":path", Some(OPENAI_CHAT_COMPLETIONS_PATH));
+                self.chat_completions_request = true;
+            }
+            // Otherwise let the filter continue.
+            _ => (),
+        }
+    }
+
     fn save_ratelimit_header(&mut self) {
         self.ratelimit_selector = self
             .get_http_request_header(RATELIMIT_SELECTOR_HEADER_KEY)
@@ -246,7 +237,6 @@ impl StreamContext {
             token_id
         );
 
-        self.metrics.active_http_calls.increment(1);
         callout_context.response_handler_type = ResponseHandlerType::ZeroShotIntent;
         if self.callouts.insert(token_id, callout_context).is_some() {
@@ -441,7 +431,6 @@ impl StreamContext {
                 BOLT_FC_CLUSTER,
                 token_id
             );
-            self.metrics.active_http_calls.increment(1);
             callout_context.response_handler_type = ResponseHandlerType::FunctionResolver;
             callout_context.prompt_target_name = Some(prompt_target.name);
             if self.callouts.insert(token_id, callout_context).is_some() {
@@ -449,6 +438,7 @@ impl StreamContext {
                 }
             }
         }
+        self.metrics.active_http_calls.increment(1);
     }
 
     fn function_resolver_handler(&mut self, body: Vec<u8>, mut callout_context: CallContext) {
@@ -605,7 +595,7 @@ impl StreamContext {
         });
 
         let chat_completions_request: ChatCompletionsRequest = ChatCompletionsRequest {
-            model: callout_context.request_body.model,
+            model: GPT_35_TURBO.to_string(),
             messages,
             tools: None,
             stream: callout_context.request_body.stream,
@@ -761,24 +751,11 @@ impl HttpContext for StreamContext {
     // Envoy's HTTP model is event driven. The WASM ABI has given implementors events to hook onto
     // the lifecycle of the http request and response.
     fn on_http_request_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action {
-        let provider_hint = self
-            .get_http_request_header("x-bolt-deterministic-provider")
-            .is_some();
-        self.llm_provider = Some(routing::get_llm_provider(provider_hint));
-
-        self.add_routing_header();
-        if let Err(error) = self.modify_auth_headers() {
-            self.send_server_error(error, Some(StatusCode::BAD_REQUEST));
-        }
+        self.save_host_header();
         self.delete_content_length_header();
+        self.modify_path_header();
         self.save_ratelimit_header();
 
-        debug!(
-            "S[{}] req_headers={:?}",
-            self.context_id,
-            self.get_http_request_headers()
-        );
-
         Action::Continue
     }
 
@@ -819,9 +796,6 @@ impl HttpContext for StreamContext {
             }
         };
 
-        // Set the model based on the chosen LLM Provider
-        deserialized_body.model = String::from(self.llm_provider().choose_model());
-
         self.streaming_response = deserialized_body.stream;
         if deserialized_body.stream && deserialized_body.stream_options.is_none() {
             deserialized_body.stream_options = Some(StreamOptions {
@@ -943,21 +917,15 @@ impl HttpContext for StreamContext {
     }
 
     fn on_http_response_body(&mut self, body_size: usize, end_of_stream: bool) -> Action {
+        if !self.chat_completions_request {
+            return Action::Continue;
+        }
+
         debug!(
             "recv [S={}] bytes={} end_stream={}",
             self.context_id, body_size, end_of_stream
         );
 
-        if !self.chat_completions_request {
-            if let Some(body_str) = self
-                .get_http_response_body(0, body_size)
-                .and_then(|bytes| String::from_utf8(bytes).ok())
-            {
-                debug!("recv [S={}] body_str={}", self.context_id, body_str);
-            }
-            return Action::Continue;
-        }
-
         if !end_of_stream && !self.streaming_response {
             return Action::Pause;
         }
diff --git a/envoyfilter/tests/integration.rs b/envoyfilter/tests/integration.rs
index f45cde7c..cb8646f5 100644
--- a/envoyfilter/tests/integration.rs
+++ b/envoyfilter/tests/integration.rs
@@ -25,52 +25,6 @@ fn wasm_module() -> String {
     wasm_file.to_str().unwrap().to_string()
 }
 
-fn request_headers_expectations(module: &mut Tester, http_context: i32) {
-    module
-        .call_proxy_on_request_headers(http_context, 0, false)
-        .expect_get_header_map_value(
-            Some(MapType::HttpRequestHeaders),
-            Some("x-bolt-deterministic-provider"),
-        )
-        .returning(Some("true"))
-        .expect_add_header_map_value(
-            Some(MapType::HttpRequestHeaders),
-            Some("x-bolt-llm-provider"),
-            Some("openai"),
-        )
-        .expect_get_header_map_value(
-            Some(MapType::HttpRequestHeaders),
-            Some("x-bolt-openai-api-key"),
-        )
-        .returning(Some("api-key"))
-        .expect_replace_header_map_value(
-            Some(MapType::HttpRequestHeaders),
-            Some("Authorization"),
-            Some("Bearer api-key"),
-        )
-        .expect_remove_header_map_value(
-            Some(MapType::HttpRequestHeaders),
-            Some("x-bolt-openai-api-key"),
-        )
-        .expect_remove_header_map_value(
-            Some(MapType::HttpRequestHeaders),
-            Some("x-bolt-mistral-api-key"),
-        )
-        .expect_remove_header_map_value(Some(MapType::HttpRequestHeaders), Some("content-length"))
-        .expect_get_header_map_value(
-            Some(MapType::HttpRequestHeaders),
-            Some("x-bolt-ratelimit-selector"),
-        )
-        .returning(Some("selector-key"))
-        .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("selector-key"))
-        .returning(Some("selector-value"))
-        .expect_get_header_map_pairs(Some(MapType::HttpRequestHeaders))
-        .returning(None)
-        .expect_log(Some(LogLevel::Debug), None)
-        .execute_and_expect(ReturnType::Action(Action::Continue))
-        .unwrap();
-}
-
 fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
     module
         .call_proxy_on_context_create(http_context, filter_context)
@@ -78,7 +32,28 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
         .execute_and_expect(ReturnType::None)
         .unwrap();
 
-    request_headers_expectations(module, http_context);
+    // Request Headers
+    module
+        .call_proxy_on_request_headers(http_context, 0, false)
+        .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":host"))
+        .returning(Some("api.openai.com"))
+        .expect_remove_header_map_value(Some(MapType::HttpRequestHeaders), Some("content-length"))
+        .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":path"))
+        .returning(Some("/llmrouting"))
+        .expect_replace_header_map_value(
+            Some(MapType::HttpRequestHeaders),
+            Some(":path"),
+            Some("/v1/chat/completions"),
+        )
+        .expect_get_header_map_value(
+            Some(MapType::HttpRequestHeaders),
+            Some("x-katanemo-ratelimit-selector"),
+        )
+        .returning(Some("selector-key"))
+        .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("selector-key"))
+        .returning(Some("selector-value"))
+        .execute_and_expect(ReturnType::Action(Action::Continue))
+        .unwrap();
 
     // Request Body
     let chat_completions_request_body = "\
@@ -107,8 +82,8 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
         // The actual call is not important in this test, we just need to grab the token_id
         .expect_http_call(Some("model_server"), None, None, None, None)
         .returning(Some(1))
-        .expect_log(Some(LogLevel::Debug), None)
         .expect_metric_increment("active_http_calls", 1)
+        .expect_log(Some(LogLevel::Debug), None)
         .expect_log(Some(LogLevel::Info), None)
         .execute_and_expect(ReturnType::Action(Action::Pause))
         .unwrap();
@@ -140,7 +115,6 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
         .expect_http_call(Some("model_server"), None, None, None, None)
         .returning(Some(2))
         .expect_metric_increment("active_http_calls", 1)
-        .expect_log(Some(LogLevel::Debug), None)
         .execute_and_expect(ReturnType::None)
         .unwrap();
 
@@ -261,7 +235,26 @@ fn successful_request_to_open_ai_chat_completions() {
         .execute_and_expect(ReturnType::None)
         .unwrap();
 
-    request_headers_expectations(&mut module, http_context);
+    // Request Headers
+    module
+        .call_proxy_on_request_headers(http_context, 0, false)
+        .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":host"))
+        .returning(Some("api.openai.com"))
+        .expect_remove_header_map_value(Some(MapType::HttpRequestHeaders), Some("content-length"))
+        .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":path"))
+        .returning(Some("/llmrouting"))
+        .expect_replace_header_map_value(
+            Some(MapType::HttpRequestHeaders),
+            Some(":path"),
+            Some("/v1/chat/completions"),
+        )
+        .expect_get_header_map_value(
+            Some(MapType::HttpRequestHeaders),
+            Some("x-katanemo-ratelimit-selector"),
+        )
+        .returning(None)
+        .execute_and_expect(ReturnType::Action(Action::Continue))
+        .unwrap();
 
     // Request Body
     let chat_completions_request_body = "\
@@ -330,7 +323,26 @@ fn bad_request_to_open_ai_chat_completions() {
         .execute_and_expect(ReturnType::None)
         .unwrap();
 
-    request_headers_expectations(&mut module, http_context);
+    // Request Headers
+    module
+        .call_proxy_on_request_headers(http_context, 0, false)
+        .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":host"))
+        .returning(Some("api.openai.com"))
+        .expect_remove_header_map_value(Some(MapType::HttpRequestHeaders), Some("content-length"))
+        .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":path"))
+        .returning(Some("/llmrouting"))
+        .expect_replace_header_map_value(
+            Some(MapType::HttpRequestHeaders),
+            Some(":path"),
+            Some("/v1/chat/completions"),
+        )
+        .expect_get_header_map_value(
+            Some(MapType::HttpRequestHeaders),
+            Some("x-katanemo-ratelimit-selector"),
+        )
+        .returning(None)
+        .execute_and_expect(ReturnType::Action(Action::Continue))
+        .unwrap();
 
     // Request Body
     let incomplete_chat_completions_request_body = "\