Improve error handling (#23)

Signed-off-by: José Ulises Niño Rivera <junr03@users.noreply.github.com>
This commit is contained in:
José Ulises Niño Rivera 2024-07-29 12:15:26 -07:00 committed by GitHub
parent a51a467cad
commit 7ef68eccfb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 397 additions and 102 deletions

View file

@ -1,17 +1,17 @@
use common_types::{CallContext, EmbeddingRequest};
use configuration::PromptTarget;
use http::StatusCode;
use log::error;
use log::info;
use md5::Digest;
use open_message_format::models::{
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
};
use serde_json::to_string;
use stats::RecordingMetric;
use std::collections::HashMap;
use std::time::Duration;
use proxy_wasm::traits::*;
use proxy_wasm::types::*;
use stats::{Gauge, RecordingMetric};
use std::collections::HashMap;
use std::time::Duration;
mod common_types;
mod configuration;
@ -85,21 +85,47 @@ impl HttpContext for StreamContext {
match self.get_http_request_body(0, body_size) {
Some(body_bytes) => match serde_json::from_slice(&body_bytes) {
Ok(deserialized) => deserialized,
Err(msg) => panic!("Failed to deserialize: {}", msg),
Err(msg) => {
self.send_http_response(
StatusCode::BAD_REQUEST.as_u16().into(),
vec![],
Some(format!("Failed to deserialize: {}", msg).as_bytes()),
);
return Action::Pause;
}
},
None => panic!(
"Failed to obtain body bytes even though body_size is {}",
body_size
),
None => {
self.send_http_response(
StatusCode::INTERNAL_SERVER_ERROR.as_u16().into(),
vec![],
None,
);
error!(
"Failed to obtain body bytes even though body_size is {}",
body_size
);
return Action::Pause;
}
};
// Modify JSON payload
deserialized_body.model = String::from("gpt-3.5-turbo");
let json_string = serde_json::to_string(&deserialized_body).unwrap();
self.set_http_request_body(0, body_size, &json_string.into_bytes());
Action::Continue
match serde_json::to_string(&deserialized_body) {
Ok(json_string) => {
self.set_http_request_body(0, body_size, &json_string.into_bytes());
Action::Continue
}
Err(error) => {
self.send_http_response(
StatusCode::INTERNAL_SERVER_ERROR.as_u16().into(),
vec![],
None,
);
error!("Failed to serialize body: {}", error);
Action::Pause
}
}
}
}
@ -107,13 +133,13 @@ impl Context for StreamContext {}
// Set of metrics published by this filter. Derives `Copy` and `Clone`.
// NOTE(review): this listing is a rendered diff, so the pre-change (`stats::Gauge`) and
// post-change (`Gauge`) spellings of the same line both appear below.
#[derive(Copy, Clone)]
struct WasmMetrics {
// Gauge constructed with the name "active_http_calls" (see `new`).
active_http_calls: stats::Gauge,
active_http_calls: Gauge,
}
impl WasmMetrics {
// Builds the metric set, constructing the gauge with the name "active_http_calls".
fn new() -> WasmMetrics {
WasmMetrics {
active_http_calls: stats::Gauge::new(String::from("active_http_calls")),
active_http_calls: Gauge::new(String::from("active_http_calls")),
}
}
}
@ -135,7 +161,13 @@ impl FilterContext {
}
fn process_prompt_targets(&mut self) {
for prompt_target in &self.config.as_ref().unwrap().prompt_config.prompt_targets {
for prompt_target in &self
.config
.as_ref()
.expect("Gateway configuration cannot be non-existent")
.prompt_config
.prompt_targets
{
for few_shot_example in &prompt_target.few_shot_examples {
info!("few_shot_example: {:?}", few_shot_example);
@ -149,8 +181,12 @@ impl FilterContext {
user: None,
};
// TODO: Handle potential errors
let json_data: String = to_string(&embeddings_input).unwrap();
let json_data: String = match serde_json::to_string(&embeddings_input) {
Ok(json_data) => json_data,
Err(error) => {
panic!("Error serializing embeddings input: {}", error);
}
};
let token_id = match self.dispatch_http_call(
"embeddingserver",
@ -159,6 +195,7 @@ impl FilterContext {
(":path", "/embeddings"),
(":authority", "embeddingserver"),
("content-type", "application/json"),
("x-envoy-max-retries", "3"),
],
Some(json_data.as_bytes()),
vec![],
@ -166,7 +203,7 @@ impl FilterContext {
) {
Ok(token_id) => token_id,
Err(e) => {
panic!("Error dispatching HTTP call: {:?}", e);
panic!("Error dispatching embedding server HTTP call: {:?}", e);
}
};
info!("on_tick: dispatched HTTP call with token_id = {}", token_id);
@ -181,7 +218,10 @@ impl FilterContext {
})
.is_some()
{
panic!("duplicate token_id")
panic!(
"duplicate token_id={} in embedding server requests",
token_id
)
}
self.metrics
.active_http_calls
@ -197,78 +237,65 @@ impl FilterContext {
prompt_target: PromptTarget,
) {
info!("response received for CreateEmbeddingRequest");
if let Some(body) = self.get_http_call_response_body(0, body_size) {
if !body.is_empty() {
let embedding_response: CreateEmbeddingResponse =
serde_json::from_slice(&body).unwrap();
info!(
"embedding_response model: {}, vector len: {}",
embedding_response.model,
embedding_response.data[0].embedding.len()
);
let mut payload: HashMap<String, String> = HashMap::new();
payload.insert(
"prompt-target".to_string(),
to_string(&prompt_target).unwrap(),
);
let id: Option<Digest>;
match *create_embedding_request.input {
CreateEmbeddingRequestInput::String(input) => {
id = Some(md5::compute(&input));
payload.insert("input".to_string(), input);
}
CreateEmbeddingRequestInput::Array(_) => todo!(),
}
let create_vector_store_points = common_types::StoreVectorEmbeddingsRequest {
points: vec![common_types::VectorPoint {
id: format!("{:x}", id.unwrap()),
payload,
vector: embedding_response.data[0].embedding.clone(),
}],
};
let json_data = to_string(&create_vector_store_points).unwrap(); // Handle potential errors
info!(
"create_vector_store_points: points length: {}",
embedding_response.data[0].embedding.len()
);
let token_id = match self.dispatch_http_call(
"qdrant",
vec![
(":method", "PUT"),
(":path", "/collections/prompt_vector_store/points"),
(":authority", "qdrant"),
("content-type", "application/json"),
],
Some(json_data.as_bytes()),
vec![],
Duration::from_secs(5),
) {
Ok(token_id) => token_id,
Err(e) => {
panic!("Error dispatching HTTP call: {:?}", e);
}
};
info!("on_tick: dispatched HTTP call with token_id = {}", token_id);
if self
.callouts
.insert(
token_id,
CallContext::StoreVectorEmbeddings(create_vector_store_points),
)
.is_some()
{
panic!("duplicate token_id")
}
self.metrics
.active_http_calls
.record(self.callouts.len().try_into().unwrap());
let body = match self.get_http_call_response_body(0, body_size) {
Some(body) => body,
None => {
return;
}
};
if body.is_empty() {
return;
}
let (json_data, create_vector_store_points) =
match build_qdrant_data(&body, create_embedding_request, &prompt_target) {
Ok(tuple) => tuple,
Err(error) => {
panic!(
"Error building qdrant data from embedding response {}",
error
);
}
};
let token_id = match self.dispatch_http_call(
"qdrant",
vec![
(":method", "PUT"),
(":path", "/collections/prompt_vector_store/points"),
(":authority", "qdrant"),
("content-type", "application/json"),
("x-envoy-max-retries", "3"),
],
Some(json_data.as_bytes()),
vec![],
Duration::from_secs(5),
) {
Ok(token_id) => token_id,
Err(e) => {
panic!("Error dispatching qdrant HTTP call: {:?}", e);
}
};
info!("on_tick: dispatched HTTP call with token_id = {}", token_id);
if self
.callouts
.insert(
token_id,
CallContext::StoreVectorEmbeddings(create_vector_store_points),
)
.is_some()
{
panic!("duplicate token_id={} in qdrant requests", token_id)
}
self.metrics
.active_http_calls
.record(self.callouts.len().try_into().unwrap());
}
// TODO: @adilhafeez implement.
fn create_vector_store_points_handler(&self, body_size: usize) {
info!("response received for CreateVectorStorePoints");
if let Some(body) = self.get_http_call_response_body(0, body_size) {
@ -294,7 +321,6 @@ impl Context for FilterContext {
self.metrics
.active_http_calls
.record(self.callouts.len().try_into().unwrap());
match callout_data {
common_types::CallContext::EmbeddingRequest(common_types::EmbeddingRequest {
create_embedding_request,
@ -313,7 +339,12 @@ impl Context for FilterContext {
impl RootContext for FilterContext {
fn on_configure(&mut self, _: usize) -> bool {
if let Some(config_bytes) = self.get_plugin_configuration() {
self.config = serde_yaml::from_slice(&config_bytes).unwrap();
self.config = match serde_yaml::from_slice(&config_bytes) {
Ok(config) => config,
Err(error) => {
panic!("Failed to deserialize plugin configuration: {}", error);
}
};
info!("on_configure: plugin configuration loaded");
}
true
@ -339,3 +370,51 @@ impl RootContext for FilterContext {
self.set_tick_period(Duration::from_secs(0));
}
}
/// Builds the qdrant upsert request for a single embedding response.
///
/// Deserializes `embedding_data` as a `CreateEmbeddingResponse` and creates one `VectorPoint`:
/// its id is the hex MD5 digest of the request's input string, its payload carries the
/// JSON-serialized `prompt_target` under "prompt-target" and the raw input under "input", and its
/// vector is the first embedding in the response.
///
/// Returns the serialized request body together with the request value itself.
///
/// # Errors
///
/// Returns the `serde_json::Error` when the embedding response cannot be deserialized, or when
/// the prompt target or the resulting request cannot be serialized.
///
/// # Panics
///
/// Panics if the response's `data` list is empty, or if the request input is the
/// not-yet-supported `Array` variant.
fn build_qdrant_data(
    embedding_data: &[u8],
    create_embedding_request: CreateEmbeddingRequest,
    prompt_target: &PromptTarget,
) -> Result<(String, common_types::StoreVectorEmbeddingsRequest), serde_json::Error> {
    // Propagate a malformed response through the declared error type instead of panicking here;
    // the caller decides how to react to the failure.
    let embedding_response: CreateEmbeddingResponse = serde_json::from_slice(embedding_data)?;

    // Only the first embedding is used; indexing panics when `data` is empty.
    let embedding = &embedding_response.data[0].embedding;
    info!(
        "embedding_response model: {}, vector len: {}",
        embedding_response.model,
        embedding.len()
    );

    let mut payload: HashMap<String, String> = HashMap::new();
    payload.insert(
        "prompt-target".to_string(),
        serde_json::to_string(&prompt_target)?,
    );

    // The request is owned, so the input string can be moved into the payload after hashing.
    let id: Digest = match *create_embedding_request.input {
        CreateEmbeddingRequestInput::String(input) => {
            let digest = md5::compute(&input);
            payload.insert("input".to_string(), input);
            digest
        }
        CreateEmbeddingRequestInput::Array(_) => todo!(),
    };

    let create_vector_store_points = common_types::StoreVectorEmbeddingsRequest {
        points: vec![common_types::VectorPoint {
            id: format!("{:x}", id),
            payload,
            vector: embedding.clone(),
        }],
    };
    let json_data = serde_json::to_string(&create_vector_store_points)?;
    info!(
        "create_vector_store_points: points length: {}",
        embedding.len()
    );

    Ok((json_data, create_vector_store_points))
}

View file

@ -1,14 +1,15 @@
use log::error;
use proxy_wasm::hostcalls;
use proxy_wasm::types::*;
// Base trait for a metric addressed by a numeric id in the `hostcalls` metric functions.
// NOTE(review): this listing is a rendered diff, so the pre-change `u64` signature with its
// panicking arms appears alongside the post-change `Result<u64, String>` signature with its
// error-returning arms.
#[allow(unused)]
pub trait Metric {
// Identifier passed to the `hostcalls` metric functions.
fn id(&self) -> u32;
// Reads the metric's current value through `hostcalls::get_metric`.
fn value(&self) -> u64 {
fn value(&self) -> Result<u64, String> {
match hostcalls::get_metric(self.id()) {
Ok(value) => value,
Err(Status::NotFound) => panic!("metric not found: {}", self.id()),
Err(err) => panic!("unexpected status: {:?}", err),
Ok(value) => Ok(value),
Err(Status::NotFound) => Err(format!("metric not found: {}", self.id())),
Err(err) => Err(format!("unexpected status: {:?}", err)),
}
}
}
@ -17,9 +18,8 @@ pub trait Metric {
// Metrics that can be adjusted by a signed offset.
// NOTE(review): this listing is a rendered diff, so the pre-change panicking arms and the
// post-change logging arms both appear in the match below.
pub trait IncrementingMetric: Metric {
// Applies `offset` to this metric through `hostcalls::increment_metric`.
fn increment(&self, offset: i64) {
match hostcalls::increment_metric(self.id(), offset) {
Ok(data) => data,
Err(Status::NotFound) => panic!("metric not found: {}", self.id()),
Err(err) => panic!("unexpected status: {:?}", err),
Ok(_) => (),
Err(err) => error!("error incrementing metric: {:?}", err),
}
}
}
@ -27,9 +27,8 @@ pub trait IncrementingMetric: Metric {
// Metrics whose value is set directly.
// NOTE(review): this listing is a rendered diff, so the pre-change panicking arms and the
// post-change logging arms both appear in the match below.
pub trait RecordingMetric: Metric {
// Records `value` for this metric through `hostcalls::record_metric`.
fn record(&self, value: u64) {
match hostcalls::record_metric(self.id(), value) {
Ok(data) => data,
Err(Status::NotFound) => panic!("metric not found: {}", self.id()),
Err(err) => panic!("unexpected status: {:?}", err),
Ok(_) => (),
Err(err) => error!("error recording metric: {:?}", err),
}
}
}