mirror of
https://github.com/katanemo/plano.git
synced 2026-06-05 14:45:15 +02:00
Remove optional PromptGuards from Stream Context (#113)
Signed-off-by: José Ulises Niño Rivera <junr03@users.noreply.github.com>
This commit is contained in:
parent
8ea917aae5
commit
af018e5fd8
3 changed files with 150 additions and 61 deletions
|
|
@ -47,8 +47,7 @@ pub struct FilterContext {
|
|||
callouts: HashMap<u32, CallContext>,
|
||||
overrides: Rc<Option<Overrides>>,
|
||||
prompt_targets: Rc<RwLock<HashMap<String, PromptTarget>>>,
|
||||
// This should be Option<Rc<PromptGuards>>, because StreamContext::new() should get an Rc<PromptGuards> not Option<Rc<PromptGuards>>.
|
||||
prompt_guards: Rc<Option<PromptGuards>>,
|
||||
prompt_guards: Rc<PromptGuards>,
|
||||
llm_providers: Option<Rc<LlmProviders>>,
|
||||
}
|
||||
|
||||
|
|
@ -67,7 +66,7 @@ impl FilterContext {
|
|||
metrics: Rc::new(WasmMetrics::new()),
|
||||
prompt_targets: Rc::new(RwLock::new(HashMap::new())),
|
||||
overrides: Rc::new(None),
|
||||
prompt_guards: Rc::new(Some(PromptGuards::default())),
|
||||
prompt_guards: Rc::new(PromptGuards::default()),
|
||||
llm_providers: None,
|
||||
}
|
||||
}
|
||||
|
|
@ -242,7 +241,7 @@ impl RootContext for FilterContext {
|
|||
ratelimit::ratelimits(config.ratelimits);
|
||||
|
||||
if let Some(prompt_guards) = config.prompt_guards {
|
||||
self.prompt_guards = Rc::new(Some(prompt_guards))
|
||||
self.prompt_guards = Rc::new(prompt_guards)
|
||||
}
|
||||
|
||||
match config.llm_providers.try_into() {
|
||||
|
|
|
|||
|
|
@ -48,8 +48,8 @@ pub struct CallContext {
|
|||
prompt_target_name: Option<String>,
|
||||
request_body: ChatCompletionsRequest,
|
||||
similarity_scores: Option<Vec<(String, f64)>>,
|
||||
up_stream_cluster: Option<String>,
|
||||
up_stream_cluster_path: Option<String>,
|
||||
upstream_cluster: Option<String>,
|
||||
upstream_cluster_path: Option<String>,
|
||||
}
|
||||
|
||||
pub struct StreamContext {
|
||||
|
|
@ -62,9 +62,9 @@ pub struct StreamContext {
|
|||
streaming_response: bool,
|
||||
response_tokens: usize,
|
||||
chat_completions_request: bool,
|
||||
prompt_guards: Rc<PromptGuards>,
|
||||
llm_providers: Rc<LlmProviders>,
|
||||
llm_provider: Option<Rc<LlmProvider>>,
|
||||
prompt_guards: Rc<Option<PromptGuards>>,
|
||||
}
|
||||
|
||||
impl StreamContext {
|
||||
|
|
@ -72,7 +72,7 @@ impl StreamContext {
|
|||
context_id: u32,
|
||||
metrics: Rc<WasmMetrics>,
|
||||
prompt_targets: Rc<RwLock<HashMap<String, PromptTarget>>>,
|
||||
prompt_guards: Rc<Option<PromptGuards>>,
|
||||
prompt_guards: Rc<PromptGuards>,
|
||||
overrides: Rc<Option<Overrides>>,
|
||||
llm_providers: Rc<LlmProviders>,
|
||||
) -> Self {
|
||||
|
|
@ -615,8 +615,8 @@ impl StreamContext {
|
|||
}
|
||||
};
|
||||
|
||||
callout_context.up_stream_cluster = Some(endpoint.name);
|
||||
callout_context.up_stream_cluster_path = Some(path);
|
||||
callout_context.upstream_cluster = Some(endpoint.name);
|
||||
callout_context.upstream_cluster_path = Some(path);
|
||||
callout_context.response_handler_type = ResponseHandlerType::FunctionCall;
|
||||
if self.callouts.insert(token_id, callout_context).is_some() {
|
||||
panic!("duplicate token_id")
|
||||
|
|
@ -630,8 +630,8 @@ impl StreamContext {
|
|||
if http_status.1 != StatusCode::OK.as_str() {
|
||||
let error_msg = format!(
|
||||
"Error in function call response: cluster: {}, path: {}, status code: {}",
|
||||
callout_context.up_stream_cluster.unwrap(),
|
||||
callout_context.up_stream_cluster_path.unwrap(),
|
||||
callout_context.upstream_cluster.unwrap(),
|
||||
callout_context.upstream_cluster_path.unwrap(),
|
||||
http_status.1
|
||||
);
|
||||
return self.send_server_error(error_msg, Some(StatusCode::BAD_REQUEST));
|
||||
|
|
@ -741,9 +741,9 @@ impl StreamContext {
|
|||
|
||||
if prompt_guard_resp.jailbreak_verdict.unwrap_or_default() {
|
||||
//TODO: handle other scenarios like forward to error target
|
||||
let msg = (*self.prompt_guards)
|
||||
.as_ref()
|
||||
.and_then(|pg| pg.jailbreak_on_exception_message())
|
||||
let msg = self
|
||||
.prompt_guards
|
||||
.jailbreak_on_exception_message()
|
||||
.unwrap_or("Jailbreak detected. Please refrain from discussing jailbreaking.");
|
||||
return self.send_server_error(msg.to_string(), Some(StatusCode::BAD_REQUEST));
|
||||
}
|
||||
|
|
@ -801,8 +801,8 @@ impl StreamContext {
|
|||
prompt_target_name: None,
|
||||
request_body: callout_context.request_body,
|
||||
similarity_scores: None,
|
||||
up_stream_cluster: None,
|
||||
up_stream_cluster_path: None,
|
||||
upstream_cluster: None,
|
||||
upstream_cluster_path: None,
|
||||
};
|
||||
if self.callouts.insert(token_id, call_context).is_some() {
|
||||
panic!(
|
||||
|
|
@ -810,6 +810,8 @@ impl StreamContext {
|
|||
token_id
|
||||
)
|
||||
}
|
||||
|
||||
self.metrics.active_http_calls.increment(1);
|
||||
}
|
||||
|
||||
fn default_target_handler(&self, body: Vec<u8>, callout_context: CallContext) {
|
||||
|
|
@ -971,39 +973,20 @@ impl HttpContext for StreamContext {
|
|||
}
|
||||
};
|
||||
|
||||
let prompt_guards = match self.prompt_guards.as_ref() {
|
||||
Some(prompt_guards) => {
|
||||
debug!("prompt guards: {:?}", prompt_guards);
|
||||
prompt_guards
|
||||
}
|
||||
None => {
|
||||
let callout_context = CallContext {
|
||||
response_handler_type: ResponseHandlerType::ArchGuard,
|
||||
user_message: Some(user_message),
|
||||
prompt_target_name: None,
|
||||
request_body: deserialized_body,
|
||||
similarity_scores: None,
|
||||
up_stream_cluster: None,
|
||||
up_stream_cluster_path: None,
|
||||
};
|
||||
self.get_embeddings(callout_context);
|
||||
return Action::Pause;
|
||||
}
|
||||
};
|
||||
|
||||
let prompt_guard_jailbreak_task = prompt_guards
|
||||
let prompt_guard_jailbreak_task = self
|
||||
.prompt_guards
|
||||
.input_guards
|
||||
.contains_key(&public_types::configuration::GuardType::Jailbreak);
|
||||
if !prompt_guard_jailbreak_task {
|
||||
info!("Input guards set but no prompt guards were found");
|
||||
debug!("Missing input guard. Making inline call to retrieve");
|
||||
let callout_context = CallContext {
|
||||
response_handler_type: ResponseHandlerType::ArchGuard,
|
||||
user_message: Some(user_message),
|
||||
prompt_target_name: None,
|
||||
request_body: deserialized_body,
|
||||
similarity_scores: None,
|
||||
up_stream_cluster: None,
|
||||
up_stream_cluster_path: None,
|
||||
upstream_cluster: None,
|
||||
upstream_cluster_path: None,
|
||||
};
|
||||
self.get_embeddings(callout_context);
|
||||
return Action::Pause;
|
||||
|
|
@ -1056,8 +1039,8 @@ impl HttpContext for StreamContext {
|
|||
prompt_target_name: None,
|
||||
request_body: deserialized_body,
|
||||
similarity_scores: None,
|
||||
up_stream_cluster: None,
|
||||
up_stream_cluster_path: None,
|
||||
upstream_cluster: None,
|
||||
upstream_cluster_path: None,
|
||||
};
|
||||
if self.callouts.insert(token_id, call_context).is_some() {
|
||||
panic!(
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ use proxy_wasm_test_framework::types::{
|
|||
};
|
||||
use public_types::common_types::open_ai::{ChatCompletionsResponse, Choice, Message, Usage};
|
||||
use public_types::common_types::open_ai::{FunctionCallDetail, ToolCall, ToolType};
|
||||
use public_types::common_types::PromptGuardResponse;
|
||||
use public_types::embeddings::embedding::Object;
|
||||
use public_types::embeddings::{
|
||||
create_embedding_response, CreateEmbeddingResponse, CreateEmbeddingResponseUsage, Embedding,
|
||||
|
|
@ -91,14 +92,66 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
|
|||
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
|
||||
.returning(Some(chat_completions_request_body))
|
||||
// The actual call is not important in this test, we just need to grab the token_id
|
||||
.expect_http_call(Some("model_server"), None, None, None, None)
|
||||
.expect_http_call(
|
||||
Some("model_server"),
|
||||
Some(vec![
|
||||
(":method", "POST"),
|
||||
(":path", "/guard"),
|
||||
(":authority", "model_server"),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
("x-envoy-upstream-rq-timeout-ms", "60000"),
|
||||
]),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.returning(Some(1))
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_metric_increment("active_http_calls", 1)
|
||||
.expect_log(Some(LogLevel::Info), None)
|
||||
.execute_and_expect(ReturnType::Action(Action::Pause))
|
||||
.unwrap();
|
||||
|
||||
let prompt_guard_response = PromptGuardResponse {
|
||||
toxic_prob: None,
|
||||
toxic_verdict: None,
|
||||
jailbreak_prob: None,
|
||||
jailbreak_verdict: None,
|
||||
};
|
||||
let prompt_guard_response_buffer = serde_json::to_string(&prompt_guard_response).unwrap();
|
||||
module
|
||||
.call_proxy_on_http_call_response(
|
||||
http_context,
|
||||
1,
|
||||
0,
|
||||
prompt_guard_response_buffer.len() as i32,
|
||||
0,
|
||||
)
|
||||
.expect_metric_increment("active_http_calls", -1)
|
||||
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
|
||||
.returning(Some(&prompt_guard_response_buffer))
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_http_call(
|
||||
Some("model_server"),
|
||||
Some(vec![
|
||||
(":method", "POST"),
|
||||
(":path", "/embeddings"),
|
||||
(":authority", "model_server"),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
("x-envoy-upstream-rq-timeout-ms", "60000"),
|
||||
]),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.returning(Some(2))
|
||||
.expect_metric_increment("active_http_calls", 1)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
let embedding_response = CreateEmbeddingResponse {
|
||||
data: vec![Embedding {
|
||||
index: 0,
|
||||
|
|
@ -113,7 +166,7 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
|
|||
module
|
||||
.call_proxy_on_http_call_response(
|
||||
http_context,
|
||||
1,
|
||||
2,
|
||||
0,
|
||||
embeddings_response_buffer.len() as i32,
|
||||
0,
|
||||
|
|
@ -123,8 +176,21 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
|
|||
.returning(Some(&embeddings_response_buffer))
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_http_call(Some("model_server"), None, None, None, None)
|
||||
.returning(Some(2))
|
||||
.expect_http_call(
|
||||
Some("model_server"),
|
||||
Some(vec![
|
||||
(":method", "POST"),
|
||||
(":path", "/zeroshot"),
|
||||
(":authority", "model_server"),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
("x-envoy-upstream-rq-timeout-ms", "60000"),
|
||||
]),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.returning(Some(3))
|
||||
.expect_metric_increment("active_http_calls", 1)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.execute_and_expect(ReturnType::None)
|
||||
|
|
@ -140,7 +206,7 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
|
|||
module
|
||||
.call_proxy_on_http_call_response(
|
||||
http_context,
|
||||
2,
|
||||
3,
|
||||
0,
|
||||
zeroshot_intent_detection_buffer.len() as i32,
|
||||
0,
|
||||
|
|
@ -151,8 +217,21 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
|
|||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Info), None)
|
||||
.expect_http_call(Some("arch_fc"), None, None, None, None)
|
||||
.returning(Some(3))
|
||||
.expect_http_call(
|
||||
Some("arch_fc"),
|
||||
Some(vec![
|
||||
(":method", "POST"),
|
||||
(":path", "/v1/chat/completions"),
|
||||
(":authority", "arch_fc"),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
("x-envoy-upstream-rq-timeout-ms", "120000"),
|
||||
]),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.returning(Some(4))
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_metric_increment("active_http_calls", 1)
|
||||
|
|
@ -189,8 +268,13 @@ overrides:
|
|||
system_prompt: |
|
||||
You are a helpful assistant.
|
||||
|
||||
prompt_targets:
|
||||
prompt_guards:
|
||||
input_guards:
|
||||
jailbreak:
|
||||
on_exception:
|
||||
message: "Looks like you're curious about my abilities, but I can only provide assistance within my programmed parameters."
|
||||
|
||||
prompt_targets:
|
||||
- name: weather_forecast
|
||||
description: This function provides realtime weather forecast information for a given city.
|
||||
parameters:
|
||||
|
|
@ -308,7 +392,6 @@ fn successful_request_to_open_ai_chat_completions() {
|
|||
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
|
||||
.returning(Some(chat_completions_request_body))
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Info), None)
|
||||
.expect_http_call(Some("model_server"), None, None, None, None)
|
||||
.returning(Some(4))
|
||||
.expect_metric_increment("active_http_calls", 1)
|
||||
|
|
@ -459,7 +542,7 @@ fn request_ratelimited() {
|
|||
|
||||
let arch_fc_resp_str = serde_json::to_string(&arch_fc_resp).unwrap();
|
||||
module
|
||||
.call_proxy_on_http_call_response(http_context, 3, 0, arch_fc_resp_str.len() as i32, 0)
|
||||
.call_proxy_on_http_call_response(http_context, 4, 0, arch_fc_resp_str.len() as i32, 0)
|
||||
.expect_metric_increment("active_http_calls", -1)
|
||||
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
|
||||
.returning(Some(&arch_fc_resp_str))
|
||||
|
|
@ -470,15 +553,27 @@ fn request_ratelimited() {
|
|||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_http_call(Some("api_server"), None, None, None, None)
|
||||
.returning(Some(4))
|
||||
.expect_http_call(
|
||||
Some("api_server"),
|
||||
Some(vec![
|
||||
(":method", "POST"),
|
||||
(":path", "/weather"),
|
||||
(":authority", "api_server"),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
]),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.returning(Some(5))
|
||||
.expect_metric_increment("active_http_calls", 1)
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
let body_text = String::from("test body");
|
||||
module
|
||||
.call_proxy_on_http_call_response(http_context, 4, 0, body_text.len() as i32, 0)
|
||||
.call_proxy_on_http_call_response(http_context, 5, 0, body_text.len() as i32, 0)
|
||||
.expect_metric_increment("active_http_calls", -1)
|
||||
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
|
||||
.returning(Some(&body_text))
|
||||
|
|
@ -573,7 +668,7 @@ fn request_not_ratelimited() {
|
|||
|
||||
let arch_fc_resp_str = serde_json::to_string(&arch_fc_resp).unwrap();
|
||||
module
|
||||
.call_proxy_on_http_call_response(http_context, 3, 0, arch_fc_resp_str.len() as i32, 0)
|
||||
.call_proxy_on_http_call_response(http_context, 4, 0, arch_fc_resp_str.len() as i32, 0)
|
||||
.expect_metric_increment("active_http_calls", -1)
|
||||
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
|
||||
.returning(Some(&arch_fc_resp_str))
|
||||
|
|
@ -584,15 +679,27 @@ fn request_not_ratelimited() {
|
|||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_http_call(Some("api_server"), None, None, None, None)
|
||||
.returning(Some(4))
|
||||
.expect_http_call(
|
||||
Some("api_server"),
|
||||
Some(vec![
|
||||
(":method", "POST"),
|
||||
(":path", "/weather"),
|
||||
(":authority", "api_server"),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
]),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.returning(Some(5))
|
||||
.expect_metric_increment("active_http_calls", 1)
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
let body_text = String::from("test body");
|
||||
module
|
||||
.call_proxy_on_http_call_response(http_context, 4, 0, body_text.len() as i32, 0)
|
||||
.call_proxy_on_http_call_response(http_context, 5, 0, body_text.len() as i32, 0)
|
||||
.expect_metric_increment("active_http_calls", -1)
|
||||
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
|
||||
.returning(Some(&body_text))
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue