refactor code a bit (#21)

This commit is contained in:
Adil Hafeez 2024-07-24 14:13:18 -07:00 committed by GitHub
parent 15a44e680b
commit a0abd9c42d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 70 additions and 61 deletions

23
envoyfilter/Cargo.lock generated
View file

@ -266,6 +266,12 @@ dependencies = [
"allocator-api2",
]
[[package]]
name = "hermit-abi"
version = "0.3.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024"
[[package]]
name = "http"
version = "1.1.0"
@ -490,13 +496,14 @@ dependencies = [
[[package]]
name = "mio"
version = "0.8.11"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a4a650543ca06a924e8b371db273b2756685faae30f8487da1b56505a8f78b0c"
checksum = "4569e456d394deccd22ce1c1913e6ea0e54519f577285001215d33557431afe4"
dependencies = [
"hermit-abi",
"libc",
"wasi",
"windows-sys 0.48.0",
"windows-sys 0.52.0",
]
[[package]]
@ -738,9 +745,9 @@ dependencies = [
[[package]]
name = "rustls"
version = "0.23.11"
version = "0.23.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4828ea528154ae444e5a642dbb7d5623354030dc9822b83fd9bb79683c7399d0"
checksum = "c58f8c84392efc0a126acce10fa59ff7b3d2ac06ab451a33f2741989b806b044"
dependencies = [
"once_cell",
"rustls-pki-types",
@ -974,9 +981,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"
[[package]]
name = "tokio"
version = "1.38.1"
version = "1.39.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eb2caba9f80616f438e09748d5acda951967e1ea58508ef53d9c6402485a46df"
checksum = "d040ac2b29ab03b09d4129c2f5bbd012a3ac2f79d38ff506a4bf8dd34b0eac8a"
dependencies = [
"backtrace",
"bytes",
@ -984,7 +991,7 @@ dependencies = [
"mio",
"pin-project-lite",
"socket2",
"windows-sys 0.48.0",
"windows-sys 0.52.0",
]
[[package]]

View file

@ -1,7 +1,6 @@
use open_message_format::models::CreateEmbeddingRequest;
use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use crate::configuration;
@ -11,17 +10,6 @@ pub struct EmbeddingRequest {
pub prompt_target: configuration::PromptTarget,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum MessageType {
EmbeddingRequest(EmbeddingRequest),
CreateVectorStorePoints(CreateVectorStorePoints),
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CalloutData {
pub message: MessageType,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorPoint {
pub id: String,
@ -30,10 +18,16 @@ pub struct VectorPoint {
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CreateVectorStorePoints {
pub struct StoreVectorEmbeddingsRequest {
pub points: Vec<VectorPoint>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum CallContext {
EmbeddingRequest(EmbeddingRequest),
StoreVectorEmbeddings(StoreVectorEmbeddingsRequest),
}
pub mod open_ai {
use serde::{Deserialize, Serialize};

View file

@ -1,4 +1,4 @@
use common_types::EmbeddingRequest;
use common_types::{CallContext, EmbeddingRequest};
use configuration::PromptTarget;
use log::info;
use md5::Digest;
@ -6,8 +6,6 @@ use open_message_format::models::{
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
};
use serde_json::to_string;
use stats::IncrementingMetric;
use stats::Metric;
use stats::RecordingMetric;
use std::collections::HashMap;
use std::time::Duration;
@ -109,33 +107,21 @@ impl Context for StreamContext {}
#[derive(Copy, Clone)]
struct WasmMetrics {
// TODO: remove example metrics once real ones are created. These ones exist in order to silence the dead code warning.
example_counter: stats::Counter,
example_gauge: stats::Gauge,
example_histogram: stats::Histogram,
active_http_calls: stats::Gauge,
}
impl WasmMetrics {
fn new() -> WasmMetrics {
let new_metrics = WasmMetrics {
example_counter: stats::Counter::new(String::from("example_counter")),
example_gauge: stats::Gauge::new(String::from("example_gauge")),
example_histogram: stats::Histogram::new(String::from("example_histogram")),
};
new_metrics.example_counter.increment(10);
new_metrics.example_counter.value();
new_metrics.example_gauge.record(20);
new_metrics.example_histogram.record(30);
new_metrics
WasmMetrics {
active_http_calls: stats::Gauge::new(String::from("active_http_calls")),
}
}
}
struct FilterContext {
_metrics: WasmMetrics,
metrics: WasmMetrics,
// callouts stores token_id to request mapping that we use during #on_http_call_response to match the response to the request.
callouts: HashMap<u32, common_types::CalloutData>,
callouts: HashMap<u32, common_types::CallContext>,
config: Option<configuration::Configuration>,
}
@ -144,7 +130,7 @@ impl FilterContext {
FilterContext {
callouts: HashMap::new(),
config: None,
_metrics: WasmMetrics::new(),
metrics: WasmMetrics::new(),
}
}
@ -164,7 +150,7 @@ impl FilterContext {
};
// TODO: Handle potential errors
let json_data = to_string(&embeddings_input).unwrap();
let json_data: String = to_string(&embeddings_input).unwrap();
let token_id = match self.dispatch_http_call(
"embeddingserver",
@ -188,12 +174,18 @@ impl FilterContext {
create_embedding_request: embeddings_input,
prompt_target: prompt_target.clone(),
};
let callout_message = common_types::CalloutData {
message: common_types::MessageType::EmbeddingRequest(embedding_request),
};
if self.callouts.insert(token_id, callout_message).is_some() {
if self
.callouts
.insert(token_id, {
CallContext::EmbeddingRequest(embedding_request)
})
.is_some()
{
panic!("duplicate token_id")
}
self.metrics
.active_http_calls
.record(self.callouts.len().try_into().unwrap());
}
}
}
@ -229,7 +221,7 @@ impl FilterContext {
CreateEmbeddingRequestInput::Array(_) => todo!(),
}
let create_vector_store_points = common_types::CreateVectorStorePoints {
let create_vector_store_points = common_types::StoreVectorEmbeddingsRequest {
points: vec![common_types::VectorPoint {
id: format!("{:x}", id.unwrap()),
payload,
@ -260,14 +252,19 @@ impl FilterContext {
};
info!("on_tick: dispatched HTTP call with token_id = {}", token_id);
let callout_message = common_types::CalloutData {
message: common_types::MessageType::CreateVectorStorePoints(
create_vector_store_points,
),
};
if self.callouts.insert(token_id, callout_message).is_some() {
if self
.callouts
.insert(
token_id,
CallContext::StoreVectorEmbeddings(create_vector_store_points),
)
.is_some()
{
panic!("duplicate token_id")
}
self.metrics
.active_http_calls
.record(self.callouts.len().try_into().unwrap());
}
}
}
@ -294,14 +291,18 @@ impl Context for FilterContext {
let callout_data = self.callouts.remove(&token_id).expect("invalid token_id");
match callout_data.message {
common_types::MessageType::EmbeddingRequest(common_types::EmbeddingRequest {
self.metrics
.active_http_calls
.record(self.callouts.len().try_into().unwrap());
match callout_data {
common_types::CallContext::EmbeddingRequest(common_types::EmbeddingRequest {
create_embedding_request,
prompt_target,
}) => {
self.embedding_request_handler(body_size, create_embedding_request, prompt_target)
}
common_types::MessageType::CreateVectorStorePoints(_) => {
common_types::CallContext::StoreVectorEmbeddings(_) => {
self.create_vector_store_points_handler(body_size)
}
}

View file

@ -1,9 +1,9 @@
use proxy_wasm::hostcalls;
use proxy_wasm::types::*;
#[allow(unused)]
pub trait Metric {
fn id(&self) -> u32;
fn value(&self) -> u64 {
match hostcalls::get_metric(self.id()) {
Ok(value) => value,
@ -13,6 +13,7 @@ pub trait Metric {
}
}
#[allow(unused)]
pub trait IncrementingMetric: Metric {
fn increment(&self, offset: i64) {
match hostcalls::increment_metric(self.id(), offset) {
@ -38,6 +39,7 @@ pub struct Counter {
id: u32,
}
#[allow(unused)]
impl Counter {
pub fn new(name: String) -> Counter {
let returned_id = hostcalls::define_metric(MetricType::Counter, &name)
@ -80,6 +82,7 @@ pub struct Histogram {
id: u32,
}
#[allow(unused)]
impl Histogram {
pub fn new(name: String) -> Histogram {
let returned_id = hostcalls::define_metric(MetricType::Histogram, &name)

View file

@ -12,6 +12,10 @@
"name": "embedding-server",
"path": "embedding-server"
},
{
"name": "chatbot-client",
"path": "chatbot-client"
},
{
"name": "open-message-format",
"path": "open-message-format"

@ -1 +1 @@
Subproject commit f70179ae3de7110cb40412c902b275b7900b40a1
Subproject commit 1469e1ec49dd9b7fcf91e16fea4781294ced7660