mirror of
https://github.com/katanemo/plano.git
synced 2026-04-25 00:36:34 +02:00
Clean up imports (#25)
Signed-off-by: José Ulises Niño Rivera <junr03@users.noreply.github.com>
This commit is contained in:
parent
5ecdf30fdc
commit
c13682a03b
3 changed files with 57 additions and 66 deletions
|
|
@ -1,13 +1,12 @@
|
|||
use crate::configuration::PromptTarget;
|
||||
use open_message_format::models::CreateEmbeddingRequest;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::configuration;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct EmbeddingRequest {
|
||||
pub create_embedding_request: CreateEmbeddingRequest,
|
||||
pub prompt_target: configuration::PromptTarget,
|
||||
pub prompt_target: PromptTarget,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
|
|
|
|||
|
|
@ -1,34 +1,30 @@
|
|||
use common_types::{CallContext, EmbeddingRequest};
|
||||
use configuration::PromptTarget;
|
||||
use crate::common_types::{
|
||||
CallContext, EmbeddingRequest, StoreVectorEmbeddingsRequest, VectorPoint,
|
||||
};
|
||||
use crate::configuration::{Configuration, PromptTarget};
|
||||
use crate::consts::DEFAULT_EMBEDDING_MODEL;
|
||||
use crate::stats::{Gauge, RecordingMetric};
|
||||
use crate::stream_context::StreamContext;
|
||||
use log::info;
|
||||
use md5::Digest;
|
||||
use open_message_format::models::{
|
||||
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
|
||||
};
|
||||
use proxy_wasm::traits::*;
|
||||
use proxy_wasm::types::*;
|
||||
use serde_json::to_string;
|
||||
use stats::RecordingMetric;
|
||||
use std::collections::HashMap;
|
||||
use std::time::Duration;
|
||||
|
||||
use consts::DEFAULT_EMBEDDING_MODEL;
|
||||
use proxy_wasm::traits::*;
|
||||
use proxy_wasm::types::*;
|
||||
|
||||
use crate::common_types;
|
||||
use crate::configuration;
|
||||
use crate::consts;
|
||||
use crate::stats;
|
||||
use crate::stream_context::StreamContext;
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
struct WasmMetrics {
|
||||
active_http_calls: stats::Gauge,
|
||||
active_http_calls: Gauge,
|
||||
}
|
||||
|
||||
impl WasmMetrics {
|
||||
fn new() -> WasmMetrics {
|
||||
WasmMetrics {
|
||||
active_http_calls: stats::Gauge::new(String::from("active_http_calls")),
|
||||
active_http_calls: Gauge::new(String::from("active_http_calls")),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -36,8 +32,8 @@ impl WasmMetrics {
|
|||
pub struct FilterContext {
|
||||
metrics: WasmMetrics,
|
||||
// callouts stores token_id to request mapping that we use during #on_http_call_response to match the response to the request.
|
||||
callouts: HashMap<u32, common_types::CallContext>,
|
||||
config: Option<configuration::Configuration>,
|
||||
callouts: HashMap<u32, CallContext>,
|
||||
config: Option<Configuration>,
|
||||
}
|
||||
|
||||
impl FilterContext {
|
||||
|
|
@ -127,8 +123,8 @@ impl FilterContext {
|
|||
CreateEmbeddingRequestInput::Array(_) => todo!(),
|
||||
}
|
||||
|
||||
let create_vector_store_points = common_types::StoreVectorEmbeddingsRequest {
|
||||
points: vec![common_types::VectorPoint {
|
||||
let create_vector_store_points = StoreVectorEmbeddingsRequest {
|
||||
points: vec![VectorPoint {
|
||||
id: format!("{:x}", id.unwrap()),
|
||||
payload,
|
||||
vector: embedding_response.data[0].embedding.clone(),
|
||||
|
|
@ -231,16 +227,16 @@ impl Context for FilterContext {
|
|||
.record(self.callouts.len().try_into().unwrap());
|
||||
|
||||
match callout_data {
|
||||
common_types::CallContext::EmbeddingRequest(common_types::EmbeddingRequest {
|
||||
CallContext::EmbeddingRequest(EmbeddingRequest {
|
||||
create_embedding_request,
|
||||
prompt_target,
|
||||
}) => {
|
||||
self.embedding_request_handler(body_size, create_embedding_request, prompt_target)
|
||||
}
|
||||
common_types::CallContext::StoreVectorEmbeddings(_) => {
|
||||
CallContext::StoreVectorEmbeddings(_) => {
|
||||
self.create_vector_store_points_handler(body_size)
|
||||
}
|
||||
common_types::CallContext::CreateVectorCollection(_) => {
|
||||
CallContext::CreateVectorCollection(_) => {
|
||||
let mut http_status_code = "Nil".to_string();
|
||||
self.get_http_call_response_headers()
|
||||
.iter()
|
||||
|
|
|
|||
|
|
@ -1,3 +1,14 @@
|
|||
use crate::common_types::{
|
||||
open_ai::{ChatCompletions, Message},
|
||||
NERRequest, NERResponse, SearchPointsRequest, SearchPointsResponse,
|
||||
};
|
||||
use crate::configuration::EntityDetail;
|
||||
use crate::configuration::EntityType;
|
||||
use crate::configuration::PromptTarget;
|
||||
use crate::consts::{
|
||||
DEFAULT_COLLECTION_NAME, DEFAULT_EMBEDDING_MODEL, DEFAULT_NER_MODEL, DEFAULT_NER_THRESHOLD,
|
||||
DEFAULT_PROMPT_TARGET_THRESHOLD, SYSTEM_ROLE, USER_ROLE,
|
||||
};
|
||||
use http::StatusCode;
|
||||
use log::error;
|
||||
use log::info;
|
||||
|
|
@ -5,24 +16,10 @@ use log::warn;
|
|||
use open_message_format::models::{
|
||||
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
|
||||
};
|
||||
use std::collections::HashMap;
|
||||
use std::time::Duration;
|
||||
|
||||
use proxy_wasm::traits::*;
|
||||
use proxy_wasm::types::*;
|
||||
|
||||
use consts::{
|
||||
DEFAULT_COLLECTION_NAME, DEFAULT_EMBEDDING_MODEL, DEFAULT_NER_MODEL, DEFAULT_NER_THRESHOLD,
|
||||
DEFAULT_PROMPT_TARGET_THRESHOLD, SYSTEM_ROLE, USER_ROLE,
|
||||
};
|
||||
|
||||
use crate::common_types;
|
||||
use crate::common_types::open_ai::Message;
|
||||
use crate::common_types::SearchPointsResponse;
|
||||
use crate::configuration::EntityDetail;
|
||||
use crate::configuration::EntityType;
|
||||
use crate::configuration::PromptTarget;
|
||||
use crate::consts;
|
||||
use std::collections::HashMap;
|
||||
use std::time::Duration;
|
||||
|
||||
enum RequestType {
|
||||
GetEmbedding,
|
||||
|
|
@ -35,7 +32,7 @@ pub struct CallContext {
|
|||
request_type: RequestType,
|
||||
user_message: String,
|
||||
prompt_target: Option<PromptTarget>,
|
||||
request_body: common_types::open_ai::ChatCompletions,
|
||||
request_body: ChatCompletions,
|
||||
}
|
||||
|
||||
pub struct StreamContext {
|
||||
|
|
@ -79,7 +76,7 @@ impl StreamContext {
|
|||
}
|
||||
};
|
||||
|
||||
let search_points_request = common_types::SearchPointsRequest {
|
||||
let search_points_request = SearchPointsRequest {
|
||||
vector: embedding_response.data[0].embedding.clone(),
|
||||
limit: 10,
|
||||
with_payload: true,
|
||||
|
|
@ -167,7 +164,7 @@ impl StreamContext {
|
|||
.map(|entity| entity.name.clone())
|
||||
.collect();
|
||||
let user_message = callout_context.user_message.clone();
|
||||
let ner_request = common_types::NERRequest {
|
||||
let ner_request = NERRequest {
|
||||
input: user_message,
|
||||
labels: entity_names,
|
||||
model: DEFAULT_NER_MODEL.to_string(),
|
||||
|
|
@ -208,7 +205,7 @@ impl StreamContext {
|
|||
}
|
||||
|
||||
fn ner_handler(&mut self, body: Vec<u8>, mut callout_context: CallContext) {
|
||||
let ner_response: common_types::NERResponse = match serde_json::from_slice(&body) {
|
||||
let ner_response: NERResponse = match serde_json::from_slice(&body) {
|
||||
Ok(ner_response) => ner_response,
|
||||
Err(e) => {
|
||||
warn!("Error deserializing ner_response: {:?}", e);
|
||||
|
|
@ -364,32 +361,31 @@ impl HttpContext for StreamContext {
|
|||
|
||||
// Deserialize body into spec.
|
||||
// Currently OpenAI API.
|
||||
let deserialized_body: common_types::open_ai::ChatCompletions =
|
||||
match self.get_http_request_body(0, body_size) {
|
||||
Some(body_bytes) => match serde_json::from_slice(&body_bytes) {
|
||||
Ok(deserialized) => deserialized,
|
||||
Err(msg) => {
|
||||
self.send_http_response(
|
||||
StatusCode::BAD_REQUEST.as_u16().into(),
|
||||
vec![],
|
||||
Some(format!("Failed to deserialize: {}", msg).as_bytes()),
|
||||
);
|
||||
return Action::Pause;
|
||||
}
|
||||
},
|
||||
None => {
|
||||
let deserialized_body: ChatCompletions = match self.get_http_request_body(0, body_size) {
|
||||
Some(body_bytes) => match serde_json::from_slice(&body_bytes) {
|
||||
Ok(deserialized) => deserialized,
|
||||
Err(msg) => {
|
||||
self.send_http_response(
|
||||
StatusCode::INTERNAL_SERVER_ERROR.as_u16().into(),
|
||||
StatusCode::BAD_REQUEST.as_u16().into(),
|
||||
vec![],
|
||||
None,
|
||||
);
|
||||
error!(
|
||||
"Failed to obtain body bytes even though body_size is {}",
|
||||
body_size
|
||||
Some(format!("Failed to deserialize: {}", msg).as_bytes()),
|
||||
);
|
||||
return Action::Pause;
|
||||
}
|
||||
};
|
||||
},
|
||||
None => {
|
||||
self.send_http_response(
|
||||
StatusCode::INTERNAL_SERVER_ERROR.as_u16().into(),
|
||||
vec![],
|
||||
None,
|
||||
);
|
||||
error!(
|
||||
"Failed to obtain body bytes even though body_size is {}",
|
||||
body_size
|
||||
);
|
||||
return Action::Pause;
|
||||
}
|
||||
};
|
||||
|
||||
let user_message = match deserialized_body
|
||||
.messages
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue