mirror of
https://github.com/katanemo/plano.git
synced 2026-05-03 04:42:49 +02:00
Improve error handling (#23)
Signed-off-by: José Ulises Niño Rivera <junr03@users.noreply.github.com>
This commit is contained in:
parent
a51a467cad
commit
7ef68eccfb
5 changed files with 397 additions and 102 deletions
|
|
@ -1,17 +1,17 @@
|
|||
use common_types::{CallContext, EmbeddingRequest};
|
||||
use configuration::PromptTarget;
|
||||
use http::StatusCode;
|
||||
use log::error;
|
||||
use log::info;
|
||||
use md5::Digest;
|
||||
use open_message_format::models::{
|
||||
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
|
||||
};
|
||||
use serde_json::to_string;
|
||||
use stats::RecordingMetric;
|
||||
use std::collections::HashMap;
|
||||
use std::time::Duration;
|
||||
|
||||
use proxy_wasm::traits::*;
|
||||
use proxy_wasm::types::*;
|
||||
use stats::{Gauge, RecordingMetric};
|
||||
use std::collections::HashMap;
|
||||
use std::time::Duration;
|
||||
|
||||
mod common_types;
|
||||
mod configuration;
|
||||
|
|
@ -85,21 +85,47 @@ impl HttpContext for StreamContext {
|
|||
match self.get_http_request_body(0, body_size) {
|
||||
Some(body_bytes) => match serde_json::from_slice(&body_bytes) {
|
||||
Ok(deserialized) => deserialized,
|
||||
Err(msg) => panic!("Failed to deserialize: {}", msg),
|
||||
Err(msg) => {
|
||||
self.send_http_response(
|
||||
StatusCode::BAD_REQUEST.as_u16().into(),
|
||||
vec![],
|
||||
Some(format!("Failed to deserialize: {}", msg).as_bytes()),
|
||||
);
|
||||
return Action::Pause;
|
||||
}
|
||||
},
|
||||
None => panic!(
|
||||
"Failed to obtain body bytes even though body_size is {}",
|
||||
body_size
|
||||
),
|
||||
None => {
|
||||
self.send_http_response(
|
||||
StatusCode::INTERNAL_SERVER_ERROR.as_u16().into(),
|
||||
vec![],
|
||||
None,
|
||||
);
|
||||
error!(
|
||||
"Failed to obtain body bytes even though body_size is {}",
|
||||
body_size
|
||||
);
|
||||
return Action::Pause;
|
||||
}
|
||||
};
|
||||
|
||||
// Modify JSON payload
|
||||
deserialized_body.model = String::from("gpt-3.5-turbo");
|
||||
let json_string = serde_json::to_string(&deserialized_body).unwrap();
|
||||
|
||||
self.set_http_request_body(0, body_size, &json_string.into_bytes());
|
||||
|
||||
Action::Continue
|
||||
match serde_json::to_string(&deserialized_body) {
|
||||
Ok(json_string) => {
|
||||
self.set_http_request_body(0, body_size, &json_string.into_bytes());
|
||||
Action::Continue
|
||||
}
|
||||
Err(error) => {
|
||||
self.send_http_response(
|
||||
StatusCode::INTERNAL_SERVER_ERROR.as_u16().into(),
|
||||
vec![],
|
||||
None,
|
||||
);
|
||||
error!("Failed to serialize body: {}", error);
|
||||
Action::Pause
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -107,13 +133,13 @@ impl Context for StreamContext {}
|
|||
|
||||
#[derive(Copy, Clone)]
|
||||
struct WasmMetrics {
|
||||
active_http_calls: stats::Gauge,
|
||||
active_http_calls: Gauge,
|
||||
}
|
||||
|
||||
impl WasmMetrics {
|
||||
fn new() -> WasmMetrics {
|
||||
WasmMetrics {
|
||||
active_http_calls: stats::Gauge::new(String::from("active_http_calls")),
|
||||
active_http_calls: Gauge::new(String::from("active_http_calls")),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -135,7 +161,13 @@ impl FilterContext {
|
|||
}
|
||||
|
||||
fn process_prompt_targets(&mut self) {
|
||||
for prompt_target in &self.config.as_ref().unwrap().prompt_config.prompt_targets {
|
||||
for prompt_target in &self
|
||||
.config
|
||||
.as_ref()
|
||||
.expect("Gateway configuration cannot be non-existent")
|
||||
.prompt_config
|
||||
.prompt_targets
|
||||
{
|
||||
for few_shot_example in &prompt_target.few_shot_examples {
|
||||
info!("few_shot_example: {:?}", few_shot_example);
|
||||
|
||||
|
|
@ -149,8 +181,12 @@ impl FilterContext {
|
|||
user: None,
|
||||
};
|
||||
|
||||
// TODO: Handle potential errors
|
||||
let json_data: String = to_string(&embeddings_input).unwrap();
|
||||
let json_data: String = match serde_json::to_string(&embeddings_input) {
|
||||
Ok(json_data) => json_data,
|
||||
Err(error) => {
|
||||
panic!("Error serializing embeddings input: {}", error);
|
||||
}
|
||||
};
|
||||
|
||||
let token_id = match self.dispatch_http_call(
|
||||
"embeddingserver",
|
||||
|
|
@ -159,6 +195,7 @@ impl FilterContext {
|
|||
(":path", "/embeddings"),
|
||||
(":authority", "embeddingserver"),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
],
|
||||
Some(json_data.as_bytes()),
|
||||
vec![],
|
||||
|
|
@ -166,7 +203,7 @@ impl FilterContext {
|
|||
) {
|
||||
Ok(token_id) => token_id,
|
||||
Err(e) => {
|
||||
panic!("Error dispatching HTTP call: {:?}", e);
|
||||
panic!("Error dispatching embedding server HTTP call: {:?}", e);
|
||||
}
|
||||
};
|
||||
info!("on_tick: dispatched HTTP call with token_id = {}", token_id);
|
||||
|
|
@ -181,7 +218,10 @@ impl FilterContext {
|
|||
})
|
||||
.is_some()
|
||||
{
|
||||
panic!("duplicate token_id")
|
||||
panic!(
|
||||
"duplicate token_id={} in embedding server requests",
|
||||
token_id
|
||||
)
|
||||
}
|
||||
self.metrics
|
||||
.active_http_calls
|
||||
|
|
@ -197,78 +237,65 @@ impl FilterContext {
|
|||
prompt_target: PromptTarget,
|
||||
) {
|
||||
info!("response received for CreateEmbeddingRequest");
|
||||
if let Some(body) = self.get_http_call_response_body(0, body_size) {
|
||||
if !body.is_empty() {
|
||||
let embedding_response: CreateEmbeddingResponse =
|
||||
serde_json::from_slice(&body).unwrap();
|
||||
info!(
|
||||
"embedding_response model: {}, vector len: {}",
|
||||
embedding_response.model,
|
||||
embedding_response.data[0].embedding.len()
|
||||
);
|
||||
|
||||
let mut payload: HashMap<String, String> = HashMap::new();
|
||||
payload.insert(
|
||||
"prompt-target".to_string(),
|
||||
to_string(&prompt_target).unwrap(),
|
||||
);
|
||||
let id: Option<Digest>;
|
||||
match *create_embedding_request.input {
|
||||
CreateEmbeddingRequestInput::String(input) => {
|
||||
id = Some(md5::compute(&input));
|
||||
payload.insert("input".to_string(), input);
|
||||
}
|
||||
CreateEmbeddingRequestInput::Array(_) => todo!(),
|
||||
}
|
||||
|
||||
let create_vector_store_points = common_types::StoreVectorEmbeddingsRequest {
|
||||
points: vec![common_types::VectorPoint {
|
||||
id: format!("{:x}", id.unwrap()),
|
||||
payload,
|
||||
vector: embedding_response.data[0].embedding.clone(),
|
||||
}],
|
||||
};
|
||||
let json_data = to_string(&create_vector_store_points).unwrap(); // Handle potential errors
|
||||
info!(
|
||||
"create_vector_store_points: points length: {}",
|
||||
embedding_response.data[0].embedding.len()
|
||||
);
|
||||
let token_id = match self.dispatch_http_call(
|
||||
"qdrant",
|
||||
vec![
|
||||
(":method", "PUT"),
|
||||
(":path", "/collections/prompt_vector_store/points"),
|
||||
(":authority", "qdrant"),
|
||||
("content-type", "application/json"),
|
||||
],
|
||||
Some(json_data.as_bytes()),
|
||||
vec![],
|
||||
Duration::from_secs(5),
|
||||
) {
|
||||
Ok(token_id) => token_id,
|
||||
Err(e) => {
|
||||
panic!("Error dispatching HTTP call: {:?}", e);
|
||||
}
|
||||
};
|
||||
info!("on_tick: dispatched HTTP call with token_id = {}", token_id);
|
||||
|
||||
if self
|
||||
.callouts
|
||||
.insert(
|
||||
token_id,
|
||||
CallContext::StoreVectorEmbeddings(create_vector_store_points),
|
||||
)
|
||||
.is_some()
|
||||
{
|
||||
panic!("duplicate token_id")
|
||||
}
|
||||
self.metrics
|
||||
.active_http_calls
|
||||
.record(self.callouts.len().try_into().unwrap());
|
||||
let body = match self.get_http_call_response_body(0, body_size) {
|
||||
Some(body) => body,
|
||||
None => {
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
if body.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
let (json_data, create_vector_store_points) =
|
||||
match build_qdrant_data(&body, create_embedding_request, &prompt_target) {
|
||||
Ok(tuple) => tuple,
|
||||
Err(error) => {
|
||||
panic!(
|
||||
"Error building qdrant data from embedding response {}",
|
||||
error
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
let token_id = match self.dispatch_http_call(
|
||||
"qdrant",
|
||||
vec![
|
||||
(":method", "PUT"),
|
||||
(":path", "/collections/prompt_vector_store/points"),
|
||||
(":authority", "qdrant"),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
],
|
||||
Some(json_data.as_bytes()),
|
||||
vec![],
|
||||
Duration::from_secs(5),
|
||||
) {
|
||||
Ok(token_id) => token_id,
|
||||
Err(e) => {
|
||||
panic!("Error dispatching qdrant HTTP call: {:?}", e);
|
||||
}
|
||||
};
|
||||
info!("on_tick: dispatched HTTP call with token_id = {}", token_id);
|
||||
|
||||
if self
|
||||
.callouts
|
||||
.insert(
|
||||
token_id,
|
||||
CallContext::StoreVectorEmbeddings(create_vector_store_points),
|
||||
)
|
||||
.is_some()
|
||||
{
|
||||
panic!("duplicate token_id={} in qdrant requests", token_id)
|
||||
}
|
||||
|
||||
self.metrics
|
||||
.active_http_calls
|
||||
.record(self.callouts.len().try_into().unwrap());
|
||||
}
|
||||
|
||||
// TODO: @adilhafeez implement.
|
||||
fn create_vector_store_points_handler(&self, body_size: usize) {
|
||||
info!("response received for CreateVectorStorePoints");
|
||||
if let Some(body) = self.get_http_call_response_body(0, body_size) {
|
||||
|
|
@ -294,7 +321,6 @@ impl Context for FilterContext {
|
|||
self.metrics
|
||||
.active_http_calls
|
||||
.record(self.callouts.len().try_into().unwrap());
|
||||
|
||||
match callout_data {
|
||||
common_types::CallContext::EmbeddingRequest(common_types::EmbeddingRequest {
|
||||
create_embedding_request,
|
||||
|
|
@ -313,7 +339,12 @@ impl Context for FilterContext {
|
|||
impl RootContext for FilterContext {
|
||||
fn on_configure(&mut self, _: usize) -> bool {
|
||||
if let Some(config_bytes) = self.get_plugin_configuration() {
|
||||
self.config = serde_yaml::from_slice(&config_bytes).unwrap();
|
||||
self.config = match serde_yaml::from_slice(&config_bytes) {
|
||||
Ok(config) => config,
|
||||
Err(error) => {
|
||||
panic!("Failed to deserialize plugin configuration: {}", error);
|
||||
}
|
||||
};
|
||||
info!("on_configure: plugin configuration loaded");
|
||||
}
|
||||
true
|
||||
|
|
@ -339,3 +370,51 @@ impl RootContext for FilterContext {
|
|||
self.set_tick_period(Duration::from_secs(0));
|
||||
}
|
||||
}
|
||||
|
||||
fn build_qdrant_data(
|
||||
embedding_data: &[u8],
|
||||
create_embedding_request: CreateEmbeddingRequest,
|
||||
prompt_target: &PromptTarget,
|
||||
) -> Result<(String, common_types::StoreVectorEmbeddingsRequest), serde_json::Error> {
|
||||
let embedding_response: CreateEmbeddingResponse = match serde_json::from_slice(embedding_data) {
|
||||
Ok(embedding_response) => embedding_response,
|
||||
Err(error) => {
|
||||
panic!("Failed to deserialize embedding response: {}", error);
|
||||
}
|
||||
};
|
||||
info!(
|
||||
"embedding_response model: {}, vector len: {}",
|
||||
embedding_response.model,
|
||||
embedding_response.data[0].embedding.len()
|
||||
);
|
||||
|
||||
let mut payload: HashMap<String, String> = HashMap::new();
|
||||
payload.insert(
|
||||
"prompt-target".to_string(),
|
||||
serde_json::to_string(&prompt_target)?,
|
||||
);
|
||||
|
||||
let id: Option<Digest>;
|
||||
match *create_embedding_request.input {
|
||||
CreateEmbeddingRequestInput::String(ref input) => {
|
||||
id = Some(md5::compute(input));
|
||||
payload.insert("input".to_string(), input.clone());
|
||||
}
|
||||
CreateEmbeddingRequestInput::Array(_) => todo!(),
|
||||
}
|
||||
|
||||
let create_vector_store_points = common_types::StoreVectorEmbeddingsRequest {
|
||||
points: vec![common_types::VectorPoint {
|
||||
id: format!("{:x}", id.unwrap()),
|
||||
payload,
|
||||
vector: embedding_response.data[0].embedding.clone(),
|
||||
}],
|
||||
};
|
||||
let json_data = serde_json::to_string(&create_vector_store_points)?;
|
||||
info!(
|
||||
"create_vector_store_points: points length: {}",
|
||||
embedding_response.data[0].embedding.len()
|
||||
);
|
||||
|
||||
Ok((json_data, create_vector_store_points))
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,14 +1,15 @@
|
|||
use log::error;
|
||||
use proxy_wasm::hostcalls;
|
||||
use proxy_wasm::types::*;
|
||||
|
||||
#[allow(unused)]
|
||||
pub trait Metric {
|
||||
fn id(&self) -> u32;
|
||||
fn value(&self) -> u64 {
|
||||
fn value(&self) -> Result<u64, String> {
|
||||
match hostcalls::get_metric(self.id()) {
|
||||
Ok(value) => value,
|
||||
Err(Status::NotFound) => panic!("metric not found: {}", self.id()),
|
||||
Err(err) => panic!("unexpected status: {:?}", err),
|
||||
Ok(value) => Ok(value),
|
||||
Err(Status::NotFound) => Err(format!("metric not found: {}", self.id())),
|
||||
Err(err) => Err(format!("unexpected status: {:?}", err)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -17,9 +18,8 @@ pub trait Metric {
|
|||
pub trait IncrementingMetric: Metric {
|
||||
fn increment(&self, offset: i64) {
|
||||
match hostcalls::increment_metric(self.id(), offset) {
|
||||
Ok(data) => data,
|
||||
Err(Status::NotFound) => panic!("metric not found: {}", self.id()),
|
||||
Err(err) => panic!("unexpected status: {:?}", err),
|
||||
Ok(_) => (),
|
||||
Err(err) => error!("error incrementing metric: {:?}", err),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -27,9 +27,8 @@ pub trait IncrementingMetric: Metric {
|
|||
pub trait RecordingMetric: Metric {
|
||||
fn record(&self, value: u64) {
|
||||
match hostcalls::record_metric(self.id(), value) {
|
||||
Ok(data) => data,
|
||||
Err(Status::NotFound) => panic!("metric not found: {}", self.id()),
|
||||
Err(err) => panic!("unexpected status: {:?}", err),
|
||||
Ok(_) => (),
|
||||
Err(err) => error!("error recording metric: {:?}", err),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue