mirror of
https://github.com/katanemo/plano.git
synced 2026-05-14 02:24:00 +02:00
Move shared types into their own crate (#41)
Signed-off-by: José Ulises Niño Rivera <junr03@users.noreply.github.com>
This commit is contained in:
parent
4dd1f3693e
commit
d98517f240
13 changed files with 1435 additions and 14 deletions
89
public-types/src/common_types.rs
Normal file
89
public-types/src/common_types.rs
Normal file
|
|
@ -0,0 +1,89 @@
|
|||
use crate::configuration::PromptTarget;
|
||||
use open_message_format_embeddings::models::CreateEmbeddingRequest;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct EmbeddingRequest {
|
||||
pub create_embedding_request: CreateEmbeddingRequest,
|
||||
pub prompt_target: PromptTarget,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct VectorPoint {
|
||||
pub id: String,
|
||||
pub payload: HashMap<String, String>,
|
||||
pub vector: Vec<f64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct StoreVectorEmbeddingsRequest {
|
||||
pub points: Vec<VectorPoint>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[allow(clippy::large_enum_variant)]
|
||||
pub enum CallContext {
|
||||
EmbeddingRequest(EmbeddingRequest),
|
||||
StoreVectorEmbeddings(StoreVectorEmbeddingsRequest),
|
||||
CreateVectorCollection(String),
|
||||
}
|
||||
|
||||
// https://api.qdrant.tech/master/api-reference/search/points
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SearchPointsRequest {
|
||||
pub vector: Vec<f64>,
|
||||
pub limit: i32,
|
||||
pub with_payload: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SearchPointResult {
|
||||
pub id: String,
|
||||
pub version: i32,
|
||||
pub score: f64,
|
||||
pub payload: HashMap<String, String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SearchPointsResponse {
|
||||
pub result: Vec<SearchPointResult>,
|
||||
pub status: String,
|
||||
pub time: f64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct NERRequest {
|
||||
pub input: String,
|
||||
pub labels: Vec<String>,
|
||||
pub model: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Entity {
|
||||
pub text: String,
|
||||
pub label: String,
|
||||
pub score: f64,
|
||||
}
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct NERResponse {
|
||||
pub data: Vec<Entity>,
|
||||
pub model: String,
|
||||
}
|
||||
|
||||
pub mod open_ai {
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChatCompletions {
|
||||
#[serde(default)]
|
||||
pub model: String,
|
||||
pub messages: Vec<Message>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Message {
|
||||
pub role: String,
|
||||
pub content: Option<String>,
|
||||
}
|
||||
}
|
||||
157
public-types/src/configuration.rs
Normal file
157
public-types/src/configuration.rs
Normal file
|
|
@ -0,0 +1,157 @@
|
|||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Configuration {
|
||||
pub default_prompt_endpoint: String,
|
||||
pub load_balancing: LoadBalancing,
|
||||
pub timeout_ms: u64,
|
||||
pub embedding_provider: EmbeddingProviver,
|
||||
pub llm_providers: Vec<LlmProvider>,
|
||||
pub system_prompt: Option<String>,
|
||||
pub prompt_targets: Vec<PromptTarget>,
|
||||
pub ratelimits: Option<Vec<Ratelimit>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Ratelimit {
|
||||
pub provider: String,
|
||||
pub selector: Header,
|
||||
pub limit: Limit,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Limit {
|
||||
pub tokens: u32,
|
||||
pub unit: TimeUnit,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub enum TimeUnit {
|
||||
#[serde(rename = "second")]
|
||||
Second,
|
||||
#[serde(rename = "minute")]
|
||||
Minute,
|
||||
#[serde(rename = "hour")]
|
||||
Hour,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
|
||||
pub struct Header {
|
||||
pub key: String,
|
||||
pub value: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub enum LoadBalancing {
|
||||
#[serde(rename = "round_robin")]
|
||||
RoundRobin,
|
||||
#[serde(rename = "random")]
|
||||
Random,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
//TODO: use enum for model, but if there is a new model, we need to update the code
|
||||
pub struct EmbeddingProviver {
|
||||
pub name: String,
|
||||
pub model: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
//TODO: use enum for model, but if there is a new model, we need to update the code
|
||||
pub struct LlmProvider {
|
||||
pub name: String,
|
||||
pub api_key: Option<String>,
|
||||
pub model: String,
|
||||
pub default: Option<bool>,
|
||||
pub endpoint: Option<EnpointType>,
|
||||
}
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum EnpointType {
|
||||
String(String),
|
||||
Struct(Endpoint),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Endpoint {
|
||||
pub cluster: String,
|
||||
pub path: Option<String>,
|
||||
pub method: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Entity {
|
||||
pub name: String,
|
||||
pub required: Option<bool>,
|
||||
pub description: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PromptTarget {
|
||||
#[serde(rename = "type")]
|
||||
pub prompt_type: String,
|
||||
pub name: String,
|
||||
pub few_shot_examples: Vec<String>,
|
||||
pub entities: Option<Vec<Entity>>,
|
||||
pub endpoint: Option<Endpoint>,
|
||||
pub system_prompt: Option<String>,
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod test {
    /// YAML fixture exercising every top-level field of `Configuration`.
    // NOTE(review): the nesting below follows the field layout of
    // `Configuration` and its nested types; verify it against the upstream
    // fixture before relying on exact whitespace.
    pub const CONFIGURATION: &str = r#"
default_prompt_endpoint: "127.0.0.1"
load_balancing: "round_robin"
timeout_ms: 5000

embedding_provider:
  name: "SentenceTransformer"
  model: "all-MiniLM-L6-v2"

llm_providers:
  - name: "open-ai-gpt-4"
    api_key: "$OPEN_AI_API_KEY"
    model: gpt-4

system_prompt: |
  You are a helpful weather forecaster. Please following following guidelines when responding to user queries:
  - Use farenheight for temperature
  - Use miles per hour for wind speed

prompt_targets:
  - type: context_resolver
    name: weather_forecast
    few_shot_examples:
      - what is the weather in New York?
    endpoint:
      cluster: weatherhost
      path: /weather
    entities:
      - name: location
        required: true
        description: "The location for which the weather is requested"

  - type: context_resolver
    name: weather_forecast_2
    few_shot_examples:
      - what is the weather in New York?
    endpoint:
      cluster: weatherhost
      path: /weather
    entities:
      - name: city

ratelimits:
  - provider: open-ai-gpt-4
    selector:
      key: x-katanemo-openai-limit-id
    limit:
      tokens: 100
      unit: minute
  "#;

    /// The fixture must deserialize into `Configuration` without error.
    #[test]
    fn test_deserialize_configuration() {
        let _: super::Configuration = serde_yaml::from_str(CONFIGURATION)
            .expect("CONFIGURATION fixture should deserialize into Configuration");
    }
}
|
||||
2
public-types/src/lib.rs
Normal file
2
public-types/src/lib.rs
Normal file
|
|
@ -0,0 +1,2 @@
|
|||
pub mod common_types;
|
||||
pub mod configuration;
|
||||
Loading…
Add table
Add a link
Reference in a new issue