mirror of
https://github.com/katanemo/plano.git
synced 2026-04-25 00:36:34 +02:00
refactor code a bit (#21)
This commit is contained in:
parent
15a44e680b
commit
a0abd9c42d
6 changed files with 70 additions and 61 deletions
23
envoyfilter/Cargo.lock
generated
23
envoyfilter/Cargo.lock
generated
|
|
@ -266,6 +266,12 @@ dependencies = [
|
|||
"allocator-api2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hermit-abi"
|
||||
version = "0.3.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024"
|
||||
|
||||
[[package]]
|
||||
name = "http"
|
||||
version = "1.1.0"
|
||||
|
|
@ -490,13 +496,14 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "mio"
|
||||
version = "0.8.11"
|
||||
version = "1.0.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a4a650543ca06a924e8b371db273b2756685faae30f8487da1b56505a8f78b0c"
|
||||
checksum = "4569e456d394deccd22ce1c1913e6ea0e54519f577285001215d33557431afe4"
|
||||
dependencies = [
|
||||
"hermit-abi",
|
||||
"libc",
|
||||
"wasi",
|
||||
"windows-sys 0.48.0",
|
||||
"windows-sys 0.52.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -738,9 +745,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "rustls"
|
||||
version = "0.23.11"
|
||||
version = "0.23.12"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4828ea528154ae444e5a642dbb7d5623354030dc9822b83fd9bb79683c7399d0"
|
||||
checksum = "c58f8c84392efc0a126acce10fa59ff7b3d2ac06ab451a33f2741989b806b044"
|
||||
dependencies = [
|
||||
"once_cell",
|
||||
"rustls-pki-types",
|
||||
|
|
@ -974,9 +981,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"
|
|||
|
||||
[[package]]
|
||||
name = "tokio"
|
||||
version = "1.38.1"
|
||||
version = "1.39.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "eb2caba9f80616f438e09748d5acda951967e1ea58508ef53d9c6402485a46df"
|
||||
checksum = "d040ac2b29ab03b09d4129c2f5bbd012a3ac2f79d38ff506a4bf8dd34b0eac8a"
|
||||
dependencies = [
|
||||
"backtrace",
|
||||
"bytes",
|
||||
|
|
@ -984,7 +991,7 @@ dependencies = [
|
|||
"mio",
|
||||
"pin-project-lite",
|
||||
"socket2",
|
||||
"windows-sys 0.48.0",
|
||||
"windows-sys 0.52.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
use open_message_format::models::CreateEmbeddingRequest;
|
||||
use std::collections::HashMap;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::configuration;
|
||||
|
||||
|
|
@ -11,17 +10,6 @@ pub struct EmbeddingRequest {
|
|||
pub prompt_target: configuration::PromptTarget,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub enum MessageType {
|
||||
EmbeddingRequest(EmbeddingRequest),
|
||||
CreateVectorStorePoints(CreateVectorStorePoints),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct CalloutData {
|
||||
pub message: MessageType,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct VectorPoint {
|
||||
pub id: String,
|
||||
|
|
@ -30,10 +18,16 @@ pub struct VectorPoint {
|
|||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct CreateVectorStorePoints {
|
||||
pub struct StoreVectorEmbeddingsRequest {
|
||||
pub points: Vec<VectorPoint>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub enum CallContext {
|
||||
EmbeddingRequest(EmbeddingRequest),
|
||||
StoreVectorEmbeddings(StoreVectorEmbeddingsRequest),
|
||||
}
|
||||
|
||||
pub mod open_ai {
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
use common_types::EmbeddingRequest;
|
||||
use common_types::{CallContext, EmbeddingRequest};
|
||||
use configuration::PromptTarget;
|
||||
use log::info;
|
||||
use md5::Digest;
|
||||
|
|
@ -6,8 +6,6 @@ use open_message_format::models::{
|
|||
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
|
||||
};
|
||||
use serde_json::to_string;
|
||||
use stats::IncrementingMetric;
|
||||
use stats::Metric;
|
||||
use stats::RecordingMetric;
|
||||
use std::collections::HashMap;
|
||||
use std::time::Duration;
|
||||
|
|
@ -109,33 +107,21 @@ impl Context for StreamContext {}
|
|||
|
||||
#[derive(Copy, Clone)]
|
||||
struct WasmMetrics {
|
||||
// TODO: remove example metrics once real ones are created. These ones exist in order to silence the dead code warning.
|
||||
example_counter: stats::Counter,
|
||||
example_gauge: stats::Gauge,
|
||||
example_histogram: stats::Histogram,
|
||||
active_http_calls: stats::Gauge,
|
||||
}
|
||||
|
||||
impl WasmMetrics {
|
||||
fn new() -> WasmMetrics {
|
||||
let new_metrics = WasmMetrics {
|
||||
example_counter: stats::Counter::new(String::from("example_counter")),
|
||||
example_gauge: stats::Gauge::new(String::from("example_gauge")),
|
||||
example_histogram: stats::Histogram::new(String::from("example_histogram")),
|
||||
};
|
||||
|
||||
new_metrics.example_counter.increment(10);
|
||||
new_metrics.example_counter.value();
|
||||
new_metrics.example_gauge.record(20);
|
||||
new_metrics.example_histogram.record(30);
|
||||
|
||||
new_metrics
|
||||
WasmMetrics {
|
||||
active_http_calls: stats::Gauge::new(String::from("active_http_calls")),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct FilterContext {
|
||||
_metrics: WasmMetrics,
|
||||
metrics: WasmMetrics,
|
||||
// callouts stores token_id to request mapping that we use during #on_http_call_response to match the response to the request.
|
||||
callouts: HashMap<u32, common_types::CalloutData>,
|
||||
callouts: HashMap<u32, common_types::CallContext>,
|
||||
config: Option<configuration::Configuration>,
|
||||
}
|
||||
|
||||
|
|
@ -144,7 +130,7 @@ impl FilterContext {
|
|||
FilterContext {
|
||||
callouts: HashMap::new(),
|
||||
config: None,
|
||||
_metrics: WasmMetrics::new(),
|
||||
metrics: WasmMetrics::new(),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -164,7 +150,7 @@ impl FilterContext {
|
|||
};
|
||||
|
||||
// TODO: Handle potential errors
|
||||
let json_data = to_string(&embeddings_input).unwrap();
|
||||
let json_data: String = to_string(&embeddings_input).unwrap();
|
||||
|
||||
let token_id = match self.dispatch_http_call(
|
||||
"embeddingserver",
|
||||
|
|
@ -188,12 +174,18 @@ impl FilterContext {
|
|||
create_embedding_request: embeddings_input,
|
||||
prompt_target: prompt_target.clone(),
|
||||
};
|
||||
let callout_message = common_types::CalloutData {
|
||||
message: common_types::MessageType::EmbeddingRequest(embedding_request),
|
||||
};
|
||||
if self.callouts.insert(token_id, callout_message).is_some() {
|
||||
if self
|
||||
.callouts
|
||||
.insert(token_id, {
|
||||
CallContext::EmbeddingRequest(embedding_request)
|
||||
})
|
||||
.is_some()
|
||||
{
|
||||
panic!("duplicate token_id")
|
||||
}
|
||||
self.metrics
|
||||
.active_http_calls
|
||||
.record(self.callouts.len().try_into().unwrap());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -229,7 +221,7 @@ impl FilterContext {
|
|||
CreateEmbeddingRequestInput::Array(_) => todo!(),
|
||||
}
|
||||
|
||||
let create_vector_store_points = common_types::CreateVectorStorePoints {
|
||||
let create_vector_store_points = common_types::StoreVectorEmbeddingsRequest {
|
||||
points: vec![common_types::VectorPoint {
|
||||
id: format!("{:x}", id.unwrap()),
|
||||
payload,
|
||||
|
|
@ -260,14 +252,19 @@ impl FilterContext {
|
|||
};
|
||||
info!("on_tick: dispatched HTTP call with token_id = {}", token_id);
|
||||
|
||||
let callout_message = common_types::CalloutData {
|
||||
message: common_types::MessageType::CreateVectorStorePoints(
|
||||
create_vector_store_points,
|
||||
),
|
||||
};
|
||||
if self.callouts.insert(token_id, callout_message).is_some() {
|
||||
if self
|
||||
.callouts
|
||||
.insert(
|
||||
token_id,
|
||||
CallContext::StoreVectorEmbeddings(create_vector_store_points),
|
||||
)
|
||||
.is_some()
|
||||
{
|
||||
panic!("duplicate token_id")
|
||||
}
|
||||
self.metrics
|
||||
.active_http_calls
|
||||
.record(self.callouts.len().try_into().unwrap());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -294,14 +291,18 @@ impl Context for FilterContext {
|
|||
|
||||
let callout_data = self.callouts.remove(&token_id).expect("invalid token_id");
|
||||
|
||||
match callout_data.message {
|
||||
common_types::MessageType::EmbeddingRequest(common_types::EmbeddingRequest {
|
||||
self.metrics
|
||||
.active_http_calls
|
||||
.record(self.callouts.len().try_into().unwrap());
|
||||
|
||||
match callout_data {
|
||||
common_types::CallContext::EmbeddingRequest(common_types::EmbeddingRequest {
|
||||
create_embedding_request,
|
||||
prompt_target,
|
||||
}) => {
|
||||
self.embedding_request_handler(body_size, create_embedding_request, prompt_target)
|
||||
}
|
||||
common_types::MessageType::CreateVectorStorePoints(_) => {
|
||||
common_types::CallContext::StoreVectorEmbeddings(_) => {
|
||||
self.create_vector_store_points_handler(body_size)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
use proxy_wasm::hostcalls;
|
||||
use proxy_wasm::types::*;
|
||||
|
||||
#[allow(unused)]
|
||||
pub trait Metric {
|
||||
fn id(&self) -> u32;
|
||||
|
||||
fn value(&self) -> u64 {
|
||||
match hostcalls::get_metric(self.id()) {
|
||||
Ok(value) => value,
|
||||
|
|
@ -13,6 +13,7 @@ pub trait Metric {
|
|||
}
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
pub trait IncrementingMetric: Metric {
|
||||
fn increment(&self, offset: i64) {
|
||||
match hostcalls::increment_metric(self.id(), offset) {
|
||||
|
|
@ -38,6 +39,7 @@ pub struct Counter {
|
|||
id: u32,
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
impl Counter {
|
||||
pub fn new(name: String) -> Counter {
|
||||
let returned_id = hostcalls::define_metric(MetricType::Counter, &name)
|
||||
|
|
@ -80,6 +82,7 @@ pub struct Histogram {
|
|||
id: u32,
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
impl Histogram {
|
||||
pub fn new(name: String) -> Histogram {
|
||||
let returned_id = hostcalls::define_metric(MetricType::Histogram, &name)
|
||||
|
|
|
|||
|
|
@ -12,6 +12,10 @@
|
|||
"name": "embedding-server",
|
||||
"path": "embedding-server"
|
||||
},
|
||||
{
|
||||
"name": "chatbot-client",
|
||||
"path": "chatbot-client"
|
||||
},
|
||||
{
|
||||
"name": "open-message-format",
|
||||
"path": "open-message-format"
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
Subproject commit f70179ae3de7110cb40412c902b275b7900b40a1
|
||||
Subproject commit 1469e1ec49dd9b7fcf91e16fea4781294ced7660
|
||||
Loading…
Add table
Add a link
Reference in a new issue