mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
refactor more
This commit is contained in:
parent
e70f55dd5b
commit
539860efea
3 changed files with 117 additions and 113 deletions
103
crates/prompt_gateway/src/context.rs
Normal file
103
crates/prompt_gateway/src/context.rs
Normal file
|
|
@ -0,0 +1,103 @@
|
|||
use common::errors::ServerError;
|
||||
use common::stats::IncrementingMetric;
|
||||
use proxy_wasm::traits::Context;
|
||||
|
||||
use crate::stream_context::{ResponseHandlerType, StreamContext};
|
||||
|
||||
impl Context for StreamContext {
|
||||
fn on_http_call_response(
|
||||
&mut self,
|
||||
token_id: u32,
|
||||
_num_headers: usize,
|
||||
body_size: usize,
|
||||
_num_trailers: usize,
|
||||
) {
|
||||
let callout_context = self
|
||||
.callouts
|
||||
.get_mut()
|
||||
.remove(&token_id)
|
||||
.expect("invalid token_id");
|
||||
self.metrics.active_http_calls.increment(-1);
|
||||
|
||||
/*
|
||||
state transition
|
||||
|
||||
graph LR
|
||||
|
||||
on_http_request_body --> prompt received
|
||||
prompt received --> get embeddings & arch guard
|
||||
arch guard --> get embeddings
|
||||
get embeddings --> zeroshot intent
|
||||
|
||||
┌──────────────────────┐ ┌─────────────────┐ ┌────────────────┐ ┌─────────────────┐
|
||||
│ │ │ │ │ │ │ │
|
||||
│ on_http_request_body ├──►│ prompt received ├──►│ get embeddings ├──►│ zeroshot intent │
|
||||
│ │ │ │ │ │ │ │
|
||||
└──────────────────────┘ └────────┬────────┘ └────────────────┘ └─────────────────┘
|
||||
│ ▲
|
||||
│ │
|
||||
│ │
|
||||
│ ┌────────┴───────┐
|
||||
│ │ │
|
||||
└───────────►│ arch guard │
|
||||
│ │
|
||||
└────────────────┘
|
||||
|
||||
|
||||
continue from zeroshot intent
|
||||
|
||||
graph LR
|
||||
|
||||
zeroshot intent --> arch_fc
|
||||
zeroshot intent --> default prompt target
|
||||
arch_fc --> developer api call & hallucination check
|
||||
hallucination check --> parameter gathering & developer api call
|
||||
developer api call --> resume request to llm
|
||||
|
||||
|
||||
┌─────────────────┐ ┌───────────────────────┐ ┌─────────────────────┐ ┌───────────────────────┐
|
||||
│ │ │ │ │ │ │ │
|
||||
│ zeroshot intent ├──►│ arch_fc ├──►│ developer api call ├──►│ resume request to llm │
|
||||
│ │ │ │ │ │ │ │
|
||||
└────────┬────────┘ └───────────┬───────────┘ └─────────────────────┘ └───────────────────────┘
|
||||
│ │ ▲
|
||||
│ └─────────────┐ │
|
||||
│ │ │
|
||||
│ ┌───────────────────────┐ │ ┌──────────┴──────────┐ ┌───────────────────────┐
|
||||
│ │ │ │ │ │ │ │
|
||||
└───────────►│ default prompt target │ └▲│ hallucination check ├──►│ parameter gathering │
|
||||
│ │ │ │ │ │
|
||||
└───────────────────────┘ └─────────────────────┘ └───────────────────────┘
|
||||
|
||||
|
||||
using https://mermaid-ascii.art/
|
||||
*/
|
||||
|
||||
if let Some(body) = self.get_http_call_response_body(0, body_size) {
|
||||
match callout_context.response_handler_type {
|
||||
ResponseHandlerType::GetEmbeddings => {
|
||||
self.embeddings_handler(body, callout_context)
|
||||
}
|
||||
ResponseHandlerType::ArchGuard => self.arch_guard_handler(body, callout_context),
|
||||
ResponseHandlerType::ZeroShotIntent => {
|
||||
self.zero_shot_intent_detection_resp_handler(body, callout_context)
|
||||
}
|
||||
ResponseHandlerType::ArchFC => self.arch_fc_response_handler(body, callout_context),
|
||||
ResponseHandlerType::HallucinationDetect => {
|
||||
self.hallucination_classification_resp_handler(body, callout_context)
|
||||
}
|
||||
ResponseHandlerType::FunctionCall => {
|
||||
self.function_call_response_handler(body, callout_context)
|
||||
}
|
||||
ResponseHandlerType::DefaultTarget => {
|
||||
self.default_target_handler(body, callout_context)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
self.send_server_error(
|
||||
ServerError::LogicError(String::from("No response body in inline HTTP request")),
|
||||
None,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -2,6 +2,7 @@ use filter_context::FilterContext;
|
|||
use proxy_wasm::traits::*;
|
||||
use proxy_wasm::types::*;
|
||||
|
||||
mod context;
|
||||
mod filter_context;
|
||||
mod hallucination;
|
||||
mod http_context;
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ use common::embeddings::{
|
|||
};
|
||||
use common::errors::ServerError;
|
||||
use common::http::{CallArgs, Client};
|
||||
use common::stats::{Gauge, IncrementingMetric};
|
||||
use common::stats::Gauge;
|
||||
use derivative::Derivative;
|
||||
use http::StatusCode;
|
||||
use log::{debug, info, warn};
|
||||
|
|
@ -61,12 +61,12 @@ pub struct StreamCallContext {
|
|||
}
|
||||
|
||||
pub struct StreamContext {
|
||||
metrics: Rc<WasmMetrics>,
|
||||
system_prompt: Rc<Option<String>>,
|
||||
prompt_targets: Rc<HashMap<String, PromptTarget>>,
|
||||
embeddings_store: Option<Rc<EmbeddingsStore>>,
|
||||
overrides: Rc<Option<Overrides>>,
|
||||
callouts: RefCell<HashMap<u32, StreamCallContext>>,
|
||||
pub metrics: Rc<WasmMetrics>,
|
||||
pub callouts: RefCell<HashMap<u32, StreamCallContext>>,
|
||||
pub context_id: u32,
|
||||
pub tool_calls: Option<Vec<ToolCall>>,
|
||||
pub tool_call_response: Option<String>,
|
||||
|
|
@ -242,7 +242,7 @@ impl StreamContext {
|
|||
}
|
||||
}
|
||||
|
||||
fn hallucination_classification_resp_handler(
|
||||
pub fn hallucination_classification_resp_handler(
|
||||
&mut self,
|
||||
body: Vec<u8>,
|
||||
callout_context: StreamCallContext,
|
||||
|
|
@ -307,7 +307,7 @@ impl StreamContext {
|
|||
}
|
||||
}
|
||||
|
||||
fn zero_shot_intent_detection_resp_handler(
|
||||
pub fn zero_shot_intent_detection_resp_handler(
|
||||
&mut self,
|
||||
body: Vec<u8>,
|
||||
mut callout_context: StreamCallContext,
|
||||
|
|
@ -558,7 +558,11 @@ impl StreamContext {
|
|||
}
|
||||
}
|
||||
|
||||
fn arch_fc_response_handler(&mut self, body: Vec<u8>, mut callout_context: StreamCallContext) {
|
||||
pub fn arch_fc_response_handler(
|
||||
&mut self,
|
||||
body: Vec<u8>,
|
||||
mut callout_context: StreamCallContext,
|
||||
) {
|
||||
let body_str = String::from_utf8(body).unwrap();
|
||||
debug!("arch <= app response body: {}", body_str);
|
||||
|
||||
|
|
@ -746,7 +750,7 @@ impl StreamContext {
|
|||
}
|
||||
}
|
||||
|
||||
fn function_call_response_handler(
|
||||
pub fn function_call_response_handler(
|
||||
&mut self,
|
||||
body: Vec<u8>,
|
||||
callout_context: StreamCallContext,
|
||||
|
|
@ -856,7 +860,7 @@ impl StreamContext {
|
|||
self.resume_http_request();
|
||||
}
|
||||
|
||||
fn arch_guard_handler(&mut self, body: Vec<u8>, callout_context: StreamCallContext) {
|
||||
pub fn arch_guard_handler(&mut self, body: Vec<u8>, callout_context: StreamCallContext) {
|
||||
debug!("response received for arch guard");
|
||||
let prompt_guard_resp: PromptGuardResponse = serde_json::from_slice(&body).unwrap();
|
||||
debug!("prompt_guard_resp: {:?}", prompt_guard_resp);
|
||||
|
|
@ -933,7 +937,7 @@ impl StreamContext {
|
|||
}
|
||||
}
|
||||
|
||||
fn default_target_handler(&self, body: Vec<u8>, callout_context: StreamCallContext) {
|
||||
pub fn default_target_handler(&self, body: Vec<u8>, callout_context: StreamCallContext) {
|
||||
let prompt_target = self
|
||||
.prompt_targets
|
||||
.get(callout_context.prompt_target_name.as_ref().unwrap())
|
||||
|
|
@ -1010,110 +1014,6 @@ impl StreamContext {
|
|||
}
|
||||
}
|
||||
|
||||
impl Context for StreamContext {
|
||||
fn on_http_call_response(
|
||||
&mut self,
|
||||
token_id: u32,
|
||||
_num_headers: usize,
|
||||
body_size: usize,
|
||||
_num_trailers: usize,
|
||||
) {
|
||||
let callout_context = self
|
||||
.callouts
|
||||
.get_mut()
|
||||
.remove(&token_id)
|
||||
.expect("invalid token_id");
|
||||
self.metrics.active_http_calls.increment(-1);
|
||||
|
||||
// state transition
|
||||
|
||||
/*
|
||||
|
||||
graph LR
|
||||
|
||||
on_http_request_body --> prompt received
|
||||
prompt received --> get embeddings & arch guard
|
||||
arch guard --> get embeddings
|
||||
get embeddings --> zeroshot intent
|
||||
|
||||
┌──────────────────────┐ ┌─────────────────┐ ┌────────────────┐ ┌─────────────────┐
|
||||
│ │ │ │ │ │ │ │
|
||||
│ on_http_request_body ├──►│ prompt received ├──►│ get embeddings ├──►│ zeroshot intent │
|
||||
│ │ │ │ │ │ │ │
|
||||
└──────────────────────┘ └────────┬────────┘ └────────────────┘ └─────────────────┘
|
||||
│ ▲
|
||||
│ │
|
||||
│ │
|
||||
│ ┌────────┴───────┐
|
||||
│ │ │
|
||||
└───────────►│ arch guard │
|
||||
│ │
|
||||
└────────────────┘
|
||||
|
||||
|
||||
continue from zeroshot intent
|
||||
|
||||
graph LR
|
||||
|
||||
zeroshot intent --> arch_fc
|
||||
zeroshot intent --> default prompt target
|
||||
arch_fc --> developer api call & hallucination check
|
||||
hallucination check --> parameter gathering & developer api call
|
||||
developer api call --> resume request to llm
|
||||
|
||||
|
||||
┌─────────────────┐ ┌───────────────────────┐ ┌─────────────────────┐ ┌───────────────────────┐
|
||||
│ │ │ │ │ │ │ │
|
||||
│ zeroshot intent ├──►│ arch_fc ├──►│ developer api call ├──►│ resume request to llm │
|
||||
│ │ │ │ │ │ │ │
|
||||
└────────┬────────┘ └───────────┬───────────┘ └─────────────────────┘ └───────────────────────┘
|
||||
│ │ ▲
|
||||
│ └─────────────┐ │
|
||||
│ │ │
|
||||
│ ┌───────────────────────┐ │ ┌──────────┴──────────┐ ┌───────────────────────┐
|
||||
│ │ │ │ │ │ │ │
|
||||
└───────────►│ default prompt target │ └▲│ hallucination check ├──►│ parameter gathering │
|
||||
│ │ │ │ │ │
|
||||
└───────────────────────┘ └─────────────────────┘ └───────────────────────┘
|
||||
|
||||
|
||||
using https://mermaid-ascii.art/
|
||||
*/
|
||||
|
||||
|
||||
|
||||
|
||||
if let Some(body) = self.get_http_call_response_body(0, body_size) {
|
||||
match callout_context.response_handler_type {
|
||||
ResponseHandlerType::GetEmbeddings => {
|
||||
self.embeddings_handler(body, callout_context)
|
||||
}
|
||||
ResponseHandlerType::ArchGuard => self.arch_guard_handler(body, callout_context),
|
||||
ResponseHandlerType::ZeroShotIntent => {
|
||||
self.zero_shot_intent_detection_resp_handler(body, callout_context)
|
||||
}
|
||||
ResponseHandlerType::ArchFC => {
|
||||
self.arch_fc_response_handler(body, callout_context)
|
||||
}
|
||||
ResponseHandlerType::HallucinationDetect => {
|
||||
self.hallucination_classification_resp_handler(body, callout_context)
|
||||
}
|
||||
ResponseHandlerType::FunctionCall => {
|
||||
self.function_call_response_handler(body, callout_context)
|
||||
}
|
||||
ResponseHandlerType::DefaultTarget => {
|
||||
self.default_target_handler(body, callout_context)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
self.send_server_error(
|
||||
ServerError::LogicError(String::from("No response body in inline HTTP request")),
|
||||
None,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Client for StreamContext {
|
||||
type CallContext = StreamCallContext;
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue