diff --git a/crates/common/src/errors.rs b/crates/common/src/errors.rs index fd634915..27b0341e 100644 --- a/crates/common/src/errors.rs +++ b/crates/common/src/errors.rs @@ -22,11 +22,12 @@ pub enum ServerError { Serialization(serde_json::Error), #[error("{0}")] LogicError(String), - #[error("upstream error response authority={authority}, path={path}, status={status}")] + #[error("upstream application error host={host}, path={path}, status={status}, body={body}")] Upstream { - authority: String, + host: String, path: String, status: String, + body: String, }, #[error("jailbreak detected: {0}")] Jailbreak(String), diff --git a/crates/common/src/http.rs b/crates/common/src/http.rs index 842818e2..f5d66a65 100644 --- a/crates/common/src/http.rs +++ b/crates/common/src/http.rs @@ -1,4 +1,7 @@ -use crate::{errors::ClientError, stats::{Gauge, IncrementingMetric}}; +use crate::{ + errors::ClientError, + stats::{Gauge, IncrementingMetric}, +}; use derivative::Derivative; use log::debug; use proxy_wasm::{traits::Context, types::Status}; diff --git a/crates/common/src/lib.rs b/crates/common/src/lib.rs index c23443ca..f2c95bc5 100644 --- a/crates/common/src/lib.rs +++ b/crates/common/src/lib.rs @@ -4,10 +4,10 @@ pub mod common_types; pub mod configuration; pub mod consts; pub mod embeddings; +pub mod errors; pub mod http; pub mod llm_providers; pub mod ratelimit; pub mod routing; pub mod stats; pub mod tokenizer; -pub mod errors; diff --git a/crates/prompt_gateway/src/context.rs b/crates/prompt_gateway/src/context.rs new file mode 100644 index 00000000..f33d6238 --- /dev/null +++ b/crates/prompt_gateway/src/context.rs @@ -0,0 +1,103 @@ +use common::errors::ServerError; +use common::stats::IncrementingMetric; +use proxy_wasm::traits::Context; + +use crate::stream_context::{ResponseHandlerType, StreamContext}; + +impl Context for StreamContext { + fn on_http_call_response( + &mut self, + token_id: u32, + _num_headers: usize, + body_size: usize, + _num_trailers: usize, + ) { + 
let callout_context = self + .callouts + .get_mut() + .remove(&token_id) + .expect("invalid token_id"); + self.metrics.active_http_calls.increment(-1); + + /* + state transition + + graph LR + + on_http_request_body --> prompt received + prompt received --> get embeddings & arch guard + arch guard --> get embeddings + get embeddings --> zeroshot intent + + ┌──────────────────────┐ ┌─────────────────┐ ┌────────────────┐ ┌─────────────────┐ + │ │ │ │ │ │ │ │ + │ on_http_request_body ├──►│ prompt received ├──►│ get embeddings ├──►│ zeroshot intent │ + │ │ │ │ │ │ │ │ + └──────────────────────┘ └────────┬────────┘ └────────────────┘ └─────────────────┘ + │ ▲ + │ │ + │ │ + │ ┌────────┴───────┐ + │ │ │ + └───────────►│ arch guard │ + │ │ + └────────────────┘ + + + continue from zeroshot intent + + graph LR + + zeroshot intent --> arch_fc + zeroshot intent --> default prompt target + arch_fc --> developer api call & hallucination check + hallucination check --> parameter gathering & developer api call + developer api call --> resume request to llm + + + ┌─────────────────┐ ┌───────────────────────┐ ┌─────────────────────┐ ┌───────────────────────┐ + │ │ │ │ │ │ │ │ + │ zeroshot intent ├──►│ arch_fc ├──►│ developer api call ├──►│ resume request to llm │ + │ │ │ │ │ │ │ │ + └────────┬────────┘ └───────────┬───────────┘ └─────────────────────┘ └───────────────────────┘ + │ │ ▲ + │ └─────────────┐ │ + │ │ │ + │ ┌───────────────────────┐ │ ┌──────────┴──────────┐ ┌───────────────────────┐ + │ │ │ │ │ │ │ │ + └───────────►│ default prompt target │ └▲│ hallucination check ├──►│ parameter gathering │ + │ │ │ │ │ │ + └───────────────────────┘ └─────────────────────┘ └───────────────────────┘ + + + using https://mermaid-ascii.art/ + */ + + if let Some(body) = self.get_http_call_response_body(0, body_size) { + match callout_context.response_handler_type { + ResponseHandlerType::GetEmbeddings => { + self.embeddings_handler(body, callout_context) + } + ResponseHandlerType::ArchGuard 
=> self.arch_guard_handler(body, callout_context), + ResponseHandlerType::ZeroShotIntent => { + self.zero_shot_intent_detection_resp_handler(body, callout_context) + } + ResponseHandlerType::ArchFC => self.arch_fc_response_handler(body, callout_context), + ResponseHandlerType::HallucinationDetect => { + self.hallucination_classification_resp_handler(body, callout_context) + } + ResponseHandlerType::FunctionCall => { + self.function_call_response_handler(body, callout_context) + } + ResponseHandlerType::DefaultTarget => { + self.default_target_handler(body, callout_context) + } + } + } else { + self.send_server_error( + ServerError::LogicError(String::from("No response body in inline HTTP request")), + None, + ); + } + } +} diff --git a/crates/prompt_gateway/src/filter_context.rs b/crates/prompt_gateway/src/filter_context.rs index b60191a5..3f1d3f0d 100644 --- a/crates/prompt_gateway/src/filter_context.rs +++ b/crates/prompt_gateway/src/filter_context.rs @@ -1,9 +1,9 @@ use crate::stream_context::StreamContext; use common::common_types::EmbeddingType; -use common::consts::{ARCH_INTERNAL_CLUSTER_NAME, EMBEDDINGS_INTERNAL_HOST}; use common::configuration::{Configuration, Overrides, PromptGuards, PromptTarget}; use common::consts::ARCH_UPSTREAM_HOST_HEADER; use common::consts::DEFAULT_EMBEDDING_MODEL; +use common::consts::{ARCH_INTERNAL_CLUSTER_NAME, EMBEDDINGS_INTERNAL_HOST}; use common::embeddings::{ CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse, }; diff --git a/crates/prompt_gateway/src/hallucination.rs b/crates/prompt_gateway/src/hallucination.rs index 99d487ab..71d1e7cf 100644 --- a/crates/prompt_gateway/src/hallucination.rs +++ b/crates/prompt_gateway/src/hallucination.rs @@ -1,12 +1,12 @@ use common::{common_types::open_ai::Message, consts::USER_ROLE}; -pub fn extract_messages_for_hallucination(messages: &Vec) -> Vec { +pub fn extract_messages_for_hallucination(messages: &[Message]) -> Vec { let all_user_messages = messages .iter() 
.filter(|m| m.role == USER_ROLE) .map(|m| m.content.as_ref().unwrap().clone()) .collect::>(); - return all_user_messages; + all_user_messages } #[cfg(test)] @@ -17,7 +17,7 @@ mod test { #[test] fn test_hallucination_message() { - let test_str = r#" + let test_str = r#" [ { "role": "system", @@ -32,8 +32,8 @@ mod test { ] "#; - let messages: Vec = serde_json::from_str(test_str).unwrap(); - let messages_for_halluncination = extract_messages_for_hallucination(&messages); - assert_eq!(messages_for_halluncination.len(), 2); + let messages: Vec = serde_json::from_str(test_str).unwrap(); + let messages_for_halluncination = extract_messages_for_hallucination(&messages); + assert_eq!(messages_for_halluncination.len(), 2); } } diff --git a/crates/prompt_gateway/src/http_context.rs b/crates/prompt_gateway/src/http_context.rs new file mode 100644 index 00000000..04fc976e --- /dev/null +++ b/crates/prompt_gateway/src/http_context.rs @@ -0,0 +1,333 @@ +use std::{collections::HashMap, time::Duration}; + +use common::{ + common_types::{ + open_ai::{ + ArchState, ChatCompletionsRequest, ChatCompletionsResponse, Message, StreamOptions, + }, + PromptGuardRequest, PromptGuardTask, + }, + consts::{ + ARCH_FC_MODEL_NAME, ARCH_INTERNAL_CLUSTER_NAME, ARCH_STATE_HEADER, + ARCH_UPSTREAM_HOST_HEADER, ASSISTANT_ROLE, CHAT_COMPLETIONS_PATH, GUARD_INTERNAL_HOST, + REQUEST_ID_HEADER, TOOL_ROLE, USER_ROLE, + }, + errors::ServerError, + http::{CallArgs, Client}, +}; +use http::StatusCode; +use log::{debug, warn}; +use proxy_wasm::{traits::HttpContext, types::Action}; +use serde_json::Value; + +use crate::stream_context::{ResponseHandlerType, StreamCallContext, StreamContext}; + +// HttpContext is the trait that allows the Rust code to interact with HTTP objects. +impl HttpContext for StreamContext { + // Envoy's HTTP model is event driven. The WASM ABI has given implementors events to hook onto + // the lifecycle of the http request and response. 
+    fn on_http_request_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action {
+        // Remove the Content-Length header because further body manipulations in the gateway logic will invalidate it.
+        // Servers generally throw away requests whose body length does not match the Content-Length header.
+        // However, a missing Content-Length header is not grounds for bad requests given that intermediary hops could
+        // manipulate the body in benign ways e.g., compression.
+        self.set_http_request_header("content-length", None);
+
+        self.is_chat_completions_request =
+            self.get_http_request_header(":path").unwrap_or_default() == CHAT_COMPLETIONS_PATH;
+
+        debug!(
+            "on_http_request_headers S[{}] req_headers={:?}",
+            self.context_id,
+            self.get_http_request_headers()
+        );
+
+        self.request_id = self.get_http_request_header(REQUEST_ID_HEADER);
+
+        Action::Continue
+    }
+
+    fn on_http_request_body(&mut self, body_size: usize, end_of_stream: bool) -> Action {
+        // Let the client send the gateway all the data before sending to the LLM_provider.
+        // TODO: consider a streaming API.
+        if !end_of_stream {
+            return Action::Pause;
+        }
+
+        if body_size == 0 {
+            return Action::Continue;
+        }
+
+        self.request_body_size = body_size;
+
+        debug!(
+            "on_http_request_body S[{}] body_size={}",
+            self.context_id, body_size
+        );
+
+        // Deserialize body into spec.
+        // Currently OpenAI API.
+ let mut deserialized_body: ChatCompletionsRequest = + match self.get_http_request_body(0, body_size) { + Some(body_bytes) => match serde_json::from_slice(&body_bytes) { + Ok(deserialized) => deserialized, + Err(e) => { + self.send_server_error( + ServerError::Deserialization(e), + Some(StatusCode::BAD_REQUEST), + ); + return Action::Pause; + } + }, + None => { + self.send_server_error( + ServerError::LogicError(format!( + "Failed to obtain body bytes even though body_size is {}", + body_size + )), + None, + ); + return Action::Pause; + } + }; + + self.arch_state = match deserialized_body.metadata { + Some(ref metadata) => { + if metadata.contains_key(ARCH_STATE_HEADER) { + let arch_state_str = metadata[ARCH_STATE_HEADER].clone(); + let arch_state: Vec = serde_json::from_str(&arch_state_str).unwrap(); + Some(arch_state) + } else { + None + } + } + None => None, + }; + + self.streaming_response = deserialized_body.stream; + if deserialized_body.stream && deserialized_body.stream_options.is_none() { + deserialized_body.stream_options = Some(StreamOptions { + include_usage: true, + }); + } + + let last_user_prompt = match deserialized_body + .messages + .iter() + .filter(|msg| msg.role == USER_ROLE) + .last() + { + Some(content) => content, + None => { + warn!("No messages in the request body"); + return Action::Continue; + } + }; + + self.user_prompt = Some(last_user_prompt.clone()); + + let user_message_str = self.user_prompt.as_ref().unwrap().content.clone(); + + let prompt_guard_jailbreak_task = self + .prompt_guards + .input_guards + .contains_key(&common::configuration::GuardType::Jailbreak); + + self.chat_completions_request = Some(deserialized_body); + + if !prompt_guard_jailbreak_task { + debug!("Missing input guard. 
Making inline call to retrieve embeddings"); + let callout_context = StreamCallContext { + response_handler_type: ResponseHandlerType::ArchGuard, + user_message: user_message_str.clone(), + prompt_target_name: None, + request_body: self.chat_completions_request.as_ref().unwrap().clone(), + similarity_scores: None, + upstream_cluster: None, + upstream_cluster_path: None, + tool_calls: None, + }; + self.get_embeddings(callout_context); + return Action::Pause; + } + + let get_prompt_guards_request = PromptGuardRequest { + input: self + .user_prompt + .as_ref() + .unwrap() + .content + .as_ref() + .unwrap() + .clone(), + task: PromptGuardTask::Jailbreak, + }; + + let json_data: String = match serde_json::to_string(&get_prompt_guards_request) { + Ok(json_data) => json_data, + Err(error) => { + self.send_server_error(ServerError::Serialization(error), None); + return Action::Pause; + } + }; + + let mut headers = vec![ + (ARCH_UPSTREAM_HOST_HEADER, GUARD_INTERNAL_HOST), + (":method", "POST"), + (":path", "/guard"), + (":authority", GUARD_INTERNAL_HOST), + ("content-type", "application/json"), + ("x-envoy-max-retries", "3"), + ("x-envoy-upstream-rq-timeout-ms", "60000"), + ]; + + if self.request_id.is_some() { + headers.push((REQUEST_ID_HEADER, self.request_id.as_ref().unwrap())); + } + + let call_args = CallArgs::new( + ARCH_INTERNAL_CLUSTER_NAME, + "/guard", + headers, + Some(json_data.as_bytes()), + vec![], + Duration::from_secs(5), + ); + let call_context = StreamCallContext { + response_handler_type: ResponseHandlerType::ArchGuard, + user_message: self.user_prompt.as_ref().unwrap().content.clone(), + prompt_target_name: None, + request_body: self.chat_completions_request.as_ref().unwrap().clone(), + similarity_scores: None, + upstream_cluster: None, + upstream_cluster_path: None, + tool_calls: None, + }; + + if let Err(e) = self.http_call(call_args, call_context) { + self.send_server_error(ServerError::HttpDispatch(e), None); + } + + Action::Pause + } + + fn 
on_http_response_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action {
+        debug!(
+            "on_http_response_headers recv [S={}] headers={:?}",
+            self.context_id,
+            self.get_http_response_headers()
+        );
+        // Delete the content-length header and let Envoy calculate it, because we modify the
+        // response body, which would result in a different content-length.
+        self.set_http_response_header("content-length", None);
+        Action::Continue
+    }
+
+    fn on_http_response_body(&mut self, body_size: usize, end_of_stream: bool) -> Action {
+        debug!(
+            "recv [S={}] bytes={} end_stream={}",
+            self.context_id, body_size, end_of_stream
+        );
+
+        if !self.is_chat_completions_request {
+            if let Some(body_str) = self
+                .get_http_response_body(0, body_size)
+                .and_then(|bytes| String::from_utf8(bytes).ok())
+            {
+                debug!("recv [S={}] body_str={}", self.context_id, body_str);
+            }
+            return Action::Continue;
+        }
+
+        if !end_of_stream {
+            return Action::Pause;
+        }
+
+        let body = self
+            .get_http_response_body(0, body_size)
+            .expect("cant get response body");
+
+        if self.streaming_response {
+            debug!("streaming response");
+        } else {
+            debug!("non streaming response");
+            let chat_completions_response: ChatCompletionsResponse =
+                match serde_json::from_slice(&body) {
+                    Ok(de) => de,
+                    Err(e) => {
+                        debug!(
+                            "invalid response: {}, {}",
+                            String::from_utf8_lossy(&body),
+                            e
+                        );
+                        return Action::Continue;
+                    }
+                };
+
+            if chat_completions_response.usage.is_some() {
+                self.response_tokens += chat_completions_response
+                    .usage
+                    .as_ref()
+                    .unwrap()
+                    .completion_tokens;
+            }
+
+            if let Some(tool_calls) = self.tool_calls.as_ref() {
+                if !tool_calls.is_empty() {
+                    if self.arch_state.is_none() {
+                        self.arch_state = Some(Vec::new());
+                    }
+
+                    let mut data = serde_json::from_slice(&body).unwrap();
+                    // use serde::Value to manipulate the json object and ensure that we don't lose any data
+                    if let Value::Object(ref mut map) = data {
+                        // serialize arch state and add to metadata
+                        let metadata = map
.entry("metadata") + .or_insert(Value::Object(serde_json::Map::new())); + if metadata == &Value::Null { + *metadata = Value::Object(serde_json::Map::new()); + } + + // since arch gateway generates tool calls (using arch-fc) and calls upstream api to + // get response, we will send these back to developer so they can see the api response + // and tool call arch-fc generated + let fc_messages = vec![ + Message { + role: ASSISTANT_ROLE.to_string(), + content: None, + model: Some(ARCH_FC_MODEL_NAME.to_string()), + tool_calls: self.tool_calls.clone(), + tool_call_id: None, + }, + Message { + role: TOOL_ROLE.to_string(), + content: self.tool_call_response.clone(), + model: None, + tool_calls: None, + tool_call_id: Some(self.tool_calls.as_ref().unwrap()[0].id.clone()), + }, + ]; + let fc_messages_str = serde_json::to_string(&fc_messages).unwrap(); + let arch_state = HashMap::from([("messages".to_string(), fc_messages_str)]); + let arch_state_str = serde_json::to_string(&arch_state).unwrap(); + metadata.as_object_mut().unwrap().insert( + ARCH_STATE_HEADER.to_string(), + serde_json::Value::String(arch_state_str), + ); + let data_serialized = serde_json::to_string(&data).unwrap(); + debug!("arch => user: {}", data_serialized); + self.set_http_response_body(0, body_size, data_serialized.as_bytes()); + }; + } + } + } + + debug!( + "recv [S={}] total_tokens={} end_stream={}", + self.context_id, self.response_tokens, end_of_stream + ); + + Action::Continue + } +} diff --git a/crates/prompt_gateway/src/lib.rs b/crates/prompt_gateway/src/lib.rs index 7ca26e44..f873b9bf 100644 --- a/crates/prompt_gateway/src/lib.rs +++ b/crates/prompt_gateway/src/lib.rs @@ -2,9 +2,11 @@ use filter_context::FilterContext; use proxy_wasm::traits::*; use proxy_wasm::types::*; +mod context; mod filter_context; -mod stream_context; mod hallucination; +mod http_context; +mod stream_context; proxy_wasm::main! 
{{ proxy_wasm::set_log_level(LogLevel::Trace); diff --git a/crates/prompt_gateway/src/stream_context.rs b/crates/prompt_gateway/src/stream_context.rs index 0364cd87..9b99409c 100644 --- a/crates/prompt_gateway/src/stream_context.rs +++ b/crates/prompt_gateway/src/stream_context.rs @@ -3,47 +3,42 @@ use crate::hallucination::extract_messages_for_hallucination; use acap::cos; use common::common_types::open_ai::{ ArchState, ChatCompletionTool, ChatCompletionsRequest, ChatCompletionsResponse, Choice, - FunctionDefinition, FunctionParameter, FunctionParameters, Message, ParameterType, - StreamOptions, ToolCall, ToolType, + FunctionDefinition, FunctionParameter, FunctionParameters, Message, ParameterType, ToolCall, + ToolType, }; use common::common_types::{ EmbeddingType, HallucinationClassificationRequest, HallucinationClassificationResponse, - PromptGuardRequest, PromptGuardResponse, PromptGuardTask, ZeroShotClassificationRequest, - ZeroShotClassificationResponse, + PromptGuardResponse, ZeroShotClassificationRequest, ZeroShotClassificationResponse, }; use common::configuration::{Overrides, PromptGuards, PromptTarget}; use common::consts::{ ARCH_FC_INTERNAL_HOST, ARCH_FC_MODEL_NAME, ARCH_FC_REQUEST_TIMEOUT_MS, ARCH_INTERNAL_CLUSTER_NAME, ARCH_MESSAGES_KEY, ARCH_MODEL_PREFIX, ARCH_STATE_HEADER, - ARCH_UPSTREAM_HOST_HEADER, ASSISTANT_ROLE, CHAT_COMPLETIONS_PATH, DEFAULT_EMBEDDING_MODEL, - DEFAULT_HALLUCINATED_THRESHOLD, DEFAULT_INTENT_MODEL, DEFAULT_PROMPT_TARGET_THRESHOLD, - EMBEDDINGS_INTERNAL_HOST, GPT_35_TURBO, GUARD_INTERNAL_HOST, HALLUCINATION_INTERNAL_HOST, - REQUEST_ID_HEADER, SYSTEM_ROLE, TOOL_ROLE, USER_ROLE, ZEROSHOT_INTERNAL_HOST, + ARCH_UPSTREAM_HOST_HEADER, DEFAULT_EMBEDDING_MODEL, DEFAULT_HALLUCINATED_THRESHOLD, + DEFAULT_INTENT_MODEL, DEFAULT_PROMPT_TARGET_THRESHOLD, EMBEDDINGS_INTERNAL_HOST, GPT_35_TURBO, + HALLUCINATION_INTERNAL_HOST, REQUEST_ID_HEADER, SYSTEM_ROLE, TOOL_ROLE, USER_ROLE, + ZEROSHOT_INTERNAL_HOST, }; use common::embeddings::{ 
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse, }; -use common::errors::ClientError; +use common::errors::ServerError; use common::http::{CallArgs, Client}; use common::stats::Gauge; use derivative::Derivative; use http::StatusCode; use log::{debug, info, warn}; use proxy_wasm::traits::*; -use proxy_wasm::types::*; -use serde_json::Value; use std::cell::RefCell; use std::collections::HashMap; use std::rc::Rc; use std::str::FromStr; use std::time::Duration; -use common::stats::IncrementingMetric; - #[derive(Debug, Clone)] -enum ResponseHandlerType { +pub enum ResponseHandlerType { GetEmbeddings, - FunctionResolver, + ArchFC, FunctionCall, ZeroShotIntent, HallucinationDetect, @@ -54,59 +49,36 @@ enum ResponseHandlerType { #[derive(Clone, Derivative)] #[derivative(Debug)] pub struct StreamCallContext { - response_handler_type: ResponseHandlerType, - user_message: Option, - prompt_target_name: Option, + pub response_handler_type: ResponseHandlerType, + pub user_message: Option, + pub prompt_target_name: Option, #[derivative(Debug = "ignore")] - request_body: ChatCompletionsRequest, - tool_calls: Option>, - similarity_scores: Option>, - upstream_cluster: Option, - upstream_cluster_path: Option, -} - -#[derive(thiserror::Error, Debug)] -pub enum ServerError { - #[error(transparent)] - HttpDispatch(ClientError), - #[error(transparent)] - Deserialization(serde_json::Error), - #[error(transparent)] - Serialization(serde_json::Error), - #[error("{0}")] - LogicError(String), - #[error("upstream application error host={host}, path={path}, status={status}, body={body}")] - Upstream { - host: String, - path: String, - status: String, - body: String, - }, - #[error("jailbreak detected: {0}")] - Jailbreak(String), - #[error("{why}")] - NoMessagesFound { why: String }, + pub request_body: ChatCompletionsRequest, + pub tool_calls: Option>, + pub similarity_scores: Option>, + pub upstream_cluster: Option, + pub upstream_cluster_path: Option, } pub struct 
StreamContext { - context_id: u32, - metrics: Rc, system_prompt: Rc>, prompt_targets: Rc>, embeddings_store: Option>, overrides: Rc>, - callouts: RefCell>, - tool_calls: Option>, - tool_call_response: Option, - arch_state: Option>, - request_body_size: usize, - streaming_response: bool, - user_prompt: Option, - response_tokens: usize, - is_chat_completions_request: bool, - chat_completions_request: Option, - prompt_guards: Rc, - request_id: Option, + pub metrics: Rc, + pub callouts: RefCell>, + pub context_id: u32, + pub tool_calls: Option>, + pub tool_call_response: Option, + pub arch_state: Option>, + pub request_body_size: usize, + pub streaming_response: bool, + pub user_prompt: Option, + pub response_tokens: usize, + pub is_chat_completions_request: bool, + pub chat_completions_request: Option, + pub prompt_guards: Rc, + pub request_id: Option, } impl StreamContext { @@ -146,15 +118,7 @@ impl StreamContext { .expect("embeddings store is not set") } - fn delete_content_length_header(&mut self) { - // Remove the Content-Length header because further body manipulations in the gateway logic will invalidate it. - // Server's generally throw away requests whose body length do not match the Content-Length header. - // However, a missing Content-Length header is not grounds for bad requests given that intermediary hops could - // manipulate the body in benign ways e.g., compression. 
- self.set_http_request_header("content-length", None); - } - - fn send_server_error(&self, error: ServerError, override_status_code: Option) { + pub fn send_server_error(&self, error: ServerError, override_status_code: Option) { self.send_http_response( override_status_code .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR) @@ -165,7 +129,7 @@ impl StreamContext { ); } - fn embeddings_handler(&mut self, body: Vec, mut callout_context: StreamCallContext) { + pub fn embeddings_handler(&mut self, body: Vec, mut callout_context: StreamCallContext) { let embedding_response: CreateEmbeddingResponse = match serde_json::from_slice(&body) { Ok(embedding_response) => embedding_response, Err(e) => { @@ -278,7 +242,7 @@ impl StreamContext { } } - fn hallucination_classification_resp_handler( + pub fn hallucination_classification_resp_handler( &mut self, body: Vec, callout_context: StreamCallContext, @@ -343,7 +307,7 @@ impl StreamContext { } } - fn zero_shot_intent_detection_resp_handler( + pub fn zero_shot_intent_detection_resp_handler( &mut self, body: Vec, mut callout_context: StreamCallContext, @@ -585,7 +549,7 @@ impl StreamContext { vec![], Duration::from_secs(5), ); - callout_context.response_handler_type = ResponseHandlerType::FunctionResolver; + callout_context.response_handler_type = ResponseHandlerType::ArchFC; callout_context.prompt_target_name = Some(prompt_target.name); if let Err(e) = self.http_call(call_args, callout_context) { @@ -594,7 +558,11 @@ impl StreamContext { } } - fn function_resolver_handler(&mut self, body: Vec, mut callout_context: StreamCallContext) { + pub fn arch_fc_response_handler( + &mut self, + body: Vec, + mut callout_context: StreamCallContext, + ) { let body_str = String::from_utf8(body).unwrap(); debug!("arch <= app response body: {}", body_str); @@ -782,7 +750,7 @@ impl StreamContext { } } - fn function_call_response_handler( + pub fn function_call_response_handler( &mut self, body: Vec, callout_context: StreamCallContext, @@ -892,7 +860,7 
@@ impl StreamContext { self.resume_http_request(); } - fn arch_guard_handler(&mut self, body: Vec, callout_context: StreamCallContext) { + pub fn arch_guard_handler(&mut self, body: Vec, callout_context: StreamCallContext) { debug!("response received for arch guard"); let prompt_guard_resp: PromptGuardResponse = serde_json::from_slice(&body).unwrap(); debug!("prompt_guard_resp: {:?}", prompt_guard_resp); @@ -913,7 +881,7 @@ impl StreamContext { self.get_embeddings(callout_context); } - fn get_embeddings(&mut self, callout_context: StreamCallContext) { + pub fn get_embeddings(&mut self, callout_context: StreamCallContext) { let user_message = callout_context.user_message.unwrap(); let get_embeddings_input = CreateEmbeddingRequest { // Need to clone into input because user_message is used below. @@ -969,7 +937,7 @@ impl StreamContext { } } - fn default_target_handler(&self, body: Vec, callout_context: StreamCallContext) { + pub fn default_target_handler(&self, body: Vec, callout_context: StreamCallContext) { let prompt_target = self .prompt_targets .get(callout_context.prompt_target_name.as_ref().unwrap()) @@ -1046,357 +1014,6 @@ impl StreamContext { } } -// HttpContext is the trait that allows the Rust code to interact with HTTP objects. -impl HttpContext for StreamContext { - // Envoy's HTTP model is event driven. The WASM ABI has given implementors events to hook onto - // the lifecycle of the http request and response. 
- fn on_http_request_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action { - self.delete_content_length_header(); - - self.is_chat_completions_request = - self.get_http_request_header(":path").unwrap_or_default() == CHAT_COMPLETIONS_PATH; - - debug!( - "on_http_request_headers S[{}] req_headers={:?}", - self.context_id, - self.get_http_request_headers() - ); - - self.request_id = self.get_http_request_header(REQUEST_ID_HEADER); - - Action::Continue - } - - fn on_http_request_body(&mut self, body_size: usize, end_of_stream: bool) -> Action { - // Let the client send the gateway all the data before sending to the LLM_provider. - // TODO: consider a streaming API. - if !end_of_stream { - return Action::Pause; - } - - if body_size == 0 { - return Action::Continue; - } - - self.request_body_size = body_size; - - debug!( - "on_http_request_body S[{}] body_size={}", - self.context_id, body_size - ); - - // Deserialize body into spec. - // Currently OpenAI API. - let mut deserialized_body: ChatCompletionsRequest = - match self.get_http_request_body(0, body_size) { - Some(body_bytes) => match serde_json::from_slice(&body_bytes) { - Ok(deserialized) => deserialized, - Err(e) => { - self.send_server_error( - ServerError::Deserialization(e), - Some(StatusCode::BAD_REQUEST), - ); - return Action::Pause; - } - }, - None => { - self.send_server_error( - ServerError::LogicError(format!( - "Failed to obtain body bytes even though body_size is {}", - body_size - )), - None, - ); - return Action::Pause; - } - }; - - self.arch_state = match deserialized_body.metadata { - Some(ref metadata) => { - if metadata.contains_key(ARCH_STATE_HEADER) { - let arch_state_str = metadata[ARCH_STATE_HEADER].clone(); - let arch_state: Vec = serde_json::from_str(&arch_state_str).unwrap(); - Some(arch_state) - } else { - None - } - } - None => None, - }; - - self.streaming_response = deserialized_body.stream; - if deserialized_body.stream && deserialized_body.stream_options.is_none() 
{ - deserialized_body.stream_options = Some(StreamOptions { - include_usage: true, - }); - } - - let last_user_prompt = match deserialized_body - .messages - .iter() - .filter(|msg| msg.role == USER_ROLE) - .last() - { - Some(content) => content, - None => { - warn!("No messages in the request body"); - return Action::Continue; - } - }; - - self.user_prompt = Some(last_user_prompt.clone()); - - let user_message_str = self.user_prompt.as_ref().unwrap().content.clone(); - - let prompt_guard_jailbreak_task = self - .prompt_guards - .input_guards - .contains_key(&common::configuration::GuardType::Jailbreak); - - self.chat_completions_request = Some(deserialized_body); - - if !prompt_guard_jailbreak_task { - debug!("Missing input guard. Making inline call to retrieve embeddings"); - let callout_context = StreamCallContext { - response_handler_type: ResponseHandlerType::ArchGuard, - user_message: user_message_str.clone(), - prompt_target_name: None, - request_body: self.chat_completions_request.as_ref().unwrap().clone(), - similarity_scores: None, - upstream_cluster: None, - upstream_cluster_path: None, - tool_calls: None, - }; - self.get_embeddings(callout_context); - return Action::Pause; - } - - let get_prompt_guards_request = PromptGuardRequest { - input: self - .user_prompt - .as_ref() - .unwrap() - .content - .as_ref() - .unwrap() - .clone(), - task: PromptGuardTask::Jailbreak, - }; - - let json_data: String = match serde_json::to_string(&get_prompt_guards_request) { - Ok(json_data) => json_data, - Err(error) => { - self.send_server_error(ServerError::Serialization(error), None); - return Action::Pause; - } - }; - - let mut headers = vec![ - (ARCH_UPSTREAM_HOST_HEADER, GUARD_INTERNAL_HOST), - (":method", "POST"), - (":path", "/guard"), - (":authority", GUARD_INTERNAL_HOST), - ("content-type", "application/json"), - ("x-envoy-max-retries", "3"), - ("x-envoy-upstream-rq-timeout-ms", "60000"), - ]; - - if self.request_id.is_some() { - headers.push((REQUEST_ID_HEADER, 
self.request_id.as_ref().unwrap())); - } - - let call_args = CallArgs::new( - ARCH_INTERNAL_CLUSTER_NAME, - "/guard", - headers, - Some(json_data.as_bytes()), - vec![], - Duration::from_secs(5), - ); - let call_context = StreamCallContext { - response_handler_type: ResponseHandlerType::ArchGuard, - user_message: self.user_prompt.as_ref().unwrap().content.clone(), - prompt_target_name: None, - request_body: self.chat_completions_request.as_ref().unwrap().clone(), - similarity_scores: None, - upstream_cluster: None, - upstream_cluster_path: None, - tool_calls: None, - }; - - if let Err(e) = self.http_call(call_args, call_context) { - self.send_server_error(ServerError::HttpDispatch(e), None); - } - - Action::Pause - } - - fn on_http_response_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action { - debug!( - "on_http_response_headers recv [S={}] headers={:?}", - self.context_id, - self.get_http_response_headers() - ); - // delete content-lenght header let envoy calculate it, because we modify the response body - // that would result in a different content-length - self.set_http_response_header("content-length", None); - Action::Continue - } - - fn on_http_response_body(&mut self, body_size: usize, end_of_stream: bool) -> Action { - debug!( - "recv [S={}] bytes={} end_stream={}", - self.context_id, body_size, end_of_stream - ); - - if !self.is_chat_completions_request { - if let Some(body_str) = self - .get_http_response_body(0, body_size) - .and_then(|bytes| String::from_utf8(bytes).ok()) - { - debug!("recv [S={}] body_str={}", self.context_id, body_str); - } - return Action::Continue; - } - - if !end_of_stream { - return Action::Pause; - } - - let body = self - .get_http_response_body(0, body_size) - .expect("cant get response body"); - - if self.streaming_response { - debug!("streaming response"); - } else { - debug!("non streaming response"); - let chat_completions_response: ChatCompletionsResponse = - match serde_json::from_slice(&body) { - 
Ok(de) => de, - Err(e) => { - debug!( - "invalid response: {}, {}", - String::from_utf8_lossy(&body), - e - ); - return Action::Continue; - } - }; - - if chat_completions_response.usage.is_some() { - self.response_tokens += chat_completions_response - .usage - .as_ref() - .unwrap() - .completion_tokens; - } - - if let Some(tool_calls) = self.tool_calls.as_ref() { - if !tool_calls.is_empty() { - if self.arch_state.is_none() { - self.arch_state = Some(Vec::new()); - } - - let mut data: Value = serde_json::from_slice(&body).unwrap(); - // use serde::Value to manipulate the json object and ensure that we don't lose any data - if let Value::Object(ref mut map) = data { - // serialize arch state and add to metadata - let metadata = map - .entry("metadata") - .or_insert(Value::Object(serde_json::Map::new())); - if metadata == &Value::Null { - *metadata = Value::Object(serde_json::Map::new()); - } - - // since arch gateway generates tool calls (using arch-fc) and calls upstream api to - // get response, we will send these back to developer so they can see the api response - // and tool call arch-fc generated - let mut fc_messages = Vec::new(); - fc_messages.push(Message { - role: ASSISTANT_ROLE.to_string(), - content: None, - model: Some(ARCH_FC_MODEL_NAME.to_string()), - tool_calls: self.tool_calls.clone(), - tool_call_id: None, - }); - fc_messages.push(Message { - role: TOOL_ROLE.to_string(), - content: self.tool_call_response.clone(), - model: None, - tool_calls: None, - tool_call_id: Some(self.tool_calls.as_ref().unwrap()[0].id.clone()), - }); - let fc_messages_str = serde_json::to_string(&fc_messages).unwrap(); - let arch_state = HashMap::from([("messages".to_string(), fc_messages_str)]); - let arch_state_str = serde_json::to_string(&arch_state).unwrap(); - metadata.as_object_mut().unwrap().insert( - ARCH_STATE_HEADER.to_string(), - serde_json::Value::String(arch_state_str), - ); - let data_serialized = serde_json::to_string(&data).unwrap(); - debug!("arch => user: 
{}", data_serialized); - self.set_http_response_body(0, body_size, data_serialized.as_bytes()); - }; - } - } - } - - debug!( - "recv [S={}] total_tokens={} end_stream={}", - self.context_id, self.response_tokens, end_of_stream - ); - - Action::Continue - } -} - -impl Context for StreamContext { - fn on_http_call_response( - &mut self, - token_id: u32, - _num_headers: usize, - body_size: usize, - _num_trailers: usize, - ) { - let callout_context = self - .callouts - .get_mut() - .remove(&token_id) - .expect("invalid token_id"); - self.metrics.active_http_calls.increment(-1); - - if let Some(body) = self.get_http_call_response_body(0, body_size) { - match callout_context.response_handler_type { - ResponseHandlerType::GetEmbeddings => { - self.embeddings_handler(body, callout_context) - } - ResponseHandlerType::ZeroShotIntent => { - self.zero_shot_intent_detection_resp_handler(body, callout_context) - } - ResponseHandlerType::HallucinationDetect => { - self.hallucination_classification_resp_handler(body, callout_context) - } - ResponseHandlerType::FunctionResolver => { - self.function_resolver_handler(body, callout_context) - } - ResponseHandlerType::FunctionCall => { - self.function_call_response_handler(body, callout_context) - } - ResponseHandlerType::ArchGuard => self.arch_guard_handler(body, callout_context), - ResponseHandlerType::DefaultTarget => { - self.default_target_handler(body, callout_context) - } - } - } else { - self.send_server_error( - ServerError::LogicError(String::from("No response body in inline HTTP request")), - None, - ); - } - } -} - impl Client for StreamContext { type CallContext = StreamCallContext;