Improve prompt target intent matching (#51)

This commit is contained in:
Adil Hafeez 2024-09-16 19:20:07 -07:00 committed by GitHub
parent 8565462ec4
commit 9e50957f22
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 461 additions and 415 deletions

View file

@ -1,14 +1,18 @@
use crate::configuration::PromptTarget;
use open_message_format_embeddings::models::CreateEmbeddingRequest;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingRequest {
pub create_embedding_request: CreateEmbeddingRequest,
pub prompt_target: PromptTarget,
}
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub enum EmbeddingType {
Name,
Description,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorPoint {
pub id: String,
@ -21,21 +25,6 @@ pub struct StoreVectorEmbeddingsRequest {
pub points: Vec<VectorPoint>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[allow(clippy::large_enum_variant)]
pub enum CallContext {
EmbeddingRequest(EmbeddingRequest),
StoreVectorEmbeddings(StoreVectorEmbeddingsRequest),
CreateVectorCollection(String),
}
// https://api.qdrant.tech/master/api-reference/search/points
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchPointsRequest {
pub vector: Vec<f64>,
pub limit: i32,
pub with_payload: bool,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchPointResult {
@ -45,13 +34,6 @@ pub struct SearchPointResult {
pub payload: HashMap<String, String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchPointsResponse {
pub result: Vec<SearchPointResult>,
pub status: String,
pub time: f64,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolParameter {
#[serde(rename = "type")]
@ -125,3 +107,18 @@ pub mod open_ai {
pub model: Option<String>,
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ZeroShotClassificationRequest {
pub input: String,
pub labels: Vec<String>,
pub model: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ZeroShotClassificationResponse {
pub predicted_class: String,
pub predicted_class_score: f64,
pub scores: HashMap<String, f64>,
pub model: String,
}

View file

@ -99,8 +99,7 @@ pub struct PromptTarget {
#[serde(rename = "type")]
pub prompt_type: PromptType,
pub name: String,
pub description: Option<String>,
pub few_shot_examples: Vec<String>,
pub description: String,
pub parameters: Option<Vec<Parameter>>,
pub endpoint: Option<Endpoint>,
pub system_prompt: Option<String>,