refactor more

This commit is contained in:
Adil Hafeez 2024-10-21 14:29:21 -07:00
parent e70f55dd5b
commit 539860efea
3 changed files with 117 additions and 113 deletions

View file

@ -0,0 +1,103 @@
use common::errors::ServerError;
use common::stats::IncrementingMetric;
use proxy_wasm::traits::Context;
use crate::stream_context::{ResponseHandlerType, StreamContext};
impl Context for StreamContext {
    /// Handles the response to an HTTP callout previously dispatched by this stream.
    ///
    /// The callout context registered under `token_id` is removed from
    /// `self.callouts`, the `active_http_calls` metric is decremented, and the
    /// response body is forwarded to the handler selected by the callout's
    /// `response_handler_type`.
    ///
    /// If the response carries no body, a `ServerError::LogicError` is reported
    /// through `send_server_error` and no handler runs.
    ///
    /// # Panics
    ///
    /// Panics with `"invalid token_id"` when no callout is registered for
    /// `token_id`.
    fn on_http_call_response(
        &mut self,
        token_id: u32,
        _num_headers: usize,
        body_size: usize,
        _num_trailers: usize,
    ) {
        // `&mut self` gives exclusive access, so the map can be reached through
        // `RefCell::get_mut` without a runtime borrow check.
        let pending_calls = self.callouts.get_mut();
        let call_ctx = pending_calls.remove(&token_id).expect("invalid token_id");
        self.metrics.active_http_calls.increment(-1);

        // State transitions driven by callout responses, as Mermaid source
        // (rendered with https://mermaid-ascii.art/):
        //
        //   graph LR
        //       on_http_request_body --> prompt received
        //       prompt received --> get embeddings & arch guard
        //       arch guard --> get embeddings
        //       get embeddings --> zeroshot intent
        //
        // continuing from zeroshot intent:
        //
        //   graph LR
        //       zeroshot intent --> arch_fc
        //       zeroshot intent --> default prompt target
        //       arch_fc --> developer api call & hallucination check
        //       hallucination check --> parameter gathering & developer api call
        //       developer api call --> resume request to llm

        // Without a body there is nothing to hand to a handler: report and stop.
        let payload = match self.get_http_call_response_body(0, body_size) {
            Some(bytes) => bytes,
            None => {
                let missing_body = ServerError::LogicError(String::from(
                    "No response body in inline HTTP request",
                ));
                self.send_server_error(missing_body, None);
                return;
            }
        };

        // One arm per callout kind; each handler takes ownership of the body
        // and of the callout context.
        match call_ctx.response_handler_type {
            ResponseHandlerType::GetEmbeddings => self.embeddings_handler(payload, call_ctx),
            ResponseHandlerType::ArchGuard => self.arch_guard_handler(payload, call_ctx),
            ResponseHandlerType::ZeroShotIntent => {
                self.zero_shot_intent_detection_resp_handler(payload, call_ctx)
            }
            ResponseHandlerType::ArchFC => self.arch_fc_response_handler(payload, call_ctx),
            ResponseHandlerType::HallucinationDetect => {
                self.hallucination_classification_resp_handler(payload, call_ctx)
            }
            ResponseHandlerType::FunctionCall => {
                self.function_call_response_handler(payload, call_ctx)
            }
            ResponseHandlerType::DefaultTarget => self.default_target_handler(payload, call_ctx),
        }
    }
}

View file

@ -2,6 +2,7 @@ use filter_context::FilterContext;
use proxy_wasm::traits::*;
use proxy_wasm::types::*;
mod context;
mod filter_context;
mod hallucination;
mod http_context;

View file

