Clean up imports (#25)

Signed-off-by: José Ulises Niño Rivera <junr03@users.noreply.github.com>
This commit is contained in:
José Ulises Niño Rivera 2024-07-31 11:33:39 -07:00 committed by GitHub
parent 5ecdf30fdc
commit c13682a03b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 57 additions and 66 deletions

View file

@@ -1,13 +1,12 @@
use crate::configuration::PromptTarget;
use open_message_format::models::CreateEmbeddingRequest;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use crate::configuration;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingRequest {
pub create_embedding_request: CreateEmbeddingRequest,
pub prompt_target: configuration::PromptTarget,
pub prompt_target: PromptTarget,
}
#[derive(Debug, Clone, Serialize, Deserialize)]

View file

@@ -1,34 +1,30 @@
use common_types::{CallContext, EmbeddingRequest};
use configuration::PromptTarget;
use crate::common_types::{
CallContext, EmbeddingRequest, StoreVectorEmbeddingsRequest, VectorPoint,
};
use crate::configuration::{Configuration, PromptTarget};
use crate::consts::DEFAULT_EMBEDDING_MODEL;
use crate::stats::{Gauge, RecordingMetric};
use crate::stream_context::StreamContext;
use log::info;
use md5::Digest;
use open_message_format::models::{
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
};
use proxy_wasm::traits::*;
use proxy_wasm::types::*;
use serde_json::to_string;
use stats::RecordingMetric;
use std::collections::HashMap;
use std::time::Duration;
use consts::DEFAULT_EMBEDDING_MODEL;
use proxy_wasm::traits::*;
use proxy_wasm::types::*;
use crate::common_types;
use crate::configuration;
use crate::consts;
use crate::stats;
use crate::stream_context::StreamContext;
#[derive(Copy, Clone)]
struct WasmMetrics {
active_http_calls: stats::Gauge,
active_http_calls: Gauge,
}
impl WasmMetrics {
fn new() -> WasmMetrics {
WasmMetrics {
active_http_calls: stats::Gauge::new(String::from("active_http_calls")),
active_http_calls: Gauge::new(String::from("active_http_calls")),
}
}
}
@@ -36,8 +32,8 @@ impl WasmMetrics {
pub struct FilterContext {
metrics: WasmMetrics,
// callouts stores token_id to request mapping that we use during #on_http_call_response to match the response to the request.
callouts: HashMap<u32, common_types::CallContext>,
config: Option<configuration::Configuration>,
callouts: HashMap<u32, CallContext>,
config: Option<Configuration>,
}
impl FilterContext {
@@ -127,8 +123,8 @@ impl FilterContext {
CreateEmbeddingRequestInput::Array(_) => todo!(),
}
let create_vector_store_points = common_types::StoreVectorEmbeddingsRequest {
points: vec![common_types::VectorPoint {
let create_vector_store_points = StoreVectorEmbeddingsRequest {
points: vec![VectorPoint {
id: format!("{:x}", id.unwrap()),
payload,
vector: embedding_response.data[0].embedding.clone(),
@@ -231,16 +227,16 @@ impl Context for FilterContext {
.record(self.callouts.len().try_into().unwrap());
match callout_data {
common_types::CallContext::EmbeddingRequest(common_types::EmbeddingRequest {
CallContext::EmbeddingRequest(EmbeddingRequest {
create_embedding_request,
prompt_target,
}) => {
self.embedding_request_handler(body_size, create_embedding_request, prompt_target)
}
common_types::CallContext::StoreVectorEmbeddings(_) => {
CallContext::StoreVectorEmbeddings(_) => {
self.create_vector_store_points_handler(body_size)
}
common_types::CallContext::CreateVectorCollection(_) => {
CallContext::CreateVectorCollection(_) => {
let mut http_status_code = "Nil".to_string();
self.get_http_call_response_headers()
.iter()

View file

@@ -1,3 +1,14 @@
use crate::common_types::{
open_ai::{ChatCompletions, Message},
NERRequest, NERResponse, SearchPointsRequest, SearchPointsResponse,
};
use crate::configuration::EntityDetail;
use crate::configuration::EntityType;
use crate::configuration::PromptTarget;
use crate::consts::{
DEFAULT_COLLECTION_NAME, DEFAULT_EMBEDDING_MODEL, DEFAULT_NER_MODEL, DEFAULT_NER_THRESHOLD,
DEFAULT_PROMPT_TARGET_THRESHOLD, SYSTEM_ROLE, USER_ROLE,
};
use http::StatusCode;
use log::error;
use log::info;
@@ -5,24 +16,10 @@ use log::warn;
use open_message_format::models::{
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
};
use std::collections::HashMap;
use std::time::Duration;
use proxy_wasm::traits::*;
use proxy_wasm::types::*;
use consts::{
DEFAULT_COLLECTION_NAME, DEFAULT_EMBEDDING_MODEL, DEFAULT_NER_MODEL, DEFAULT_NER_THRESHOLD,
DEFAULT_PROMPT_TARGET_THRESHOLD, SYSTEM_ROLE, USER_ROLE,
};
use crate::common_types;
use crate::common_types::open_ai::Message;
use crate::common_types::SearchPointsResponse;
use crate::configuration::EntityDetail;
use crate::configuration::EntityType;
use crate::configuration::PromptTarget;
use crate::consts;
use std::collections::HashMap;
use std::time::Duration;
enum RequestType {
GetEmbedding,
@@ -35,7 +32,7 @@ pub struct CallContext {
request_type: RequestType,
user_message: String,
prompt_target: Option<PromptTarget>,
request_body: common_types::open_ai::ChatCompletions,
request_body: ChatCompletions,
}
pub struct StreamContext {
@@ -79,7 +76,7 @@ impl StreamContext {
}
};
let search_points_request = common_types::SearchPointsRequest {
let search_points_request = SearchPointsRequest {
vector: embedding_response.data[0].embedding.clone(),
limit: 10,
with_payload: true,
@@ -167,7 +164,7 @@ impl StreamContext {
.map(|entity| entity.name.clone())
.collect();
let user_message = callout_context.user_message.clone();
let ner_request = common_types::NERRequest {
let ner_request = NERRequest {
input: user_message,
labels: entity_names,
model: DEFAULT_NER_MODEL.to_string(),
@@ -208,7 +205,7 @@ impl StreamContext {
}
fn ner_handler(&mut self, body: Vec<u8>, mut callout_context: CallContext) {
let ner_response: common_types::NERResponse = match serde_json::from_slice(&body) {
let ner_response: NERResponse = match serde_json::from_slice(&body) {
Ok(ner_response) => ner_response,
Err(e) => {
warn!("Error deserializing ner_response: {:?}", e);
@@ -364,32 +361,31 @@ impl HttpContext for StreamContext {
// Deserialize body into spec.
// Currently OpenAI API.
let deserialized_body: common_types::open_ai::ChatCompletions =
match self.get_http_request_body(0, body_size) {
Some(body_bytes) => match serde_json::from_slice(&body_bytes) {
Ok(deserialized) => deserialized,
Err(msg) => {
self.send_http_response(
StatusCode::BAD_REQUEST.as_u16().into(),
vec![],
Some(format!("Failed to deserialize: {}", msg).as_bytes()),
);
return Action::Pause;
}
},
None => {
let deserialized_body: ChatCompletions = match self.get_http_request_body(0, body_size) {
Some(body_bytes) => match serde_json::from_slice(&body_bytes) {
Ok(deserialized) => deserialized,
Err(msg) => {
self.send_http_response(
StatusCode::INTERNAL_SERVER_ERROR.as_u16().into(),
StatusCode::BAD_REQUEST.as_u16().into(),
vec![],
None,
);
error!(
"Failed to obtain body bytes even though body_size is {}",
body_size
Some(format!("Failed to deserialize: {}", msg).as_bytes()),
);
return Action::Pause;
}
};
},
None => {
self.send_http_response(
StatusCode::INTERNAL_SERVER_ERROR.as_u16().into(),
vec![],
None,
);
error!(
"Failed to obtain body bytes even though body_size is {}",
body_size
);
return Action::Pause;
}
};
let user_message = match deserialized_body
.messages