diff --git a/crates/prompt_gateway/tests/integration.rs b/crates/prompt_gateway/tests/integration.rs index 6fbed26a..c5d23297 100644 --- a/crates/prompt_gateway/tests/integration.rs +++ b/crates/prompt_gateway/tests/integration.rs @@ -456,3 +456,252 @@ fn prompt_gateway_request_to_llm_gateway() { .execute_and_expect(ReturnType::Action(Action::Continue)) .unwrap(); } + +#[test] +#[serial] +fn prompt_gateway_request_no_intent_match() { + let args = tester::MockSettings { + wasm_path: wasm_module(), + quiet: false, + allow_unexpected: false, + }; + let mut module = tester::mock(args).unwrap(); + + module + .call_start() + .execute_and_expect(ReturnType::None) + .unwrap(); + + // Setup Filter + let mut config: Configuration = serde_yaml::from_str(default_config()).unwrap(); + config.ratelimits.as_mut().unwrap()[0].limit.tokens += 1000; + let config_str = serde_json::to_string(&config).unwrap(); + + let filter_context = setup_filter(&mut module, &config_str); + + // Setup HTTP Stream + let http_context = 2; + + normal_flow(&mut module, filter_context, http_context); + + let arch_fc_resp = ChatCompletionsResponse { + usage: Some(Usage { + completion_tokens: 0, + }), + choices: vec![Choice { + finish_reason: Some("test".to_string()), + index: Some(0), + message: Message { + role: "system".to_string(), + content: None, + tool_calls: Some(vec![ToolCall { + id: String::from("test"), + tool_type: ToolType::Function, + function: FunctionCallDetail { + name: String::from("weather_forecast"), + arguments: HashMap::from([( + String::from("city"), + Value::String(String::from("seattle")), + )]), + }, + }]), + model: None, + tool_call_id: None, + }, + }], + model: String::from("test"), + metadata: None, + }; + + let arch_fc_resp_str = serde_json::to_string(&arch_fc_resp).unwrap(); + module + .call_proxy_on_http_call_response(http_context, 1, 0, arch_fc_resp_str.len() as i32, 0) + .expect_metric_increment("active_http_calls", -1) + .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody)) + .returning(Some(&arch_fc_resp_str)) + .expect_log(Some(LogLevel::Warn), None) + .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Trace), None) + .expect_log(Some(LogLevel::Debug), Some("intent not matched")) + .expect_log( + Some(LogLevel::Debug), + Some("no default prompt target found, forwarding request to upstream llm"), + ) + .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Debug), None) + .expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None) + .execute_and_expect(ReturnType::None) + .unwrap(); +} + +fn arch_config_default_target() -> &'static str { + r#" +version: "0.1-beta" + +listener: + address: 0.0.0.0 + port: 10000 + message_format: huggingface + connect_timeout: 0.005s + +endpoints: + api_server: + endpoint: api_server:80 + connect_timeout: 0.005s + +llm_providers: + - name: open-ai-gpt-4 + provider_interface: openai + access_key: secret_key + model: gpt-4 + default: true + +overrides: + # confidence threshold for prompt target intent matching + prompt_target_intent_matching_threshold: 0.0 + +system_prompt: | + You are a helpful assistant. + +prompt_guards: + input_guards: + jailbreak: + on_exception: + message: "Looks like you're curious about my abilities, but I can only provide assistance within my programmed parameters." + +prompt_targets: + - name: weather_forecast + description: This function provides realtime weather forecast information for a given city. + parameters: + - name: city + required: true + description: The city for which the weather forecast is requested. + - name: days + description: The number of days for which the weather forecast is requested. + - name: units + description: The units in which the weather forecast is requested. + endpoint: + name: api_server + path: /weather + http_method: POST + system_prompt: | + You are a helpful weather forecaster. Use weater data that is provided to you. Please following following guidelines when responding to user queries: + - Use farenheight for temperature + - Use miles per hour for wind speed + + - name: default_target + default: true + description: This is the default target for all unmatched prompts. + endpoint: + name: weather_forecast_service + path: /default_target + http_method: POST + system_prompt: | + You are a helpful assistant! Summarize the user's request and provide a helpful response. + # if it is set to false arch will send response that it received from this prompt target to the user + # if true arch will forward the response to the default LLM + auto_llm_dispatch_on_response: false + +ratelimits: + - model: gpt-4 + selector: + key: selector-key + value: selector-value + limit: + tokens: 1 + unit: minute +"# +} + +#[test] +#[serial] +fn prompt_gateway_request_no_intent_match_default_target() { + let args = tester::MockSettings { + wasm_path: wasm_module(), + quiet: false, + allow_unexpected: false, + }; + let mut module = tester::mock(args).unwrap(); + + module + .call_start() + .execute_and_expect(ReturnType::None) + .unwrap(); + + // Setup Filter + let mut config: Configuration = serde_yaml::from_str(arch_config_default_target()).unwrap(); + config.ratelimits.as_mut().unwrap()[0].limit.tokens += 1000; + let config_str = serde_json::to_string(&config).unwrap(); + + let filter_context = setup_filter(&mut module, &config_str); + + // Setup HTTP Stream + let http_context = 2; + + normal_flow(&mut module, filter_context, http_context); + + let arch_fc_resp = ChatCompletionsResponse { + usage: Some(Usage { + completion_tokens: 0, + }), + choices: vec![Choice { + finish_reason: Some("test".to_string()), + index: Some(0), + message: Message { + role: "system".to_string(), + content: None, + tool_calls: Some(vec![ToolCall { + id: String::from("test"), + tool_type: ToolType::Function, + function: FunctionCallDetail { + name: String::from("weather_forecast"), + arguments: HashMap::from([( + String::from("city"), + Value::String(String::from("seattle")), + )]), + }, + }]), + model: None, + tool_call_id: None, + }, + }], + model: String::from("test"), + metadata: None, + }; + + let arch_fc_resp_str = serde_json::to_string(&arch_fc_resp).unwrap(); + module + .call_proxy_on_http_call_response(http_context, 1, 0, arch_fc_resp_str.len() as i32, 0) + .expect_metric_increment("active_http_calls", -1) + .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody)) + .returning(Some(&arch_fc_resp_str)) + .expect_log(Some(LogLevel::Warn), None) + .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Trace), None) + .expect_log(Some(LogLevel::Debug), Some("intent not matched")) + .expect_log( + Some(LogLevel::Debug), + Some("default prompt target found, forwarding request to default prompt target"), + ) + .expect_log(Some(LogLevel::Trace), None) + .expect_log(Some(LogLevel::Debug), None) + .expect_http_call( + Some("arch_internal"), + Some(vec![ + (":method", "POST"), + ("x-arch-upstream", "weather_forecast_service"), + (":path", "/default_target"), + (":authority", "weather_forecast_service"), + ("content-type", "application/json"), + ("x-envoy-max-retries", "3"), + ("x-envoy-upstream-rq-timeout-ms", "30000"), + ]), + None, + None, + Some(5000), + ) + .returning(Some(2)) + .expect_metric_increment("active_http_calls", 1) + .execute_and_expect(ReturnType::None) + .unwrap(); +}