diff --git a/crates/prompt_gateway/src/stream_context.rs b/crates/prompt_gateway/src/stream_context.rs index 79dd99a7..a3188181 100644 --- a/crates/prompt_gateway/src/stream_context.rs +++ b/crates/prompt_gateway/src/stream_context.rs @@ -1074,7 +1074,36 @@ impl StreamContext { .prompt_guards .jailbreak_on_exception_message() .unwrap_or("refrain from discussing jailbreaking."); - warn!("jailbreak detected: {}", msg); + info!("jailbreak detected: {}", msg); + + let default_target_response_str = if self.streaming_response { + let chunks = vec![ + ChatCompletionStreamResponse::new( + None, + Some(ASSISTANT_ROLE.to_string()), + Some(ARCH_FC_MODEL_NAME.to_owned()), + None, + ), + ChatCompletionStreamResponse::new( + Some(msg.to_string()), + None, + Some(ARCH_FC_MODEL_NAME.to_owned()), + None, + ), + ]; + + to_server_events(chunks) + } else { + let chat_completion_response = ChatCompletionsResponse::new(msg.to_string()); + serde_json::to_string(&chat_completion_response).unwrap() + }; + + self.send_http_response( + StatusCode::OK.as_u16().into(), + vec![("Powered-By", "Katanemo")], + Some(default_target_response_str.as_bytes()), + ); + return self.send_server_error( ServerError::Jailbreak(String::from(msg)), Some(StatusCode::BAD_REQUEST), diff --git a/e2e_tests/test_prompt_gateway.py b/e2e_tests/test_prompt_gateway.py index 52801607..fd089941 100644 --- a/e2e_tests/test_prompt_gateway.py +++ b/e2e_tests/test_prompt_gateway.py @@ -286,9 +286,24 @@ def test_prompt_gateway_prompt_guard_jailbreak(stream): "stream": stream, } response = requests.post(PROMPT_GATEWAY_ENDPOINT, json=body, stream=stream) - assert response.status_code == 400 - response_json = response.text - assert ( - response_json - == "jailbreak detected: Looks like you're curious about my abilities, but I can only provide assistance within my programmed parameters." - ) + assert response.status_code == 200 + + if stream: + chunks = get_data_chunks(response, n=20) + assert len(chunks) == 2 + + response_json = json.loads(chunks[1]) + print(response_json) + choices = response_json.get("choices", []) + assert len(choices) > 0 + content = choices[0]["delta"]["content"] + assert ( + content + == "Looks like you're curious about my abilities, but I can only provide assistance for weather forecasting." + ) + else: + response_json = response.json() + assert ( + response_json.get("choices")[0]["message"]["content"] + == "Looks like you're curious about my abilities, but I can only provide assistance for weather forecasting." + )