diff --git a/crates/prompt_gateway/src/context.rs b/crates/prompt_gateway/src/context.rs new file mode 100644 index 00000000..f33d6238 --- /dev/null +++ b/crates/prompt_gateway/src/context.rs @@ -0,0 +1,103 @@ +use common::errors::ServerError; +use common::stats::IncrementingMetric; +use proxy_wasm::traits::Context; + +use crate::stream_context::{ResponseHandlerType, StreamContext}; + +impl Context for StreamContext { + fn on_http_call_response( + &mut self, + token_id: u32, + _num_headers: usize, + body_size: usize, + _num_trailers: usize, + ) { + let callout_context = self + .callouts + .get_mut() + .remove(&token_id) + .expect("invalid token_id"); + self.metrics.active_http_calls.increment(-1); + + /* + state transition + + graph LR + + on_http_request_body --> prompt received + prompt received --> get embeddings & arch guard + arch guard --> get embeddings + get embeddings --> zeroshot intent + + ┌──────────────────────┐ ┌─────────────────┐ ┌────────────────┐ ┌─────────────────┐ + │ │ │ │ │ │ │ │ + │ on_http_request_body ├──►│ prompt received ├──►│ get embeddings ├──►│ zeroshot intent │ + │ │ │ │ │ │ │ │ + └──────────────────────┘ └────────┬────────┘ └────────────────┘ └─────────────────┘ + │ ▲ + │ │ + │ │ + │ ┌────────┴───────┐ + │ │ │ + └───────────►│ arch guard │ + │ │ + └────────────────┘ + + + continue from zeroshot intent + + graph LR + + zeroshot intent --> arch_fc + zeroshot intent --> default prompt target + arch_fc --> developer api call & hallucination check + hallucination check --> parameter gathering & developer api call + developer api call --> resume request to llm + + + ┌─────────────────┐ ┌───────────────────────┐ ┌─────────────────────┐ ┌───────────────────────┐ + │ │ │ │ │ │ │ │ + │ zeroshot intent ├──►│ arch_fc ├──►│ developer api call ├──►│ resume request to llm │ + │ │ │ │ │ │ │ │ + └────────┬────────┘ └───────────┬───────────┘ └─────────────────────┘ └───────────────────────┘ + │ │ ▲ + │ └─────────────┐ │ + │ │ │ + │ ┌───────────────────────┐ │ ┌──────────┴──────────┐ ┌───────────────────────┐ + │ │ │ │ │ │ │ │ + └───────────►│ default prompt target │ └▲│ hallucination check ├──►│ parameter gathering │ + │ │ │ │ │ │ + └───────────────────────┘ └─────────────────────┘ └───────────────────────┘ + + + using https://mermaid-ascii.art/ + */ + + if let Some(body) = self.get_http_call_response_body(0, body_size) { + match callout_context.response_handler_type { + ResponseHandlerType::GetEmbeddings => { + self.embeddings_handler(body, callout_context) + } + ResponseHandlerType::ArchGuard => self.arch_guard_handler(body, callout_context), + ResponseHandlerType::ZeroShotIntent => { + self.zero_shot_intent_detection_resp_handler(body, callout_context) + } + ResponseHandlerType::ArchFC => self.arch_fc_response_handler(body, callout_context), + ResponseHandlerType::HallucinationDetect => { + self.hallucination_classification_resp_handler(body, callout_context) + } + ResponseHandlerType::FunctionCall => { + self.function_call_response_handler(body, callout_context) + } + ResponseHandlerType::DefaultTarget => { + self.default_target_handler(body, callout_context) + } + } + } else { + self.send_server_error( + ServerError::LogicError(String::from("No response body in inline HTTP request")), + None, + ); + } + } +} diff --git a/crates/prompt_gateway/src/lib.rs b/crates/prompt_gateway/src/lib.rs index f3305883..f873b9bf 100644 --- a/crates/prompt_gateway/src/lib.rs +++ b/crates/prompt_gateway/src/lib.rs @@ -2,6 +2,7 @@ use filter_context::FilterContext; use proxy_wasm::traits::*; use proxy_wasm::types::*; +mod context; mod filter_context; mod hallucination; mod http_context; diff --git a/crates/prompt_gateway/src/stream_context.rs b/crates/prompt_gateway/src/stream_context.rs index 32e61983..9b99409c 100644 --- a/crates/prompt_gateway/src/stream_context.rs +++ b/crates/prompt_gateway/src/stream_context.rs @@ -24,7 +24,7 @@ use common::embeddings::{ }; use common::errors::ServerError; use common::http::{CallArgs, Client}; -use common::stats::{Gauge, IncrementingMetric}; +use common::stats::Gauge; use derivative::Derivative; use http::StatusCode; use log::{debug, info, warn}; @@ -61,12 +61,12 @@ pub struct StreamCallContext { } pub struct StreamContext { - metrics: Rc, system_prompt: Rc>, prompt_targets: Rc>, embeddings_store: Option>, overrides: Rc>, - callouts: RefCell>, + pub metrics: Rc, + pub callouts: RefCell>, pub context_id: u32, pub tool_calls: Option>, pub tool_call_response: Option, @@ -242,7 +242,7 @@ impl StreamContext { } } - fn hallucination_classification_resp_handler( + pub fn hallucination_classification_resp_handler( &mut self, body: Vec, callout_context: StreamCallContext, @@ -307,7 +307,7 @@ impl StreamContext { } } - fn zero_shot_intent_detection_resp_handler( + pub fn zero_shot_intent_detection_resp_handler( &mut self, body: Vec, mut callout_context: StreamCallContext, @@ -558,7 +558,11 @@ impl StreamContext { } } - fn arch_fc_response_handler(&mut self, body: Vec, mut callout_context: StreamCallContext) { + pub fn arch_fc_response_handler( + &mut self, + body: Vec, + mut callout_context: StreamCallContext, + ) { let body_str = String::from_utf8(body).unwrap(); debug!("arch <= app response body: {}", body_str); @@ -746,7 +750,7 @@ impl StreamContext { } } - fn function_call_response_handler( + pub fn function_call_response_handler( &mut self, body: Vec, callout_context: StreamCallContext, @@ -856,7 +860,7 @@ impl StreamContext { self.resume_http_request(); } - fn arch_guard_handler(&mut self, body: Vec, callout_context: StreamCallContext) { + pub fn arch_guard_handler(&mut self, body: Vec, callout_context: StreamCallContext) { debug!("response received for arch guard"); let prompt_guard_resp: PromptGuardResponse = serde_json::from_slice(&body).unwrap(); debug!("prompt_guard_resp: {:?}", prompt_guard_resp); @@ -933,7 +937,7 @@ impl StreamContext { } } - fn default_target_handler(&self, body: Vec, callout_context: StreamCallContext) { + pub fn default_target_handler(&self, body: Vec, callout_context: StreamCallContext) { let prompt_target = self .prompt_targets .get(callout_context.prompt_target_name.as_ref().unwrap()) @@ -1010,110 +1014,6 @@ impl StreamContext { } } -impl Context for StreamContext { - fn on_http_call_response( - &mut self, - token_id: u32, - _num_headers: usize, - body_size: usize, - _num_trailers: usize, - ) { - let callout_context = self - .callouts - .get_mut() - .remove(&token_id) - .expect("invalid token_id"); - self.metrics.active_http_calls.increment(-1); - -// state transition - -/* - -graph LR - -on_http_request_body --> prompt received -prompt received --> get embeddings & arch guard -arch guard --> get embeddings -get embeddings --> zeroshot intent - -┌──────────────────────┐ ┌─────────────────┐ ┌────────────────┐ ┌─────────────────┐ -│ │ │ │ │ │ │ │ -│ on_http_request_body ├──►│ prompt received ├──►│ get embeddings ├──►│ zeroshot intent │ -│ │ │ │ │ │ │ │ -└──────────────────────┘ └────────┬────────┘ └────────────────┘ └─────────────────┘ - │ ▲ - │ │ - │ │ - │ ┌────────┴───────┐ - │ │ │ - └───────────►│ arch guard │ - │ │ - └────────────────┘ - - -continue from zeroshot intent - -graph LR - -zeroshot intent --> arch_fc -zeroshot intent --> default prompt target -arch_fc --> developer api call & hallucination check -hallucination check --> parameter gathering & developer api call -developer api call --> resume request to llm - - -┌─────────────────┐ ┌───────────────────────┐ ┌─────────────────────┐ ┌───────────────────────┐ -│ │ │ │ │ │ │ │ -│ zeroshot intent ├──►│ arch_fc ├──►│ developer api call ├──►│ resume request to llm │ -│ │ │ │ │ │ │ │ -└────────┬────────┘ └───────────┬───────────┘ └─────────────────────┘ └───────────────────────┘ - │ │ ▲ - │ └─────────────┐ │ - │ │ │ - │ ┌───────────────────────┐ │ ┌──────────┴──────────┐ ┌───────────────────────┐ - │ │ │ │ │ │ │ │ - └───────────►│ default prompt target │ └▲│ hallucination check ├──►│ parameter gathering │ - │ │ │ │ │ │ - └───────────────────────┘ └─────────────────────┘ └───────────────────────┘ - - -using https://mermaid-ascii.art/ -*/ - - - - - if let Some(body) = self.get_http_call_response_body(0, body_size) { - match callout_context.response_handler_type { - ResponseHandlerType::GetEmbeddings => { - self.embeddings_handler(body, callout_context) - } - ResponseHandlerType::ArchGuard => self.arch_guard_handler(body, callout_context), - ResponseHandlerType::ZeroShotIntent => { - self.zero_shot_intent_detection_resp_handler(body, callout_context) - } - ResponseHandlerType::ArchFC => { - self.arch_fc_response_handler(body, callout_context) - } - ResponseHandlerType::HallucinationDetect => { - self.hallucination_classification_resp_handler(body, callout_context) - } - ResponseHandlerType::FunctionCall => { - self.function_call_response_handler(body, callout_context) - } - ResponseHandlerType::DefaultTarget => { - self.default_target_handler(body, callout_context) - } - } - } else { - self.send_server_error( - ServerError::LogicError(String::from("No response body in inline HTTP request")), - None, - ); - } - } -} - impl Client for StreamContext { type CallContext = StreamCallContext;