[Kan-103] add support toxic/jailbreak model (#49)

* add toxic/jailbreak model

* fix path loading model

* fix syntax

* fix bug,lint, format

* fix bug

* formatting

* add parallel + chunking

* fix bug

* working version

* fix onnnx name erorr

* device

* fix jailbreak config

* fix syntax error

* format

* add requirement + cli download for dockerfile

* add task

* add skeleton change for envoy filter for prompt guard

* fix hardware config

* fix bug

* add config changes

* add gitignore

* merge main

* integrate arch-guard with filter

* add hardware config

* nothing

* add hardware config feature

* fix requirement

* fix chat ui

* fix onnx

* fix lint

* remove non intel cpu

* remove onnx

* working version

* modify docker

* fix guard time

* add nvidia support

* remove nvidia

* add gpu

* add gpu

* add gpu support

* add gpu support for compose

* add gpu support for compose

* add gpu support for compose

* add gpu support for compose

* add gpu support for compose

* fix docker file

* fix int test

* correct gpu docker

* upgrad python 10

* fix logits to be gpu compatible

* default to cpu dockerfile

* resolve comments

* fix lint + unused parameters

* fix

* remove eetq install for cpu

* remove deploy gpu

---------

Co-authored-by: Adil Hafeez <adil@katanemo.com>
This commit is contained in:
Co Tran 2024-09-23 12:07:31 -07:00 committed by GitHub
parent 80c554ce1a
commit 79b1c5415f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
18 changed files with 1622 additions and 191 deletions

View file

@ -9,7 +9,7 @@ use open_message_format_embeddings::models::{
use proxy_wasm::traits::*;
use proxy_wasm::types::*;
use public_types::common_types::EmbeddingType;
use public_types::configuration::{Configuration, Overrides, PromptTarget};
use public_types::configuration::{Configuration, Overrides, PromptGuards, PromptTarget};
use serde_json::to_string;
use std::collections::HashMap;
use std::rc::Rc;
@ -47,6 +47,7 @@ pub struct FilterContext {
config: Option<Configuration>,
overrides: Rc<Option<Overrides>>,
prompt_targets: Rc<RwLock<HashMap<String, PromptTarget>>>,
prompt_guards: Rc<Option<PromptGuards>>,
}
pub fn embeddings_store() -> &'static RwLock<HashMap<String, EmbeddingTypeMap>> {
@ -65,6 +66,7 @@ impl FilterContext {
metrics: Rc::new(WasmMetrics::new()),
prompt_targets: Rc::new(RwLock::new(HashMap::new())),
overrides: Rc::new(None),
prompt_guards: Rc::new(Some(PromptGuards::default())),
}
}
@ -238,6 +240,14 @@ impl RootContext for FilterContext {
{
ratelimit::ratelimits(Some(std::mem::take(ratelimits_config)));
}
if let Some(prompt_guards) = self
.config
.as_mut()
.and_then(|config| config.prompt_guards.as_mut())
{
self.prompt_guards = Rc::new(Some(std::mem::take(prompt_guards)));
}
}
true
}
@ -247,6 +257,7 @@ impl RootContext for FilterContext {
context_id,
Rc::clone(&self.metrics),
Rc::clone(&self.prompt_targets),
Rc::clone(&self.prompt_guards),
Rc::clone(&self.overrides),
)))
}

View file

