Refactor callout types into CallContext and track active HTTP calls with a gauge (#21)

This commit is contained in:
Adil Hafeez 2024-07-24 14:13:18 -07:00 committed by GitHub
parent 15a44e680b
commit a0abd9c42d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 70 additions and 61 deletions

View file

@ -1,7 +1,6 @@
use open_message_format::models::CreateEmbeddingRequest;
use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use crate::configuration;
@ -11,17 +10,6 @@ pub struct EmbeddingRequest {
pub prompt_target: configuration::PromptTarget,
}
/// Kinds of request payload that a [`CalloutData`] can carry.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum MessageType {
/// Carries an [`EmbeddingRequest`].
EmbeddingRequest(EmbeddingRequest),
/// Carries a [`CreateVectorStorePoints`] payload.
CreateVectorStorePoints(CreateVectorStorePoints),
}
/// Wrapper around a single [`MessageType`] value.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CalloutData {
/// The wrapped request payload.
pub message: MessageType,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorPoint {
pub id: String,
@ -30,10 +18,16 @@ pub struct VectorPoint {
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CreateVectorStorePoints {
pub struct StoreVectorEmbeddingsRequest {
pub points: Vec<VectorPoint>,
}
/// Request context with one variant per kind of request payload.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum CallContext {
/// Carries an [`EmbeddingRequest`].
EmbeddingRequest(EmbeddingRequest),
/// Carries a [`StoreVectorEmbeddingsRequest`].
StoreVectorEmbeddings(StoreVectorEmbeddingsRequest),
}
pub mod open_ai {
use serde::{Deserialize, Serialize};

View file

@ -1,4 +1,4 @@
use common_types::EmbeddingRequest;
use common_types::{CallContext, EmbeddingRequest};
use configuration::PromptTarget;
use log::info;
use md5::Digest;
@ -6,8 +6,6 @@ use open_message_format::models::{
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
};
use serde_json::to_string;
use stats::IncrementingMetric;
use stats::Metric;
use stats::RecordingMetric;
use std::collections::HashMap;
use std::time::Duration;
@ -109,33 +107,21 @@ impl Context for StreamContext {}
/// Bundle of `stats` metric handles held by the filter.
#[derive(Copy, Clone)]
struct WasmMetrics {
// TODO: remove example metrics once real ones are created. These ones exist in order to silence the dead code warning.
example_counter: stats::Counter,
example_gauge: stats::Gauge,
example_histogram: stats::Histogram,
// Gauge that the filter records with the current number of entries in its `callouts` map.
active_http_calls: stats::Gauge,
}
impl WasmMetrics {
fn new() -> WasmMetrics {
let new_metrics = WasmMetrics {
example_counter: stats::Counter::new(String::from("example_counter")),
example_gauge: stats::Gauge::new(String::from("example_gauge")),
example_histogram: stats::Histogram::new(String::from("example_histogram")),
};
new_metrics.example_counter.increment(10);
new_metrics.example_counter.value();
new_metrics.example_gauge.record(20);
new_metrics.example_histogram.record(30);
new_metrics
WasmMetrics {
active_http_calls: stats::Gauge::new(String::from("active_http_calls")),
}
}
}
/// State held by the filter: metrics, outstanding callouts, and an optional configuration.
struct FilterContext {
_metrics: WasmMetrics,
metrics: WasmMetrics,
// `callouts` maps a token_id to the request it was dispatched for, so that
// #on_http_call_response can match a response back to its request.
callouts: HashMap<u32, common_types::CalloutData>,
callouts: HashMap<u32, common_types::CallContext>,
// Initialized to `None` when the context is constructed.
config: Option<configuration::Configuration>,
}
@ -144,7 +130,7 @@ impl FilterContext {
FilterContext {
callouts: HashMap::new(),
config: None,
_metrics: WasmMetrics::new(),
metrics: WasmMetrics::new(),
}
}
@ -164,7 +150,7 @@ impl FilterContext {
};
// TODO: Handle potential errors
let json_data = to_string(&embeddings_input).unwrap();
let json_data: String = to_string(&embeddings_input).unwrap();
let token_id = match self.dispatch_http_call(
"embeddingserver",
@ -188,12 +174,18 @@ impl FilterContext {
create_embedding_request: embeddings_input,
prompt_target: prompt_target.clone(),
};
let callout_message = common_types::CalloutData {
message: common_types::MessageType::EmbeddingRequest(embedding_request),
};
if self.callouts.insert(token_id, callout_message).is_some() {
if self
.callouts
.insert(token_id, {
CallContext::EmbeddingRequest(embedding_request)
})
.is_some()
{
panic!("duplicate token_id")
}
self.metrics
.active_http_calls
.record(self.callouts.len().try_into().unwrap());
}
}
}
@ -229,7 +221,7 @@ impl FilterContext {
CreateEmbeddingRequestInput::Array(_) => todo!(),
}
let create_vector_store_points = common_types::CreateVectorStorePoints {
let create_vector_store_points = common_types::StoreVectorEmbeddingsRequest {
points: vec![common_types::VectorPoint {
id: format!("{:x}", id.unwrap()),
payload,
@ -260,14 +252,19 @@ impl FilterContext {
};
info!("on_tick: dispatched HTTP call with token_id = {}", token_id);
let callout_message = common_types::CalloutData {
message: common_types::MessageType::CreateVectorStorePoints(
create_vector_store_points,
),
};
if self.callouts.insert(token_id, callout_message).is_some() {
if self
.callouts
.insert(
token_id,
CallContext::StoreVectorEmbeddings(create_vector_store_points),
)
.is_some()
{
panic!("duplicate token_id")
}
self.metrics
.active_http_calls
.record(self.callouts.len().try_into().unwrap());
}
}
}
@ -294,14 +291,18 @@ impl Context for FilterContext {
let callout_data = self.callouts.remove(&token_id).expect("invalid token_id");
match callout_data.message {
common_types::MessageType::EmbeddingRequest(common_types::EmbeddingRequest {
self.metrics
.active_http_calls
.record(self.callouts.len().try_into().unwrap());
match callout_data {
common_types::CallContext::EmbeddingRequest(common_types::EmbeddingRequest {
create_embedding_request,
prompt_target,
}) => {
self.embedding_request_handler(body_size, create_embedding_request, prompt_target)
}
common_types::MessageType::CreateVectorStorePoints(_) => {
common_types::CallContext::StoreVectorEmbeddings(_) => {
self.create_vector_store_points_handler(body_size)
}
}

View file

@ -1,9 +1,9 @@
use proxy_wasm::hostcalls;
use proxy_wasm::types::*;
#[allow(unused)]
pub trait Metric {
fn id(&self) -> u32;
fn value(&self) -> u64 {
match hostcalls::get_metric(self.id()) {
Ok(value) => value,
@ -13,6 +13,7 @@ pub trait Metric {
}
}
#[allow(unused)]
pub trait IncrementingMetric: Metric {
fn increment(&self, offset: i64) {
match hostcalls::increment_metric(self.id(), offset) {
@ -38,6 +39,7 @@ pub struct Counter {
id: u32,
}
#[allow(unused)]
impl Counter {
pub fn new(name: String) -> Counter {
let returned_id = hostcalls::define_metric(MetricType::Counter, &name)
@ -80,6 +82,7 @@ pub struct Histogram {
id: u32,
}
#[allow(unused)]
impl Histogram {
pub fn new(name: String) -> Histogram {
let returned_id = hostcalls::define_metric(MetricType::Histogram, &name)