Improve error handling (#23)

Signed-off-by: José Ulises Niño Rivera <junr03@users.noreply.github.com>
This commit is contained in:
José Ulises Niño Rivera 2024-07-29 12:15:26 -07:00 committed by GitHub
parent a51a467cad
commit 7ef68eccfb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 397 additions and 102 deletions

128
envoyfilter/Cargo.lock generated
View file

@ -511,6 +511,21 @@ dependencies = [
"percent-encoding",
]
[[package]]
name = "futures"
version = "0.3.30"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "645c6916888f6cb6350d2550b80fb63e734897a8498abe35cfb732b6487804b0"
dependencies = [
"futures-channel",
"futures-core",
"futures-executor",
"futures-io",
"futures-sink",
"futures-task",
"futures-util",
]
[[package]]
name = "futures-channel"
version = "0.3.30"
@ -518,6 +533,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eac8f7d7865dcb88bd4373ab671c8cf4508703796caa2b1985a9ca867b3fcb78"
dependencies = [
"futures-core",
"futures-sink",
]
[[package]]
@ -526,6 +542,23 @@ version = "0.3.30"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dfc6580bb841c5a68e9ef15c77ccc837b40a7504914d52e47b8b0e9bbda25a1d"
[[package]]
name = "futures-executor"
version = "0.3.30"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a576fc72ae164fca6b9db127eaa9a9dda0d61316034f33a0a0d4eda41f02b01d"
dependencies = [
"futures-core",
"futures-task",
"futures-util",
]
[[package]]
name = "futures-io"
version = "0.3.30"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1"
[[package]]
name = "futures-sink"
version = "0.3.30"
@ -544,10 +577,15 @@ version = "0.3.30"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48"
dependencies = [
"futures-channel",
"futures-core",
"futures-io",
"futures-sink",
"futures-task",
"memchr",
"pin-project-lite",
"pin-utils",
"slab",
]
[[package]]
@ -822,6 +860,7 @@ dependencies = [
name = "intelligent-prompt-gateway"
version = "0.1.0"
dependencies = [
"http",
"log",
"md5",
"open-message-format",
@ -830,6 +869,7 @@ dependencies = [
"serde",
"serde_json",
"serde_yaml",
"serial_test",
]
[[package]]
@ -925,6 +965,16 @@ version = "0.4.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "78b3ae25bc7c8c38cec158d1f2757ee79e9b3740fbc7ccf0e59e4b08d793fa89"
[[package]]
name = "lock_api"
version = "0.4.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "07af8b9cdd281b7915f413fa73f29ebd5d55d0d3f0155584dade1ff18cea1b17"
dependencies = [
"autocfg",
"scopeguard",
]
[[package]]
name = "log"
version = "0.4.22"
@ -1112,6 +1162,29 @@ dependencies = [
"vcpkg",
]
[[package]]
name = "parking_lot"
version = "0.12.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f1bf18183cf54e8d6059647fc3063646a1801cf30896933ec2311622cc4b9a27"
dependencies = [
"lock_api",
"parking_lot_core",
]
[[package]]
name = "parking_lot_core"
version = "0.9.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e401f977ab385c9e4e3ab30627d6f26d00e2c73eef317493c4ec6d468726cf8"
dependencies = [
"cfg-if 1.0.0",
"libc",
"redox_syscall",
"smallvec",
"windows-targets 0.52.6",
]
[[package]]
name = "paste"
version = "1.0.15"
@ -1293,6 +1366,15 @@ dependencies = [
"crossbeam-utils",
]
[[package]]
name = "redox_syscall"
version = "0.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2a908a6e00f1fdd0dfd9c0eb08ce85126f6d8bbda50017e74bc4a4b7d4a926a4"
dependencies = [
"bitflags 2.6.0",
]
[[package]]
name = "redox_users"
version = "0.4.5"
@ -1447,6 +1529,15 @@ version = "1.0.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f"
[[package]]
name = "scc"
version = "2.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1fadf67e3cf23f8b11a6c8c48a16cb2437381503615acd91094ec7b4686a5a53"
dependencies = [
"sdd",
]
[[package]]
name = "schannel"
version = "0.1.23"
@ -1456,6 +1547,18 @@ dependencies = [
"windows-sys 0.52.0",
]
[[package]]
name = "scopeguard"
version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
[[package]]
name = "sdd"
version = "1.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "85f05a494052771fc5bd0619742363b5e24e5ad72ab3111ec2e27925b8edc5f3"
[[package]]
name = "security-framework"
version = "2.11.1"
@ -1550,6 +1653,31 @@ dependencies = [
"unsafe-libyaml",
]
[[package]]
name = "serial_test"
version = "3.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4b4b487fe2acf240a021cf57c6b2b4903b1e78ca0ecd862a71b71d2a51fed77d"
dependencies = [
"futures",
"log",
"once_cell",
"parking_lot",
"scc",
"serial_test_derive",
]
[[package]]
name = "serial_test_derive"
version = "3.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "82fe9db325bcef1fbcde82e078a5cc4efdf787e96b3b9cf45b50b529f2083d67"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.72",
]
[[package]]
name = "sha2"
version = "0.10.8"

View file

@ -15,6 +15,8 @@ serde_yaml = "0.9.34"
serde_json = "1.0"
md5 = "0.7.0"
open-message-format = { path = "../open-message-format/clients/omf-rust" }
http = "1.1.0"
[dev-dependencies]
proxy-wasm-test-framework = { git = "https://github.com/katanemo/test-framework.git", branch = "main" }
serial_test = "3.1.1"

View file

@ -1,17 +1,17 @@
use common_types::{CallContext, EmbeddingRequest};
use configuration::PromptTarget;
use http::StatusCode;
use log::error;
use log::info;
use md5::Digest;
use open_message_format::models::{
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
};
use serde_json::to_string;
use stats::RecordingMetric;
use std::collections::HashMap;
use std::time::Duration;
use proxy_wasm::traits::*;
use proxy_wasm::types::*;
use stats::{Gauge, RecordingMetric};
use std::collections::HashMap;
use std::time::Duration;
mod common_types;
mod configuration;
@ -85,21 +85,47 @@ impl HttpContext for StreamContext {
match self.get_http_request_body(0, body_size) {
Some(body_bytes) => match serde_json::from_slice(&body_bytes) {
Ok(deserialized) => deserialized,
Err(msg) => panic!("Failed to deserialize: {}", msg),
Err(msg) => {
self.send_http_response(
StatusCode::BAD_REQUEST.as_u16().into(),
vec![],
Some(format!("Failed to deserialize: {}", msg).as_bytes()),
);
return Action::Pause;
}
},
None => panic!(
"Failed to obtain body bytes even though body_size is {}",
body_size
),
None => {
self.send_http_response(
StatusCode::INTERNAL_SERVER_ERROR.as_u16().into(),
vec![],
None,
);
error!(
"Failed to obtain body bytes even though body_size is {}",
body_size
);
return Action::Pause;
}
};
// Modify JSON payload
deserialized_body.model = String::from("gpt-3.5-turbo");
let json_string = serde_json::to_string(&deserialized_body).unwrap();
self.set_http_request_body(0, body_size, &json_string.into_bytes());
Action::Continue
match serde_json::to_string(&deserialized_body) {
Ok(json_string) => {
self.set_http_request_body(0, body_size, &json_string.into_bytes());
Action::Continue
}
Err(error) => {
self.send_http_response(
StatusCode::INTERNAL_SERVER_ERROR.as_u16().into(),
vec![],
None,
);
error!("Failed to serialize body: {}", error);
Action::Pause
}
}
}
}
@ -107,13 +133,13 @@ impl Context for StreamContext {}
#[derive(Copy, Clone)]
struct WasmMetrics {
active_http_calls: stats::Gauge,
active_http_calls: Gauge,
}
impl WasmMetrics {
fn new() -> WasmMetrics {
WasmMetrics {
active_http_calls: stats::Gauge::new(String::from("active_http_calls")),
active_http_calls: Gauge::new(String::from("active_http_calls")),
}
}
}
@ -135,7 +161,13 @@ impl FilterContext {
}
fn process_prompt_targets(&mut self) {
for prompt_target in &self.config.as_ref().unwrap().prompt_config.prompt_targets {
for prompt_target in &self
.config
.as_ref()
.expect("Gateway configuration cannot be non-existent")
.prompt_config
.prompt_targets
{
for few_shot_example in &prompt_target.few_shot_examples {
info!("few_shot_example: {:?}", few_shot_example);
@ -149,8 +181,12 @@ impl FilterContext {
user: None,
};
// TODO: Handle potential errors
let json_data: String = to_string(&embeddings_input).unwrap();
let json_data: String = match serde_json::to_string(&embeddings_input) {
Ok(json_data) => json_data,
Err(error) => {
panic!("Error serializing embeddings input: {}", error);
}
};
let token_id = match self.dispatch_http_call(
"embeddingserver",
@ -159,6 +195,7 @@ impl FilterContext {
(":path", "/embeddings"),
(":authority", "embeddingserver"),
("content-type", "application/json"),
("x-envoy-max-retries", "3"),
],
Some(json_data.as_bytes()),
vec![],
@ -166,7 +203,7 @@ impl FilterContext {
) {
Ok(token_id) => token_id,
Err(e) => {
panic!("Error dispatching HTTP call: {:?}", e);
panic!("Error dispatching embedding server HTTP call: {:?}", e);
}
};
info!("on_tick: dispatched HTTP call with token_id = {}", token_id);
@ -181,7 +218,10 @@ impl FilterContext {
})
.is_some()
{
panic!("duplicate token_id")
panic!(
"duplicate token_id={} in embedding server requests",
token_id
)
}
self.metrics
.active_http_calls
@ -197,78 +237,65 @@ impl FilterContext {
prompt_target: PromptTarget,
) {
info!("response received for CreateEmbeddingRequest");
if let Some(body) = self.get_http_call_response_body(0, body_size) {
if !body.is_empty() {
let embedding_response: CreateEmbeddingResponse =
serde_json::from_slice(&body).unwrap();
info!(
"embedding_response model: {}, vector len: {}",
embedding_response.model,
embedding_response.data[0].embedding.len()
);
let mut payload: HashMap<String, String> = HashMap::new();
payload.insert(
"prompt-target".to_string(),
to_string(&prompt_target).unwrap(),
);
let id: Option<Digest>;
match *create_embedding_request.input {
CreateEmbeddingRequestInput::String(input) => {
id = Some(md5::compute(&input));
payload.insert("input".to_string(), input);
}
CreateEmbeddingRequestInput::Array(_) => todo!(),
}
let create_vector_store_points = common_types::StoreVectorEmbeddingsRequest {
points: vec![common_types::VectorPoint {
id: format!("{:x}", id.unwrap()),
payload,
vector: embedding_response.data[0].embedding.clone(),
}],
};
let json_data = to_string(&create_vector_store_points).unwrap(); // Handle potential errors
info!(
"create_vector_store_points: points length: {}",
embedding_response.data[0].embedding.len()
);
let token_id = match self.dispatch_http_call(
"qdrant",
vec![
(":method", "PUT"),
(":path", "/collections/prompt_vector_store/points"),
(":authority", "qdrant"),
("content-type", "application/json"),
],
Some(json_data.as_bytes()),
vec![],
Duration::from_secs(5),
) {
Ok(token_id) => token_id,
Err(e) => {
panic!("Error dispatching HTTP call: {:?}", e);
}
};
info!("on_tick: dispatched HTTP call with token_id = {}", token_id);
if self
.callouts
.insert(
token_id,
CallContext::StoreVectorEmbeddings(create_vector_store_points),
)
.is_some()
{
panic!("duplicate token_id")
}
self.metrics
.active_http_calls
.record(self.callouts.len().try_into().unwrap());
let body = match self.get_http_call_response_body(0, body_size) {
Some(body) => body,
None => {
return;
}
};
if body.is_empty() {
return;
}
let (json_data, create_vector_store_points) =
match build_qdrant_data(&body, create_embedding_request, &prompt_target) {
Ok(tuple) => tuple,
Err(error) => {
panic!(
"Error building qdrant data from embedding response {}",
error
);
}
};
let token_id = match self.dispatch_http_call(
"qdrant",
vec![
(":method", "PUT"),
(":path", "/collections/prompt_vector_store/points"),
(":authority", "qdrant"),
("content-type", "application/json"),
("x-envoy-max-retries", "3"),
],
Some(json_data.as_bytes()),
vec![],
Duration::from_secs(5),
) {
Ok(token_id) => token_id,
Err(e) => {
panic!("Error dispatching qdrant HTTP call: {:?}", e);
}
};
info!("on_tick: dispatched HTTP call with token_id = {}", token_id);
if self
.callouts
.insert(
token_id,
CallContext::StoreVectorEmbeddings(create_vector_store_points),
)
.is_some()
{
panic!("duplicate token_id={} in qdrant requests", token_id)
}
self.metrics
.active_http_calls
.record(self.callouts.len().try_into().unwrap());
}
// TODO: @adilhafeez implement.
fn create_vector_store_points_handler(&self, body_size: usize) {
info!("response received for CreateVectorStorePoints");
if let Some(body) = self.get_http_call_response_body(0, body_size) {
@ -294,7 +321,6 @@ impl Context for FilterContext {
self.metrics
.active_http_calls
.record(self.callouts.len().try_into().unwrap());
match callout_data {
common_types::CallContext::EmbeddingRequest(common_types::EmbeddingRequest {
create_embedding_request,
@ -313,7 +339,12 @@ impl Context for FilterContext {
impl RootContext for FilterContext {
fn on_configure(&mut self, _: usize) -> bool {
if let Some(config_bytes) = self.get_plugin_configuration() {
self.config = serde_yaml::from_slice(&config_bytes).unwrap();
self.config = match serde_yaml::from_slice(&config_bytes) {
Ok(config) => config,
Err(error) => {
panic!("Failed to deserialize plugin configuration: {}", error);
}
};
info!("on_configure: plugin configuration loaded");
}
true
@ -339,3 +370,51 @@ impl RootContext for FilterContext {
self.set_tick_period(Duration::from_secs(0));
}
}
/// Deserializes an embedding-server response and builds the payload for a
/// qdrant "upsert points" request.
///
/// Returns the serialized JSON body together with the structured request so
/// the caller can both dispatch the HTTP call and track it in its callout map.
///
/// # Errors
/// Returns a `serde_json::Error` if the embedding response cannot be
/// deserialized, or if the prompt target / vector-store request cannot be
/// serialized.
fn build_qdrant_data(
    embedding_data: &[u8],
    create_embedding_request: CreateEmbeddingRequest,
    prompt_target: &PromptTarget,
) -> Result<(String, common_types::StoreVectorEmbeddingsRequest), serde_json::Error> {
    // Propagate deserialization failures via the declared error type instead of
    // panicking: this function already returns serde_json::Error, and panicking
    // here defeats the caller's ability to decide how to handle the failure.
    let embedding_response: CreateEmbeddingResponse = serde_json::from_slice(embedding_data)?;
    // NOTE(review): assumes the embedding server always returns at least one
    // data entry — indexing data[0] panics on an empty response; confirm.
    info!(
        "embedding_response model: {}, vector len: {}",
        embedding_response.model,
        embedding_response.data[0].embedding.len()
    );

    let mut payload: HashMap<String, String> = HashMap::new();
    payload.insert(
        "prompt-target".to_string(),
        serde_json::to_string(&prompt_target)?,
    );

    // Key each vector point by the md5 digest of its input so re-submitting
    // the same prompt overwrites the existing point instead of duplicating it.
    // Using a match *expression* makes `id` infallibly initialized, removing
    // the Option + unwrap() of the previous version.
    let id: Digest = match *create_embedding_request.input {
        CreateEmbeddingRequestInput::String(ref input) => {
            let digest = md5::compute(input);
            payload.insert("input".to_string(), input.clone());
            digest
        }
        // Batched (array) embedding inputs are not supported yet.
        CreateEmbeddingRequestInput::Array(_) => todo!(),
    };

    let create_vector_store_points = common_types::StoreVectorEmbeddingsRequest {
        points: vec![common_types::VectorPoint {
            id: format!("{:x}", id),
            payload,
            vector: embedding_response.data[0].embedding.clone(),
        }],
    };
    let json_data = serde_json::to_string(&create_vector_store_points)?;
    info!(
        "create_vector_store_points: points length: {}",
        embedding_response.data[0].embedding.len()
    );
    Ok((json_data, create_vector_store_points))
}

View file

@ -1,14 +1,15 @@
use log::error;
use proxy_wasm::hostcalls;
use proxy_wasm::types::*;
#[allow(unused)]
pub trait Metric {
fn id(&self) -> u32;
fn value(&self) -> u64 {
fn value(&self) -> Result<u64, String> {
match hostcalls::get_metric(self.id()) {
Ok(value) => value,
Err(Status::NotFound) => panic!("metric not found: {}", self.id()),
Err(err) => panic!("unexpected status: {:?}", err),
Ok(value) => Ok(value),
Err(Status::NotFound) => Err(format!("metric not found: {}", self.id())),
Err(err) => Err(format!("unexpected status: {:?}", err)),
}
}
}
@ -17,9 +18,8 @@ pub trait Metric {
pub trait IncrementingMetric: Metric {
fn increment(&self, offset: i64) {
match hostcalls::increment_metric(self.id(), offset) {
Ok(data) => data,
Err(Status::NotFound) => panic!("metric not found: {}", self.id()),
Err(err) => panic!("unexpected status: {:?}", err),
Ok(_) => (),
Err(err) => error!("error incrementing metric: {:?}", err),
}
}
}
@ -27,9 +27,8 @@ pub trait IncrementingMetric: Metric {
pub trait RecordingMetric: Metric {
fn record(&self, value: u64) {
match hostcalls::record_metric(self.id(), value) {
Ok(data) => data,
Err(Status::NotFound) => panic!("metric not found: {}", self.id()),
Err(err) => panic!("unexpected status: {:?}", err),
Ok(_) => (),
Err(err) => error!("error recording metric: {:?}", err),
}
}
}

View file

@ -1,5 +1,7 @@
use http::StatusCode;
use proxy_wasm_test_framework::tester;
use proxy_wasm_test_framework::types::{Action, BufferType, MapType, MetricType, ReturnType};
use serial_test::serial;
use std::path::Path;
fn wasm_module() -> String {
@ -12,7 +14,8 @@ fn wasm_module() -> String {
}
#[test]
fn request_to_open_ai_chat_completions() {
#[serial]
fn successful_request_to_open_ai_chat_completions() {
let args = tester::MockSettings {
wasm_path: wasm_module(),
quiet: false,
@ -90,3 +93,87 @@ fn request_to_open_ai_chat_completions() {
.execute_and_expect(ReturnType::Action(Action::Continue))
.unwrap();
}
// Verifies that a request body the filter cannot deserialize produces a local
// 400 (Bad Request) response and pauses the stream, instead of panicking.
// `#[serial]` because the mock wasm module cannot be shared across concurrent
// tests.
#[test]
#[serial]
fn bad_request_to_open_ai_chat_completions() {
    let args = tester::MockSettings {
        wasm_path: wasm_module(),
        quiet: false,
        // Any host call the test did not explicitly expect fails the test.
        allow_unexpected: false,
    };
    let mut module = tester::mock(args).unwrap();

    module
        .call_start()
        .execute_and_expect(ReturnType::None)
        .unwrap();

    // Setup Filter
    let root_context = 1;
    module
        .call_proxy_on_context_create(root_context, 0)
        .expect_metric_creation(MetricType::Gauge, "active_http_calls")
        .execute_and_expect(ReturnType::None)
        .unwrap();

    // Setup HTTP Stream
    let http_context = 2;
    module
        .call_proxy_on_context_create(http_context, root_context)
        .execute_and_expect(ReturnType::None)
        .unwrap();

    // Request Headers
    // Header phase behaves exactly as in the happy-path test: the filter
    // rewrites /llmrouting to the OpenAI completions path and continues.
    module
        .call_proxy_on_request_headers(http_context, 0, false)
        .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":host"))
        .returning(Some("api.openai.com"))
        .expect_add_header_map_value(
            Some(MapType::HttpRequestHeaders),
            Some("content-length"),
            Some(""),
        )
        .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":path"))
        .returning(Some("/llmrouting"))
        .expect_add_header_map_value(
            Some(MapType::HttpRequestHeaders),
            Some(":path"),
            Some("/v1/chat/completions"),
        )
        .execute_and_expect(ReturnType::Action(Action::Continue))
        .unwrap();

    // Request Body
    // Deliberately malformed JSON: the trailing comma after the "system" role
    // object (and its missing "content" field) makes deserialization fail.
    // The backslash line-continuations strip newlines/indentation from the
    // literal, so the string is one compact JSON candidate.
    let incomplete_chat_completions_request_body = "\
    {\
        \"messages\": [\
        {\
            \"role\": \"system\",\
        },\
        {\
            \"role\": \"user\",\
            \"content\": \"Compose a poem that explains the concept of recursion in programming.\"\
        }\
        ]\
    }";
    // The filter must answer with a local 400 response and pause the stream
    // rather than forwarding the broken body upstream.
    module
        .call_proxy_on_request_body(
            http_context,
            incomplete_chat_completions_request_body.len() as i32,
            true,
        )
        .expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
        .returning(Some(incomplete_chat_completions_request_body))
        .expect_send_local_response(
            Some(StatusCode::BAD_REQUEST.as_u16().into()),
            None,
            None,
            None,
        )
        .execute_and_expect(ReturnType::Action(Action::Pause))
        .unwrap();
}