@ -21,10 +21,11 @@ use public_types::common_types::open_ai::{
StreamOptions,
};
use public_types::common_types::{
BoltFCToolsCall, EmbeddingType, ToolParameter, ToolParameters, ToolsDefinition,
ZeroShotClassificationRequest, ZeroShotClassificationResponse,
BoltFCToolsCall, EmbeddingType, PromptGuardRequest, PromptGuardResponse, PromptGuardTask,
ToolParameter, ToolParameters, ToolsDefinition, ZeroShotClassificationRequest,
ZeroShotClassificationResponse,
};
use public_types::configuration::{Overrides, PromptTarget, PromptType};
use public_types::configuration::{Overrides, PromptGuards, PromptTarget, PromptType};
use std::collections::HashMap;
use std::num::NonZero;
use std::rc::Rc;
@ -36,6 +37,7 @@ enum ResponseHandlerType {
FunctionResolver,
FunctionCall,
ZeroShotIntent,
ArchGuard,
}
pub struct CallContext {
@ -57,6 +59,7 @@ pub struct StreamContext {
streaming_response: bool,
response_tokens: usize,
chat_completions_request: bool,
prompt_guards: Rc<Option<PromptGuards>>,
}
impl StreamContext {
@ -64,6 +67,7 @@ impl StreamContext {
context_id: u32,
metrics: Rc<WasmMetrics>,
prompt_targets: Rc<RwLock<HashMap<String, PromptTarget>>>,
prompt_guards: Rc<Option<PromptGuards>>,
overrides: Rc<Option<Overrides>>,
) -> Self {
StreamContext {
@ -76,6 +80,7 @@ impl StreamContext {
streaming_response: false,
response_tokens: 0,
chat_completions_request: false,
prompt_guards,
overrides,
}
}
@ -640,6 +645,108 @@ impl StreamContext {
self.set_http_request_body(0, json_string.len(), &json_string.into_bytes());
self.resume_http_request();
}
fn arch_guard_handler(&mut self, body: Vec<u8>, callout_context: CallContext) {
debug!("response received for arch guard");
let prompt_guard_resp: PromptGuardResponse = serde_json::from_slice(&body).unwrap();
debug!("prompt_guard_resp: {:?}", prompt_guard_resp);
if prompt_guard_resp.jailbreak_verdict.is_some()
&& prompt_guard_resp.jailbreak_verdict.unwrap()
{
let default_err = "Jailbreak detected. Please refrain from discussing jailbreaking.";
let error_msg = match self.prompt_guards.as_ref() {
Some(prompt_guards) => match prompt_guards.input_guards.jailbreak.as_ref() {
Some(jailbreak) => match jailbreak.on_exception_message.as_ref() {
Some(error_msg) => error_msg,
None => default_err,
},
None => default_err,
},
None => default_err,
};
return self.send_server_error(error_msg.to_string(), Some(StatusCode::BAD_REQUEST));
}
if prompt_guard_resp.toxic_verdict.is_some() && prompt_guard_resp.toxic_verdict.unwrap() {
let default_err = "Toxicity detected. Please refrain from using toxic language.";
let error_msg = match self.prompt_guards.as_ref() {
Some(prompt_guards) => match prompt_guards.input_guards.toxicity.as_ref() {
Some(toxicity) => match toxicity.on_exception_message.as_ref() {
Some(error_msg) => error_msg,
None => default_err,
},
None => default_err,
},
None => default_err,
};
return self.send_server_error(error_msg.to_string(), Some(StatusCode::BAD_REQUEST));
}
self.get_embeddings(callout_context);
}
fn get_embeddings(&mut self, callout_context: CallContext) {
let user_message = callout_context.user_message.unwrap();
let get_embeddings_input = CreateEmbeddingRequest {
// Need to clone into input because user_message is used below.
input: Box::new(CreateEmbeddingRequestInput::String(user_message.clone())),
model: String::from(DEFAULT_EMBEDDING_MODEL),
encoding_format: None,
dimensions: None,
user: None,
};
let json_data: String = match serde_json::to_string(&get_embeddings_input) {
Ok(json_data) => json_data,
Err(error) => {
panic!("Error serializing embeddings input: {}", error);
}
};
let token_id = match self.dispatch_http_call(
MODEL_SERVER_NAME,
vec![
(":method", "POST"),
(":path", "/embeddings"),
(":authority", MODEL_SERVER_NAME),
("content-type", "application/json"),
("x-envoy-max-retries", "3"),
("x-envoy-upstream-rq-timeout-ms", "60000"),
],
Some(json_data.as_bytes()),
vec![],
Duration::from_secs(5),
) {
Ok(token_id) => token_id,
Err(e) => {
panic!(
"Error dispatching embedding server HTTP call for get-embeddings: {:?}",
e
);
}
};
debug!(
"dispatched HTTP call to embedding server token_id={}",
token_id
);
let call_context = CallContext {
response_handler_type: ResponseHandlerType::GetEmbeddings,
user_message: Some(user_message),
prompt_target_name: None,
request_body: callout_context.request_body,
similarity_scores: None,
};
if self.callouts.insert(token_id, call_context).is_some() {
panic!(
"duplicate token_id={} in embedding server requests",
token_id
)
}
}
}
// HttpContext is the trait that allows the Rust code to interact with HTTP objects.
@ -711,16 +818,51 @@ impl HttpContext for StreamContext {
}
};
let get_embeddings_input = CreateEmbeddingRequest {
// Need to clone into input because user_message is used below.
input: Box::new(CreateEmbeddingRequestInput::String(user_message.clone())),
model: String::from(DEFAULT_EMBEDDING_MODEL),
encoding_format: None,
dimensions: None,
user: None,
let prompt_guards = match self.prompt_guards.as_ref() {
Some(prompt_guards) => {
debug!("prompt guards: {:?}", prompt_guards);
prompt_guards
}
None => {
let callout_context = CallContext {
response_handler_type: ResponseHandlerType::ArchGuard,
user_message: Some(user_message),
prompt_target_name: None,
request_body: deserialized_body,
similarity_scores: None,
};
self.get_embeddings(callout_context);
return Action::Pause;
}
};
let json_data: String = match serde_json::to_string(&get_embeddings_input) {
let prompt_guard_task = match (
prompt_guards.input_guards.toxicity.is_some(),
prompt_guards.input_guards.jailbreak.is_some(),
) {
(true, true) => PromptGuardTask::Both,
(true, false) => PromptGuardTask::Toxicity,
(false, true) => PromptGuardTask::Jailbreak,
(false, false) => {
info!("Input guards set but no prompt guards were found");
let callout_context = CallContext {
response_handler_type: ResponseHandlerType::ArchGuard,
user_message: Some(user_message),
prompt_target_name: None,
request_body: deserialized_body,
similarity_scores: None,
};
self.get_embeddings(callout_context);
return Action::Pause;
}
};
let get_prompt_guards_request = PromptGuardRequest {
input: user_message.clone(),
task: prompt_guard_task,
};
let json_data: String = match serde_json::to_string(&get_prompt_guards_request) {
Ok(json_data) => json_data,
Err(error) => {
panic!("Error serializing embeddings input: {}", error);
@ -731,7 +873,7 @@ impl HttpContext for StreamContext {
MODEL_SERVER_NAME,
vec![
(":method", "POST"),
(":path", "/embeddings"),
(":path", "/guard"),
(":authority", MODEL_SERVER_NAME),
("content-type", "application/json"),
("x-envoy-max-retries", "3"),
@ -749,13 +891,11 @@ impl HttpContext for StreamContext {
);
}
};
debug!(
"dispatched HTTP call to embedding server token_id={}",
token_id
);
debug!("dispatched HTTP call to bolt_guard token_id={}", token_id);
let call_context = CallContext {
response_handler_type: ResponseHandlerType::GetEmbeddings,
response_handler_type: ResponseHandlerType::ArchGuard,
user_message: Some(user_message),
prompt_target_name: None,
request_body: deserialized_body,
@ -876,15 +1016,16 @@ impl Context for StreamContext {
ResponseHandlerType::GetEmbeddings => {
self.embeddings_handler(body, callout_context)
}
ResponseHandlerType::ZeroShotIntent => {
self.zero_shot_intent_detection_resp_handler(body, callout_context)
}
ResponseHandlerType::FunctionResolver => {
self.function_resolver_handler(body, callout_context)
}
ResponseHandlerType::FunctionCall => {
self.function_call_response_handler(body, callout_context)
}
ResponseHandlerType::ZeroShotIntent => {
self.zero_shot_intent_detection_resp_handler(body, callout_context)
}
ResponseHandlerType::ArchGuard => self.arch_guard_handler(body, callout_context),
}
} else {
self.send_server_error(