Use open-message-format to serialize and deserialize the embeddings API (#18)

* Use open-message-format to serialize and deserialize the embeddings API
This commit is contained in:
Adil Hafeez 2024-07-23 11:56:49 -07:00 committed by GitHub
parent a59c7df2a2
commit cad38295bf
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 1265 additions and 47 deletions

View file

@ -1,29 +1,10 @@
use open_message_format::models::CreateEmbeddingRequest;
use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use crate::configuration;
/// Request body for an embeddings call: the text to embed and the model to use.
///
/// Serializable and deserializable through serde.
///
/// NOTE(review): a type with this same name is also imported from
/// `open_message_format::models` at the top of this file; the two cannot
/// coexist in one scope — confirm which definition is meant to remain.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CreateEmbeddingRequest {
/// Text to be embedded.
pub input: String,
/// Name of the embedding model to use.
pub model: String,
}
/// Response body of an embeddings call.
///
/// Serializable and deserializable through serde.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CreateEmbeddingResponse {
/// Presumably a type tag for the response object — TODO confirm against the API.
pub object: String,
/// Name of the model reported in the response.
pub model: String,
/// The returned embeddings, one `Embedding` entry per element.
pub data: Vec<Embedding>,
}
/// A single embedding entry, as carried in `CreateEmbeddingResponse::data`.
///
/// Serializable and deserializable through serde.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Embedding {
/// Presumably a type tag for this entry — TODO confirm against the API.
pub object: String,
/// Presumably the position of this entry within the response list — TODO confirm.
pub index: i32,
/// The embedding vector, stored as single-precision floats.
pub embedding: Vec<f32>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingRequest {
pub create_embedding_request: CreateEmbeddingRequest,
@ -45,7 +26,7 @@ pub struct CalloutData {
/// A single point to store in the vector store: an id, a string-to-string
/// payload, and the embedding vector.
///
/// NOTE(review): `vector` is declared twice below (`Vec<f32>` and `Vec<f64>`).
/// This looks like the before/after lines of a diff shown together; a struct
/// can only have one field with that name — confirm which element type is
/// intended to remain.
pub struct VectorPoint {
/// Identifier of the point.
pub id: String,
/// Arbitrary string metadata attached to the point.
pub payload: HashMap<String, String>,
// Embedding vector with single-precision elements.
pub vector: Vec<f32>,
// Embedding vector with double-precision elements.
pub vector: Vec<f64>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]

View file

@ -1,7 +1,10 @@
use common_types::CreateEmbeddingRequest;
use common_types::EmbeddingRequest;
use configuration::PromptTarget;
use log::info;
use md5::Digest;
use open_message_format::models::{
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
};
use serde_json::to_string;
use stats::IncrementingMetric;
use stats::Metric;
@ -149,9 +152,15 @@ impl FilterContext {
for prompt_target in &self.config.as_ref().unwrap().prompt_config.prompt_targets {
for few_shot_example in &prompt_target.few_shot_examples {
info!("few_shot_example: {:?}", few_shot_example);
let embeddings_input = common_types::CreateEmbeddingRequest {
input: few_shot_example.to_string(),
let embeddings_input = CreateEmbeddingRequest {
input: Box::new(CreateEmbeddingRequestInput::String(
few_shot_example.to_string(),
)),
model: String::from(consts::DEFAULT_EMBEDDING_MODEL),
encoding_format: None,
dimensions: None,
user: None,
};
// TODO: Handle potential errors
@ -198,7 +207,7 @@ impl FilterContext {
info!("response received for CreateEmbeddingRequest");
if let Some(body) = self.get_http_call_response_body(0, body_size) {
if !body.is_empty() {
let embedding_response: common_types::CreateEmbeddingResponse =
let embedding_response: CreateEmbeddingResponse =
serde_json::from_slice(&body).unwrap();
info!(
"embedding_response model: {}, vector len: {}",
@ -211,16 +220,18 @@ impl FilterContext {
"prompt-target".to_string(),
to_string(&prompt_target).unwrap(),
);
payload.insert(
"few-shot-example".to_string(),
create_embedding_request.input.clone(),
);
let id = md5::compute(create_embedding_request.input);
let id: Option<Digest>;
match *create_embedding_request.input {
CreateEmbeddingRequestInput::String(input) => {
id = Some(md5::compute(&input));
payload.insert("input".to_string(), input);
}
CreateEmbeddingRequestInput::Array(_) => todo!(),
}
let create_vector_store_points = common_types::CreateVectorStorePoints {
points: vec![common_types::VectorPoint {
id: format!("{:x}", id),
id: format!("{:x}", id.unwrap()),
payload,
vector: embedding_response.data[0].embedding.clone(),
}],