Improve prompt target intent matching (#51)

This commit is contained in:
Adil Hafeez 2024-09-16 19:20:07 -07:00 committed by GitHub
parent 8565462ec4
commit 9e50957f22
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 461 additions and 415 deletions

View file

@ -71,13 +71,6 @@ services:
profiles:
- manual
qdrant:
image: qdrant/qdrant
hostname: vector-db
ports:
- 16333:6333
- 16334:6334
chatbot-ui:
build:
context: ../../chatbot-ui

View file

@ -1,6 +1,7 @@
import os
import sentence_transformers
from gliner import GLiNER
from transformers import pipeline
def load_transformers(models = os.getenv("MODELS", "BAAI/bge-large-en-v1.5")):
transformers = {}
@ -17,3 +18,11 @@ def load_ner_models(models = os.getenv("NER_MODELS", "urchade/gliner_large-v2.1"
ner_models[model] = GLiNER.from_pretrained(model)
return ner_models
def load_zero_shot_models(models = os.getenv("ZERO_SHOT_MODELS", "tasksource/deberta-base-long-nli")):
    """Build a zero-shot-classification pipeline per comma-separated model name.

    Returns a dict keyed by model name; each value is a HF `pipeline` instance.
    """
    return {
        name: pipeline("zero-shot-classification", model=name)
        for name in models.split(',')
    }

View file

@ -1,11 +1,13 @@
import random
from fastapi import FastAPI, Response, HTTPException
from pydantic import BaseModel
from load_models import load_ner_models, load_transformers
from load_models import load_ner_models, load_transformers, load_zero_shot_models
from datetime import date, timedelta
import string
transformers = load_transformers()
ner_models = load_ner_models()
zero_shot_models = load_zero_shot_models()
app = FastAPI()
@ -81,6 +83,42 @@ async def ner(req: NERRequest, res: Response):
"object": "list",
}
# Request payload for the /zeroshot endpoint.
class ZeroShotRequest(BaseModel):
    # Text to classify.
    input: str
    # Candidate class labels; a score is returned for each one.
    labels: list[str]
    # Name of a model previously loaded into zero_shot_models; unknown names get HTTP 400.
    model: str
def remove_punctuations(s, lower=True):
    """Return *s* with punctuation replaced by spaces and whitespace collapsed.

    Each punctuation character becomes a single space, runs of whitespace are
    squeezed to one space, and the result is lowercased unless lower=False.
    """
    cleaned = "".join(" " if ch in string.punctuation else ch for ch in s)
    cleaned = " ".join(cleaned.split())
    return cleaned.lower() if lower else cleaned
@app.post("/zeroshot")
async def zeroshot(req: ZeroShotRequest, res: Response):
    """Classify req.input against req.labels with the requested zero-shot model.

    Labels are normalized (punctuation stripped, lowercased) before being fed
    to the classifier, then mapped back to their original spelling in the
    response. Raises HTTP 400 for an unknown model name.
    """
    try:
        classifier = zero_shot_models[req.model]
    except KeyError:
        raise HTTPException(status_code=400, detail="unknown model: " + req.model)
    normalized_labels = [remove_punctuations(label) for label in req.labels]
    prediction = classifier(req.input, candidate_labels=normalized_labels, multi_label=True)
    # NOTE(review): two labels that normalize to the same string collide here
    # (last one wins) — confirm callers never pass such label sets.
    back_to_original = dict(zip(normalized_labels, req.labels))
    ordered_originals = [back_to_original[label] for label in prediction["labels"]]
    scores_by_label = dict(zip(ordered_originals, prediction["scores"]))
    # The classifier returns labels sorted by score, so index 0 is the winner.
    top_label = back_to_original[prediction["labels"][0]]
    return {
        "predicted_class": top_label,
        "predicted_class_score": scores_by_label[top_label],
        "scores": scores_by_label,
        "model": req.model,
    }
class WeatherRequest(BaseModel):
city: str

97
envoyfilter/Cargo.lock generated
View file

@ -2,6 +2,15 @@
# It is not intended for manual editing.
version = 3
[[package]]
name = "acap"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6851a0b3b2d5729a0b7e61e3c36923ed9d72240146b0efda61121b0b84ad595d"
dependencies = [
"num-traits",
]
[[package]]
name = "addr2line"
version = "0.21.0"
@ -13,18 +22,18 @@ dependencies = [
[[package]]
name = "addr2line"
version = "0.22.0"
version = "0.24.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6e4503c46a5c0c7844e948c9a4d6acd9f50cccb4de1c48eb9e291ea17470c678"
checksum = "f5fb1d8e4442bd405fdfd1dacb42792696b0cf9cb15882e5d097b742a676d375"
dependencies = [
"gimli 0.29.0",
"gimli 0.31.0",
]
[[package]]
name = "adler"
version = "1.0.2"
name = "adler2"
version = "2.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe"
checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627"
[[package]]
name = "ahash"
@ -116,17 +125,17 @@ checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0"
[[package]]
name = "backtrace"
version = "0.3.73"
version = "0.3.74"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5cc23269a4f8976d0a4d2e7109211a419fe30e8d88d677cd60b6bc79c5732e0a"
checksum = "8d82cb332cdfaed17ae235a638438ac4d4839913cc2af585c3c6746e8f8bee1a"
dependencies = [
"addr2line 0.22.0",
"cc",
"addr2line 0.24.1",
"cfg-if 1.0.0",
"libc",
"miniz_oxide",
"object",
"rustc-demangle",
"windows-targets",
]
[[package]]
@ -208,9 +217,9 @@ checksum = "8318a53db07bb3f8dca91a600466bdb3f2eaadeedfdbcf02e1accbad9271ba50"
[[package]]
name = "cc"
version = "1.1.17"
version = "1.1.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a93fe60e2fc87b6ba2c117f67ae14f66e3fc7d6a1e612a25adb238cc980eadb3"
checksum = "b62ac837cdb5cb22e10a256099b4fc502b1dfe560cb282963a974d7abd80e476"
dependencies = [
"jobserver",
"libc",
@ -718,9 +727,9 @@ dependencies = [
[[package]]
name = "gimli"
version = "0.29.0"
version = "0.31.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "40ecd4077b5ae9fd2e9e169b102c6c330d0605168eb0e8bf79952b256dbefffd"
checksum = "32085ea23f3234fc7846555e85283ba4de91e21016dc0455a16286d87a292d64"
[[package]]
name = "governor"
@ -910,9 +919,9 @@ dependencies = [
[[package]]
name = "hyper-util"
version = "0.1.7"
version = "0.1.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cde7055719c54e36e95e8719f95883f22072a48ede39db7fc17a4e1d5281e9b9"
checksum = "da62f120a8a37763efb0cf8fdf264b884c7b8b9ac8660b900c8661030c00e6ba"
dependencies = [
"bytes",
"futures-channel",
@ -959,6 +968,7 @@ dependencies = [
name = "intelligent-prompt-gateway"
version = "0.1.0"
dependencies = [
"acap",
"governor",
"http",
"log",
@ -976,9 +986,9 @@ dependencies = [
[[package]]
name = "ipnet"
version = "2.9.0"
version = "2.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8f518f335dce6725a761382244631d86cf0ccb2863413590b31338feb467f9c3"
checksum = "187674a687eed5fe42285b40c6291f9a01517d415fad1c3cbc6a9f778af7fcd4"
[[package]]
name = "itertools"
@ -1137,11 +1147,11 @@ dependencies = [
[[package]]
name = "miniz_oxide"
version = "0.7.4"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b8a240ddb74feaf34a79a7add65a741f3167852fba007066dcac1ca548d89c08"
checksum = "e2d80299ef12ff69b16a84bb182e3b9df68b5a91574d3d4fa6e41b65deec4df1"
dependencies = [
"adler",
"adler2",
]
[[package]]
@ -1194,6 +1204,15 @@ version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38bf9645c8b145698bb0b18a4637dcacbc421ea49bef2317e4fd8065a387cf21"
[[package]]
name = "num-traits"
version = "0.2.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841"
dependencies = [
"autocfg",
]
[[package]]
name = "object"
version = "0.36.4"
@ -1648,9 +1667,9 @@ dependencies = [
[[package]]
name = "rustls"
version = "0.23.12"
version = "0.23.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c58f8c84392efc0a126acce10fa59ff7b3d2ac06ab451a33f2741989b806b044"
checksum = "f2dabaac7466917e566adb06783a81ca48944c6898a1b08b9374106dd671f4c8"
dependencies = [
"once_cell",
"rustls-pki-types",
@ -1677,9 +1696,9 @@ checksum = "fc0a2ce646f8655401bb81e7927b812614bd5d91dbc968696be50603510fcaf0"
[[package]]
name = "rustls-webpki"
version = "0.102.7"
version = "0.102.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "84678086bd54edf2b415183ed7a94d0efb049f1b646a33e22a36f3794be6ae56"
checksum = "64ca1bc8749bd4cf37b5ce386cc146580777b4e8572c7b97baf22c83f444bee9"
dependencies = [
"ring",
"rustls-pki-types",
@ -1694,20 +1713,20 @@ checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f"
[[package]]
name = "scc"
version = "2.1.16"
version = "2.1.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "aeb7ac86243095b70a7920639507b71d51a63390d1ba26c4f60a552fbb914a37"
checksum = "0c947adb109a8afce5fc9c7bf951f87f146e9147b3a6a58413105628fb1d1e66"
dependencies = [
"sdd",
]
[[package]]
name = "schannel"
version = "0.1.23"
version = "0.1.24"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fbc91545643bcf3a0bbb6569265615222618bdf33ce4ffbbd13c4bbd4c093534"
checksum = "e9aaafd5a2b6e3d657ff009d82fbd630b6bd54dd4eb06f21693925cdf80f9b8b"
dependencies = [
"windows-sys 0.52.0",
"windows-sys 0.59.0",
]
[[package]]
@ -1718,9 +1737,9 @@ checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
[[package]]
name = "sdd"
version = "3.0.2"
version = "3.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0495e4577c672de8254beb68d01a9b62d0e8a13c099edecdbedccce3223cd29f"
checksum = "60a7b59a5d9b0099720b417b6325d91a52cbf5b3dcb5041d864be53eefa58abc"
[[package]]
name = "security-framework"
@ -2431,9 +2450,9 @@ dependencies = [
[[package]]
name = "wasm-encoder"
version = "0.216.0"
version = "0.217.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "04c23aebea22c8a75833ae08ed31ccc020835b12a41999e58c31464271b94a88"
checksum = "7b88b0814c9a2b323a9b46c687e726996c255ac8b64aa237dd11c81ed4854760"
dependencies = [
"leb128",
]
@ -2721,22 +2740,22 @@ dependencies = [
[[package]]
name = "wast"
version = "216.0.0"
version = "217.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f7eb1f2eecd913fdde0dc6c3439d0f24530a98ac6db6cb3d14d92a5328554a08"
checksum = "79004ecebded92d3c710d4841383368c7f04b63d0992ddd6b0c7d5029b7629b7"
dependencies = [
"bumpalo",
"leb128",
"memchr",
"unicode-width",
"wasm-encoder 0.216.0",
"wasm-encoder 0.217.0",
]
[[package]]
name = "wat"
version = "1.216.0"
version = "1.217.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ac0409090fb5154f95fb5ba3235675fd9e579e731524d63b6a2f653e1280c82a"
checksum = "c126271c3d92ca0f7c63e4e462e40c69cca52fd4245fcda730d1cf558fb55088"
dependencies = [
"wast",
]

View file

@ -19,6 +19,7 @@ public-types = { path = "../public-types" }
http = "1.1.0"
governor = { version = "0.6.3", default-features = false, features = ["no_std"]}
tiktoken-rs = "0.5.9"
acap = "0.3.0"
[dev-dependencies]
proxy-wasm-test-framework = { git = "https://github.com/katanemo/test-framework.git", branch = "main" }

View file

@ -14,5 +14,6 @@ RUN cargo build --release --target wasm32-wasi
FROM envoyproxy/envoy:v1.30-latest
COPY --from=builder /envoyfilter/target/wasm32-wasi/release/intelligent_prompt_gateway.wasm /etc/envoy/proxy-wasm-plugins/intelligent_prompt_gateway.wasm
COPY envoyfilter/envoy.yaml /etc/envoy.yaml
CMD ["envoy", "-c", "/etc/envoy/envoy.yaml"]
# CMD ["envoy", "-c", "/etc/envoy/envoy.yaml"]
# CMD ["envoy", "-c", "/etc/envoy/envoy.yaml", "--log-level", "debug"]
CMD ["envoy", "-c", "/etc/envoy/envoy.yaml", "--component-log-level", "wasm:debug"]

View file

@ -165,26 +165,12 @@ static_resources:
address: embeddingserver
port_value: 80
hostname: "embeddingserver"
- name: qdrant
connect_timeout: 5s
type: STRICT_DNS
lb_policy: ROUND_ROBIN
load_assignment:
cluster_name: qdrant
endpoints:
- lb_endpoints:
- endpoint:
address:
socket_address:
address: qdrant
port_value: 6333
hostname: "qdrant"
- name: mistral_7b_instruct
connect_timeout: 5s
type: STRICT_DNS
lb_policy: ROUND_ROBIN
load_assignment:
cluster_name: qdrant
cluster_name: mistral_7b_instruct
endpoints:
- lb_endpoints:
- endpoint:

View file

@ -1,6 +1,6 @@
pub const DEFAULT_EMBEDDING_MODEL: &str = "BAAI/bge-large-en-v1.5";
pub const DEFAULT_COLLECTION_NAME: &str = "prompt_vector_store";
pub const DEFAULT_PROMPT_TARGET_THRESHOLD: f64 = 0.6;
pub const DEFAULT_INTENT_MODEL: &str = "tasksource/deberta-base-long-nli";
pub const DEFAULT_PROMPT_TARGET_THRESHOLD: f64 = 0.8;
pub const RATELIMIT_SELECTOR_HEADER_KEY: &str = "x-katanemo-ratelimit-selector";
pub const SYSTEM_ROLE: &str = "system";
pub const USER_ROLE: &str = "user";

View file

@ -3,22 +3,20 @@ use crate::ratelimit;
use crate::stats::{Counter, Gauge, RecordingMetric};
use crate::stream_context::StreamContext;
use log::debug;
use md5::Digest;
use open_message_format_embeddings::models::{
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
};
use proxy_wasm::traits::*;
use proxy_wasm::types::*;
use public_types::common_types::{
CallContext, EmbeddingRequest, StoreVectorEmbeddingsRequest, VectorPoint,
};
use public_types::common_types::EmbeddingType;
use public_types::configuration::{Configuration, PromptTarget};
use serde_json::to_string;
use std::collections::HashMap;
use std::rc::Rc;
use std::sync::{OnceLock, RwLock};
use std::time::Duration;
#[derive(Copy, Clone)]
#[derive(Copy, Clone, Debug)]
pub struct WasmMetrics {
pub active_http_calls: Gauge,
pub ratelimited_rq: Counter,
@ -33,11 +31,29 @@ impl WasmMetrics {
}
}
#[derive(Debug)]
struct CallContext {
prompt_target: String,
embedding_type: EmbeddingType,
}
pub type EmbeddingTypeMap = HashMap<EmbeddingType, Vec<f64>>;
#[derive(Debug)]
pub struct FilterContext {
metrics: Rc<WasmMetrics>,
// callouts stores token_id to request mapping that we use during #on_http_call_response to match the response to the request.
callouts: HashMap<u32, CallContext>,
config: Option<Configuration>,
prompt_targets: Rc<RwLock<HashMap<String, PromptTarget>>>,
}
pub fn embeddings_store() -> &'static RwLock<HashMap<String, EmbeddingTypeMap>> {
static EMBEDDINGS: OnceLock<RwLock<HashMap<String, EmbeddingTypeMap>>> = OnceLock::new();
EMBEDDINGS.get_or_init(|| {
let embeddings: HashMap<String, EmbeddingTypeMap> = HashMap::new();
RwLock::new(embeddings)
})
}
impl FilterContext {
@ -46,75 +62,95 @@ impl FilterContext {
callouts: HashMap::new(),
config: None,
metrics: Rc::new(WasmMetrics::new()),
prompt_targets: Rc::new(RwLock::new(HashMap::new())),
}
}
fn process_prompt_targets(&mut self) {
for prompt_target in &self.config.as_ref().unwrap().prompt_targets {
for few_shot_example in &prompt_target.few_shot_examples {
let embeddings_input = CreateEmbeddingRequest {
input: Box::new(CreateEmbeddingRequestInput::String(
few_shot_example.to_string(),
)),
model: String::from(DEFAULT_EMBEDDING_MODEL),
encoding_format: None,
dimensions: None,
user: None,
};
// TODO: Handle potential errors
let json_data: String = to_string(&embeddings_input).unwrap();
let token_id = match self.dispatch_http_call(
"embeddingserver",
vec![
(":method", "POST"),
(":path", "/embeddings"),
(":authority", "embeddingserver"),
("content-type", "application/json"),
("x-envoy-max-retries", "3"),
("x-envoy-upstream-rq-timeout-ms", "20000"),
],
Some(json_data.as_bytes()),
vec![],
Duration::from_secs(5),
) {
Ok(token_id) => token_id,
Err(e) => {
panic!("Error dispatching HTTP call: {:?}", e);
}
};
let embedding_request = EmbeddingRequest {
create_embedding_request: embeddings_input,
// Need to clone prompt target to leave config string intact.
prompt_target: prompt_target.clone(),
};
debug!(
"dispatched HTTP call to embedding server token_id={}",
token_id
);
if self
.callouts
.insert(token_id, {
CallContext::EmbeddingRequest(embedding_request)
})
.is_some()
{
panic!("duplicate token_id")
}
self.metrics
.active_http_calls
.record(self.callouts.len().try_into().unwrap());
let prompt_targets = match self.prompt_targets.read() {
Ok(prompt_targets) => prompt_targets,
Err(e) => {
panic!("Error reading prompt targets: {:?}", e);
}
};
for values in prompt_targets.iter() {
let prompt_target = &values.1;
// schedule embeddings call for prompt target name
let token_id = self.schedule_embeddings_call(prompt_target.name.clone());
if self
.callouts
.insert(token_id, {
CallContext {
prompt_target: prompt_target.name.clone(),
embedding_type: EmbeddingType::Name,
}
})
.is_some()
{
panic!("duplicate token_id")
}
// schedule embeddings call for prompt target description
let token_id = self.schedule_embeddings_call(prompt_target.description.clone());
if self
.callouts
.insert(token_id, {
CallContext {
prompt_target: prompt_target.name.clone(),
embedding_type: EmbeddingType::Description,
}
})
.is_some()
{
panic!("duplicate token_id")
}
self.metrics
.active_http_calls
.record(self.callouts.len().try_into().unwrap());
}
}
fn embedding_request_handler(
// Build a CreateEmbeddingRequest for `input` and dispatch it asynchronously to
// the "embeddingserver" cluster; returns the callout token_id used later in
// on_http_call_response to match the response back to this request.
fn schedule_embeddings_call(&self, input: String) -> u32 {
let embeddings_input = CreateEmbeddingRequest {
input: Box::new(CreateEmbeddingRequestInput::String(input)),
model: String::from(DEFAULT_EMBEDDING_MODEL),
encoding_format: None,
dimensions: None,
user: None,
};
// Serialization of our own request type is not expected to fail.
let json_data = to_string(&embeddings_input).unwrap();
let token_id = match self.dispatch_http_call(
"embeddingserver",
vec![
(":method", "POST"),
(":path", "/embeddings"),
(":authority", "embeddingserver"),
("content-type", "application/json"),
// Generous timeouts: embedding model inference can be slow on cold start.
("x-envoy-upstream-rq-timeout-ms", "60000"),
],
Some(json_data.as_bytes()),
vec![],
Duration::from_secs(60),
) {
Ok(token_id) => token_id,
Err(e) => {
// Dispatch failure at startup is unrecoverable for this filter.
panic!("Error dispatching HTTP call: {:?}", e);
}
};
token_id
}
fn embedding_response_handler(
&mut self,
body_size: usize,
create_embedding_request: CreateEmbeddingRequest,
prompt_target: PromptTarget,
embedding_type: EmbeddingType,
prompt_target_name: String,
) {
let prompt_targets = self.prompt_targets.read().unwrap();
let prompt_target = prompt_targets.get(&prompt_target_name).unwrap();
if let Some(body) = self.get_http_call_response_body(0, body_size) {
if !body.is_empty() {
let mut embedding_response: CreateEmbeddingResponse =
@ -129,111 +165,24 @@ impl FilterContext {
}
};
let mut payload: HashMap<String, String> = HashMap::new();
payload.insert(
"prompt-target".to_string(),
to_string(&prompt_target).unwrap(),
let embeddings = embedding_response.data.remove(0).embedding;
log::info!(
"Adding embeddings for prompt target name: {:?}, description: {:?}, embedding type: {:?}",
prompt_target.name,
prompt_target.description,
embedding_type
);
let id: Option<Digest>;
match *create_embedding_request.input {
CreateEmbeddingRequestInput::String(input) => {
id = Some(md5::compute(&input));
payload.insert("input".to_string(), input);
}
CreateEmbeddingRequestInput::Array(_) => todo!(),
}
let create_vector_store_points = StoreVectorEmbeddingsRequest {
points: vec![VectorPoint {
id: format!("{:x}", id.unwrap()),
payload,
vector: embedding_response.data.remove(0).embedding,
}],
};
let json_data = to_string(&create_vector_store_points).unwrap(); // Handle potential errors
let token_id = match self.dispatch_http_call(
"qdrant",
vec![
(":method", "PUT"),
(":path", "/collections/prompt_vector_store/points"),
(":authority", "qdrant"),
("content-type", "application/json"),
],
Some(json_data.as_bytes()),
vec![],
Duration::from_secs(5),
) {
Ok(token_id) => token_id,
Err(e) => {
panic!("Error dispatching HTTP call: {:?}", e);
}
};
if self
.callouts
.insert(
token_id,
CallContext::StoreVectorEmbeddings(create_vector_store_points),
)
.is_some()
{
panic!("duplicate token_id")
}
self.metrics
.active_http_calls
.record(self.callouts.len().try_into().unwrap());
embeddings_store().write().unwrap().insert(
prompt_target.name.clone(),
HashMap::from([(embedding_type, embeddings)]),
);
}
} else {
panic!("No body in response");
}
}
fn create_vector_store_points_handler(&self, body_size: usize) {
if let Some(body) = self.get_http_call_response_body(0, body_size) {
if !body.is_empty() {
debug!(
"response body: len {:?}",
String::from_utf8(body).unwrap().len()
);
}
}
}
//TODO: run once per envoy instance, right now it runs once per worker
fn init_vector_store(&mut self) {
let token_id = match self.dispatch_http_call(
"qdrant",
vec![
(":method", "PUT"),
(":path", "/collections/prompt_vector_store"),
(":authority", "qdrant"),
("content-type", "application/json"),
],
Some(b"{ \"vectors\": { \"size\": 1024, \"distance\": \"Cosine\"}}"),
vec![],
Duration::from_secs(5),
) {
Ok(token_id) => token_id,
Err(e) => {
panic!("Error dispatching HTTP call for init-vector-store: {:?}", e);
}
};
if self
.callouts
.insert(
token_id,
CallContext::CreateVectorCollection("prompt_vector_store".to_string()),
)
.is_some()
{
panic!("duplicate token_id")
}
self.metrics
.active_http_calls
.record(self.callouts.len().try_into().unwrap());
}
}
impl Context for FilterContext {
fn on_http_call_response(
&mut self,
@ -242,37 +191,18 @@ impl Context for FilterContext {
body_size: usize,
_num_trailers: usize,
) {
let callout_data = self
.callouts
.remove(&token_id)
.expect("invalid token_id: {}");
debug!("on_http_call_response called with token_id: {:?}", token_id);
let callout_data = self.callouts.remove(&token_id).expect("invalid token_id");
self.metrics
.active_http_calls
.record(self.callouts.len().try_into().unwrap());
match callout_data {
CallContext::EmbeddingRequest(EmbeddingRequest {
create_embedding_request,
prompt_target,
}) => {
self.embedding_request_handler(body_size, create_embedding_request, prompt_target)
}
CallContext::StoreVectorEmbeddings(_) => {
self.create_vector_store_points_handler(body_size)
}
CallContext::CreateVectorCollection(_) => {
let mut http_status_code = "Nil".to_string();
self.get_http_call_response_headers()
.iter()
.for_each(|(k, v)| {
if k == ":status" {
http_status_code.clone_from(v);
}
});
debug!("CreateVectorCollection response: {}", http_status_code);
}
}
self.embedding_response_handler(
body_size,
callout_data.embedding_type,
callout_data.prompt_target,
)
}
}
@ -282,6 +212,13 @@ impl RootContext for FilterContext {
if let Some(config_bytes) = self.get_plugin_configuration() {
self.config = serde_yaml::from_slice(&config_bytes).unwrap();
for pt in self.config.clone().unwrap().prompt_targets {
self.prompt_targets
.write()
.unwrap()
.insert(pt.name.clone(), pt.clone());
}
debug!("set configuration object: {:?}", self.config);
if let Some(ratelimits_config) = self
@ -301,6 +238,7 @@ impl RootContext for FilterContext {
ratelimit_selector: None,
callouts: HashMap::new(),
metrics: Rc::clone(&self.metrics),
prompt_targets: Rc::clone(&self.prompt_targets),
}))
}
@ -314,7 +252,6 @@ impl RootContext for FilterContext {
}
fn on_tick(&mut self) {
self.init_vector_store();
self.process_prompt_targets();
self.set_tick_period(Duration::from_secs(0));
}

View file

@ -33,7 +33,7 @@ pub trait RecordingMetric: Metric {
}
}
#[derive(Copy, Clone)]
#[derive(Copy, Clone, Debug)]
pub struct Counter {
id: u32,
}
@ -55,7 +55,7 @@ impl Metric for Counter {
impl IncrementingMetric for Counter {}
#[derive(Copy, Clone)]
#[derive(Copy, Clone, Debug)]
pub struct Gauge {
id: u32,
}

View file

@ -1,13 +1,14 @@
use crate::consts::{
BOLT_FC_CLUSTER, BOLT_FC_REQUEST_TIMEOUT_MS, DEFAULT_COLLECTION_NAME, DEFAULT_EMBEDDING_MODEL,
BOLT_FC_CLUSTER, BOLT_FC_REQUEST_TIMEOUT_MS, DEFAULT_EMBEDDING_MODEL, DEFAULT_INTENT_MODEL,
DEFAULT_PROMPT_TARGET_THRESHOLD, GPT_35_TURBO, RATELIMIT_SELECTOR_HEADER_KEY, SYSTEM_ROLE,
USER_ROLE,
};
use crate::filter_context::WasmMetrics;
use crate::filter_context::{embeddings_store, WasmMetrics};
use crate::ratelimit;
use crate::ratelimit::Header;
use crate::stats::IncrementingMetric;
use crate::tokenizer;
use acap::cos;
use http::StatusCode;
use log::{debug, error, info, warn};
use open_message_format_embeddings::models::{
@ -15,31 +16,31 @@ use open_message_format_embeddings::models::{
};
use proxy_wasm::traits::*;
use proxy_wasm::types::*;
use public_types::common_types::open_ai::{ChatCompletions, Message};
use public_types::common_types::{
open_ai::{ChatCompletions, Message},
SearchPointsRequest, SearchPointsResponse,
};
use public_types::common_types::{
BoltFCResponse, BoltFCToolsCall, ToolParameter, ToolParameters, ToolsDefinition,
BoltFCResponse, BoltFCToolsCall, EmbeddingType, ToolParameter, ToolParameters, ToolsDefinition,
ZeroShotClassificationRequest, ZeroShotClassificationResponse,
};
use public_types::configuration::{PromptTarget, PromptType};
use std::collections::HashMap;
use std::num::NonZero;
use std::rc::Rc;
use std::sync::RwLock;
use std::time::Duration;
enum RequestType {
GetEmbedding,
SearchPoints,
enum ResponseHandlerType {
GetEmbeddings,
FunctionResolver,
FunctionCallResponse,
FunctionCall,
ZeroShotIntent,
}
pub struct CallContext {
request_type: RequestType,
response_handler_type: ResponseHandlerType,
user_message: Option<String>,
prompt_target: Option<PromptTarget>,
request_body: ChatCompletions,
similarity_scores: Option<Vec<(String, f64)>>,
}
pub struct StreamContext {
@ -47,6 +48,7 @@ pub struct StreamContext {
pub ratelimit_selector: Option<Header>,
pub callouts: HashMap<u32, CallContext>,
pub metrics: Rc<WasmMetrics>,
pub prompt_targets: Rc<RwLock<HashMap<String, PromptTarget>>>,
}
impl StreamContext {
@ -61,7 +63,6 @@ impl StreamContext {
// However, a missing Content-Length header is not grounds for bad requests given that intermediary hops could
// manipulate the body in benign ways e.g., compression.
self.set_http_request_header("content-length", None);
// self.set_http_request_header("authorization", None);
}
fn modify_path_header(&mut self) {
@ -85,7 +86,7 @@ impl StreamContext {
});
}
fn send_server_error(&mut self, error: String) {
fn send_server_error(&self, error: String) {
debug!("server error occurred: {}", error);
self.send_http_response(
StatusCode::INTERNAL_SERVER_ERROR.as_u16().into(),
@ -103,30 +104,85 @@ impl StreamContext {
}
};
let search_points_request = SearchPointsRequest {
vector: embedding_response.data[0].embedding.clone(),
limit: 10,
with_payload: true,
};
let embeddings_vector = &embedding_response.data[0].embedding;
let json_data: String = match serde_json::to_string(&search_points_request) {
Ok(json_data) => json_data,
debug!(
"embedding model: {}, vector length: {:?}",
embedding_response.model,
embeddings_vector.len()
);
let prompt_target_embeddings = match embeddings_store().read() {
Ok(embeddings) => embeddings,
Err(e) => {
self.send_server_error(format!("Error serializing search_points_request: {:?}", e));
let error_message = format!("Error reading embeddings store: {:?}", e);
warn!("{}", error_message);
self.send_server_error(error_message);
return;
}
};
let path = format!("/collections/{}/points/search", DEFAULT_COLLECTION_NAME);
let prompt_targets = match self.prompt_targets.read() {
Ok(prompt_targets) => prompt_targets,
Err(e) => {
let error_message = format!("Error reading prompt targets: {:?}", e);
warn!("{}", error_message);
self.send_server_error(error_message);
return;
}
};
let prompt_target_names = prompt_targets
.iter()
.map(|(name, _)| name.clone())
.collect();
let similarity_scores: Vec<(String, f64)> = prompt_targets
.iter()
.map(|(prompt_name, _prompt_target)| {
let default_embeddings = HashMap::new();
let pte = prompt_target_embeddings
.get(prompt_name)
.unwrap_or(&default_embeddings);
let description_embeddings = pte.get(&EmbeddingType::Description);
let similarity_score_description = cos::cosine_similarity(
&embeddings_vector,
&description_embeddings.unwrap_or(&vec![0.0]),
);
(prompt_name.clone(), similarity_score_description)
})
.collect();
debug!(
"similarity scores based on description embeddings match: {:?}",
similarity_scores
);
callout_context.similarity_scores = Some(similarity_scores);
let zero_shot_classification_request = ZeroShotClassificationRequest {
// Need to clone into input because user_message is used below.
input: callout_context.user_message.as_ref().unwrap().clone(),
model: String::from(DEFAULT_INTENT_MODEL),
labels: prompt_target_names,
};
let json_data: String = match serde_json::to_string(&zero_shot_classification_request) {
Ok(json_data) => json_data,
Err(error) => {
panic!("Error serializing zero shot request: {}", error);
}
};
let token_id = match self.dispatch_http_call(
"qdrant",
"embeddingserver",
vec![
(":method", "POST"),
(":path", &path),
(":authority", "qdrant"),
(":path", "/zeroshot"),
(":authority", "embeddingserver"),
("content-type", "application/json"),
("x-envoy-max-retries", "3"),
("x-envoy-upstream-rq-timeout-ms", "60000"),
],
Some(json_data.as_bytes()),
vec![],
@ -134,39 +190,60 @@ impl StreamContext {
) {
Ok(token_id) => token_id,
Err(e) => {
panic!("Error dispatching HTTP call for get-embeddings: {:?}", e);
panic!(
"Error dispatching embedding server HTTP call for zero-shot-intent-detection: {:?}",
e
);
}
};
debug!(
"dispatched HTTP call to embedding server for zero-shot-intent-detection token_id={}",
token_id
);
callout_context.response_handler_type = ResponseHandlerType::ZeroShotIntent;
callout_context.request_type = RequestType::SearchPoints;
if self.callouts.insert(token_id, callout_context).is_some() {
panic!("duplicate token_id")
panic!(
"duplicate token_id={} in embedding server requests",
token_id
)
}
self.metrics.active_http_calls.increment(1);
}
fn search_points_handler(&mut self, body: Vec<u8>, mut callout_context: CallContext) {
let search_points_response: SearchPointsResponse = match serde_json::from_slice(&body) {
Ok(search_points_response) => search_points_response,
Err(e) => {
self.send_server_error(format!(
"Error deserializing search_points_response: {:?}",
e
));
fn zero_shot_intent_detection_resp_handler(
&mut self,
body: Vec<u8>,
mut callout_context: CallContext,
) {
let zeroshot_intent_response: ZeroShotClassificationResponse =
match serde_json::from_slice(&body) {
Ok(zeroshot_response) => zeroshot_response,
Err(e) => {
warn!(
"Error deserializing zeroshot intent detection response: {:?}",
e
);
info!("body: {:?}", String::from_utf8(body).unwrap());
self.resume_http_request();
return;
}
};
return;
}
};
debug!("zeroshot intent response: {:?}", zeroshot_intent_response);
let search_results = &search_points_response.result;
let prompt_target_similarity_score = zeroshot_intent_response.predicted_class_score * 0.7
+ callout_context.similarity_scores.as_ref().unwrap()[0].1 * 0.3;
if search_results.is_empty() {
info!("No prompt target matched");
self.resume_http_request();
return;
}
debug!(
"similarity score: {}, intent score: {}, description embedding score: {}",
prompt_target_similarity_score,
zeroshot_intent_response.predicted_class_score,
callout_context.similarity_scores.as_ref().unwrap()[0].1
);
let prompt_target_name = zeroshot_intent_response.predicted_class.clone();
info!("similarity score: {}", search_results[0].score);
// Check to see who responded to user message. This will help us identify if control should be passed to Bolt FC or not.
// If the last message was from Bolt FC, then Bolt FC is handling the conversation (possibly for parameter collection).
let mut bolt_assistant = false;
@ -175,7 +252,6 @@ impl StreamContext {
let latest_assistant_message = &messages[messages.len() - 2];
if let Some(model) = latest_assistant_message.model.as_ref() {
if model.starts_with("Bolt") {
info!("Bolt assistant message found");
bolt_assistant = true;
}
}
@ -183,23 +259,30 @@ impl StreamContext {
info!("no assistant message found, probably first interaction");
}
if search_results[0].score < DEFAULT_PROMPT_TARGET_THRESHOLD && !bolt_assistant {
info!(
"prompt target below threshold: {}",
DEFAULT_PROMPT_TARGET_THRESHOLD
);
self.resume_http_request();
return;
}
let prompt_target_str = search_results[0].payload.get("prompt-target").unwrap();
let prompt_target: PromptTarget = match serde_json::from_slice(prompt_target_str.as_bytes())
{
Ok(prompt_target) => prompt_target,
Err(e) => {
self.send_server_error(format!("Error deserializing prompt_target: {:?}", e));
// check to ensure that the prompt target similarity score is above the threshold
if prompt_target_similarity_score < DEFAULT_PROMPT_TARGET_THRESHOLD && !bolt_assistant {
// if bolt fc responded to the user message, then we don't need to check the similarity score
// it may be that bolt fc is handling the conversation for parameter collection
if bolt_assistant {
info!("bolt assistant is handling the conversation");
} else {
info!(
"prompt target below threshold: {}, continue conversation with user",
prompt_target_similarity_score,
);
self.resume_http_request();
return;
}
};
}
let prompt_target = self
.prompt_targets
.read()
.unwrap()
.get(&prompt_target_name)
.unwrap()
.clone();
info!(
"prompt_target name: {:?}, type: {:?}",
prompt_target.name, prompt_target.prompt_type
@ -231,7 +314,7 @@ impl StreamContext {
let tools_defintion: ToolsDefinition = ToolsDefinition {
name: prompt_target.name.clone(),
description: prompt_target.description.clone().unwrap_or("".to_string()),
description: prompt_target.description.clone(),
parameters: tools_parameters,
};
@ -283,7 +366,7 @@ impl StreamContext {
BOLT_FC_CLUSTER, token_id
);
callout_context.request_type = RequestType::FunctionResolver;
callout_context.response_handler_type = ResponseHandlerType::FunctionResolver;
callout_context.prompt_target = Some(prompt_target);
if self.callouts.insert(token_id, callout_context).is_some() {
panic!("duplicate token_id")
@ -342,7 +425,7 @@ impl StreamContext {
{
warn!("boltfc did not extract required parameter: {}", param.name);
return self.send_http_response(
StatusCode::BAD_REQUEST.as_u16().into(),
StatusCode::INTERNAL_SERVER_ERROR.as_u16().into(),
vec![],
Some("missing required parameter".as_bytes()),
);
@ -380,7 +463,7 @@ impl StreamContext {
}
};
callout_context.request_type = RequestType::FunctionCallResponse;
callout_context.response_handler_type = ResponseHandlerType::FunctionCall;
if self.callouts.insert(token_id, callout_context).is_some() {
panic!("duplicate token_id")
}
@ -468,7 +551,6 @@ impl StreamContext {
}
}
debug!("sending request to openai: msg {}", json_string);
self.set_http_request_body(0, json_string.len(), &json_string.into_bytes());
self.resume_http_request();
}
@ -582,10 +664,11 @@ impl HttpContext for StreamContext {
);
let call_context = CallContext {
request_type: RequestType::GetEmbedding,
response_handler_type: ResponseHandlerType::GetEmbeddings,
user_message: Some(user_message),
prompt_target: None,
request_body: deserialized_body,
similarity_scores: None,
};
if self.callouts.insert(token_id, call_context).is_some() {
panic!(
@ -593,6 +676,7 @@ impl HttpContext for StreamContext {
token_id
)
}
self.metrics.active_http_calls.increment(1);
Action::Pause
@ -611,18 +695,24 @@ impl Context for StreamContext {
self.metrics.active_http_calls.increment(-1);
if let Some(body) = self.get_http_call_response_body(0, body_size) {
match callout_context.request_type {
RequestType::GetEmbedding => self.embeddings_handler(body, callout_context),
RequestType::SearchPoints => self.search_points_handler(body, callout_context),
RequestType::FunctionResolver => {
match callout_context.response_handler_type {
ResponseHandlerType::GetEmbeddings => {
self.embeddings_handler(body, callout_context)
}
ResponseHandlerType::FunctionResolver => {
self.function_resolver_handler(body, callout_context)
}
RequestType::FunctionCallResponse => {
ResponseHandlerType::FunctionCall => {
self.function_call_response_handler(body, callout_context)
}
ResponseHandlerType::ZeroShotIntent => {
self.zero_shot_intent_detection_resp_handler(body, callout_context)
}
}
} else {
warn!("No response body in inline HTTP request");
let error_message = "No response body in inline HTTP request";
warn!("{}", error_message);
self.send_server_error(error_message.to_owned());
}
}
}

View file

@ -8,16 +8,10 @@ use proxy_wasm_test_framework::tester::{self, Tester};
use proxy_wasm_test_framework::types::{
Action, BufferType, LogLevel, MapType, MetricType, ReturnType,
};
use public_types::{
common_types::{
open_ai::Message, BoltFCResponse, BoltFCToolsCall, IntOrString, ToolCallDetail,
},
configuration::{self, Endpoint, PromptTarget},
};
use public_types::{
common_types::{SearchPointResult, SearchPointsResponse},
configuration::Configuration,
use public_types::common_types::{
open_ai::Message, BoltFCResponse, BoltFCToolsCall, IntOrString, ToolCallDetail,
};
use public_types::{common_types::ZeroShotClassificationResponse, configuration::Configuration};
use serial_test::serial;
use std::collections::HashMap;
use std::path::Path;
@ -118,59 +112,39 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
.expect_metric_increment("active_http_calls", -1)
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
.returning(Some(&embeddings_response_buffer))
.expect_http_call(Some("qdrant"), None, None, None, None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_http_call(Some("embeddingserver"), None, None, None, None)
.returning(Some(2))
.expect_metric_increment("active_http_calls", 1)
.execute_and_expect(ReturnType::None)
.unwrap();
let prompt_target = PromptTarget {
name: String::from("test-prompt-target"),
description: None,
prompt_type: configuration::PromptType::FunctionResolver,
few_shot_examples: vec![],
parameters: Some(vec![configuration::Parameter {
name: String::from("test-entity"),
parameter_type: Some(String::from("string")),
description: String::from("test-description"),
required: Some(true),
}]),
endpoint: Some(Endpoint {
cluster: String::from("test-endpoint-cluster"),
path: None,
method: None,
}),
system_prompt: None,
let zero_shot_response = ZeroShotClassificationResponse {
predicted_class: "weather_forecast".to_string(),
predicted_class_score: 0.1,
scores: HashMap::new(),
model: "test-model".to_string(),
};
let prompt_target_str = serde_json::to_string(&prompt_target).unwrap();
let search_points_response = SearchPointsResponse {
status: String::new(),
time: 0.0,
result: vec![SearchPointResult {
id: String::new(),
version: 0,
score: 0.7,
payload: HashMap::from([(String::from("prompt-target"), prompt_target_str)]),
}],
};
let search_points_response_buffer = serde_json::to_string(&search_points_response).unwrap();
let zeroshot_intent_detection_buffer = serde_json::to_string(&zero_shot_response).unwrap();
module
.call_proxy_on_http_call_response(
http_context,
2,
0,
search_points_response_buffer.len() as i32,
zeroshot_intent_detection_buffer.len() as i32,
0,
)
.expect_metric_increment("active_http_calls", -1)
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
.returning(Some(&search_points_response_buffer))
.expect_log(Some(LogLevel::Info), None)
.expect_log(Some(LogLevel::Info), None)
.returning(Some(&zeroshot_intent_detection_buffer))
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Info), None)
.expect_http_call(Some("bolt_fc_1b"), None, None, None, None)
.returning(Some(3))
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_metric_increment("active_http_calls", 1)
.execute_and_expect(ReturnType::None)
.unwrap();
@ -199,20 +173,22 @@ system_prompt: |
prompt_targets:
- type: function_resolver
name: weather_forecast
few_shot_examples:
- what is the weather in New York?
description: This resolver provides weather forecast information.
endpoint:
cluster: weatherhost
path: /weather
entities:
- name: location
parameters:
- name: city
required: true
description: "The location for which the weather is requested"
description: The city for which the weather forecast is requested.
- name: days
description: The number of days for which the weather forecast is requested.
- name: units
description: The units in which the weather forecast is requested.
- type: function_resolver
name: weather_forecast_2
few_shot_examples:
- what is the weather in New York?
description: This resolver provides weather forecast information.
endpoint:
cluster: weatherhost
path: /weather
@ -450,10 +426,10 @@ fn request_ratelimited() {
normal_flow(&mut module, filter_context, http_context);
let tool_call_detail = vec![ToolCallDetail {
name: String::from("test-tool"),
name: String::from("weather_forecast"),
arguments: HashMap::from([(
String::from("test-entity"),
IntOrString::Text(String::from("test-value")),
String::from("city"),
IntOrString::Text(String::from("seattle")),
)]),
}];
@ -485,7 +461,7 @@ fn request_ratelimited() {
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_http_call(Some("test-endpoint-cluster"), None, None, None, None)
.expect_http_call(Some("weatherhost"), None, None, None, None)
.returning(Some(4))
.expect_metric_increment("active_http_calls", 1)
.execute_and_expect(ReturnType::None)
@ -555,10 +531,10 @@ fn request_not_ratelimited() {
normal_flow(&mut module, filter_context, http_context);
let tool_call_detail = vec![ToolCallDetail {
name: String::from("test-tool"),
name: String::from("weather_forecast"),
arguments: HashMap::from([(
String::from("test-entity"),
IntOrString::Text(String::from("test-value")),
String::from("city"),
IntOrString::Text(String::from("seattle")),
)]),
}];
@ -590,7 +566,7 @@ fn request_not_ratelimited() {
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_http_call(Some("test-endpoint-cluster"), None, None, None, None)
.expect_http_call(Some("weatherhost"), None, None, None, None)
.returning(Some(4))
.expect_metric_increment("active_http_calls", 1)
.execute_and_expect(ReturnType::None)
@ -608,7 +584,7 @@ fn request_not_ratelimited() {
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
.expect_log(Some(LogLevel::Debug), None)
// .expect_log(Some(LogLevel::Debug), None)
.execute_and_expect(ReturnType::None)
.unwrap();
}

View file

@ -1,14 +1,18 @@
use crate::configuration::PromptTarget;
use open_message_format_embeddings::models::CreateEmbeddingRequest;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingRequest {
pub create_embedding_request: CreateEmbeddingRequest,
pub prompt_target: PromptTarget,
}
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub enum EmbeddingType {
Name,
Description,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorPoint {
pub id: String,
@ -21,21 +25,6 @@ pub struct StoreVectorEmbeddingsRequest {
pub points: Vec<VectorPoint>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[allow(clippy::large_enum_variant)]
pub enum CallContext {
EmbeddingRequest(EmbeddingRequest),
StoreVectorEmbeddings(StoreVectorEmbeddingsRequest),
CreateVectorCollection(String),
}
// https://api.qdrant.tech/master/api-reference/search/points
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchPointsRequest {
pub vector: Vec<f64>,
pub limit: i32,
pub with_payload: bool,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchPointResult {
@ -45,13 +34,6 @@ pub struct SearchPointResult {
pub payload: HashMap<String, String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchPointsResponse {
pub result: Vec<SearchPointResult>,
pub status: String,
pub time: f64,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolParameter {
#[serde(rename = "type")]
@ -125,3 +107,18 @@ pub mod open_ai {
pub model: Option<String>,
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ZeroShotClassificationRequest {
pub input: String,
pub labels: Vec<String>,
pub model: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ZeroShotClassificationResponse {
pub predicted_class: String,
pub predicted_class_score: f64,
pub scores: HashMap<String, f64>,
pub model: String,
}

View file

@ -99,8 +99,7 @@ pub struct PromptTarget {
#[serde(rename = "type")]
pub prompt_type: PromptType,
pub name: String,
pub description: Option<String>,
pub few_shot_examples: Vec<String>,
pub description: String,
pub parameters: Option<Vec<Parameter>>,
pub endpoint: Option<Endpoint>,
pub system_prompt: Option<String>,