mirror of
https://github.com/katanemo/plano.git
synced 2026-07-02 15:51:02 +02:00
refactor more
This commit is contained in:
parent
e70f55dd5b
commit
539860efea
3 changed files with 117 additions and 113 deletions
103
crates/prompt_gateway/src/context.rs
Normal file
103
crates/prompt_gateway/src/context.rs
Normal file
|
|
@ -0,0 +1,103 @@
|
||||||
|
use common::errors::ServerError;
|
||||||
|
use common::stats::IncrementingMetric;
|
||||||
|
use proxy_wasm::traits::Context;
|
||||||
|
|
||||||
|
use crate::stream_context::{ResponseHandlerType, StreamContext};
|
||||||
|
|
||||||
|
impl Context for StreamContext {
|
||||||
|
fn on_http_call_response(
|
||||||
|
&mut self,
|
||||||
|
token_id: u32,
|
||||||
|
_num_headers: usize,
|
||||||
|
body_size: usize,
|
||||||
|
_num_trailers: usize,
|
||||||
|
) {
|
||||||
|
let callout_context = self
|
||||||
|
.callouts
|
||||||
|
.get_mut()
|
||||||
|
.remove(&token_id)
|
||||||
|
.expect("invalid token_id");
|
||||||
|
self.metrics.active_http_calls.increment(-1);
|
||||||
|
|
||||||
|
/*
|
||||||
|
state transition
|
||||||
|
|
||||||
|
graph LR
|
||||||
|
|
||||||
|
on_http_request_body --> prompt received
|
||||||
|
prompt received --> get embeddings & arch guard
|
||||||
|
arch guard --> get embeddings
|
||||||
|
get embeddings --> zeroshot intent
|
||||||
|
|
||||||
|
┌──────────────────────┐ ┌─────────────────┐ ┌────────────────┐ ┌─────────────────┐
|
||||||
|
│ │ │ │ │ │ │ │
|
||||||
|
│ on_http_request_body ├──►│ prompt received ├──►│ get embeddings ├──►│ zeroshot intent │
|
||||||
|
│ │ │ │ │ │ │ │
|
||||||
|
└──────────────────────┘ └────────┬────────┘ └────────────────┘ └─────────────────┘
|
||||||
|
│ ▲
|
||||||
|
│ │
|
||||||
|
│ │
|
||||||
|
│ ┌────────┴───────┐
|
||||||
|
│ │ │
|
||||||
|
└───────────►│ arch guard │
|
||||||
|
│ │
|
||||||
|
└────────────────┘
|
||||||
|
|
||||||
|
|
||||||
|
continue from zeroshot intent
|
||||||
|
|
||||||
|
graph LR
|
||||||
|
|
||||||
|
zeroshot intent --> arch_fc
|
||||||
|
zeroshot intent --> default prompt target
|
||||||
|
arch_fc --> developer api call & hallucination check
|
||||||
|
hallucination check --> parameter gathering & developer api call
|
||||||
|
developer api call --> resume request to llm
|
||||||
|
|
||||||
|
|
||||||
|
┌─────────────────┐ ┌───────────────────────┐ ┌─────────────────────┐ ┌───────────────────────┐
|
||||||
|
│ │ │ │ │ │ │ │
|
||||||
|
│ zeroshot intent ├──►│ arch_fc ├──►│ developer api call ├──►│ resume request to llm │
|
||||||
|
│ │ │ │ │ │ │ │
|
||||||
|
└────────┬────────┘ └───────────┬───────────┘ └─────────────────────┘ └───────────────────────┘
|
||||||
|
│ │ ▲
|
||||||
|
│ └─────────────┐ │
|
||||||
|
│ │ │
|
||||||
|
│ ┌───────────────────────┐ │ ┌──────────┴──────────┐ ┌───────────────────────┐
|
||||||
|
│ │ │ │ │ │ │ │
|
||||||
|
└───────────►│ default prompt target │ └▲│ hallucination check ├──►│ parameter gathering │
|
||||||
|
│ │ │ │ │ │
|
||||||
|
└───────────────────────┘ └─────────────────────┘ └───────────────────────┘
|
||||||
|
|
||||||
|
|
||||||
|
using https://mermaid-ascii.art/
|
||||||
|
*/
|
||||||
|
|
||||||
|
if let Some(body) = self.get_http_call_response_body(0, body_size) {
|
||||||
|
match callout_context.response_handler_type {
|
||||||
|
ResponseHandlerType::GetEmbeddings => {
|
||||||
|
self.embeddings_handler(body, callout_context)
|
||||||
|
}
|
||||||
|
ResponseHandlerType::ArchGuard => self.arch_guard_handler(body, callout_context),
|
||||||
|
ResponseHandlerType::ZeroShotIntent => {
|
||||||
|
self.zero_shot_intent_detection_resp_handler(body, callout_context)
|
||||||
|
}
|
||||||
|
ResponseHandlerType::ArchFC => self.arch_fc_response_handler(body, callout_context),
|
||||||
|
ResponseHandlerType::HallucinationDetect => {
|
||||||
|
self.hallucination_classification_resp_handler(body, callout_context)
|
||||||
|
}
|
||||||
|
ResponseHandlerType::FunctionCall => {
|
||||||
|
self.function_call_response_handler(body, callout_context)
|
||||||
|
}
|
||||||
|
ResponseHandlerType::DefaultTarget => {
|
||||||
|
self.default_target_handler(body, callout_context)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
self.send_server_error(
|
||||||
|
ServerError::LogicError(String::from("No response body in inline HTTP request")),
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -2,6 +2,7 @@ use filter_context::FilterContext;
|
||||||
use proxy_wasm::traits::*;
|
use proxy_wasm::traits::*;
|
||||||
use proxy_wasm::types::*;
|
use proxy_wasm::types::*;
|
||||||
|
|
||||||
|
mod context;
|
||||||
mod filter_context;
|
mod filter_context;
|
||||||
mod hallucination;
|
mod hallucination;
|
||||||
mod http_context;
|
mod http_context;
|
||||||
|
|
|
||||||
|
|
@ -24,7 +24,7 @@ use common::embeddings::{
|
||||||
};
|
};
|
||||||
use common::errors::ServerError;
|
use common::errors::ServerError;
|
||||||
use common::http::{CallArgs, Client};
|
use common::http::{CallArgs, Client};
|
||||||
use common::stats::{Gauge, IncrementingMetric};
|
use common::stats::Gauge;
|
||||||
use derivative::Derivative;
|
use derivative::Derivative;
|
||||||
use http::StatusCode;
|
use http::StatusCode;
|
||||||
use log::{debug, info, warn};
|
use log::{debug, info, warn};
|
||||||
|
|
@ -61,12 +61,12 @@ pub struct StreamCallContext {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct StreamContext {
|
pub struct StreamContext {
|
||||||
metrics: Rc<WasmMetrics>,
|
|
||||||
system_prompt: Rc<Option<String>>,
|
system_prompt: Rc<Option<String>>,
|
||||||
prompt_targets: Rc<HashMap<String, PromptTarget>>,
|
prompt_targets: Rc<HashMap<String, PromptTarget>>,
|
||||||
embeddings_store: Option<Rc<EmbeddingsStore>>,
|
embeddings_store: Option<Rc<EmbeddingsStore>>,
|
||||||
overrides: Rc<Option<Overrides>>,
|
overrides: Rc<Option<Overrides>>,
|
||||||
callouts: RefCell<HashMap<u32, StreamCallContext>>,
|
pub metrics: Rc<WasmMetrics>,
|
||||||
|
pub callouts: RefCell<HashMap<u32, StreamCallContext>>,
|
||||||
pub context_id: u32,
|
pub context_id: u32,
|
||||||
pub tool_calls: Option<Vec<ToolCall>>,
|
pub tool_calls: Option<Vec<ToolCall>>,
|
||||||
pub tool_call_response: Option<String>,
|
pub tool_call_response: Option<String>,
|
||||||
|
|
@ -242,7 +242,7 @@ impl StreamContext {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn hallucination_classification_resp_handler(
|
pub fn hallucination_classification_resp_handler(
|
||||||
&mut self,
|
&mut self,
|
||||||
body: Vec<u8>,
|
body: Vec<u8>,
|
||||||
callout_context: StreamCallContext,
|
callout_context: StreamCallContext,
|
||||||
|
|
@ -307,7 +307,7 @@ impl StreamContext {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn zero_shot_intent_detection_resp_handler(
|
pub fn zero_shot_intent_detection_resp_handler(
|
||||||
&mut self,
|
&mut self,
|
||||||
body: Vec<u8>,
|
body: Vec<u8>,
|
||||||
mut callout_context: StreamCallContext,
|
mut callout_context: StreamCallContext,
|
||||||
|
|
@ -558,7 +558,11 @@ impl StreamContext {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn arch_fc_response_handler(&mut self, body: Vec<u8>, mut callout_context: StreamCallContext) {
|
pub fn arch_fc_response_handler(
|
||||||
|
&mut self,
|
||||||
|
body: Vec<u8>,
|
||||||
|
mut callout_context: StreamCallContext,
|
||||||
|
) {
|
||||||
let body_str = String::from_utf8(body).unwrap();
|
let body_str = String::from_utf8(body).unwrap();
|
||||||
debug!("arch <= app response body: {}", body_str);
|
debug!("arch <= app response body: {}", body_str);
|
||||||
|
|
||||||
|
|
@ -746,7 +750,7 @@ impl StreamContext {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn function_call_response_handler(
|
pub fn function_call_response_handler(
|
||||||
&mut self,
|
&mut self,
|
||||||
body: Vec<u8>,
|
body: Vec<u8>,
|
||||||
callout_context: StreamCallContext,
|
callout_context: StreamCallContext,
|
||||||
|
|
@ -856,7 +860,7 @@ impl StreamContext {
|
||||||
self.resume_http_request();
|
self.resume_http_request();
|
||||||
}
|
}
|
||||||
|
|
||||||
fn arch_guard_handler(&mut self, body: Vec<u8>, callout_context: StreamCallContext) {
|
pub fn arch_guard_handler(&mut self, body: Vec<u8>, callout_context: StreamCallContext) {
|
||||||
debug!("response received for arch guard");
|
debug!("response received for arch guard");
|
||||||
let prompt_guard_resp: PromptGuardResponse = serde_json::from_slice(&body).unwrap();
|
let prompt_guard_resp: PromptGuardResponse = serde_json::from_slice(&body).unwrap();
|
||||||
debug!("prompt_guard_resp: {:?}", prompt_guard_resp);
|
debug!("prompt_guard_resp: {:?}", prompt_guard_resp);
|
||||||
|
|
@ -933,7 +937,7 @@ impl StreamContext {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn default_target_handler(&self, body: Vec<u8>, callout_context: StreamCallContext) {
|
pub fn default_target_handler(&self, body: Vec<u8>, callout_context: StreamCallContext) {
|
||||||
let prompt_target = self
|
let prompt_target = self
|
||||||
.prompt_targets
|
.prompt_targets
|
||||||
.get(callout_context.prompt_target_name.as_ref().unwrap())
|
.get(callout_context.prompt_target_name.as_ref().unwrap())
|
||||||
|
|
@ -1010,110 +1014,6 @@ impl StreamContext {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Context for StreamContext {
|
|
||||||
fn on_http_call_response(
|
|
||||||
&mut self,
|
|
||||||
token_id: u32,
|
|
||||||
_num_headers: usize,
|
|
||||||
body_size: usize,
|
|
||||||
_num_trailers: usize,
|
|
||||||
) {
|
|
||||||
let callout_context = self
|
|
||||||
.callouts
|
|
||||||
.get_mut()
|
|
||||||
.remove(&token_id)
|
|
||||||
.expect("invalid token_id");
|
|
||||||
self.metrics.active_http_calls.increment(-1);
|
|
||||||
|
|
||||||
// state transition
|
|
||||||
|
|
||||||
/*
|
|
||||||
|
|
||||||
graph LR
|
|
||||||
|
|
||||||
on_http_request_body --> prompt received
|
|
||||||
prompt received --> get embeddings & arch guard
|
|
||||||
arch guard --> get embeddings
|
|
||||||
get embeddings --> zeroshot intent
|
|
||||||
|
|
||||||
┌──────────────────────┐ ┌─────────────────┐ ┌────────────────┐ ┌─────────────────┐
|
|
||||||
│ │ │ │ │ │ │ │
|
|
||||||
│ on_http_request_body ├──►│ prompt received ├──►│ get embeddings ├──►│ zeroshot intent │
|
|
||||||
│ │ │ │ │ │ │ │
|
|
||||||
└──────────────────────┘ └────────┬────────┘ └────────────────┘ └─────────────────┘
|
|
||||||
│ ▲
|
|
||||||
│ │
|
|
||||||
│ │
|
|
||||||
│ ┌────────┴───────┐
|
|
||||||
│ │ │
|
|
||||||
└───────────►│ arch guard │
|
|
||||||
│ │
|
|
||||||
└────────────────┘
|
|
||||||
|
|
||||||
|
|
||||||
continue from zeroshot intent
|
|
||||||
|
|
||||||
graph LR
|
|
||||||
|
|
||||||
zeroshot intent --> arch_fc
|
|
||||||
zeroshot intent --> default prompt target
|
|
||||||
arch_fc --> developer api call & hallucination check
|
|
||||||
hallucination check --> parameter gathering & developer api call
|
|
||||||
developer api call --> resume request to llm
|
|
||||||
|
|
||||||
|
|
||||||
┌─────────────────┐ ┌───────────────────────┐ ┌─────────────────────┐ ┌───────────────────────┐
|
|
||||||
│ │ │ │ │ │ │ │
|
|
||||||
│ zeroshot intent ├──►│ arch_fc ├──►│ developer api call ├──►│ resume request to llm │
|
|
||||||
│ │ │ │ │ │ │ │
|
|
||||||
└────────┬────────┘ └───────────┬───────────┘ └─────────────────────┘ └───────────────────────┘
|
|
||||||
│ │ ▲
|
|
||||||
│ └─────────────┐ │
|
|
||||||
│ │ │
|
|
||||||
│ ┌───────────────────────┐ │ ┌──────────┴──────────┐ ┌───────────────────────┐
|
|
||||||
│ │ │ │ │ │ │ │
|
|
||||||
└───────────►│ default prompt target │ └▲│ hallucination check ├──►│ parameter gathering │
|
|
||||||
│ │ │ │ │ │
|
|
||||||
└───────────────────────┘ └─────────────────────┘ └───────────────────────┘
|
|
||||||
|
|
||||||
|
|
||||||
using https://mermaid-ascii.art/
|
|
||||||
*/
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if let Some(body) = self.get_http_call_response_body(0, body_size) {
|
|
||||||
match callout_context.response_handler_type {
|
|
||||||
ResponseHandlerType::GetEmbeddings => {
|
|
||||||
self.embeddings_handler(body, callout_context)
|
|
||||||
}
|
|
||||||
ResponseHandlerType::ArchGuard => self.arch_guard_handler(body, callout_context),
|
|
||||||
ResponseHandlerType::ZeroShotIntent => {
|
|
||||||
self.zero_shot_intent_detection_resp_handler(body, callout_context)
|
|
||||||
}
|
|
||||||
ResponseHandlerType::ArchFC => {
|
|
||||||
self.arch_fc_response_handler(body, callout_context)
|
|
||||||
}
|
|
||||||
ResponseHandlerType::HallucinationDetect => {
|
|
||||||
self.hallucination_classification_resp_handler(body, callout_context)
|
|
||||||
}
|
|
||||||
ResponseHandlerType::FunctionCall => {
|
|
||||||
self.function_call_response_handler(body, callout_context)
|
|
||||||
}
|
|
||||||
ResponseHandlerType::DefaultTarget => {
|
|
||||||
self.default_target_handler(body, callout_context)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
self.send_server_error(
|
|
||||||
ServerError::LogicError(String::from("No response body in inline HTTP request")),
|
|
||||||
None,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Client for StreamContext {
|
impl Client for StreamContext {
|
||||||
type CallContext = StreamCallContext;
|
type CallContext = StreamCallContext;
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue