mirror of
https://github.com/katanemo/plano.git
synced 2026-05-18 13:45:15 +02:00
[Kan-103] add support toxic/jailbreak model (#49)
* add toxic/jailbreak model * fix path loading model * fix syntax * fix bug,lint, format * fix bug * formatting * add parallel + chunking * fix bug * working version * fix onnnx name erorr * device * fix jailbreak config * fix syntax error * format * add requirement + cli download for dockerfile * add task * add skeleton change for envoy filter for prompt guard * fix hardware config * fix bug * add config changes * add gitignore * merge main * integrate arch-guard with filter * add hardware config * nothing * add hardware config feature * fix requirement * fix chat ui * fix onnx * fix lint * remove non intel cpu * remove onnx * working version * modify docker * fix guard time * add nvidia support * remove nvidia * add gpu * add gpu * add gpu support * add gpu support for compose * add gpu support for compose * add gpu support for compose * add gpu support for compose * add gpu support for compose * fix docker file * fix int test * correct gpu docker * upgrad python 10 * fix logits to be gpu compatible * default to cpu dockerfile * resolve comments * fix lint + unused parameters * fix * remove eetq install for cpu * remove deploy gpu --------- Co-authored-by: Adil Hafeez <adil@katanemo.com>
This commit is contained in:
parent
80c554ce1a
commit
79b1c5415f
18 changed files with 1622 additions and 191 deletions
|
|
@ -9,7 +9,7 @@ use open_message_format_embeddings::models::{
|
|||
use proxy_wasm::traits::*;
|
||||
use proxy_wasm::types::*;
|
||||
use public_types::common_types::EmbeddingType;
|
||||
use public_types::configuration::{Configuration, Overrides, PromptTarget};
|
||||
use public_types::configuration::{Configuration, Overrides, PromptGuards, PromptTarget};
|
||||
use serde_json::to_string;
|
||||
use std::collections::HashMap;
|
||||
use std::rc::Rc;
|
||||
|
|
@ -47,6 +47,7 @@ pub struct FilterContext {
|
|||
config: Option<Configuration>,
|
||||
overrides: Rc<Option<Overrides>>,
|
||||
prompt_targets: Rc<RwLock<HashMap<String, PromptTarget>>>,
|
||||
prompt_guards: Rc<Option<PromptGuards>>,
|
||||
}
|
||||
|
||||
pub fn embeddings_store() -> &'static RwLock<HashMap<String, EmbeddingTypeMap>> {
|
||||
|
|
@ -65,6 +66,7 @@ impl FilterContext {
|
|||
metrics: Rc::new(WasmMetrics::new()),
|
||||
prompt_targets: Rc::new(RwLock::new(HashMap::new())),
|
||||
overrides: Rc::new(None),
|
||||
prompt_guards: Rc::new(Some(PromptGuards::default())),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -238,6 +240,14 @@ impl RootContext for FilterContext {
|
|||
{
|
||||
ratelimit::ratelimits(Some(std::mem::take(ratelimits_config)));
|
||||
}
|
||||
|
||||
if let Some(prompt_guards) = self
|
||||
.config
|
||||
.as_mut()
|
||||
.and_then(|config| config.prompt_guards.as_mut())
|
||||
{
|
||||
self.prompt_guards = Rc::new(Some(std::mem::take(prompt_guards)));
|
||||
}
|
||||
}
|
||||
true
|
||||
}
|
||||
|
|
@ -247,6 +257,7 @@ impl RootContext for FilterContext {
|
|||
context_id,
|
||||
Rc::clone(&self.metrics),
|
||||
Rc::clone(&self.prompt_targets),
|
||||
Rc::clone(&self.prompt_guards),
|
||||
Rc::clone(&self.overrides),
|
||||
)))
|
||||
}
|
||||
|
|
|
|||
|
|
@ -21,10 +21,11 @@ use public_types::common_types::open_ai::{
|
|||
StreamOptions,
|
||||
};
|
||||
use public_types::common_types::{
|
||||
BoltFCToolsCall, EmbeddingType, ToolParameter, ToolParameters, ToolsDefinition,
|
||||
ZeroShotClassificationRequest, ZeroShotClassificationResponse,
|
||||
BoltFCToolsCall, EmbeddingType, PromptGuardRequest, PromptGuardResponse, PromptGuardTask,
|
||||
ToolParameter, ToolParameters, ToolsDefinition, ZeroShotClassificationRequest,
|
||||
ZeroShotClassificationResponse,
|
||||
};
|
||||
use public_types::configuration::{Overrides, PromptTarget, PromptType};
|
||||
use public_types::configuration::{Overrides, PromptGuards, PromptTarget, PromptType};
|
||||
use std::collections::HashMap;
|
||||
use std::num::NonZero;
|
||||
use std::rc::Rc;
|
||||
|
|
@ -36,6 +37,7 @@ enum ResponseHandlerType {
|
|||
FunctionResolver,
|
||||
FunctionCall,
|
||||
ZeroShotIntent,
|
||||
ArchGuard,
|
||||
}
|
||||
|
||||
pub struct CallContext {
|
||||
|
|
@ -57,6 +59,7 @@ pub struct StreamContext {
|
|||
streaming_response: bool,
|
||||
response_tokens: usize,
|
||||
chat_completions_request: bool,
|
||||
prompt_guards: Rc<Option<PromptGuards>>,
|
||||
}
|
||||
|
||||
impl StreamContext {
|
||||
|
|
@ -64,6 +67,7 @@ impl StreamContext {
|
|||
context_id: u32,
|
||||
metrics: Rc<WasmMetrics>,
|
||||
prompt_targets: Rc<RwLock<HashMap<String, PromptTarget>>>,
|
||||
prompt_guards: Rc<Option<PromptGuards>>,
|
||||
overrides: Rc<Option<Overrides>>,
|
||||
) -> Self {
|
||||
StreamContext {
|
||||
|
|
@ -76,6 +80,7 @@ impl StreamContext {
|
|||
streaming_response: false,
|
||||
response_tokens: 0,
|
||||
chat_completions_request: false,
|
||||
prompt_guards,
|
||||
overrides,
|
||||
}
|
||||
}
|
||||
|
|
@ -640,6 +645,108 @@ impl StreamContext {
|
|||
self.set_http_request_body(0, json_string.len(), &json_string.into_bytes());
|
||||
self.resume_http_request();
|
||||
}
|
||||
|
||||
fn arch_guard_handler(&mut self, body: Vec<u8>, callout_context: CallContext) {
|
||||
debug!("response received for arch guard");
|
||||
let prompt_guard_resp: PromptGuardResponse = serde_json::from_slice(&body).unwrap();
|
||||
debug!("prompt_guard_resp: {:?}", prompt_guard_resp);
|
||||
|
||||
if prompt_guard_resp.jailbreak_verdict.is_some()
|
||||
&& prompt_guard_resp.jailbreak_verdict.unwrap()
|
||||
{
|
||||
let default_err = "Jailbreak detected. Please refrain from discussing jailbreaking.";
|
||||
let error_msg = match self.prompt_guards.as_ref() {
|
||||
Some(prompt_guards) => match prompt_guards.input_guards.jailbreak.as_ref() {
|
||||
Some(jailbreak) => match jailbreak.on_exception_message.as_ref() {
|
||||
Some(error_msg) => error_msg,
|
||||
None => default_err,
|
||||
},
|
||||
None => default_err,
|
||||
},
|
||||
None => default_err,
|
||||
};
|
||||
|
||||
return self.send_server_error(error_msg.to_string(), Some(StatusCode::BAD_REQUEST));
|
||||
}
|
||||
|
||||
if prompt_guard_resp.toxic_verdict.is_some() && prompt_guard_resp.toxic_verdict.unwrap() {
|
||||
let default_err = "Toxicity detected. Please refrain from using toxic language.";
|
||||
let error_msg = match self.prompt_guards.as_ref() {
|
||||
Some(prompt_guards) => match prompt_guards.input_guards.toxicity.as_ref() {
|
||||
Some(toxicity) => match toxicity.on_exception_message.as_ref() {
|
||||
Some(error_msg) => error_msg,
|
||||
None => default_err,
|
||||
},
|
||||
None => default_err,
|
||||
},
|
||||
None => default_err,
|
||||
};
|
||||
|
||||
return self.send_server_error(error_msg.to_string(), Some(StatusCode::BAD_REQUEST));
|
||||
}
|
||||
|
||||
self.get_embeddings(callout_context);
|
||||
}
|
||||
|
||||
fn get_embeddings(&mut self, callout_context: CallContext) {
|
||||
let user_message = callout_context.user_message.unwrap();
|
||||
let get_embeddings_input = CreateEmbeddingRequest {
|
||||
// Need to clone into input because user_message is used below.
|
||||
input: Box::new(CreateEmbeddingRequestInput::String(user_message.clone())),
|
||||
model: String::from(DEFAULT_EMBEDDING_MODEL),
|
||||
encoding_format: None,
|
||||
dimensions: None,
|
||||
user: None,
|
||||
};
|
||||
|
||||
let json_data: String = match serde_json::to_string(&get_embeddings_input) {
|
||||
Ok(json_data) => json_data,
|
||||
Err(error) => {
|
||||
panic!("Error serializing embeddings input: {}", error);
|
||||
}
|
||||
};
|
||||
|
||||
let token_id = match self.dispatch_http_call(
|
||||
MODEL_SERVER_NAME,
|
||||
vec![
|
||||
(":method", "POST"),
|
||||
(":path", "/embeddings"),
|
||||
(":authority", MODEL_SERVER_NAME),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
("x-envoy-upstream-rq-timeout-ms", "60000"),
|
||||
],
|
||||
Some(json_data.as_bytes()),
|
||||
vec![],
|
||||
Duration::from_secs(5),
|
||||
) {
|
||||
Ok(token_id) => token_id,
|
||||
Err(e) => {
|
||||
panic!(
|
||||
"Error dispatching embedding server HTTP call for get-embeddings: {:?}",
|
||||
e
|
||||
);
|
||||
}
|
||||
};
|
||||
debug!(
|
||||
"dispatched HTTP call to embedding server token_id={}",
|
||||
token_id
|
||||
);
|
||||
|
||||
let call_context = CallContext {
|
||||
response_handler_type: ResponseHandlerType::GetEmbeddings,
|
||||
user_message: Some(user_message),
|
||||
prompt_target_name: None,
|
||||
request_body: callout_context.request_body,
|
||||
similarity_scores: None,
|
||||
};
|
||||
if self.callouts.insert(token_id, call_context).is_some() {
|
||||
panic!(
|
||||
"duplicate token_id={} in embedding server requests",
|
||||
token_id
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// HttpContext is the trait that allows the Rust code to interact with HTTP objects.
|
||||
|
|
@ -711,16 +818,51 @@ impl HttpContext for StreamContext {
|
|||
}
|
||||
};
|
||||
|
||||
let get_embeddings_input = CreateEmbeddingRequest {
|
||||
// Need to clone into input because user_message is used below.
|
||||
input: Box::new(CreateEmbeddingRequestInput::String(user_message.clone())),
|
||||
model: String::from(DEFAULT_EMBEDDING_MODEL),
|
||||
encoding_format: None,
|
||||
dimensions: None,
|
||||
user: None,
|
||||
let prompt_guards = match self.prompt_guards.as_ref() {
|
||||
Some(prompt_guards) => {
|
||||
debug!("prompt guards: {:?}", prompt_guards);
|
||||
prompt_guards
|
||||
}
|
||||
None => {
|
||||
let callout_context = CallContext {
|
||||
response_handler_type: ResponseHandlerType::ArchGuard,
|
||||
user_message: Some(user_message),
|
||||
prompt_target_name: None,
|
||||
request_body: deserialized_body,
|
||||
similarity_scores: None,
|
||||
};
|
||||
self.get_embeddings(callout_context);
|
||||
return Action::Pause;
|
||||
}
|
||||
};
|
||||
|
||||
let json_data: String = match serde_json::to_string(&get_embeddings_input) {
|
||||
let prompt_guard_task = match (
|
||||
prompt_guards.input_guards.toxicity.is_some(),
|
||||
prompt_guards.input_guards.jailbreak.is_some(),
|
||||
) {
|
||||
(true, true) => PromptGuardTask::Both,
|
||||
(true, false) => PromptGuardTask::Toxicity,
|
||||
(false, true) => PromptGuardTask::Jailbreak,
|
||||
(false, false) => {
|
||||
info!("Input guards set but no prompt guards were found");
|
||||
let callout_context = CallContext {
|
||||
response_handler_type: ResponseHandlerType::ArchGuard,
|
||||
user_message: Some(user_message),
|
||||
prompt_target_name: None,
|
||||
request_body: deserialized_body,
|
||||
similarity_scores: None,
|
||||
};
|
||||
self.get_embeddings(callout_context);
|
||||
return Action::Pause;
|
||||
}
|
||||
};
|
||||
|
||||
let get_prompt_guards_request = PromptGuardRequest {
|
||||
input: user_message.clone(),
|
||||
task: prompt_guard_task,
|
||||
};
|
||||
|
||||
let json_data: String = match serde_json::to_string(&get_prompt_guards_request) {
|
||||
Ok(json_data) => json_data,
|
||||
Err(error) => {
|
||||
panic!("Error serializing embeddings input: {}", error);
|
||||
|
|
@ -731,7 +873,7 @@ impl HttpContext for StreamContext {
|
|||
MODEL_SERVER_NAME,
|
||||
vec![
|
||||
(":method", "POST"),
|
||||
(":path", "/embeddings"),
|
||||
(":path", "/guard"),
|
||||
(":authority", MODEL_SERVER_NAME),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
|
|
@ -749,13 +891,11 @@ impl HttpContext for StreamContext {
|
|||
);
|
||||
}
|
||||
};
|
||||
debug!(
|
||||
"dispatched HTTP call to embedding server token_id={}",
|
||||
token_id
|
||||
);
|
||||
|
||||
debug!("dispatched HTTP call to bolt_guard token_id={}", token_id);
|
||||
|
||||
let call_context = CallContext {
|
||||
response_handler_type: ResponseHandlerType::GetEmbeddings,
|
||||
response_handler_type: ResponseHandlerType::ArchGuard,
|
||||
user_message: Some(user_message),
|
||||
prompt_target_name: None,
|
||||
request_body: deserialized_body,
|
||||
|
|
@ -876,15 +1016,16 @@ impl Context for StreamContext {
|
|||
ResponseHandlerType::GetEmbeddings => {
|
||||
self.embeddings_handler(body, callout_context)
|
||||
}
|
||||
ResponseHandlerType::ZeroShotIntent => {
|
||||
self.zero_shot_intent_detection_resp_handler(body, callout_context)
|
||||
}
|
||||
ResponseHandlerType::FunctionResolver => {
|
||||
self.function_resolver_handler(body, callout_context)
|
||||
}
|
||||
ResponseHandlerType::FunctionCall => {
|
||||
self.function_call_response_handler(body, callout_context)
|
||||
}
|
||||
ResponseHandlerType::ZeroShotIntent => {
|
||||
self.zero_shot_intent_detection_resp_handler(body, callout_context)
|
||||
}
|
||||
ResponseHandlerType::ArchGuard => self.arch_guard_handler(body, callout_context),
|
||||
}
|
||||
} else {
|
||||
self.send_server_error(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue