mirror of
https://github.com/katanemo/plano.git
synced 2026-06-08 14:55:14 +02:00
Use open-message-format to serialize and deserialize embeddings api (#18)
* Use open-message-format to serialize and deserialize embeddings api
This commit is contained in:
parent
a59c7df2a2
commit
cad38295bf
9 changed files with 1265 additions and 47 deletions
18
.github/workflows/checks.yml
vendored
18
.github/workflows/checks.yml
vendored
|
|
@ -9,6 +9,10 @@ jobs:
|
|||
steps:
|
||||
- name: Setup | Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
submodules: 'true'
|
||||
# TODO: Remove this once the repo is public
|
||||
token: ${{ secrets.ADIL_GITHUB_TOKEN }}
|
||||
- name: Setup | Rust
|
||||
run: rustup toolchain install stable --profile minimal
|
||||
- name: Run Clippy
|
||||
|
|
@ -21,20 +25,28 @@ jobs:
|
|||
steps:
|
||||
- name: Setup | Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
submodules: 'true'
|
||||
# TODO: Remove this once the repo is public
|
||||
token: ${{ secrets.ADIL_GITHUB_TOKEN }}
|
||||
- name: Setup | Rust
|
||||
run: rustup toolchain install stable --profile minimal
|
||||
- name: Run Rustfmt
|
||||
run: cd envoyfilter && cargo fmt --all -- --check
|
||||
|
||||
run: cd envoyfilter && cargo fmt -p intelligent-prompt-gateway -- --check
|
||||
|
||||
test:
|
||||
name: Test
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Setup | Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
submodules: 'true'
|
||||
# TODO: Remove this once the repo is public
|
||||
token: ${{ secrets.ADIL_GITHUB_TOKEN }}
|
||||
- name: Setup | Rust
|
||||
run: rustup toolchain install stable --profile minimal
|
||||
- name: Run Tests
|
||||
# --lib is to only test the library, since when integration tests are made,
|
||||
# --lib is to only test the library, since when integration tests are made,
|
||||
# they will be in a seperate tests directory
|
||||
run: cd envoyfilter && cargo test --lib
|
||||
|
|
|
|||
3
.gitmodules
vendored
Normal file
3
.gitmodules
vendored
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
[submodule "open-message-format"]
|
||||
path = open-message-format
|
||||
url = git@github.com:open-llm-initiative/open-message-format.git
|
||||
|
|
@ -41,8 +41,13 @@ async def embedding(req: EmbeddingRequest, res: Response):
|
|||
"index": len(data)
|
||||
})
|
||||
|
||||
usage = {
|
||||
"prompt_tokens": 0,
|
||||
"total_tokens": 0,
|
||||
}
|
||||
return {
|
||||
"data": data,
|
||||
"model": req.model,
|
||||
"object": "list"
|
||||
"object": "list",
|
||||
"usage": usage
|
||||
}
|
||||
|
|
|
|||
1222
envoyfilter/Cargo.lock
generated
1222
envoyfilter/Cargo.lock
generated
File diff suppressed because it is too large
Load diff
|
|
@ -14,3 +14,4 @@ serde = { version = "1.0", features = ["derive"] }
|
|||
serde_yaml = "0.9.34"
|
||||
serde_json = "1.0"
|
||||
md5 = "0.7.0"
|
||||
open-message-format = { path = "../open-message-format/clients/omf-rust" }
|
||||
|
|
|
|||
|
|
@ -1,29 +1,10 @@
|
|||
use open_message_format::models::CreateEmbeddingRequest;
|
||||
use std::collections::HashMap;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::configuration;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct CreateEmbeddingRequest {
|
||||
pub input: String,
|
||||
pub model: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct CreateEmbeddingResponse {
|
||||
pub object: String,
|
||||
pub model: String,
|
||||
pub data: Vec<Embedding>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Embedding {
|
||||
pub object: String,
|
||||
pub index: i32,
|
||||
pub embedding: Vec<f32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct EmbeddingRequest {
|
||||
pub create_embedding_request: CreateEmbeddingRequest,
|
||||
|
|
@ -45,7 +26,7 @@ pub struct CalloutData {
|
|||
pub struct VectorPoint {
|
||||
pub id: String,
|
||||
pub payload: HashMap<String, String>,
|
||||
pub vector: Vec<f32>,
|
||||
pub vector: Vec<f64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
|
|
|
|||
|
|
@ -1,7 +1,10 @@
|
|||
use common_types::CreateEmbeddingRequest;
|
||||
use common_types::EmbeddingRequest;
|
||||
use configuration::PromptTarget;
|
||||
use log::info;
|
||||
use md5::Digest;
|
||||
use open_message_format::models::{
|
||||
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
|
||||
};
|
||||
use serde_json::to_string;
|
||||
use stats::IncrementingMetric;
|
||||
use stats::Metric;
|
||||
|
|
@ -149,9 +152,15 @@ impl FilterContext {
|
|||
for prompt_target in &self.config.as_ref().unwrap().prompt_config.prompt_targets {
|
||||
for few_shot_example in &prompt_target.few_shot_examples {
|
||||
info!("few_shot_example: {:?}", few_shot_example);
|
||||
let embeddings_input = common_types::CreateEmbeddingRequest {
|
||||
input: few_shot_example.to_string(),
|
||||
|
||||
let embeddings_input = CreateEmbeddingRequest {
|
||||
input: Box::new(CreateEmbeddingRequestInput::String(
|
||||
few_shot_example.to_string(),
|
||||
)),
|
||||
model: String::from(consts::DEFAULT_EMBEDDING_MODEL),
|
||||
encoding_format: None,
|
||||
dimensions: None,
|
||||
user: None,
|
||||
};
|
||||
|
||||
// TODO: Handle potential errors
|
||||
|
|
@ -198,7 +207,7 @@ impl FilterContext {
|
|||
info!("response received for CreateEmbeddingRequest");
|
||||
if let Some(body) = self.get_http_call_response_body(0, body_size) {
|
||||
if !body.is_empty() {
|
||||
let embedding_response: common_types::CreateEmbeddingResponse =
|
||||
let embedding_response: CreateEmbeddingResponse =
|
||||
serde_json::from_slice(&body).unwrap();
|
||||
info!(
|
||||
"embedding_response model: {}, vector len: {}",
|
||||
|
|
@ -211,16 +220,18 @@ impl FilterContext {
|
|||
"prompt-target".to_string(),
|
||||
to_string(&prompt_target).unwrap(),
|
||||
);
|
||||
payload.insert(
|
||||
"few-shot-example".to_string(),
|
||||
create_embedding_request.input.clone(),
|
||||
);
|
||||
|
||||
let id = md5::compute(create_embedding_request.input);
|
||||
let id: Option<Digest>;
|
||||
match *create_embedding_request.input {
|
||||
CreateEmbeddingRequestInput::String(input) => {
|
||||
id = Some(md5::compute(&input));
|
||||
payload.insert("input".to_string(), input);
|
||||
}
|
||||
CreateEmbeddingRequestInput::Array(_) => todo!(),
|
||||
}
|
||||
|
||||
let create_vector_store_points = common_types::CreateVectorStorePoints {
|
||||
points: vec![common_types::VectorPoint {
|
||||
id: format!("{:x}", id),
|
||||
id: format!("{:x}", id.unwrap()),
|
||||
payload,
|
||||
vector: embedding_response.data[0].embedding.clone(),
|
||||
}],
|
||||
|
|
|
|||
|
|
@ -12,6 +12,10 @@
|
|||
"name": "embedding-server",
|
||||
"path": "embedding-server"
|
||||
},
|
||||
{
|
||||
"name": "open-message-format",
|
||||
"path": "open-message-format"
|
||||
},
|
||||
{
|
||||
"name": "demos",
|
||||
"path": "./demos"
|
||||
|
|
|
|||
1
open-message-format
Submodule
1
open-message-format
Submodule
|
|
@ -0,0 +1 @@
|
|||
Subproject commit f70179ae3de7110cb40412c902b275b7900b40a1
|
||||
Loading…
Add table
Add a link
Reference in a new issue