diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 868c7548..1e577bbf 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -13,7 +13,7 @@ repos: name: cargo-fmt language: system types: [file, rust] - entry: bash -c "cd crates/llm_gateway && cargo fmt -- --check" + entry: bash -c "cd crates/llm_gateway && cargo fmt" - id: cargo-clippy name: cargo-clippy diff --git a/crates/llm_gateway/src/llm_stream_context.rs b/crates/llm_gateway/src/llm_stream_context.rs index 57bc25ed..6c585a72 100644 --- a/crates/llm_gateway/src/llm_stream_context.rs +++ b/crates/llm_gateway/src/llm_stream_context.rs @@ -164,16 +164,12 @@ impl HttpContext for LlmGatewayStreamContext { // Envoy's HTTP model is event driven. The WASM ABI has given implementors events to hook onto // the lifecycle of the http request and response. fn on_http_request_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action { - debug!("on_http_request_headers 1"); self.select_llm_provider(); - debug!("on_http_request_headers 2"); self.add_http_request_header(ARCH_ROUTING_HEADER, &self.llm_provider().name); - debug!("on_http_request_headers 3"); if let Err(error) = self.modify_auth_headers() { self.send_server_error(error, Some(StatusCode::BAD_REQUEST)); } - debug!("on_http_request_headers 4"); self.delete_content_length_header(); self.save_ratelimit_header(); @@ -231,8 +227,6 @@ impl HttpContext for LlmGatewayStreamContext { }; self.is_chat_completions_request = true; - debug!("llm gateway mode, skipping over all prompt targets"); - // remove metadata from the request body deserialized_body.metadata = None; // delete model key from message array diff --git a/crates/llm_gateway/tests/integration.rs b/crates/llm_gateway/tests/integration.rs index 5821a79a..80ff8d9f 100644 --- a/crates/llm_gateway/tests/integration.rs +++ b/crates/llm_gateway/tests/integration.rs @@ -1,19 +1,9 @@ -use common::common_types::open_ai::{ChatCompletionsResponse, Choice, Message, Usage}; -use common::common_types::open_ai::{FunctionCallDetail, ToolCall, ToolType}; -use common::common_types::{HallucinationClassificationResponse, PromptGuardResponse}; -use common::embeddings::{ - create_embedding_response, embedding, CreateEmbeddingResponse, CreateEmbeddingResponseUsage, - Embedding, -}; -use common::{common_types::ZeroShotClassificationResponse, configuration::Configuration}; use http::StatusCode; use proxy_wasm_test_framework::tester::{self, Tester}; use proxy_wasm_test_framework::types::{ Action, BufferType, LogLevel, MapType, MetricType, ReturnType, }; -use serde_yaml::Value; use serial_test::serial; -use std::collections::HashMap; use std::path::Path; fn wasm_module() -> String { @@ -34,11 +24,6 @@ fn request_headers_expectations(module: &mut Tester, http_context: i32) { ) .returning(Some("default")) .expect_log(Some(LogLevel::Debug), None) - .expect_add_header_map_value( - Some(MapType::HttpRequestHeaders), - Some("x-arch-upstream"), - Some("arch_llm_listener"), - ) .expect_add_header_map_value( Some(MapType::HttpRequestHeaders), Some("x-arch-llm-provider"), @@ -61,6 +46,8 @@ fn request_headers_expectations(module: &mut Tester, http_context: i32) { .returning(None) .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":path")) .returning(Some("/v1/chat/completions")) + .expect_get_header_map_pairs(Some(MapType::HttpRequestHeaders)) + .returning(None) .expect_log(Some(LogLevel::Debug), None) .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("x-request-id")) .returning(None) @@ -76,181 +63,6 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) { .unwrap(); request_headers_expectations(module, http_context); - - // Request Body - let chat_completions_request_body = "\ -{\ - \"messages\": [\ - {\ - \"role\": \"system\",\ - \"content\": \"You are a poetic assistant, skilled in explaining complex programming concepts with creative flair.\"\ - },\ - {\ - \"role\": \"user\",\ - \"content\": \"Compose a poem that explains the concept of recursion in programming.\"\ - }\ - ],\ - \"model\": \"gpt-4\"\ -}"; - - module - .call_proxy_on_request_body( - http_context, - chat_completions_request_body.len() as i32, - true, - ) - .expect_get_buffer_bytes(Some(BufferType::HttpRequestBody)) - .returning(Some(chat_completions_request_body)) - // The actual call is not important in this test, we just need to grab the token_id - .expect_http_call( - Some("arch_internal"), - Some(vec![ - ("x-arch-upstream", "model_server"), - (":method", "POST"), - (":path", "/guard"), - (":authority", "model_server"), - ("content-type", "application/json"), - ("x-envoy-max-retries", "3"), - ("x-envoy-upstream-rq-timeout-ms", "60000"), - ]), - None, - None, - None, - ) - .returning(Some(1)) - .expect_log(Some(LogLevel::Debug), None) - .expect_metric_increment("active_http_calls", 1) - .execute_and_expect(ReturnType::Action(Action::Pause)) - .unwrap(); - - let prompt_guard_response = PromptGuardResponse { - toxic_prob: None, - toxic_verdict: None, - jailbreak_prob: None, - jailbreak_verdict: None, - }; - let prompt_guard_response_buffer = serde_json::to_string(&prompt_guard_response).unwrap(); - module - .call_proxy_on_http_call_response( - http_context, - 1, - 0, - prompt_guard_response_buffer.len() as i32, - 0, - ) - .expect_metric_increment("active_http_calls", -1) - .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody)) - .returning(Some(&prompt_guard_response_buffer)) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) - .expect_http_call( - Some("arch_internal"), - Some(vec![ - ("x-arch-upstream", "model_server"), - (":method", "POST"), - (":path", "/embeddings"), - (":authority", "model_server"), - ("content-type", "application/json"), - ("x-envoy-max-retries", "3"), - ("x-envoy-upstream-rq-timeout-ms", "60000"), - ]), - None, - None, - None, - ) - .returning(Some(2)) - .expect_metric_increment("active_http_calls", 1) - .expect_log(Some(LogLevel::Debug), None) - .execute_and_expect(ReturnType::None) - .unwrap(); - - let embedding_response = CreateEmbeddingResponse { - data: vec![Embedding { - index: 0, - embedding: vec![], - object: embedding::Object::default(), - }], - model: String::from("test"), - object: create_embedding_response::Object::default(), - usage: Box::new(CreateEmbeddingResponseUsage::new(0, 0)), - }; - let embeddings_response_buffer = serde_json::to_string(&embedding_response).unwrap(); - module - .call_proxy_on_http_call_response( - http_context, - 2, - 0, - embeddings_response_buffer.len() as i32, - 0, - ) - .expect_metric_increment("active_http_calls", -1) - .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody)) - .returning(Some(&embeddings_response_buffer)) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) - .expect_http_call( - Some("arch_internal"), - Some(vec![ - ("x-arch-upstream", "model_server"), - (":method", "POST"), - (":path", "/zeroshot"), - (":authority", "model_server"), - ("content-type", "application/json"), - ("x-envoy-max-retries", "3"), - ("x-envoy-upstream-rq-timeout-ms", "60000"), - ]), - None, - None, - None, - ) - .returning(Some(3)) - .expect_metric_increment("active_http_calls", 1) - .expect_log(Some(LogLevel::Debug), None) - .execute_and_expect(ReturnType::None) - .unwrap(); - - let zero_shot_response = ZeroShotClassificationResponse { - predicted_class: "weather_forecast".to_string(), - predicted_class_score: 0.1, - scores: HashMap::new(), - model: "test-model".to_string(), - }; - let zeroshot_intent_detection_buffer = serde_json::to_string(&zero_shot_response).unwrap(); - module - .call_proxy_on_http_call_response( - http_context, - 3, - 0, - zeroshot_intent_detection_buffer.len() as i32, - 0, - ) - .expect_metric_increment("active_http_calls", -1) - .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody)) - .returning(Some(&zeroshot_intent_detection_buffer)) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Info), None) - .expect_http_call( - Some("arch_internal"), - Some(vec![ - (":method", "POST"), - ("x-arch-upstream", "arch_fc"), - (":path", "/v1/chat/completions"), - (":authority", "arch_fc"), - ("content-type", "application/json"), - ("x-envoy-max-retries", "3"), - ("x-envoy-upstream-rq-timeout-ms", "120000"), - ]), - None, - None, - None, - ) - .returning(Some(4)) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) - .expect_metric_increment("active_http_calls", 1) - .execute_and_expect(ReturnType::None) - .unwrap(); } fn setup_filter(module: &mut Tester, config: &str) -> i32 { @@ -270,69 +82,6 @@ fn setup_filter(module: &mut Tester, config: &str) -> i32 { .execute_and_expect(ReturnType::Bool(true)) .unwrap(); - module - .call_proxy_on_tick(filter_context) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) - .expect_http_call( - Some("arch_internal"), - Some(vec![ - ("x-arch-upstream", "model_server"), - (":method", "POST"), - (":path", "/embeddings"), - (":authority", "model_server"), - ("content-type", "application/json"), - ("x-envoy-upstream-rq-timeout-ms", "60000"), - ]), - None, - None, - None, - ) - .returning(Some(101)) - .expect_metric_increment("active_http_calls", 1) - .expect_set_tick_period_millis(Some(0)) - .execute_and_expect(ReturnType::None) - .unwrap(); - - let embedding_response = CreateEmbeddingResponse { - data: vec![Embedding { - embedding: vec![], - index: 0, - object: embedding::Object::default(), - }], - model: String::from("test"), - object: create_embedding_response::Object::default(), - usage: Box::new(CreateEmbeddingResponseUsage { - prompt_tokens: 0, - total_tokens: 0, - }), - }; - let embedding_response_str = serde_json::to_string(&embedding_response).unwrap(); - module - .call_proxy_on_http_call_response( - filter_context, - 101, - 0, - embedding_response_str.len() as i32, - 0, - ) - .expect_log( - Some(LogLevel::Debug), - Some( - format!( - "filter_context: on_http_call_response called with token_id: {:?}", - 101 - ) - .as_str(), - ), - ) - .expect_metric_increment("active_http_calls", -1) - .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody)) - .returning(Some(&embedding_response_str)) - .expect_log(Some(LogLevel::Debug), None) - .execute_and_expect(ReturnType::None) - .unwrap(); - filter_context } @@ -357,6 +106,10 @@ llm_providers: access_key: secret_key model: gpt-4 default: true + - name: open-ai-gpt-4o + provider: openai + access_key: secret_key + model: gpt-4o overrides: # confidence threshold for prompt target intent matching @@ -396,7 +149,7 @@ ratelimits: key: selector-key value: selector-value limit: - tokens: 1 + tokens: 50 unit: minute "# } @@ -440,7 +193,7 @@ fn successful_request_to_open_ai_chat_completions() { },\ {\ \"role\": \"user\",\ - \"content\": \"Compose a poem that explains the concept of recursion in programming.\"\ + \"content\": \"Compose a poem.\"\ }\ ],\ \"model\": \"gpt-4\"\ @@ -455,10 +208,10 @@ fn successful_request_to_open_ai_chat_completions() { .expect_get_buffer_bytes(Some(BufferType::HttpRequestBody)) .returning(Some(chat_completions_request_body)) .expect_log(Some(LogLevel::Debug), None) - .expect_http_call(Some("arch_internal"), None, None, None, None) - .returning(Some(4)) - .expect_metric_increment("active_http_calls", 1) - .execute_and_expect(ReturnType::Action(Action::Pause)) + .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Debug), None) + .expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None) + .execute_and_expect(ReturnType::Action(Action::Continue)) .unwrap(); } @@ -547,111 +300,35 @@ fn request_ratelimited() { normal_flow(&mut module, filter_context, http_context); - let arch_fc_resp = ChatCompletionsResponse { - usage: Some(Usage { - completion_tokens: 0, - }), - choices: vec![Choice { - finish_reason: "test".to_string(), - index: 0, - message: Message { - role: "system".to_string(), - content: None, - tool_calls: Some(vec![ToolCall { - id: String::from("test"), - tool_type: ToolType::Function, - function: FunctionCallDetail { - name: String::from("weather_forecast"), - arguments: HashMap::from([( - String::from("city"), - Value::String(String::from("seattle")), - )]), - }, - }]), - model: None, - }, - }], - model: String::from("test"), - metadata: None, - }; + // Request Body + let chat_completions_request_body = "\ +{\ + \"messages\": [\ + {\ + \"role\": \"system\",\ + \"content\": \"You are a poetic assistant, skilled in explaining complex programming concepts with creative flair.\"\ + },\ + {\ + \"role\": \"user\",\ + \"content\": \"Compose a poem that explains the concept of recursion in programming. Compose a poem that explains the concept of recursion in programming. Compose a poem that explains the concept of recursion in programming. \"\ + }\ + ],\ + \"model\": \"gpt-4\"\ +}"; - let arch_fc_resp_str = serde_json::to_string(&arch_fc_resp).unwrap(); module - .call_proxy_on_http_call_response(http_context, 4, 0, arch_fc_resp_str.len() as i32, 0) - .expect_metric_increment("active_http_calls", -1) - .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody)) - .returning(Some(&arch_fc_resp_str)) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) - .expect_http_call( - Some("arch_internal"), - Some(vec![ - ("x-arch-upstream", "model_server"), - (":method", "POST"), - (":path", "/hallucination"), - (":authority", "model_server"), - ("content-type", "application/json"), - ("x-envoy-max-retries", "3"), - ("x-envoy-upstream-rq-timeout-ms", "60000"), - ]), - None, - None, - None, + .call_proxy_on_request_body( + http_context, + chat_completions_request_body.len() as i32, + true, ) - .returning(Some(5)) - .expect_metric_increment("active_http_calls", 1) - .execute_and_expect(ReturnType::None) - .unwrap(); - - let hallucatination_body = HallucinationClassificationResponse { - params_scores: HashMap::from([("city".to_string(), 0.99)]), - model: "nli-model".to_string(), - }; - - let body_text = serde_json::to_string(&hallucatination_body).unwrap(); - - module - .call_proxy_on_http_call_response(http_context, 5, 0, body_text.len() as i32, 0) - .expect_metric_increment("active_http_calls", -1) - .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody)) - .returning(Some(&body_text)) - .expect_log(Some(LogLevel::Debug), None) - .expect_http_call( - Some("arch_internal"), - Some(vec![ - ("x-arch-upstream", "api_server"), - (":method", "POST"), - (":path", "/weather"), - (":authority", "api_server"), - ("content-type", "application/json"), - ("x-envoy-max-retries", "3"), - ]), - None, - None, - None, - ) - .returning(Some(6)) - .expect_metric_increment("active_http_calls", 1) - .execute_and_expect(ReturnType::None) - .unwrap(); - - let body_text = String::from("test body"); - module - .call_proxy_on_http_call_response(http_context, 6, 0, body_text.len() as i32, 0) - .expect_metric_increment("active_http_calls", -1) - .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody)) - .returning(Some(&body_text)) - .expect_get_header_map_value(Some(MapType::HttpCallResponseHeaders), Some(":status")) - .returning(Some("200")) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) + .expect_get_buffer_bytes(Some(BufferType::HttpRequestBody)) + .returning(Some(chat_completions_request_body)) + // The actual call is not important in this test, we just need to grab the token_id .expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Debug), None) + // .expect_metric_increment("active_http_calls", 1) .expect_send_local_response( Some(StatusCode::TOO_MANY_REQUESTS.as_u16().into()), None, @@ -659,7 +336,7 @@ fn request_ratelimited() { None, ) .expect_metric_increment("ratelimited_rq", 1) - .execute_and_expect(ReturnType::None) + .execute_and_expect(ReturnType::Action(Action::Continue)) .unwrap(); } @@ -679,127 +356,49 @@ fn request_not_ratelimited() { .unwrap(); // Setup Filter - let mut config: Configuration = serde_yaml::from_str(default_config()).unwrap(); - config.ratelimits.as_mut().unwrap()[0].limit.tokens += 1000; - let config_str = serde_json::to_string(&config).unwrap(); - - let filter_context = setup_filter(&mut module, &config_str); + let filter_context = setup_filter(&mut module, default_config()); // Setup HTTP Stream let http_context = 2; normal_flow(&mut module, filter_context, http_context); - let arch_fc_resp = ChatCompletionsResponse { - usage: Some(Usage { - completion_tokens: 0, - }), - choices: vec![Choice { - finish_reason: "test".to_string(), - index: 0, - message: Message { - role: "system".to_string(), - content: None, - tool_calls: Some(vec![ToolCall { - id: String::from("test"), - tool_type: ToolType::Function, - function: FunctionCallDetail { - name: String::from("weather_forecast"), - arguments: HashMap::from([( - String::from("city"), - Value::String(String::from("seattle")), - )]), - }, - }]), - model: None, - }, - }], - model: String::from("test"), - metadata: None, - }; + // give shorter body to avoid rate limiting + let chat_completions_request_body = "\ +{\ + \"messages\": [\ + {\ + \"role\": \"system\",\ + \"content\": \"You are a poetic assistant, skilled in explaining complex programming concepts with creative flair.\"\ + },\ + {\ + \"role\": \"user\",\ + \"content\": \"Compose a poem that explains the concept of recursion in programming.\"\ + }\ + ],\ + \"model\": \"gpt-4\"\ +}"; - let arch_fc_resp_str = serde_json::to_string(&arch_fc_resp).unwrap(); module - .call_proxy_on_http_call_response(http_context, 4, 0, arch_fc_resp_str.len() as i32, 0) - .expect_metric_increment("active_http_calls", -1) - .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody)) - .returning(Some(&arch_fc_resp_str)) + .call_proxy_on_request_body( + http_context, + chat_completions_request_body.len() as i32, + true, + ) + .expect_get_buffer_bytes(Some(BufferType::HttpRequestBody)) + .returning(Some(chat_completions_request_body)) + // The actual call is not important in this test, we just need to grab the token_id .expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) - .expect_http_call( - Some("arch_internal"), - Some(vec![ - ("x-arch-upstream", "model_server"), - (":method", "POST"), - (":path", "/hallucination"), - (":authority", "model_server"), - ("content-type", "application/json"), - ("x-envoy-max-retries", "3"), - ("x-envoy-upstream-rq-timeout-ms", "60000"), - ]), + // .expect_metric_increment("active_http_calls", 1) + .expect_send_local_response( + Some(StatusCode::TOO_MANY_REQUESTS.as_u16().into()), None, None, None, ) - .returning(Some(5)) - .expect_metric_increment("active_http_calls", 1) - .execute_and_expect(ReturnType::None) - .unwrap(); - - // hallucination should return that parameters were not halliucinated - // prompt: str - // parameters: dict - // model: str - - let hallucatination_body = HallucinationClassificationResponse { - params_scores: HashMap::from([("city".to_string(), 0.99)]), - model: "nli-model".to_string(), - }; - - let body_text = serde_json::to_string(&hallucatination_body).unwrap(); - - module - .call_proxy_on_http_call_response(http_context, 5, 0, body_text.len() as i32, 0) - .expect_metric_increment("active_http_calls", -1) - .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody)) - .returning(Some(&body_text)) - .expect_log(Some(LogLevel::Debug), None) - .expect_http_call( - Some("arch_internal"), - Some(vec![ - ("x-arch-upstream", "api_server"), - (":method", "POST"), - (":path", "/weather"), - (":authority", "api_server"), - ("content-type", "application/json"), - ("x-envoy-max-retries", "3"), - ]), - None, - None, - None, - ) - .returning(Some(6)) - .expect_metric_increment("active_http_calls", 1) - .execute_and_expect(ReturnType::None) - .unwrap(); - - let body_text = String::from("test body"); - module - .call_proxy_on_http_call_response(http_context, 6, 0, body_text.len() as i32, 0) - .expect_metric_increment("active_http_calls", -1) - .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody)) - .returning(Some(&body_text)) - .expect_get_header_map_value(Some(MapType::HttpCallResponseHeaders), Some(":status")) - .returning(Some("200")) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) - .expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None) - .execute_and_expect(ReturnType::None) + .expect_metric_increment("ratelimited_rq", 1) + .execute_and_expect(ReturnType::Action(Action::Continue)) .unwrap(); }