mirror of
https://github.com/katanemo/plano.git
synced 2026-04-25 00:36:34 +02:00
Improve error handling (#23)
Signed-off-by: José Ulises Niño Rivera <junr03@users.noreply.github.com>
This commit is contained in:
parent
a51a467cad
commit
7ef68eccfb
5 changed files with 397 additions and 102 deletions
128
envoyfilter/Cargo.lock
generated
128
envoyfilter/Cargo.lock
generated
|
|
@ -511,6 +511,21 @@ dependencies = [
|
|||
"percent-encoding",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "futures"
|
||||
version = "0.3.30"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "645c6916888f6cb6350d2550b80fb63e734897a8498abe35cfb732b6487804b0"
|
||||
dependencies = [
|
||||
"futures-channel",
|
||||
"futures-core",
|
||||
"futures-executor",
|
||||
"futures-io",
|
||||
"futures-sink",
|
||||
"futures-task",
|
||||
"futures-util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "futures-channel"
|
||||
version = "0.3.30"
|
||||
|
|
@ -518,6 +533,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||
checksum = "eac8f7d7865dcb88bd4373ab671c8cf4508703796caa2b1985a9ca867b3fcb78"
|
||||
dependencies = [
|
||||
"futures-core",
|
||||
"futures-sink",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -526,6 +542,23 @@ version = "0.3.30"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "dfc6580bb841c5a68e9ef15c77ccc837b40a7504914d52e47b8b0e9bbda25a1d"
|
||||
|
||||
[[package]]
|
||||
name = "futures-executor"
|
||||
version = "0.3.30"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a576fc72ae164fca6b9db127eaa9a9dda0d61316034f33a0a0d4eda41f02b01d"
|
||||
dependencies = [
|
||||
"futures-core",
|
||||
"futures-task",
|
||||
"futures-util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "futures-io"
|
||||
version = "0.3.30"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1"
|
||||
|
||||
[[package]]
|
||||
name = "futures-sink"
|
||||
version = "0.3.30"
|
||||
|
|
@ -544,10 +577,15 @@ version = "0.3.30"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48"
|
||||
dependencies = [
|
||||
"futures-channel",
|
||||
"futures-core",
|
||||
"futures-io",
|
||||
"futures-sink",
|
||||
"futures-task",
|
||||
"memchr",
|
||||
"pin-project-lite",
|
||||
"pin-utils",
|
||||
"slab",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -822,6 +860,7 @@ dependencies = [
|
|||
name = "intelligent-prompt-gateway"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"http",
|
||||
"log",
|
||||
"md5",
|
||||
"open-message-format",
|
||||
|
|
@ -830,6 +869,7 @@ dependencies = [
|
|||
"serde",
|
||||
"serde_json",
|
||||
"serde_yaml",
|
||||
"serial_test",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -925,6 +965,16 @@ version = "0.4.14"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "78b3ae25bc7c8c38cec158d1f2757ee79e9b3740fbc7ccf0e59e4b08d793fa89"
|
||||
|
||||
[[package]]
|
||||
name = "lock_api"
|
||||
version = "0.4.12"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "07af8b9cdd281b7915f413fa73f29ebd5d55d0d3f0155584dade1ff18cea1b17"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
"scopeguard",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "log"
|
||||
version = "0.4.22"
|
||||
|
|
@ -1112,6 +1162,29 @@ dependencies = [
|
|||
"vcpkg",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "parking_lot"
|
||||
version = "0.12.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f1bf18183cf54e8d6059647fc3063646a1801cf30896933ec2311622cc4b9a27"
|
||||
dependencies = [
|
||||
"lock_api",
|
||||
"parking_lot_core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "parking_lot_core"
|
||||
version = "0.9.10"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1e401f977ab385c9e4e3ab30627d6f26d00e2c73eef317493c4ec6d468726cf8"
|
||||
dependencies = [
|
||||
"cfg-if 1.0.0",
|
||||
"libc",
|
||||
"redox_syscall",
|
||||
"smallvec",
|
||||
"windows-targets 0.52.6",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "paste"
|
||||
version = "1.0.15"
|
||||
|
|
@ -1293,6 +1366,15 @@ dependencies = [
|
|||
"crossbeam-utils",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "redox_syscall"
|
||||
version = "0.5.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2a908a6e00f1fdd0dfd9c0eb08ce85126f6d8bbda50017e74bc4a4b7d4a926a4"
|
||||
dependencies = [
|
||||
"bitflags 2.6.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "redox_users"
|
||||
version = "0.4.5"
|
||||
|
|
@ -1447,6 +1529,15 @@ version = "1.0.18"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f"
|
||||
|
||||
[[package]]
|
||||
name = "scc"
|
||||
version = "2.1.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1fadf67e3cf23f8b11a6c8c48a16cb2437381503615acd91094ec7b4686a5a53"
|
||||
dependencies = [
|
||||
"sdd",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "schannel"
|
||||
version = "0.1.23"
|
||||
|
|
@ -1456,6 +1547,18 @@ dependencies = [
|
|||
"windows-sys 0.52.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "scopeguard"
|
||||
version = "1.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
|
||||
|
||||
[[package]]
|
||||
name = "sdd"
|
||||
version = "1.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "85f05a494052771fc5bd0619742363b5e24e5ad72ab3111ec2e27925b8edc5f3"
|
||||
|
||||
[[package]]
|
||||
name = "security-framework"
|
||||
version = "2.11.1"
|
||||
|
|
@ -1550,6 +1653,31 @@ dependencies = [
|
|||
"unsafe-libyaml",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serial_test"
|
||||
version = "3.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4b4b487fe2acf240a021cf57c6b2b4903b1e78ca0ecd862a71b71d2a51fed77d"
|
||||
dependencies = [
|
||||
"futures",
|
||||
"log",
|
||||
"once_cell",
|
||||
"parking_lot",
|
||||
"scc",
|
||||
"serial_test_derive",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serial_test_derive"
|
||||
version = "3.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "82fe9db325bcef1fbcde82e078a5cc4efdf787e96b3b9cf45b50b529f2083d67"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.72",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sha2"
|
||||
version = "0.10.8"
|
||||
|
|
|
|||
|
|
@ -15,6 +15,8 @@ serde_yaml = "0.9.34"
|
|||
serde_json = "1.0"
|
||||
md5 = "0.7.0"
|
||||
open-message-format = { path = "../open-message-format/clients/omf-rust" }
|
||||
http = "1.1.0"
|
||||
|
||||
[dev-dependencies]
|
||||
proxy-wasm-test-framework = { git = "https://github.com/katanemo/test-framework.git", branch = "main" }
|
||||
serial_test = "3.1.1"
|
||||
|
|
|
|||
|
|
@ -1,17 +1,17 @@
|
|||
use common_types::{CallContext, EmbeddingRequest};
|
||||
use configuration::PromptTarget;
|
||||
use http::StatusCode;
|
||||
use log::error;
|
||||
use log::info;
|
||||
use md5::Digest;
|
||||
use open_message_format::models::{
|
||||
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
|
||||
};
|
||||
use serde_json::to_string;
|
||||
use stats::RecordingMetric;
|
||||
use std::collections::HashMap;
|
||||
use std::time::Duration;
|
||||
|
||||
use proxy_wasm::traits::*;
|
||||
use proxy_wasm::types::*;
|
||||
use stats::{Gauge, RecordingMetric};
|
||||
use std::collections::HashMap;
|
||||
use std::time::Duration;
|
||||
|
||||
mod common_types;
|
||||
mod configuration;
|
||||
|
|
@ -85,21 +85,47 @@ impl HttpContext for StreamContext {
|
|||
match self.get_http_request_body(0, body_size) {
|
||||
Some(body_bytes) => match serde_json::from_slice(&body_bytes) {
|
||||
Ok(deserialized) => deserialized,
|
||||
Err(msg) => panic!("Failed to deserialize: {}", msg),
|
||||
Err(msg) => {
|
||||
self.send_http_response(
|
||||
StatusCode::BAD_REQUEST.as_u16().into(),
|
||||
vec![],
|
||||
Some(format!("Failed to deserialize: {}", msg).as_bytes()),
|
||||
);
|
||||
return Action::Pause;
|
||||
}
|
||||
},
|
||||
None => panic!(
|
||||
"Failed to obtain body bytes even though body_size is {}",
|
||||
body_size
|
||||
),
|
||||
None => {
|
||||
self.send_http_response(
|
||||
StatusCode::INTERNAL_SERVER_ERROR.as_u16().into(),
|
||||
vec![],
|
||||
None,
|
||||
);
|
||||
error!(
|
||||
"Failed to obtain body bytes even though body_size is {}",
|
||||
body_size
|
||||
);
|
||||
return Action::Pause;
|
||||
}
|
||||
};
|
||||
|
||||
// Modify JSON payload
|
||||
deserialized_body.model = String::from("gpt-3.5-turbo");
|
||||
let json_string = serde_json::to_string(&deserialized_body).unwrap();
|
||||
|
||||
self.set_http_request_body(0, body_size, &json_string.into_bytes());
|
||||
|
||||
Action::Continue
|
||||
match serde_json::to_string(&deserialized_body) {
|
||||
Ok(json_string) => {
|
||||
self.set_http_request_body(0, body_size, &json_string.into_bytes());
|
||||
Action::Continue
|
||||
}
|
||||
Err(error) => {
|
||||
self.send_http_response(
|
||||
StatusCode::INTERNAL_SERVER_ERROR.as_u16().into(),
|
||||
vec![],
|
||||
None,
|
||||
);
|
||||
error!("Failed to serialize body: {}", error);
|
||||
Action::Pause
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -107,13 +133,13 @@ impl Context for StreamContext {}
|
|||
|
||||
#[derive(Copy, Clone)]
|
||||
struct WasmMetrics {
|
||||
active_http_calls: stats::Gauge,
|
||||
active_http_calls: Gauge,
|
||||
}
|
||||
|
||||
impl WasmMetrics {
|
||||
fn new() -> WasmMetrics {
|
||||
WasmMetrics {
|
||||
active_http_calls: stats::Gauge::new(String::from("active_http_calls")),
|
||||
active_http_calls: Gauge::new(String::from("active_http_calls")),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -135,7 +161,13 @@ impl FilterContext {
|
|||
}
|
||||
|
||||
fn process_prompt_targets(&mut self) {
|
||||
for prompt_target in &self.config.as_ref().unwrap().prompt_config.prompt_targets {
|
||||
for prompt_target in &self
|
||||
.config
|
||||
.as_ref()
|
||||
.expect("Gateway configuration cannot be non-existent")
|
||||
.prompt_config
|
||||
.prompt_targets
|
||||
{
|
||||
for few_shot_example in &prompt_target.few_shot_examples {
|
||||
info!("few_shot_example: {:?}", few_shot_example);
|
||||
|
||||
|
|
@ -149,8 +181,12 @@ impl FilterContext {
|
|||
user: None,
|
||||
};
|
||||
|
||||
// TODO: Handle potential errors
|
||||
let json_data: String = to_string(&embeddings_input).unwrap();
|
||||
let json_data: String = match serde_json::to_string(&embeddings_input) {
|
||||
Ok(json_data) => json_data,
|
||||
Err(error) => {
|
||||
panic!("Error serializing embeddings input: {}", error);
|
||||
}
|
||||
};
|
||||
|
||||
let token_id = match self.dispatch_http_call(
|
||||
"embeddingserver",
|
||||
|
|
@ -159,6 +195,7 @@ impl FilterContext {
|
|||
(":path", "/embeddings"),
|
||||
(":authority", "embeddingserver"),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
],
|
||||
Some(json_data.as_bytes()),
|
||||
vec![],
|
||||
|
|
@ -166,7 +203,7 @@ impl FilterContext {
|
|||
) {
|
||||
Ok(token_id) => token_id,
|
||||
Err(e) => {
|
||||
panic!("Error dispatching HTTP call: {:?}", e);
|
||||
panic!("Error dispatching embedding server HTTP call: {:?}", e);
|
||||
}
|
||||
};
|
||||
info!("on_tick: dispatched HTTP call with token_id = {}", token_id);
|
||||
|
|
@ -181,7 +218,10 @@ impl FilterContext {
|
|||
})
|
||||
.is_some()
|
||||
{
|
||||
panic!("duplicate token_id")
|
||||
panic!(
|
||||
"duplicate token_id={} in embedding server requests",
|
||||
token_id
|
||||
)
|
||||
}
|
||||
self.metrics
|
||||
.active_http_calls
|
||||
|
|
@ -197,78 +237,65 @@ impl FilterContext {
|
|||
prompt_target: PromptTarget,
|
||||
) {
|
||||
info!("response received for CreateEmbeddingRequest");
|
||||
if let Some(body) = self.get_http_call_response_body(0, body_size) {
|
||||
if !body.is_empty() {
|
||||
let embedding_response: CreateEmbeddingResponse =
|
||||
serde_json::from_slice(&body).unwrap();
|
||||
info!(
|
||||
"embedding_response model: {}, vector len: {}",
|
||||
embedding_response.model,
|
||||
embedding_response.data[0].embedding.len()
|
||||
);
|
||||
|
||||
let mut payload: HashMap<String, String> = HashMap::new();
|
||||
payload.insert(
|
||||
"prompt-target".to_string(),
|
||||
to_string(&prompt_target).unwrap(),
|
||||
);
|
||||
let id: Option<Digest>;
|
||||
match *create_embedding_request.input {
|
||||
CreateEmbeddingRequestInput::String(input) => {
|
||||
id = Some(md5::compute(&input));
|
||||
payload.insert("input".to_string(), input);
|
||||
}
|
||||
CreateEmbeddingRequestInput::Array(_) => todo!(),
|
||||
}
|
||||
|
||||
let create_vector_store_points = common_types::StoreVectorEmbeddingsRequest {
|
||||
points: vec![common_types::VectorPoint {
|
||||
id: format!("{:x}", id.unwrap()),
|
||||
payload,
|
||||
vector: embedding_response.data[0].embedding.clone(),
|
||||
}],
|
||||
};
|
||||
let json_data = to_string(&create_vector_store_points).unwrap(); // Handle potential errors
|
||||
info!(
|
||||
"create_vector_store_points: points length: {}",
|
||||
embedding_response.data[0].embedding.len()
|
||||
);
|
||||
let token_id = match self.dispatch_http_call(
|
||||
"qdrant",
|
||||
vec![
|
||||
(":method", "PUT"),
|
||||
(":path", "/collections/prompt_vector_store/points"),
|
||||
(":authority", "qdrant"),
|
||||
("content-type", "application/json"),
|
||||
],
|
||||
Some(json_data.as_bytes()),
|
||||
vec![],
|
||||
Duration::from_secs(5),
|
||||
) {
|
||||
Ok(token_id) => token_id,
|
||||
Err(e) => {
|
||||
panic!("Error dispatching HTTP call: {:?}", e);
|
||||
}
|
||||
};
|
||||
info!("on_tick: dispatched HTTP call with token_id = {}", token_id);
|
||||
|
||||
if self
|
||||
.callouts
|
||||
.insert(
|
||||
token_id,
|
||||
CallContext::StoreVectorEmbeddings(create_vector_store_points),
|
||||
)
|
||||
.is_some()
|
||||
{
|
||||
panic!("duplicate token_id")
|
||||
}
|
||||
self.metrics
|
||||
.active_http_calls
|
||||
.record(self.callouts.len().try_into().unwrap());
|
||||
let body = match self.get_http_call_response_body(0, body_size) {
|
||||
Some(body) => body,
|
||||
None => {
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
if body.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
let (json_data, create_vector_store_points) =
|
||||
match build_qdrant_data(&body, create_embedding_request, &prompt_target) {
|
||||
Ok(tuple) => tuple,
|
||||
Err(error) => {
|
||||
panic!(
|
||||
"Error building qdrant data from embedding response {}",
|
||||
error
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
let token_id = match self.dispatch_http_call(
|
||||
"qdrant",
|
||||
vec![
|
||||
(":method", "PUT"),
|
||||
(":path", "/collections/prompt_vector_store/points"),
|
||||
(":authority", "qdrant"),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
],
|
||||
Some(json_data.as_bytes()),
|
||||
vec![],
|
||||
Duration::from_secs(5),
|
||||
) {
|
||||
Ok(token_id) => token_id,
|
||||
Err(e) => {
|
||||
panic!("Error dispatching qdrant HTTP call: {:?}", e);
|
||||
}
|
||||
};
|
||||
info!("on_tick: dispatched HTTP call with token_id = {}", token_id);
|
||||
|
||||
if self
|
||||
.callouts
|
||||
.insert(
|
||||
token_id,
|
||||
CallContext::StoreVectorEmbeddings(create_vector_store_points),
|
||||
)
|
||||
.is_some()
|
||||
{
|
||||
panic!("duplicate token_id={} in qdrant requests", token_id)
|
||||
}
|
||||
|
||||
self.metrics
|
||||
.active_http_calls
|
||||
.record(self.callouts.len().try_into().unwrap());
|
||||
}
|
||||
|
||||
// TODO: @adilhafeez implement.
|
||||
fn create_vector_store_points_handler(&self, body_size: usize) {
|
||||
info!("response received for CreateVectorStorePoints");
|
||||
if let Some(body) = self.get_http_call_response_body(0, body_size) {
|
||||
|
|
@ -294,7 +321,6 @@ impl Context for FilterContext {
|
|||
self.metrics
|
||||
.active_http_calls
|
||||
.record(self.callouts.len().try_into().unwrap());
|
||||
|
||||
match callout_data {
|
||||
common_types::CallContext::EmbeddingRequest(common_types::EmbeddingRequest {
|
||||
create_embedding_request,
|
||||
|
|
@ -313,7 +339,12 @@ impl Context for FilterContext {
|
|||
impl RootContext for FilterContext {
|
||||
fn on_configure(&mut self, _: usize) -> bool {
|
||||
if let Some(config_bytes) = self.get_plugin_configuration() {
|
||||
self.config = serde_yaml::from_slice(&config_bytes).unwrap();
|
||||
self.config = match serde_yaml::from_slice(&config_bytes) {
|
||||
Ok(config) => config,
|
||||
Err(error) => {
|
||||
panic!("Failed to deserialize plugin configuration: {}", error);
|
||||
}
|
||||
};
|
||||
info!("on_configure: plugin configuration loaded");
|
||||
}
|
||||
true
|
||||
|
|
@ -339,3 +370,51 @@ impl RootContext for FilterContext {
|
|||
self.set_tick_period(Duration::from_secs(0));
|
||||
}
|
||||
}
|
||||
|
||||
fn build_qdrant_data(
|
||||
embedding_data: &[u8],
|
||||
create_embedding_request: CreateEmbeddingRequest,
|
||||
prompt_target: &PromptTarget,
|
||||
) -> Result<(String, common_types::StoreVectorEmbeddingsRequest), serde_json::Error> {
|
||||
let embedding_response: CreateEmbeddingResponse = match serde_json::from_slice(embedding_data) {
|
||||
Ok(embedding_response) => embedding_response,
|
||||
Err(error) => {
|
||||
panic!("Failed to deserialize embedding response: {}", error);
|
||||
}
|
||||
};
|
||||
info!(
|
||||
"embedding_response model: {}, vector len: {}",
|
||||
embedding_response.model,
|
||||
embedding_response.data[0].embedding.len()
|
||||
);
|
||||
|
||||
let mut payload: HashMap<String, String> = HashMap::new();
|
||||
payload.insert(
|
||||
"prompt-target".to_string(),
|
||||
serde_json::to_string(&prompt_target)?,
|
||||
);
|
||||
|
||||
let id: Option<Digest>;
|
||||
match *create_embedding_request.input {
|
||||
CreateEmbeddingRequestInput::String(ref input) => {
|
||||
id = Some(md5::compute(input));
|
||||
payload.insert("input".to_string(), input.clone());
|
||||
}
|
||||
CreateEmbeddingRequestInput::Array(_) => todo!(),
|
||||
}
|
||||
|
||||
let create_vector_store_points = common_types::StoreVectorEmbeddingsRequest {
|
||||
points: vec![common_types::VectorPoint {
|
||||
id: format!("{:x}", id.unwrap()),
|
||||
payload,
|
||||
vector: embedding_response.data[0].embedding.clone(),
|
||||
}],
|
||||
};
|
||||
let json_data = serde_json::to_string(&create_vector_store_points)?;
|
||||
info!(
|
||||
"create_vector_store_points: points length: {}",
|
||||
embedding_response.data[0].embedding.len()
|
||||
);
|
||||
|
||||
Ok((json_data, create_vector_store_points))
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,14 +1,15 @@
|
|||
use log::error;
|
||||
use proxy_wasm::hostcalls;
|
||||
use proxy_wasm::types::*;
|
||||
|
||||
#[allow(unused)]
|
||||
pub trait Metric {
|
||||
fn id(&self) -> u32;
|
||||
fn value(&self) -> u64 {
|
||||
fn value(&self) -> Result<u64, String> {
|
||||
match hostcalls::get_metric(self.id()) {
|
||||
Ok(value) => value,
|
||||
Err(Status::NotFound) => panic!("metric not found: {}", self.id()),
|
||||
Err(err) => panic!("unexpected status: {:?}", err),
|
||||
Ok(value) => Ok(value),
|
||||
Err(Status::NotFound) => Err(format!("metric not found: {}", self.id())),
|
||||
Err(err) => Err(format!("unexpected status: {:?}", err)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -17,9 +18,8 @@ pub trait Metric {
|
|||
pub trait IncrementingMetric: Metric {
|
||||
fn increment(&self, offset: i64) {
|
||||
match hostcalls::increment_metric(self.id(), offset) {
|
||||
Ok(data) => data,
|
||||
Err(Status::NotFound) => panic!("metric not found: {}", self.id()),
|
||||
Err(err) => panic!("unexpected status: {:?}", err),
|
||||
Ok(_) => (),
|
||||
Err(err) => error!("error incrementing metric: {:?}", err),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -27,9 +27,8 @@ pub trait IncrementingMetric: Metric {
|
|||
pub trait RecordingMetric: Metric {
|
||||
fn record(&self, value: u64) {
|
||||
match hostcalls::record_metric(self.id(), value) {
|
||||
Ok(data) => data,
|
||||
Err(Status::NotFound) => panic!("metric not found: {}", self.id()),
|
||||
Err(err) => panic!("unexpected status: {:?}", err),
|
||||
Ok(_) => (),
|
||||
Err(err) => error!("error recording metric: {:?}", err),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
use http::StatusCode;
|
||||
use proxy_wasm_test_framework::tester;
|
||||
use proxy_wasm_test_framework::types::{Action, BufferType, MapType, MetricType, ReturnType};
|
||||
use serial_test::serial;
|
||||
use std::path::Path;
|
||||
|
||||
fn wasm_module() -> String {
|
||||
|
|
@ -12,7 +14,8 @@ fn wasm_module() -> String {
|
|||
}
|
||||
|
||||
#[test]
|
||||
fn request_to_open_ai_chat_completions() {
|
||||
#[serial]
|
||||
fn successful_request_to_open_ai_chat_completions() {
|
||||
let args = tester::MockSettings {
|
||||
wasm_path: wasm_module(),
|
||||
quiet: false,
|
||||
|
|
@ -90,3 +93,87 @@ fn request_to_open_ai_chat_completions() {
|
|||
.execute_and_expect(ReturnType::Action(Action::Continue))
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[serial]
|
||||
fn bad_request_to_open_ai_chat_completions() {
|
||||
let args = tester::MockSettings {
|
||||
wasm_path: wasm_module(),
|
||||
quiet: false,
|
||||
allow_unexpected: false,
|
||||
};
|
||||
let mut module = tester::mock(args).unwrap();
|
||||
|
||||
module
|
||||
.call_start()
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
// Setup Filter
|
||||
let root_context = 1;
|
||||
|
||||
module
|
||||
.call_proxy_on_context_create(root_context, 0)
|
||||
.expect_metric_creation(MetricType::Gauge, "active_http_calls")
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
// Setup HTTP Stream
|
||||
let http_context = 2;
|
||||
|
||||
module
|
||||
.call_proxy_on_context_create(http_context, root_context)
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
// Request Headers
|
||||
module
|
||||
.call_proxy_on_request_headers(http_context, 0, false)
|
||||
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":host"))
|
||||
.returning(Some("api.openai.com"))
|
||||
.expect_add_header_map_value(
|
||||
Some(MapType::HttpRequestHeaders),
|
||||
Some("content-length"),
|
||||
Some(""),
|
||||
)
|
||||
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":path"))
|
||||
.returning(Some("/llmrouting"))
|
||||
.expect_add_header_map_value(
|
||||
Some(MapType::HttpRequestHeaders),
|
||||
Some(":path"),
|
||||
Some("/v1/chat/completions"),
|
||||
)
|
||||
.execute_and_expect(ReturnType::Action(Action::Continue))
|
||||
.unwrap();
|
||||
|
||||
// Request Body
|
||||
let incomplete_chat_completions_request_body = "\
|
||||
{\
|
||||
\"messages\": [\
|
||||
{\
|
||||
\"role\": \"system\",\
|
||||
},\
|
||||
{\
|
||||
\"role\": \"user\",\
|
||||
\"content\": \"Compose a poem that explains the concept of recursion in programming.\"\
|
||||
}\
|
||||
]\
|
||||
}";
|
||||
|
||||
module
|
||||
.call_proxy_on_request_body(
|
||||
http_context,
|
||||
incomplete_chat_completions_request_body.len() as i32,
|
||||
true,
|
||||
)
|
||||
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
|
||||
.returning(Some(incomplete_chat_completions_request_body))
|
||||
.expect_send_local_response(
|
||||
Some(StatusCode::BAD_REQUEST.as_u16().into()),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.execute_and_expect(ReturnType::Action(Action::Pause))
|
||||
.unwrap();
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue