Move shared types into their own crate (#41)

Signed-off-by: José Ulises Niño Rivera <junr03@users.noreply.github.com>
This commit is contained in:
José Ulises Niño Rivera 2024-09-04 15:31:05 -07:00 committed by GitHub
parent 4dd1f3693e
commit d98517f240
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 1435 additions and 14 deletions

View file

@ -0,0 +1,89 @@
use crate::configuration::PromptTarget;
use open_message_format_embeddings::models::CreateEmbeddingRequest;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingRequest {
pub create_embedding_request: CreateEmbeddingRequest,
pub prompt_target: PromptTarget,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorPoint {
pub id: String,
pub payload: HashMap<String, String>,
pub vector: Vec<f64>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StoreVectorEmbeddingsRequest {
pub points: Vec<VectorPoint>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[allow(clippy::large_enum_variant)]
pub enum CallContext {
EmbeddingRequest(EmbeddingRequest),
StoreVectorEmbeddings(StoreVectorEmbeddingsRequest),
CreateVectorCollection(String),
}
// https://api.qdrant.tech/master/api-reference/search/points
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchPointsRequest {
pub vector: Vec<f64>,
pub limit: i32,
pub with_payload: bool,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchPointResult {
pub id: String,
pub version: i32,
pub score: f64,
pub payload: HashMap<String, String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchPointsResponse {
pub result: Vec<SearchPointResult>,
pub status: String,
pub time: f64,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NERRequest {
pub input: String,
pub labels: Vec<String>,
pub model: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Entity {
pub text: String,
pub label: String,
pub score: f64,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NERResponse {
pub data: Vec<Entity>,
pub model: String,
}
pub mod open_ai {
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletions {
#[serde(default)]
pub model: String,
pub messages: Vec<Message>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
pub role: String,
pub content: Option<String>,
}
}

View file

@ -0,0 +1,157 @@
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Configuration {
pub default_prompt_endpoint: String,
pub load_balancing: LoadBalancing,
pub timeout_ms: u64,
pub embedding_provider: EmbeddingProviver,
pub llm_providers: Vec<LlmProvider>,
pub system_prompt: Option<String>,
pub prompt_targets: Vec<PromptTarget>,
pub ratelimits: Option<Vec<Ratelimit>>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Ratelimit {
pub provider: String,
pub selector: Header,
pub limit: Limit,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Limit {
pub tokens: u32,
pub unit: TimeUnit,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TimeUnit {
#[serde(rename = "second")]
Second,
#[serde(rename = "minute")]
Minute,
#[serde(rename = "hour")]
Hour,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub struct Header {
pub key: String,
pub value: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum LoadBalancing {
#[serde(rename = "round_robin")]
RoundRobin,
#[serde(rename = "random")]
Random,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
//TODO: use enum for model, but if there is a new model, we need to update the code
pub struct EmbeddingProviver {
pub name: String,
pub model: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
//TODO: use enum for model, but if there is a new model, we need to update the code
pub struct LlmProvider {
pub name: String,
pub api_key: Option<String>,
pub model: String,
pub default: Option<bool>,
pub endpoint: Option<EnpointType>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum EnpointType {
String(String),
Struct(Endpoint),
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Endpoint {
pub cluster: String,
pub path: Option<String>,
pub method: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Entity {
pub name: String,
pub required: Option<bool>,
pub description: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PromptTarget {
#[serde(rename = "type")]
pub prompt_type: String,
pub name: String,
pub few_shot_examples: Vec<String>,
pub entities: Option<Vec<Entity>>,
pub endpoint: Option<Endpoint>,
pub system_prompt: Option<String>,
}
#[cfg(test)]
mod test {
pub const CONFIGURATION: &str = r#"
default_prompt_endpoint: "127.0.0.1"
load_balancing: "round_robin"
timeout_ms: 5000
embedding_provider:
name: "SentenceTransformer"
model: "all-MiniLM-L6-v2"
llm_providers:
- name: "open-ai-gpt-4"
api_key: "$OPEN_AI_API_KEY"
model: gpt-4
system_prompt: |
You are a helpful weather forecaster. Please following following guidelines when responding to user queries:
- Use farenheight for temperature
- Use miles per hour for wind speed
prompt_targets:
- type: context_resolver
name: weather_forecast
few_shot_examples:
- what is the weather in New York?
endpoint:
cluster: weatherhost
path: /weather
entities:
- name: location
required: true
description: "The location for which the weather is requested"
- type: context_resolver
name: weather_forecast_2
few_shot_examples:
- what is the weather in New York?
endpoint:
cluster: weatherhost
path: /weather
entities:
- name: city
ratelimits:
- provider: open-ai-gpt-4
selector:
key: x-katanemo-openai-limit-id
limit:
tokens: 100
unit: minute
"#;
#[test]
fn test_deserialize_configuration() {
let _: super::Configuration = serde_yaml::from_str(CONFIGURATION).unwrap();
}
}

2
public-types/src/lib.rs Normal file
View file

@ -0,0 +1,2 @@
pub mod common_types;
pub mod configuration;