mirror of
https://github.com/katanemo/plano.git
synced 2026-05-03 21:02:56 +02:00
Remove optional PromptGuards from Stream Context (#113)
Signed-off-by: José Ulises Niño Rivera <junr03@users.noreply.github.com>
This commit is contained in:
parent
8ea917aae5
commit
af018e5fd8
3 changed files with 150 additions and 61 deletions
|
|
@ -48,8 +48,8 @@ pub struct CallContext {
|
|||
prompt_target_name: Option<String>,
|
||||
request_body: ChatCompletionsRequest,
|
||||
similarity_scores: Option<Vec<(String, f64)>>,
|
||||
up_stream_cluster: Option<String>,
|
||||
up_stream_cluster_path: Option<String>,
|
||||
upstream_cluster: Option<String>,
|
||||
upstream_cluster_path: Option<String>,
|
||||
}
|
||||
|
||||
pub struct StreamContext {
|
||||
|
|
@ -62,9 +62,9 @@ pub struct StreamContext {
|
|||
streaming_response: bool,
|
||||
response_tokens: usize,
|
||||
chat_completions_request: bool,
|
||||
prompt_guards: Rc<PromptGuards>,
|
||||
llm_providers: Rc<LlmProviders>,
|
||||
llm_provider: Option<Rc<LlmProvider>>,
|
||||
prompt_guards: Rc<Option<PromptGuards>>,
|
||||
}
|
||||
|
||||
impl StreamContext {
|
||||
|
|
@ -72,7 +72,7 @@ impl StreamContext {
|
|||
context_id: u32,
|
||||
metrics: Rc<WasmMetrics>,
|
||||
prompt_targets: Rc<RwLock<HashMap<String, PromptTarget>>>,
|
||||
prompt_guards: Rc<Option<PromptGuards>>,
|
||||
prompt_guards: Rc<PromptGuards>,
|
||||
overrides: Rc<Option<Overrides>>,
|
||||
llm_providers: Rc<LlmProviders>,
|
||||
) -> Self {
|
||||
|
|
@ -615,8 +615,8 @@ impl StreamContext {
|
|||
}
|
||||
};
|
||||
|
||||
callout_context.up_stream_cluster = Some(endpoint.name);
|
||||
callout_context.up_stream_cluster_path = Some(path);
|
||||
callout_context.upstream_cluster = Some(endpoint.name);
|
||||
callout_context.upstream_cluster_path = Some(path);
|
||||
callout_context.response_handler_type = ResponseHandlerType::FunctionCall;
|
||||
if self.callouts.insert(token_id, callout_context).is_some() {
|
||||
panic!("duplicate token_id")
|
||||
|
|
@ -630,8 +630,8 @@ impl StreamContext {
|
|||
if http_status.1 != StatusCode::OK.as_str() {
|
||||
let error_msg = format!(
|
||||
"Error in function call response: cluster: {}, path: {}, status code: {}",
|
||||
callout_context.up_stream_cluster.unwrap(),
|
||||
callout_context.up_stream_cluster_path.unwrap(),
|
||||
callout_context.upstream_cluster.unwrap(),
|
||||
callout_context.upstream_cluster_path.unwrap(),
|
||||
http_status.1
|
||||
);
|
||||
return self.send_server_error(error_msg, Some(StatusCode::BAD_REQUEST));
|
||||
|
|
@ -741,9 +741,9 @@ impl StreamContext {
|
|||
|
||||
if prompt_guard_resp.jailbreak_verdict.unwrap_or_default() {
|
||||
//TODO: handle other scenarios like forward to error target
|
||||
let msg = (*self.prompt_guards)
|
||||
.as_ref()
|
||||
.and_then(|pg| pg.jailbreak_on_exception_message())
|
||||
let msg = self
|
||||
.prompt_guards
|
||||
.jailbreak_on_exception_message()
|
||||
.unwrap_or("Jailbreak detected. Please refrain from discussing jailbreaking.");
|
||||
return self.send_server_error(msg.to_string(), Some(StatusCode::BAD_REQUEST));
|
||||
}
|
||||
|
|
@ -801,8 +801,8 @@ impl StreamContext {
|
|||
prompt_target_name: None,
|
||||
request_body: callout_context.request_body,
|
||||
similarity_scores: None,
|
||||
up_stream_cluster: None,
|
||||
up_stream_cluster_path: None,
|
||||
upstream_cluster: None,
|
||||
upstream_cluster_path: None,
|
||||
};
|
||||
if self.callouts.insert(token_id, call_context).is_some() {
|
||||
panic!(
|
||||
|
|
@ -810,6 +810,8 @@ impl StreamContext {
|
|||
token_id
|
||||
)
|
||||
}
|
||||
|
||||
self.metrics.active_http_calls.increment(1);
|
||||
}
|
||||
|
||||
fn default_target_handler(&self, body: Vec<u8>, callout_context: CallContext) {
|
||||
|
|
@ -971,39 +973,20 @@ impl HttpContext for StreamContext {
|
|||
}
|
||||
};
|
||||
|
||||
let prompt_guards = match self.prompt_guards.as_ref() {
|
||||
Some(prompt_guards) => {
|
||||
debug!("prompt guards: {:?}", prompt_guards);
|
||||
prompt_guards
|
||||
}
|
||||
None => {
|
||||
let callout_context = CallContext {
|
||||
response_handler_type: ResponseHandlerType::ArchGuard,
|
||||
user_message: Some(user_message),
|
||||
prompt_target_name: None,
|
||||
request_body: deserialized_body,
|
||||
similarity_scores: None,
|
||||
up_stream_cluster: None,
|
||||
up_stream_cluster_path: None,
|
||||
};
|
||||
self.get_embeddings(callout_context);
|
||||
return Action::Pause;
|
||||
}
|
||||
};
|
||||
|
||||
let prompt_guard_jailbreak_task = prompt_guards
|
||||
let prompt_guard_jailbreak_task = self
|
||||
.prompt_guards
|
||||
.input_guards
|
||||
.contains_key(&public_types::configuration::GuardType::Jailbreak);
|
||||
if !prompt_guard_jailbreak_task {
|
||||
info!("Input guards set but no prompt guards were found");
|
||||
debug!("Missing input guard. Making inline call to retrieve");
|
||||
let callout_context = CallContext {
|
||||
response_handler_type: ResponseHandlerType::ArchGuard,
|
||||
user_message: Some(user_message),
|
||||
prompt_target_name: None,
|
||||
request_body: deserialized_body,
|
||||
similarity_scores: None,
|
||||
up_stream_cluster: None,
|
||||
up_stream_cluster_path: None,
|
||||
upstream_cluster: None,
|
||||
upstream_cluster_path: None,
|
||||
};
|
||||
self.get_embeddings(callout_context);
|
||||
return Action::Pause;
|
||||
|
|
@ -1056,8 +1039,8 @@ impl HttpContext for StreamContext {
|
|||
prompt_target_name: None,
|
||||
request_body: deserialized_body,
|
||||
similarity_scores: None,
|
||||
up_stream_cluster: None,
|
||||
up_stream_cluster_path: None,
|
||||
upstream_cluster: None,
|
||||
upstream_cluster_path: None,
|
||||
};
|
||||
if self.callouts.insert(token_id, call_context).is_some() {
|
||||
panic!(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue