diff --git a/crates/prompt_gateway/tests/integration.rs b/crates/prompt_gateway/tests/integration.rs index ac6009f8..0c3ccf60 100644 --- a/crates/prompt_gateway/tests/integration.rs +++ b/crates/prompt_gateway/tests/integration.rs @@ -1,14 +1,7 @@ -use common::api::hallucination::HallucinationClassificationResponse; use common::api::open_ai::{ ChatCompletionsResponse, Choice, FunctionCallDetail, Message, ToolCall, ToolType, Usage, }; -use common::api::prompt_guard::PromptGuardResponse; -use common::api::zero_shot::ZeroShotClassificationResponse; use common::configuration::Configuration; -use common::embeddings::{ - create_embedding_response, embedding, CreateEmbeddingResponse, CreateEmbeddingResponseUsage, - Embedding, -}; use http::StatusCode; use proxy_wasm_test_framework::tester::{self, Tester}; use proxy_wasm_test_framework::types::{ @@ -83,13 +76,11 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) { .expect_http_call( Some("arch_internal"), Some(vec![ - ("x-arch-upstream", "guard"), + ("x-arch-upstream", "model_server"), (":method", "POST"), - (":path", "/guard"), - (":authority", "guard"), + (":path", "/function_calling"), ("content-type", "application/json"), - ("x-envoy-max-retries", "3"), - ("x-envoy-upstream-rq-timeout-ms", "60000"), + (":authority", "model_server"), ]), None, None, @@ -97,139 +88,11 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) { ) .returning(Some(1)) .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Trace), None) .expect_metric_increment("active_http_calls", 1) .execute_and_expect(ReturnType::Action(Action::Pause)) .unwrap(); - - let prompt_guard_response = PromptGuardResponse { - toxic_prob: None, - toxic_verdict: None, - jailbreak_prob: None, - jailbreak_verdict: None, - }; - let prompt_guard_response_buffer = serde_json::to_string(&prompt_guard_response).unwrap(); - module - .call_proxy_on_http_call_response( - http_context, - 1, - 0, - prompt_guard_response_buffer.len() as i32, - 0, - ) - .expect_metric_increment("active_http_calls", -1) - .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody)) - .returning(Some(&prompt_guard_response_buffer)) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Trace), None) - .expect_http_call( - Some("arch_internal"), - Some(vec![ - ("x-arch-upstream", "embeddings"), - (":method", "POST"), - (":path", "/embeddings"), - (":authority", "embeddings"), - ("content-type", "application/json"), - ("x-envoy-max-retries", "3"), - ("x-envoy-upstream-rq-timeout-ms", "60000"), - ]), - None, - None, - None, - ) - .returning(Some(2)) - .expect_metric_increment("active_http_calls", 1) - .execute_and_expect(ReturnType::None) - .unwrap(); - - let embedding_response = CreateEmbeddingResponse { - data: vec![Embedding { - index: 0, - embedding: vec![], - object: embedding::Object::default(), - }], - model: String::from("test"), - object: create_embedding_response::Object::default(), - usage: Box::new(CreateEmbeddingResponseUsage::new(0, 0)), - }; - let embeddings_response_buffer = serde_json::to_string(&embedding_response).unwrap(); - module - .call_proxy_on_http_call_response( - http_context, - 2, - 0, - embeddings_response_buffer.len() as i32, - 0, - ) - .expect_metric_increment("active_http_calls", -1) - .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody)) - .returning(Some(&embeddings_response_buffer)) - .expect_log(Some(LogLevel::Trace), None) - .expect_log(Some(LogLevel::Warn), None) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Trace), None) - .expect_http_call( - Some("arch_internal"), - Some(vec![ - ("x-arch-upstream", "zeroshot"), - (":method", "POST"), - (":path", "/zeroshot"), - (":authority", "zeroshot"), - ("content-type", "application/json"), - ("x-envoy-max-retries", "3"), - ("x-envoy-upstream-rq-timeout-ms", "60000"), - ]), - None, - None, - None, - ) - .returning(Some(3)) - .expect_metric_increment("active_http_calls", 1) - .execute_and_expect(ReturnType::None) - .unwrap(); - - let zero_shot_response = ZeroShotClassificationResponse { - predicted_class: "weather_forecast".to_string(), - predicted_class_score: 0.1, - scores: HashMap::new(), - model: "test-model".to_string(), - }; - let zeroshot_intent_detection_buffer = serde_json::to_string(&zero_shot_response).unwrap(); - module - .call_proxy_on_http_call_response( - http_context, - 3, - 0, - zeroshot_intent_detection_buffer.len() as i32, - 0, - ) - .expect_metric_increment("active_http_calls", -1) - .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody)) - .returning(Some(&zeroshot_intent_detection_buffer)) - .expect_log(Some(LogLevel::Trace), None) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Trace), None) - .expect_http_call( - Some("arch_internal"), - Some(vec![ - (":method", "POST"), - ("x-arch-upstream", "arch_fc"), - (":path", "/v1/chat/completions"), - (":authority", "arch_fc"), - ("content-type", "application/json"), - ("x-envoy-max-retries", "3"), - ("x-envoy-upstream-rq-timeout-ms", "120000"), - ]), - None, - None, - None, - ) - .returning(Some(4)) - .expect_metric_increment("active_http_calls", 1) - .execute_and_expect(ReturnType::None) - .unwrap(); } fn setup_filter(module: &mut Tester, config: &str) -> i32 { @@ -248,69 +111,6 @@ fn setup_filter(module: &mut Tester, config: &str) -> i32 { .execute_and_expect(ReturnType::Bool(true)) .unwrap(); - module - .call_proxy_on_tick(filter_context) - .expect_log(Some(LogLevel::Info), None) - .expect_log(Some(LogLevel::Trace), None) - .expect_http_call( - Some("arch_internal"), - Some(vec![ - ("x-arch-upstream", "embeddings"), - (":method", "POST"), - (":path", "/embeddings"), - (":authority", "embeddings"), - ("content-type", "application/json"), - ("x-envoy-upstream-rq-timeout-ms", "60000"), - ]), - None, - None, - None, - ) - .returning(Some(101)) - .expect_metric_increment("active_http_calls", 1) - .expect_set_tick_period_millis(Some(5000)) - .execute_and_expect(ReturnType::None) - .unwrap(); - - let embedding_response = CreateEmbeddingResponse { - data: vec![Embedding { - embedding: vec![], - index: 0, - object: embedding::Object::default(), - }], - model: String::from("test"), - object: create_embedding_response::Object::default(), - usage: Box::new(CreateEmbeddingResponseUsage { - prompt_tokens: 0, - total_tokens: 0, - }), - }; - let embedding_response_str = serde_json::to_string(&embedding_response).unwrap(); - module - .call_proxy_on_http_call_response( - filter_context, - 101, - 0, - embedding_response_str.len() as i32, - 0, - ) - .expect_log( - Some(LogLevel::Trace), - Some( - format!( - "filter_context: on_http_call_response called with token_id: {:?}", - 101 - ) - .as_str(), - ), - ) - .expect_metric_increment("active_http_calls", -1) - .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody)) - .returning(Some(&embedding_response_str)) - .expect_log(Some(LogLevel::Debug), None) - .execute_and_expect(ReturnType::None) - .unwrap(); - filter_context } @@ -435,6 +235,7 @@ fn prompt_gateway_successful_request_to_open_ai_chat_completions() { .returning(Some(chat_completions_request_body)) .expect_log(Some(LogLevel::Trace), None) .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Trace), None) .expect_http_call(Some("arch_internal"), None, None, None, None) .returning(Some(4)) @@ -538,8 +339,8 @@ fn prompt_gateway_request_to_llm_gateway() { completion_tokens: 0, }), choices: vec![Choice { - finish_reason: "test".to_string(), - index: 0, + finish_reason: Some("test".to_string()), + index: Some(0), message: Message { role: "system".to_string(), content: None, @@ -564,55 +365,12 @@ fn prompt_gateway_request_to_llm_gateway() { let arch_fc_resp_str = serde_json::to_string(&arch_fc_resp).unwrap(); module - .call_proxy_on_http_call_response(http_context, 4, 0, arch_fc_resp_str.len() as i32, 0) + .call_proxy_on_http_call_response(http_context, 1, 0, arch_fc_resp_str.len() as i32, 0) .expect_metric_increment("active_http_calls", -1) .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody)) .returning(Some(&arch_fc_resp_str)) .expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Trace), None) - .expect_http_call( - Some("arch_internal"), - Some(vec![ - ("x-arch-upstream", "hallucination"), - (":method", "POST"), - (":path", "/hallucination"), - (":authority", "hallucination"), - ("content-type", "application/json"), - ("x-envoy-max-retries", "3"), - ("x-envoy-upstream-rq-timeout-ms", "60000"), - ]), - None, - None, - None, - ) - .returning(Some(5)) - .expect_metric_increment("active_http_calls", 1) - .execute_and_expect(ReturnType::None) - .unwrap(); - - // hallucination should return that parameters were not halliucinated - // prompt: str - // parameters: dict - // model: str - - let hallucatination_body = HallucinationClassificationResponse { - params_scores: HashMap::from([("city".to_string(), 0.99)]), - model: "nli-model".to_string(), - }; - - let body_text = serde_json::to_string(&hallucatination_body).unwrap(); - - module - .call_proxy_on_http_call_response(http_context, 5, 0, body_text.len() as i32, 0) - .expect_metric_increment("active_http_calls", -1) - .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody)) - .returning(Some(&body_text)) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Trace), None) .expect_http_call( Some("arch_internal"), @@ -628,14 +386,14 @@ fn prompt_gateway_request_to_llm_gateway() { None, None, ) - .returning(Some(6)) + .returning(Some(2)) .expect_metric_increment("active_http_calls", 1) .execute_and_expect(ReturnType::None) .unwrap(); let body_text = String::from("test body"); module - .call_proxy_on_http_call_response(http_context, 6, 0, body_text.len() as i32, 0) + .call_proxy_on_http_call_response(http_context, 2, 0, body_text.len() as i32, 0) .expect_metric_increment("active_http_calls", -1) .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody)) .returning(Some(&body_text)) @@ -652,8 +410,8 @@ fn prompt_gateway_request_to_llm_gateway() { completion_tokens: 0, }), choices: vec![Choice { - finish_reason: "test".to_string(), - index: 0, + finish_reason: Some("test".to_string()), + index: Some(0), message: Message { role: "assistant".to_string(), content: Some("hello from fake llm gateway".to_string()),