@ -24,7 +24,7 @@ use common::embeddings::{
};
use common::errors::ServerError;
use common::http::{CallArgs, Client};
use common::stats::{Gauge, IncrementingMetric};
use common::stats::Gauge;
use derivative::Derivative;
use http::StatusCode;
use log::{debug, info, warn};
@ -61,12 +61,12 @@ pub struct StreamCallContext {
}
pub struct StreamContext {
metrics: Rc<WasmMetrics>,
system_prompt: Rc<Option<String>>,
prompt_targets: Rc<HashMap<String, PromptTarget>>,
embeddings_store: Option<Rc<EmbeddingsStore>>,
overrides: Rc<Option<Overrides>>,
callouts: RefCell<HashMap<u32, StreamCallContext>>,
pub metrics: Rc<WasmMetrics>,
pub callouts: RefCell<HashMap<u32, StreamCallContext>>,
pub context_id: u32,
pub tool_calls: Option<Vec<ToolCall>>,
pub tool_call_response: Option<String>,
@ -242,7 +242,7 @@ impl StreamContext {
}
}
fn hallucination_classification_resp_handler(
pub fn hallucination_classification_resp_handler(
&mut self,
body: Vec<u8>,
callout_context: StreamCallContext,
@ -307,7 +307,7 @@ impl StreamContext {
}
}
fn zero_shot_intent_detection_resp_handler(
pub fn zero_shot_intent_detection_resp_handler(
&mut self,
body: Vec<u8>,
mut callout_context: StreamCallContext,
@ -558,7 +558,11 @@ impl StreamContext {
}
}
fn arch_fc_response_handler(&mut self, body: Vec<u8>, mut callout_context: StreamCallContext) {
pub fn arch_fc_response_handler(
&mut self,
body: Vec<u8>,
mut callout_context: StreamCallContext,
) {
let body_str = String::from_utf8(body).unwrap();
debug!("arch <= app response body: {}", body_str);
@ -746,7 +750,7 @@ impl StreamContext {
}
}
fn function_call_response_handler(
pub fn function_call_response_handler(
&mut self,
body: Vec<u8>,
callout_context: StreamCallContext,
@ -856,7 +860,7 @@ impl StreamContext {
self.resume_http_request();
}
fn arch_guard_handler(&mut self, body: Vec<u8>, callout_context: StreamCallContext) {
pub fn arch_guard_handler(&mut self, body: Vec<u8>, callout_context: StreamCallContext) {
debug!("response received for arch guard");
let prompt_guard_resp: PromptGuardResponse = serde_json::from_slice(&body).unwrap();
debug!("prompt_guard_resp: {:?}", prompt_guard_resp);
@ -933,7 +937,7 @@ impl StreamContext {
}
}
fn default_target_handler(&self, body: Vec<u8>, callout_context: StreamCallContext) {
pub fn default_target_handler(&self, body: Vec<u8>, callout_context: StreamCallContext) {
let prompt_target = self
.prompt_targets
.get(callout_context.prompt_target_name.as_ref().unwrap())
@ -1010,110 +1014,6 @@ impl StreamContext {
}
}
impl Context for StreamContext {
/// Handles the response to an HTTP callout identified by `token_id`.
///
/// Removes the matching callout context from `self.callouts`, decrements the
/// `active_http_calls` metric, and passes the response body to the handler
/// selected by the callout's `response_handler_type`. When no body is
/// available, a `ServerError::LogicError` is sent via `send_server_error`.
///
/// # Panics
///
/// Panics with `"invalid token_id"` if no callout is registered for `token_id`.
fn on_http_call_response(
&mut self,
token_id: u32,
_num_headers: usize,
body_size: usize,
_num_trailers: usize,
) {
// Take ownership of the context stored for this callout; a missing entry
// is treated as a fatal error.
let callout_context = self
.callouts
.get_mut()
.remove(&token_id)
.expect("invalid token_id");
// This callout is finished, so it no longer counts as active.
self.metrics.active_http_calls.increment(-1);
/*
State transitions, written as Mermaid source
(rendered using https://mermaid-ascii.art/):

graph LR
on_http_request_body --> prompt received
prompt received --> get embeddings & arch guard
arch guard --> get embeddings
get embeddings --> zeroshot intent

continue from zeroshot intent:

graph LR
zeroshot intent --> arch_fc
zeroshot intent --> default prompt target
arch_fc --> developer api call & hallucination check
hallucination check --> parameter gathering & developer api call
developer api call --> resume request to llm
*/
// Request the response body from offset 0, up to `body_size` bytes.
if let Some(body) = self.get_http_call_response_body(0, body_size) {
// Dispatch on the kind of callout that produced this response.
match callout_context.response_handler_type {
ResponseHandlerType::GetEmbeddings => {
self.embeddings_handler(body, callout_context)
}
ResponseHandlerType::ArchGuard => self.arch_guard_handler(body, callout_context),
ResponseHandlerType::ZeroShotIntent => {
self.zero_shot_intent_detection_resp_handler(body, callout_context)
}
ResponseHandlerType::ArchFC => {
self.arch_fc_response_handler(body, callout_context)
}
ResponseHandlerType::HallucinationDetect => {
self.hallucination_classification_resp_handler(body, callout_context)
}
ResponseHandlerType::FunctionCall => {
self.function_call_response_handler(body, callout_context)
}
ResponseHandlerType::DefaultTarget => {
self.default_target_handler(body, callout_context)
}
}
} else {
// No body to process: report the failure instead of calling a handler.
self.send_server_error(
ServerError::LogicError(String::from("No response body in inline HTTP request")),
None,
);
}
}
}
impl Client for StreamContext {
type CallContext = StreamCallContext;