refactor code a bit (#21)

This commit is contained in:
Adil Hafeez 2024-07-24 14:13:18 -07:00 committed by GitHub
parent 15a44e680b
commit a0abd9c42d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 70 additions and 61 deletions

23
envoyfilter/Cargo.lock generated
View file

@ -266,6 +266,12 @@ dependencies = [
"allocator-api2", "allocator-api2",
] ]
[[package]]
name = "hermit-abi"
version = "0.3.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024"
[[package]] [[package]]
name = "http" name = "http"
version = "1.1.0" version = "1.1.0"
@ -490,13 +496,14 @@ dependencies = [
[[package]] [[package]]
name = "mio" name = "mio"
version = "0.8.11" version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a4a650543ca06a924e8b371db273b2756685faae30f8487da1b56505a8f78b0c" checksum = "4569e456d394deccd22ce1c1913e6ea0e54519f577285001215d33557431afe4"
dependencies = [ dependencies = [
"hermit-abi",
"libc", "libc",
"wasi", "wasi",
"windows-sys 0.48.0", "windows-sys 0.52.0",
] ]
[[package]] [[package]]
@ -738,9 +745,9 @@ dependencies = [
[[package]] [[package]]
name = "rustls" name = "rustls"
version = "0.23.11" version = "0.23.12"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4828ea528154ae444e5a642dbb7d5623354030dc9822b83fd9bb79683c7399d0" checksum = "c58f8c84392efc0a126acce10fa59ff7b3d2ac06ab451a33f2741989b806b044"
dependencies = [ dependencies = [
"once_cell", "once_cell",
"rustls-pki-types", "rustls-pki-types",
@ -974,9 +981,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"
[[package]] [[package]]
name = "tokio" name = "tokio"
version = "1.38.1" version = "1.39.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eb2caba9f80616f438e09748d5acda951967e1ea58508ef53d9c6402485a46df" checksum = "d040ac2b29ab03b09d4129c2f5bbd012a3ac2f79d38ff506a4bf8dd34b0eac8a"
dependencies = [ dependencies = [
"backtrace", "backtrace",
"bytes", "bytes",
@ -984,7 +991,7 @@ dependencies = [
"mio", "mio",
"pin-project-lite", "pin-project-lite",
"socket2", "socket2",
"windows-sys 0.48.0", "windows-sys 0.52.0",
] ]
[[package]] [[package]]

View file

@ -1,7 +1,6 @@
use open_message_format::models::CreateEmbeddingRequest; use open_message_format::models::CreateEmbeddingRequest;
use std::collections::HashMap;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use crate::configuration; use crate::configuration;
@ -11,17 +10,6 @@ pub struct EmbeddingRequest {
pub prompt_target: configuration::PromptTarget, pub prompt_target: configuration::PromptTarget,
} }
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum MessageType {
EmbeddingRequest(EmbeddingRequest),
CreateVectorStorePoints(CreateVectorStorePoints),
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CalloutData {
pub message: MessageType,
}
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorPoint { pub struct VectorPoint {
pub id: String, pub id: String,
@ -30,10 +18,16 @@ pub struct VectorPoint {
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CreateVectorStorePoints { pub struct StoreVectorEmbeddingsRequest {
pub points: Vec<VectorPoint>, pub points: Vec<VectorPoint>,
} }
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum CallContext {
EmbeddingRequest(EmbeddingRequest),
StoreVectorEmbeddings(StoreVectorEmbeddingsRequest),
}
pub mod open_ai { pub mod open_ai {
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};

View file

@ -1,4 +1,4 @@
use common_types::EmbeddingRequest; use common_types::{CallContext, EmbeddingRequest};
use configuration::PromptTarget; use configuration::PromptTarget;
use log::info; use log::info;
use md5::Digest; use md5::Digest;
@ -6,8 +6,6 @@ use open_message_format::models::{
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse, CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
}; };
use serde_json::to_string; use serde_json::to_string;
use stats::IncrementingMetric;
use stats::Metric;
use stats::RecordingMetric; use stats::RecordingMetric;
use std::collections::HashMap; use std::collections::HashMap;
use std::time::Duration; use std::time::Duration;
@ -109,33 +107,21 @@ impl Context for StreamContext {}
#[derive(Copy, Clone)] #[derive(Copy, Clone)]
struct WasmMetrics { struct WasmMetrics {
// TODO: remove example metrics once real ones are created. These ones exist in order to silence the dead code warning. active_http_calls: stats::Gauge,
example_counter: stats::Counter,
example_gauge: stats::Gauge,
example_histogram: stats::Histogram,
} }
impl WasmMetrics { impl WasmMetrics {
fn new() -> WasmMetrics { fn new() -> WasmMetrics {
let new_metrics = WasmMetrics { WasmMetrics {
example_counter: stats::Counter::new(String::from("example_counter")), active_http_calls: stats::Gauge::new(String::from("active_http_calls")),
example_gauge: stats::Gauge::new(String::from("example_gauge")), }
example_histogram: stats::Histogram::new(String::from("example_histogram")),
};
new_metrics.example_counter.increment(10);
new_metrics.example_counter.value();
new_metrics.example_gauge.record(20);
new_metrics.example_histogram.record(30);
new_metrics
} }
} }
struct FilterContext { struct FilterContext {
_metrics: WasmMetrics, metrics: WasmMetrics,
// callouts stores token_id to request mapping that we use during #on_http_call_response to match the response to the request. // callouts stores token_id to request mapping that we use during #on_http_call_response to match the response to the request.
callouts: HashMap<u32, common_types::CalloutData>, callouts: HashMap<u32, common_types::CallContext>,
config: Option<configuration::Configuration>, config: Option<configuration::Configuration>,
} }
@ -144,7 +130,7 @@ impl FilterContext {
FilterContext { FilterContext {
callouts: HashMap::new(), callouts: HashMap::new(),
config: None, config: None,
_metrics: WasmMetrics::new(), metrics: WasmMetrics::new(),
} }
} }
@ -164,7 +150,7 @@ impl FilterContext {
}; };
// TODO: Handle potential errors // TODO: Handle potential errors
let json_data = to_string(&embeddings_input).unwrap(); let json_data: String = to_string(&embeddings_input).unwrap();
let token_id = match self.dispatch_http_call( let token_id = match self.dispatch_http_call(
"embeddingserver", "embeddingserver",
@ -188,12 +174,18 @@ impl FilterContext {
create_embedding_request: embeddings_input, create_embedding_request: embeddings_input,
prompt_target: prompt_target.clone(), prompt_target: prompt_target.clone(),
}; };
let callout_message = common_types::CalloutData { if self
message: common_types::MessageType::EmbeddingRequest(embedding_request), .callouts
}; .insert(token_id, {
if self.callouts.insert(token_id, callout_message).is_some() { CallContext::EmbeddingRequest(embedding_request)
})
.is_some()
{
panic!("duplicate token_id") panic!("duplicate token_id")
} }
self.metrics
.active_http_calls
.record(self.callouts.len().try_into().unwrap());
} }
} }
} }
@ -229,7 +221,7 @@ impl FilterContext {
CreateEmbeddingRequestInput::Array(_) => todo!(), CreateEmbeddingRequestInput::Array(_) => todo!(),
} }
let create_vector_store_points = common_types::CreateVectorStorePoints { let create_vector_store_points = common_types::StoreVectorEmbeddingsRequest {
points: vec![common_types::VectorPoint { points: vec![common_types::VectorPoint {
id: format!("{:x}", id.unwrap()), id: format!("{:x}", id.unwrap()),
payload, payload,
@ -260,14 +252,19 @@ impl FilterContext {
}; };
info!("on_tick: dispatched HTTP call with token_id = {}", token_id); info!("on_tick: dispatched HTTP call with token_id = {}", token_id);
let callout_message = common_types::CalloutData { if self
message: common_types::MessageType::CreateVectorStorePoints( .callouts
create_vector_store_points, .insert(
), token_id,
}; CallContext::StoreVectorEmbeddings(create_vector_store_points),
if self.callouts.insert(token_id, callout_message).is_some() { )
.is_some()
{
panic!("duplicate token_id") panic!("duplicate token_id")
} }
self.metrics
.active_http_calls
.record(self.callouts.len().try_into().unwrap());
} }
} }
} }
@ -294,14 +291,18 @@ impl Context for FilterContext {
let callout_data = self.callouts.remove(&token_id).expect("invalid token_id"); let callout_data = self.callouts.remove(&token_id).expect("invalid token_id");
match callout_data.message { self.metrics
common_types::MessageType::EmbeddingRequest(common_types::EmbeddingRequest { .active_http_calls
.record(self.callouts.len().try_into().unwrap());
match callout_data {
common_types::CallContext::EmbeddingRequest(common_types::EmbeddingRequest {
create_embedding_request, create_embedding_request,
prompt_target, prompt_target,
}) => { }) => {
self.embedding_request_handler(body_size, create_embedding_request, prompt_target) self.embedding_request_handler(body_size, create_embedding_request, prompt_target)
} }
common_types::MessageType::CreateVectorStorePoints(_) => { common_types::CallContext::StoreVectorEmbeddings(_) => {
self.create_vector_store_points_handler(body_size) self.create_vector_store_points_handler(body_size)
} }
} }

View file

@ -1,9 +1,9 @@
use proxy_wasm::hostcalls; use proxy_wasm::hostcalls;
use proxy_wasm::types::*; use proxy_wasm::types::*;
#[allow(unused)]
pub trait Metric { pub trait Metric {
fn id(&self) -> u32; fn id(&self) -> u32;
fn value(&self) -> u64 { fn value(&self) -> u64 {
match hostcalls::get_metric(self.id()) { match hostcalls::get_metric(self.id()) {
Ok(value) => value, Ok(value) => value,
@ -13,6 +13,7 @@ pub trait Metric {
} }
} }
#[allow(unused)]
pub trait IncrementingMetric: Metric { pub trait IncrementingMetric: Metric {
fn increment(&self, offset: i64) { fn increment(&self, offset: i64) {
match hostcalls::increment_metric(self.id(), offset) { match hostcalls::increment_metric(self.id(), offset) {
@ -38,6 +39,7 @@ pub struct Counter {
id: u32, id: u32,
} }
#[allow(unused)]
impl Counter { impl Counter {
pub fn new(name: String) -> Counter { pub fn new(name: String) -> Counter {
let returned_id = hostcalls::define_metric(MetricType::Counter, &name) let returned_id = hostcalls::define_metric(MetricType::Counter, &name)
@ -80,6 +82,7 @@ pub struct Histogram {
id: u32, id: u32,
} }
#[allow(unused)]
impl Histogram { impl Histogram {
pub fn new(name: String) -> Histogram { pub fn new(name: String) -> Histogram {
let returned_id = hostcalls::define_metric(MetricType::Histogram, &name) let returned_id = hostcalls::define_metric(MetricType::Histogram, &name)

View file

@ -12,6 +12,10 @@
"name": "embedding-server", "name": "embedding-server",
"path": "embedding-server" "path": "embedding-server"
}, },
{
"name": "chatbot-client",
"path": "chatbot-client"
},
{ {
"name": "open-message-format", "name": "open-message-format",
"path": "open-message-format" "path": "open-message-format"

@ -1 +1 @@
Subproject commit f70179ae3de7110cb40412c902b275b7900b40a1 Subproject commit 1469e1ec49dd9b7fcf91e16fea4781294ced7660