mirror of
https://github.com/katanemo/plano.git
synced 2026-05-04 05:12:55 +02:00
refactor code a bit (#21)
This commit is contained in:
parent
15a44e680b
commit
a0abd9c42d
6 changed files with 70 additions and 61 deletions
|
|
@ -1,7 +1,6 @@
|
|||
use open_message_format::models::CreateEmbeddingRequest;
|
||||
use std::collections::HashMap;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::configuration;
|
||||
|
||||
|
|
@ -11,17 +10,6 @@ pub struct EmbeddingRequest {
|
|||
pub prompt_target: configuration::PromptTarget,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub enum MessageType {
|
||||
EmbeddingRequest(EmbeddingRequest),
|
||||
CreateVectorStorePoints(CreateVectorStorePoints),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct CalloutData {
|
||||
pub message: MessageType,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct VectorPoint {
|
||||
pub id: String,
|
||||
|
|
@ -30,10 +18,16 @@ pub struct VectorPoint {
|
|||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct CreateVectorStorePoints {
|
||||
pub struct StoreVectorEmbeddingsRequest {
|
||||
pub points: Vec<VectorPoint>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub enum CallContext {
|
||||
EmbeddingRequest(EmbeddingRequest),
|
||||
StoreVectorEmbeddings(StoreVectorEmbeddingsRequest),
|
||||
}
|
||||
|
||||
pub mod open_ai {
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
use common_types::EmbeddingRequest;
|
||||
use common_types::{CallContext, EmbeddingRequest};
|
||||
use configuration::PromptTarget;
|
||||
use log::info;
|
||||
use md5::Digest;
|
||||
|
|
@ -6,8 +6,6 @@ use open_message_format::models::{
|
|||
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
|
||||
};
|
||||
use serde_json::to_string;
|
||||
use stats::IncrementingMetric;
|
||||
use stats::Metric;
|
||||
use stats::RecordingMetric;
|
||||
use std::collections::HashMap;
|
||||
use std::time::Duration;
|
||||
|
|
@ -109,33 +107,21 @@ impl Context for StreamContext {}
|
|||
|
||||
#[derive(Copy, Clone)]
|
||||
struct WasmMetrics {
|
||||
// TODO: remove example metrics once real ones are created. These ones exist in order to silence the dead code warning.
|
||||
example_counter: stats::Counter,
|
||||
example_gauge: stats::Gauge,
|
||||
example_histogram: stats::Histogram,
|
||||
active_http_calls: stats::Gauge,
|
||||
}
|
||||
|
||||
impl WasmMetrics {
|
||||
fn new() -> WasmMetrics {
|
||||
let new_metrics = WasmMetrics {
|
||||
example_counter: stats::Counter::new(String::from("example_counter")),
|
||||
example_gauge: stats::Gauge::new(String::from("example_gauge")),
|
||||
example_histogram: stats::Histogram::new(String::from("example_histogram")),
|
||||
};
|
||||
|
||||
new_metrics.example_counter.increment(10);
|
||||
new_metrics.example_counter.value();
|
||||
new_metrics.example_gauge.record(20);
|
||||
new_metrics.example_histogram.record(30);
|
||||
|
||||
new_metrics
|
||||
WasmMetrics {
|
||||
active_http_calls: stats::Gauge::new(String::from("active_http_calls")),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct FilterContext {
|
||||
_metrics: WasmMetrics,
|
||||
metrics: WasmMetrics,
|
||||
// callouts stores token_id to request mapping that we use during #on_http_call_response to match the response to the request.
|
||||
callouts: HashMap<u32, common_types::CalloutData>,
|
||||
callouts: HashMap<u32, common_types::CallContext>,
|
||||
config: Option<configuration::Configuration>,
|
||||
}
|
||||
|
||||
|
|
@ -144,7 +130,7 @@ impl FilterContext {
|
|||
FilterContext {
|
||||
callouts: HashMap::new(),
|
||||
config: None,
|
||||
_metrics: WasmMetrics::new(),
|
||||
metrics: WasmMetrics::new(),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -164,7 +150,7 @@ impl FilterContext {
|
|||
};
|
||||
|
||||
// TODO: Handle potential errors
|
||||
let json_data = to_string(&embeddings_input).unwrap();
|
||||
let json_data: String = to_string(&embeddings_input).unwrap();
|
||||
|
||||
let token_id = match self.dispatch_http_call(
|
||||
"embeddingserver",
|
||||
|
|
@ -188,12 +174,18 @@ impl FilterContext {
|
|||
create_embedding_request: embeddings_input,
|
||||
prompt_target: prompt_target.clone(),
|
||||
};
|
||||
let callout_message = common_types::CalloutData {
|
||||
message: common_types::MessageType::EmbeddingRequest(embedding_request),
|
||||
};
|
||||
if self.callouts.insert(token_id, callout_message).is_some() {
|
||||
if self
|
||||
.callouts
|
||||
.insert(token_id, {
|
||||
CallContext::EmbeddingRequest(embedding_request)
|
||||
})
|
||||
.is_some()
|
||||
{
|
||||
panic!("duplicate token_id")
|
||||
}
|
||||
self.metrics
|
||||
.active_http_calls
|
||||
.record(self.callouts.len().try_into().unwrap());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -229,7 +221,7 @@ impl FilterContext {
|
|||
CreateEmbeddingRequestInput::Array(_) => todo!(),
|
||||
}
|
||||
|
||||
let create_vector_store_points = common_types::CreateVectorStorePoints {
|
||||
let create_vector_store_points = common_types::StoreVectorEmbeddingsRequest {
|
||||
points: vec![common_types::VectorPoint {
|
||||
id: format!("{:x}", id.unwrap()),
|
||||
payload,
|
||||
|
|
@ -260,14 +252,19 @@ impl FilterContext {
|
|||
};
|
||||
info!("on_tick: dispatched HTTP call with token_id = {}", token_id);
|
||||
|
||||
let callout_message = common_types::CalloutData {
|
||||
message: common_types::MessageType::CreateVectorStorePoints(
|
||||
create_vector_store_points,
|
||||
),
|
||||
};
|
||||
if self.callouts.insert(token_id, callout_message).is_some() {
|
||||
if self
|
||||
.callouts
|
||||
.insert(
|
||||
token_id,
|
||||
CallContext::StoreVectorEmbeddings(create_vector_store_points),
|
||||
)
|
||||
.is_some()
|
||||
{
|
||||
panic!("duplicate token_id")
|
||||
}
|
||||
self.metrics
|
||||
.active_http_calls
|
||||
.record(self.callouts.len().try_into().unwrap());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -294,14 +291,18 @@ impl Context for FilterContext {
|
|||
|
||||
let callout_data = self.callouts.remove(&token_id).expect("invalid token_id");
|
||||
|
||||
match callout_data.message {
|
||||
common_types::MessageType::EmbeddingRequest(common_types::EmbeddingRequest {
|
||||
self.metrics
|
||||
.active_http_calls
|
||||
.record(self.callouts.len().try_into().unwrap());
|
||||
|
||||
match callout_data {
|
||||
common_types::CallContext::EmbeddingRequest(common_types::EmbeddingRequest {
|
||||
create_embedding_request,
|
||||
prompt_target,
|
||||
}) => {
|
||||
self.embedding_request_handler(body_size, create_embedding_request, prompt_target)
|
||||
}
|
||||
common_types::MessageType::CreateVectorStorePoints(_) => {
|
||||
common_types::CallContext::StoreVectorEmbeddings(_) => {
|
||||
self.create_vector_store_points_handler(body_size)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
use proxy_wasm::hostcalls;
|
||||
use proxy_wasm::types::*;
|
||||
|
||||
#[allow(unused)]
|
||||
pub trait Metric {
|
||||
fn id(&self) -> u32;
|
||||
|
||||
fn value(&self) -> u64 {
|
||||
match hostcalls::get_metric(self.id()) {
|
||||
Ok(value) => value,
|
||||
|
|
@ -13,6 +13,7 @@ pub trait Metric {
|
|||
}
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
pub trait IncrementingMetric: Metric {
|
||||
fn increment(&self, offset: i64) {
|
||||
match hostcalls::increment_metric(self.id(), offset) {
|
||||
|
|
@ -38,6 +39,7 @@ pub struct Counter {
|
|||
id: u32,
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
impl Counter {
|
||||
pub fn new(name: String) -> Counter {
|
||||
let returned_id = hostcalls::define_metric(MetricType::Counter, &name)
|
||||
|
|
@ -80,6 +82,7 @@ pub struct Histogram {
|
|||
id: u32,
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
impl Histogram {
|
||||
pub fn new(name: String) -> Histogram {
|
||||
let returned_id = hostcalls::define_metric(MetricType::Histogram, &name)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue