diff --git a/envoyfilter/src/common_types.rs b/envoyfilter/src/common_types.rs index 3d82eca4..162a7511 100644 --- a/envoyfilter/src/common_types.rs +++ b/envoyfilter/src/common_types.rs @@ -1,13 +1,12 @@ +use crate::configuration::PromptTarget; use open_message_format::models::CreateEmbeddingRequest; use serde::{Deserialize, Serialize}; use std::collections::HashMap; -use crate::configuration; - #[derive(Debug, Clone, Serialize, Deserialize)] pub struct EmbeddingRequest { pub create_embedding_request: CreateEmbeddingRequest, - pub prompt_target: configuration::PromptTarget, + pub prompt_target: PromptTarget, } #[derive(Debug, Clone, Serialize, Deserialize)] diff --git a/envoyfilter/src/filter_context.rs b/envoyfilter/src/filter_context.rs index 21654a8e..963e7719 100644 --- a/envoyfilter/src/filter_context.rs +++ b/envoyfilter/src/filter_context.rs @@ -1,34 +1,30 @@ -use common_types::{CallContext, EmbeddingRequest}; -use configuration::PromptTarget; +use crate::common_types::{ + CallContext, EmbeddingRequest, StoreVectorEmbeddingsRequest, VectorPoint, +}; +use crate::configuration::{Configuration, PromptTarget}; +use crate::consts::DEFAULT_EMBEDDING_MODEL; +use crate::stats::{Gauge, RecordingMetric}; +use crate::stream_context::StreamContext; use log::info; use md5::Digest; use open_message_format::models::{ CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse, }; +use proxy_wasm::traits::*; +use proxy_wasm::types::*; use serde_json::to_string; -use stats::RecordingMetric; use std::collections::HashMap; use std::time::Duration; -use consts::DEFAULT_EMBEDDING_MODEL; -use proxy_wasm::traits::*; -use proxy_wasm::types::*; - -use crate::common_types; -use crate::configuration; -use crate::consts; -use crate::stats; -use crate::stream_context::StreamContext; - #[derive(Copy, Clone)] struct WasmMetrics { - active_http_calls: stats::Gauge, + active_http_calls: Gauge, } impl WasmMetrics { fn new() -> WasmMetrics { WasmMetrics { - 
active_http_calls: stats::Gauge::new(String::from("active_http_calls")), + active_http_calls: Gauge::new(String::from("active_http_calls")), } } } @@ -36,8 +32,8 @@ impl WasmMetrics { pub struct FilterContext { metrics: WasmMetrics, // callouts stores token_id to request mapping that we use during #on_http_call_response to match the response to the request. - callouts: HashMap<u32, common_types::CallContext>, - config: Option<configuration::Configuration>, + callouts: HashMap<u32, CallContext>, + config: Option<Configuration>, } impl FilterContext { @@ -127,8 +123,8 @@ impl FilterContext { CreateEmbeddingRequestInput::Array(_) => todo!(), } - let create_vector_store_points = common_types::StoreVectorEmbeddingsRequest { - points: vec![common_types::VectorPoint { + let create_vector_store_points = StoreVectorEmbeddingsRequest { + points: vec![VectorPoint { id: format!("{:x}", id.unwrap()), payload, vector: embedding_response.data[0].embedding.clone(), @@ -231,16 +227,16 @@ impl Context for FilterContext { .record(self.callouts.len().try_into().unwrap()); match callout_data { - common_types::CallContext::EmbeddingRequest(common_types::EmbeddingRequest { + CallContext::EmbeddingRequest(EmbeddingRequest { create_embedding_request, prompt_target, }) => { self.embedding_request_handler(body_size, create_embedding_request, prompt_target) } - common_types::CallContext::StoreVectorEmbeddings(_) => { + CallContext::StoreVectorEmbeddings(_) => { self.create_vector_store_points_handler(body_size) } - common_types::CallContext::CreateVectorCollection(_) => { + CallContext::CreateVectorCollection(_) => { let mut http_status_code = "Nil".to_string(); self.get_http_call_response_headers() .iter() diff --git a/envoyfilter/src/stream_context.rs b/envoyfilter/src/stream_context.rs index 74503c70..9dfb7076 100644 --- a/envoyfilter/src/stream_context.rs +++ b/envoyfilter/src/stream_context.rs @@ -1,3 +1,14 @@ +use crate::common_types::{ +    open_ai::{ChatCompletions, Message}, +    NERRequest, NERResponse, SearchPointsRequest, SearchPointsResponse, +}; +use 
crate::configuration::EntityDetail; +use crate::configuration::EntityType; +use crate::configuration::PromptTarget; +use crate::consts::{ + DEFAULT_COLLECTION_NAME, DEFAULT_EMBEDDING_MODEL, DEFAULT_NER_MODEL, DEFAULT_NER_THRESHOLD, + DEFAULT_PROMPT_TARGET_THRESHOLD, SYSTEM_ROLE, USER_ROLE, +}; use http::StatusCode; use log::error; use log::info; @@ -5,24 +16,10 @@ use log::warn; use open_message_format::models::{ CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse, }; -use std::collections::HashMap; -use std::time::Duration; - use proxy_wasm::traits::*; use proxy_wasm::types::*; - -use consts::{ - DEFAULT_COLLECTION_NAME, DEFAULT_EMBEDDING_MODEL, DEFAULT_NER_MODEL, DEFAULT_NER_THRESHOLD, - DEFAULT_PROMPT_TARGET_THRESHOLD, SYSTEM_ROLE, USER_ROLE, -}; - -use crate::common_types; -use crate::common_types::open_ai::Message; -use crate::common_types::SearchPointsResponse; -use crate::configuration::EntityDetail; -use crate::configuration::EntityType; -use crate::configuration::PromptTarget; -use crate::consts; +use std::collections::HashMap; +use std::time::Duration; enum RequestType { GetEmbedding, @@ -35,7 +32,7 @@ pub struct CallContext { request_type: RequestType, user_message: String, prompt_target: Option<PromptTarget>, - request_body: common_types::open_ai::ChatCompletions, + request_body: ChatCompletions, } pub struct StreamContext { @@ -79,7 +76,7 @@ impl StreamContext { } }; - let search_points_request = common_types::SearchPointsRequest { + let search_points_request = SearchPointsRequest { vector: embedding_response.data[0].embedding.clone(), limit: 10, with_payload: true, @@ -167,7 +164,7 @@ impl StreamContext { .map(|entity| entity.name.clone()) .collect(); let user_message = callout_context.user_message.clone(); - let ner_request = common_types::NERRequest { + let ner_request = NERRequest { input: user_message, labels: entity_names, model: DEFAULT_NER_MODEL.to_string(), @@ -208,7 +205,7 @@ impl StreamContext { } fn ner_handler(&mut self, body: 
Vec<u8>, mut callout_context: CallContext) { - let ner_response: common_types::NERResponse = match serde_json::from_slice(&body) { + let ner_response: NERResponse = match serde_json::from_slice(&body) { Ok(ner_response) => ner_response, Err(e) => { warn!("Error deserializing ner_response: {:?}", e); @@ -364,32 +361,31 @@ impl HttpContext for StreamContext { // Deserialize body into spec. // Currently OpenAI API. - let deserialized_body: common_types::open_ai::ChatCompletions = - match self.get_http_request_body(0, body_size) { - Some(body_bytes) => match serde_json::from_slice(&body_bytes) { - Ok(deserialized) => deserialized, - Err(msg) => { - self.send_http_response( - StatusCode::BAD_REQUEST.as_u16().into(), - vec![], - Some(format!("Failed to deserialize: {}", msg).as_bytes()), - ); - return Action::Pause; - } - }, - None => { + let deserialized_body: ChatCompletions = match self.get_http_request_body(0, body_size) { + Some(body_bytes) => match serde_json::from_slice(&body_bytes) { + Ok(deserialized) => deserialized, + Err(msg) => { self.send_http_response( - StatusCode::INTERNAL_SERVER_ERROR.as_u16().into(), + StatusCode::BAD_REQUEST.as_u16().into(), vec![], - None, - ); - error!( - "Failed to obtain body bytes even though body_size is {}", - body_size + Some(format!("Failed to deserialize: {}", msg).as_bytes()), ); return Action::Pause; } - }; + }, + None => { + self.send_http_response( + StatusCode::INTERNAL_SERVER_ERROR.as_u16().into(), + vec![], + None, + ); + error!( + "Failed to obtain body bytes even though body_size is {}", + body_size + ); + return Action::Pause; + } + }; let user_message = match deserialized_body .messages