Code refactor and some improvements - see description (#194)

Adil Hafeez, 2024-10-18 12:53:44 -07:00, committed by GitHub
parent aa30353c85
commit c6ba28dfcc
15 changed files with 100 additions and 115 deletions

filter_context.rs (renamed from llm_filter_context.rs)

@@ -1,4 +1,4 @@
-use crate::llm_stream_context::LlmGatewayStreamContext;
+use crate::stream_context::StreamContext;
 use common::configuration::Configuration;
 use common::http::Client;
 use common::llm_providers::LlmProviders;
@@ -28,19 +28,19 @@ impl WasmMetrics {
 }
 
 #[derive(Debug)]
-pub struct FilterCallContext {}
+pub struct CallContext {}
 
 #[derive(Debug)]
-pub struct LlmGatewayFilterContext {
+pub struct FilterContext {
     metrics: Rc<WasmMetrics>,
     // callouts stores token_id to request mapping that we use during #on_http_call_response to match the response to the request.
-    callouts: RefCell<HashMap<u32, FilterCallContext>>,
+    callouts: RefCell<HashMap<u32, CallContext>>,
     llm_providers: Option<Rc<LlmProviders>>,
 }
 
-impl LlmGatewayFilterContext {
-    pub fn new() -> LlmGatewayFilterContext {
-        LlmGatewayFilterContext {
+impl FilterContext {
+    pub fn new() -> FilterContext {
+        FilterContext {
             callouts: RefCell::new(HashMap::new()),
             metrics: Rc::new(WasmMetrics::new()),
             llm_providers: None,
@@ -48,8 +48,8 @@ impl LlmGatewayFilterContext {
     }
 }
 
-impl Client for LlmGatewayFilterContext {
-    type CallContext = FilterCallContext;
+impl Client for FilterContext {
+    type CallContext = CallContext;
 
     fn callouts(&self) -> &RefCell<HashMap<u32, Self::CallContext>> {
         &self.callouts
@@ -60,10 +60,10 @@ impl Client for LlmGatewayFilterContext {
     }
 }
 
-impl Context for LlmGatewayFilterContext {}
+impl Context for FilterContext {}
 
 // RootContext allows the Rust code to reach into the Envoy Config
-impl RootContext for LlmGatewayFilterContext {
+impl RootContext for FilterContext {
     fn on_configure(&mut self, _: usize) -> bool {
         let config_bytes = self
             .get_plugin_configuration()
@@ -90,8 +90,7 @@ impl RootContext for LlmGatewayFilterContext {
             context_id
         );
         // No StreamContext can be created until the Embedding Store is fully initialized.
-        Some(Box::new(LlmGatewayStreamContext::new(
+        Some(Box::new(StreamContext::new(
             context_id,
             Rc::clone(&self.metrics),
             Rc::clone(
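
Note on the callouts map above: proxy-wasm HTTP callouts are asynchronous, so dispatch_http_call returns only a token id, and the response arrives later in on_http_call_response carrying that same id; the filter has to stash per-call state in between, which is what the common::http::Client trait abstracts. A minimal sketch of the underlying pattern, with illustrative names (the upstream cluster, headers, and CallContext contents here are assumptions, not this project's actual values):

    use std::cell::RefCell;
    use std::collections::HashMap;
    use std::time::Duration;

    use proxy_wasm::traits::Context;
    use proxy_wasm::types::Status;

    // Per-callout state that must survive until the response arrives (illustrative).
    struct CallContext {}

    struct Filter {
        callouts: RefCell<HashMap<u32, CallContext>>,
    }

    impl Filter {
        fn call_upstream(&self) -> Result<(), Status> {
            // The call is dispatched asynchronously; only a token id comes back now.
            let token_id = self.dispatch_http_call(
                "model_server", // upstream cluster name (illustrative)
                vec![
                    (":method", "POST"),
                    (":path", "/v1/chat/completions"),
                    (":authority", "model_server"),
                ],
                None,
                vec![],
                Duration::from_secs(5),
            )?;
            // Remember which logical request this token belongs to.
            self.callouts.borrow_mut().insert(token_id, CallContext {});
            Ok(())
        }
    }

    impl Context for Filter {
        fn on_http_call_response(&mut self, token_id: u32, _: usize, _: usize, _: usize) {
            // Pop the stored state to match the response back to its request.
            let _call_ctx = self
                .callouts
                .get_mut()
                .remove(&token_id)
                .expect("response for unknown token_id");
        }
    }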

main.rs

@@ -1,13 +1,13 @@
-use llm_filter_context::LlmGatewayFilterContext;
+use filter_context::FilterContext;
 use proxy_wasm::traits::*;
 use proxy_wasm::types::*;
 
-mod llm_filter_context;
-mod llm_stream_context;
+mod filter_context;
+mod stream_context;
 
 proxy_wasm::main! {{
     proxy_wasm::set_log_level(LogLevel::Trace);
     proxy_wasm::set_root_context(|_| -> Box<dyn RootContext> {
-        Box::new(LlmGatewayFilterContext::new())
+        Box::new(FilterContext::new())
     });
 }}
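
For background on this entry point: proxy_wasm::main! generates the wasm module's startup hook, and the closure passed to set_root_context is the factory the host invokes when the plugin is loaded; per-request stream contexts are then created by RootContext::create_http_context, as seen in filter_context.rs above. A stripped-down sketch of that lifecycle, independent of this codebase:

    use proxy_wasm::traits::{Context, HttpContext, RootContext};
    use proxy_wasm::types::{Action, ContextType, LogLevel};

    struct Root;
    struct Stream;

    impl Context for Root {}
    impl RootContext for Root {
        // Declare that this root context spawns HTTP (not TCP) stream contexts.
        fn get_type(&self) -> Option<ContextType> {
            Some(ContextType::HttpContext)
        }

        // Called by the host once per HTTP stream.
        fn create_http_context(&self, _context_id: u32) -> Option<Box<dyn HttpContext>> {
            Some(Box::new(Stream))
        }
    }

    impl Context for Stream {}
    impl HttpContext for Stream {
        fn on_http_request_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action {
            Action::Continue
        }
    }

    proxy_wasm::main! {{
        proxy_wasm::set_log_level(LogLevel::Trace);
        proxy_wasm::set_root_context(|_| -> Box<dyn RootContext> { Box::new(Root) });
    }}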

stream_context.rs (renamed from llm_stream_context.rs)

@@ -1,4 +1,4 @@
-use crate::llm_filter_context::WasmMetrics;
+use crate::filter_context::WasmMetrics;
 use common::common_types::open_ai::{
     ArchState, ChatCompletionChunkResponse, ChatCompletionsRequest, ChatCompletionsResponse,
     Message, ToolCall, ToolCallState,
@@ -8,6 +8,7 @@ use common::consts::{
     ARCH_PROVIDER_HINT_HEADER, ARCH_ROUTING_HEADER, ARCH_STATE_HEADER, CHAT_COMPLETIONS_PATH,
     RATELIMIT_SELECTOR_HEADER_KEY, REQUEST_ID_HEADER, USER_ROLE,
 };
+use common::errors::ServerError;
 use common::llm_providers::LlmProviders;
 use common::ratelimit::Header;
 use common::{ratelimit, routing, tokenizer};
@@ -22,25 +23,12 @@ use std::rc::Rc;
 use common::stats::IncrementingMetric;
 
-#[derive(thiserror::Error, Debug)]
-pub enum ServerError {
-    #[error(transparent)]
-    Deserialization(serde_json::Error),
-    #[error("{0}")]
-    LogicError(String),
-    #[error(transparent)]
-    ExceededRatelimit(ratelimit::Error),
-    #[error("{why}")]
-    BadRequest { why: String },
-}
-
-pub struct LlmGatewayStreamContext {
+pub struct StreamContext {
     context_id: u32,
     metrics: Rc<WasmMetrics>,
     tool_calls: Option<Vec<ToolCall>>,
     tool_call_response: Option<String>,
     arch_state: Option<Vec<ArchState>>,
-    request_body_size: usize,
     ratelimit_selector: Option<Header>,
     streaming_response: bool,
     user_prompt: Option<Message>,
@@ -52,17 +40,15 @@ pub struct LlmGatewayStreamContext {
     request_id: Option<String>,
 }
 
-impl LlmGatewayStreamContext {
-    #[allow(clippy::too_many_arguments)]
+impl StreamContext {
     pub fn new(context_id: u32, metrics: Rc<WasmMetrics>, llm_providers: Rc<LlmProviders>) -> Self {
-        LlmGatewayStreamContext {
+        StreamContext {
            context_id,
            metrics,
            chat_completions_request: None,
            tool_calls: None,
            tool_call_response: None,
            arch_state: None,
-            request_body_size: 0,
            ratelimit_selector: None,
            streaming_response: false,
            user_prompt: None,
@@ -160,7 +146,7 @@ impl LlmGatewayStreamContext {
 }
 
 // HttpContext is the trait that allows the Rust code to interact with HTTP objects.
-impl HttpContext for LlmGatewayStreamContext {
+impl HttpContext for StreamContext {
     // Envoy's HTTP model is event driven. The WASM ABI has given implementors events to hook onto
     // the lifecycle of the http request and response.
     fn on_http_request_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action {
@@ -198,8 +184,6 @@ impl HttpContext for LlmGatewayStreamContext {
             return Action::Continue;
         }
 
-        self.request_body_size = body_size;
-
         // Deserialize body into spec.
         // Currently OpenAI API.
         let mut deserialized_body: ChatCompletionsRequest =
@@ -225,7 +209,6 @@ impl HttpContext for LlmGatewayStreamContext {
                 return Action::Pause;
             }
         };
-        self.is_chat_completions_request = true;
 
         // remove metadata from the request body
         deserialized_body.metadata = None;
@@ -418,4 +401,4 @@ impl HttpContext for LlmGatewayStreamContext {
     }
 }
 
-impl Context for LlmGatewayStreamContext {}
+impl Context for StreamContext {}
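
The ServerError enum removed above is not deleted outright: the new use common::errors::ServerError; import shows it was promoted into the shared common crate, in the same pass that drops the write-only request_body_size field, the is_chat_completions_request assignment, and the now-stale clippy::too_many_arguments allow. Assuming the move was verbatim (the common/src/errors.rs location is a guess), the relocated definition would read:

    // common/src/errors.rs (path assumed)
    use crate::ratelimit;

    #[derive(thiserror::Error, Debug)]
    pub enum ServerError {
        #[error(transparent)]
        Deserialization(serde_json::Error),
        #[error("{0}")]
        LogicError(String),
        #[error(transparent)]
        ExceededRatelimit(ratelimit::Error),
        #[error("{why}")]
        BadRequest { why: String },
    }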