Code refactor and some improvements - see description (#194)

This commit is contained in:
Adil Hafeez 2024-10-18 12:53:44 -07:00 committed by GitHub
parent aa30353c85
commit c6ba28dfcc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
15 changed files with 100 additions and 115 deletions

View file

@ -228,6 +228,7 @@ dependencies = [
"proxy-wasm",
"rand",
"serde",
"serde_json",
"serde_yaml",
"thiserror",
"tiktoken-rs",

View file

@ -1,6 +1,6 @@
use crate::prompt_stream_context::PromptStreamContext;
use crate::stream_context::StreamContext;
use common::common_types::EmbeddingType;
use common::configuration::{Configuration, GatewayMode, Overrides, PromptGuards, PromptTarget};
use common::configuration::{Configuration, Overrides, PromptGuards, PromptTarget};
use common::consts::ARCH_INTERNAL_CLUSTER_NAME;
use common::consts::ARCH_UPSTREAM_HOST_HEADER;
use common::consts::DEFAULT_EMBEDDING_MODEL;
@ -10,7 +10,6 @@ use common::embeddings::{
};
use common::http::CallArgs;
use common::http::Client;
use common::llm_providers::LlmProviders;
use common::stats::Gauge;
use common::stats::IncrementingMetric;
use log::debug;
@ -45,31 +44,27 @@ pub struct FilterCallContext {
}
#[derive(Debug)]
pub struct PromptGatewayFilterContext {
pub struct FilterContext {
metrics: Rc<WasmMetrics>,
// callouts stores token_id to request mapping that we use during #on_http_call_response to match the response to the request.
callouts: RefCell<HashMap<u32, FilterCallContext>>,
overrides: Rc<Option<Overrides>>,
system_prompt: Rc<Option<String>>,
prompt_targets: Rc<HashMap<String, PromptTarget>>,
mode: GatewayMode,
prompt_guards: Rc<PromptGuards>,
llm_providers: Option<Rc<LlmProviders>>,
embeddings_store: Option<Rc<EmbeddingsStore>>,
temp_embeddings_store: EmbeddingsStore,
}
impl PromptGatewayFilterContext {
pub fn new() -> PromptGatewayFilterContext {
PromptGatewayFilterContext {
impl FilterContext {
pub fn new() -> FilterContext {
FilterContext {
callouts: RefCell::new(HashMap::new()),
metrics: Rc::new(WasmMetrics::new()),
system_prompt: Rc::new(None),
prompt_targets: Rc::new(HashMap::new()),
overrides: Rc::new(None),
prompt_guards: Rc::new(PromptGuards::default()),
mode: GatewayMode::Prompt,
llm_providers: None,
embeddings_store: Some(Rc::new(HashMap::new())),
temp_embeddings_store: HashMap::new(),
}
@ -117,7 +112,7 @@ impl PromptGatewayFilterContext {
Duration::from_secs(60),
);
let call_context = crate::prompt_filter_context::FilterCallContext {
let call_context = crate::filter_context::FilterCallContext {
prompt_target_name: String::from(prompt_target_name),
embedding_type,
};
@ -194,7 +189,7 @@ impl PromptGatewayFilterContext {
}
}
impl Client for PromptGatewayFilterContext {
impl Client for FilterContext {
type CallContext = FilterCallContext;
fn callouts(&self) -> &RefCell<HashMap<u32, Self::CallContext>> {
@ -206,7 +201,7 @@ impl Client for PromptGatewayFilterContext {
}
}
impl Context for PromptGatewayFilterContext {
impl Context for FilterContext {
fn on_http_call_response(
&mut self,
token_id: u32,
@ -235,7 +230,7 @@ impl Context for PromptGatewayFilterContext {
}
// RootContext allows the Rust code to reach into the Envoy Config
impl RootContext for PromptGatewayFilterContext {
impl RootContext for FilterContext {
fn on_configure(&mut self, _: usize) -> bool {
let config_bytes = self
.get_plugin_configuration()
@ -254,17 +249,11 @@ impl RootContext for PromptGatewayFilterContext {
}
self.system_prompt = Rc::new(config.system_prompt);
self.prompt_targets = Rc::new(prompt_targets);
self.mode = config.mode.unwrap_or_default();
if let Some(prompt_guards) = config.prompt_guards {
self.prompt_guards = Rc::new(prompt_guards)
}
match config.llm_providers.try_into() {
Ok(llm_providers) => self.llm_providers = Some(Rc::new(llm_providers)),
Err(err) => panic!("{err}"),
}
true
}
@ -274,12 +263,11 @@ impl RootContext for PromptGatewayFilterContext {
context_id
);
// No StreamContext can be created until the Embedding Store is fully initialized.
let embedding_store = match self.mode {
GatewayMode::Llm => None,
GatewayMode::Prompt => Some(Rc::clone(self.embeddings_store.as_ref().unwrap())),
let embedding_store = match self.embeddings_store.as_ref() {
None => return None,
Some(store) => Some(Rc::clone(store)),
};
Some(Box::new(PromptStreamContext::new(
Some(Box::new(StreamContext::new(
context_id,
Rc::clone(&self.metrics),
Rc::clone(&self.system_prompt),
@ -300,11 +288,8 @@ impl RootContext for PromptGatewayFilterContext {
}
fn on_tick(&mut self) {
debug!("starting up arch filter in mode: {:?}", self.mode);
if self.mode == GatewayMode::Prompt {
self.process_prompt_targets();
}
debug!("starting up arch filter in mode: prompt gateway mode");
self.process_prompt_targets();
self.set_tick_period(Duration::from_secs(0));
}
}

View file

@ -1,13 +1,13 @@
use prompt_filter_context::PromptGatewayFilterContext;
use filter_context::FilterContext;
use proxy_wasm::traits::*;
use proxy_wasm::types::*;
mod prompt_filter_context;
mod prompt_stream_context;
mod filter_context;
mod stream_context;
proxy_wasm::main! {{
proxy_wasm::set_log_level(LogLevel::Trace);
proxy_wasm::set_root_context(|_| -> Box<dyn RootContext> {
Box::new(PromptGatewayFilterContext::new())
Box::new(FilterContext::new())
});
}}

View file

@ -1,4 +1,4 @@
use crate::prompt_filter_context::{EmbeddingsStore, WasmMetrics};
use crate::filter_context::{EmbeddingsStore, WasmMetrics};
use acap::cos;
use common::common_types::open_ai::{
ArchState, ChatCompletionTool, ChatCompletionsRequest, ChatCompletionsResponse, Choice,
@ -21,7 +21,8 @@ use common::consts::{
use common::embeddings::{
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
};
use common::http::{CallArgs, Client, ClientError};
use common::errors::ClientError;
use common::http::{CallArgs, Client};
use common::stats::Gauge;
use http::StatusCode;
use log::{debug, info, warn};
@ -81,7 +82,7 @@ pub enum ServerError {
NoMessagesFound { why: String },
}
pub struct PromptStreamContext {
pub struct StreamContext {
context_id: u32,
metrics: Rc<WasmMetrics>,
system_prompt: Rc<Option<String>>,
@ -102,8 +103,7 @@ pub struct PromptStreamContext {
request_id: Option<String>,
}
impl PromptStreamContext {
#[allow(clippy::too_many_arguments)]
impl StreamContext {
pub fn new(
context_id: u32,
metrics: Rc<WasmMetrics>,
@ -113,7 +113,7 @@ impl PromptStreamContext {
overrides: Rc<Option<Overrides>>,
embeddings_store: Option<Rc<EmbeddingsStore>>,
) -> Self {
PromptStreamContext {
StreamContext {
context_id,
metrics,
system_prompt,
@ -1031,7 +1031,7 @@ impl PromptStreamContext {
}
// HttpContext is the trait that allows the Rust code to interact with HTTP objects.
impl HttpContext for PromptStreamContext {
impl HttpContext for StreamContext {
// Envoy's HTTP model is event driven. The WASM ABI has given implementors events to hook onto
// the lifecycle of the http request and response.
fn on_http_request_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action {
@ -1094,7 +1094,6 @@ impl HttpContext for PromptStreamContext {
return Action::Pause;
}
};
self.is_chat_completions_request = true;
self.arch_state = match deserialized_body.metadata {
Some(ref metadata) => {
@ -1346,7 +1345,7 @@ impl HttpContext for PromptStreamContext {
}
}
impl Context for PromptStreamContext {
impl Context for StreamContext {
fn on_http_call_response(
&mut self,
token_id: u32,
@ -1392,7 +1391,7 @@ impl Context for PromptStreamContext {
}
}
impl Client for PromptStreamContext {
impl Client for StreamContext {
type CallContext = StreamCallContext;
fn callouts(&self) -> &RefCell<HashMap<u32, Self::CallContext>> {