mirror of
https://github.com/katanemo/plano.git
synced 2026-04-25 00:36:34 +02:00
Improve prompt target intent matching (#51)
This commit is contained in:
parent
8565462ec4
commit
9e50957f22
14 changed files with 461 additions and 415 deletions
|
|
@ -71,13 +71,6 @@ services:
|
|||
profiles:
|
||||
- manual
|
||||
|
||||
qdrant:
|
||||
image: qdrant/qdrant
|
||||
hostname: vector-db
|
||||
ports:
|
||||
- 16333:6333
|
||||
- 16334:6334
|
||||
|
||||
chatbot-ui:
|
||||
build:
|
||||
context: ../../chatbot-ui
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import os
|
||||
import sentence_transformers
|
||||
from gliner import GLiNER
|
||||
from transformers import pipeline
|
||||
|
||||
def load_transformers(models = os.getenv("MODELS", "BAAI/bge-large-en-v1.5")):
|
||||
transformers = {}
|
||||
|
|
@ -17,3 +18,11 @@ def load_ner_models(models = os.getenv("NER_MODELS", "urchade/gliner_large-v2.1"
|
|||
ner_models[model] = GLiNER.from_pretrained(model)
|
||||
|
||||
return ner_models
|
||||
|
||||
def load_zero_shot_models(models = os.getenv("ZERO_SHOT_MODELS", "tasksource/deberta-base-long-nli")):
    """Load one zero-shot-classification pipeline per requested model.

    Args:
        models: Comma-separated model identifiers. The default is read from
            the ``ZERO_SHOT_MODELS`` environment variable once, when this
            function is defined (i.e. at module import).

    Returns:
        A dict mapping each model identifier to its loaded pipeline.
    """
    # Tolerate "a, b" and trailing commas: surrounding whitespace is not part
    # of a model identifier, and an empty entry would otherwise be handed to
    # `pipeline` as a model name.
    model_names = [name.strip() for name in models.split(',') if name.strip()]

    return {
        name: pipeline("zero-shot-classification", model=name)
        for name in model_names
    }
|
||||
|
|
|
|||
|
|
@ -1,11 +1,13 @@
|
|||
import random
|
||||
from fastapi import FastAPI, Response, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from load_models import load_ner_models, load_transformers
|
||||
from load_models import load_ner_models, load_transformers, load_zero_shot_models
|
||||
from datetime import date, timedelta
|
||||
import string
|
||||
|
||||
transformers = load_transformers()
|
||||
ner_models = load_ner_models()
|
||||
zero_shot_models = load_zero_shot_models()
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
|
@ -81,6 +83,42 @@ async def ner(req: NERRequest, res: Response):
|
|||
"object": "list",
|
||||
}
|
||||
|
||||
class ZeroShotRequest(BaseModel):
    """Request body accepted by the ``/zeroshot`` endpoint."""

    # Text to classify.
    input: str
    # Candidate labels the input is scored against.
    labels: list[str]
    # Key of the zero-shot model to use; the endpoint rejects names that are
    # not present in the loaded `zero_shot_models`.
    model: str
|
||||
|
||||
|
||||
# Maps every ASCII punctuation character to a space. Built once at import time
# instead of on every call.
_PUNCTUATION_TO_SPACE = str.maketrans(string.punctuation, " " * len(string.punctuation))


def remove_punctuations(s, lower=True):
    """Replace ASCII punctuation with spaces and collapse whitespace.

    Args:
        s: The string to normalize.
        lower: When true (the default), the result is also lower-cased.

    Returns:
        The normalized string: punctuation replaced by spaces, runs of
        whitespace collapsed to a single space, and no leading or trailing
        whitespace.
    """
    # split()/join collapses the spaces introduced by the translation and
    # trims both ends in one pass.
    s = " ".join(s.translate(_PUNCTUATION_TO_SPACE).split())
    return s.lower() if lower else s
|
||||
|
||||
|
||||
@app.post("/zeroshot")
async def zeroshot(req: ZeroShotRequest, res: Response):
    """Score ``req.input`` against ``req.labels`` with a zero-shot classifier.

    Labels are normalized with `remove_punctuations` before being passed to
    the classifier; the response is keyed by the caller's original labels.

    Returns:
        A dict with the top label (``predicted_class``), its score
        (``predicted_class_score``), the score of every requested label
        (``scores``, in the order the classifier returned them) and the
        model name.

    Raises:
        HTTPException: 400 when the model is unknown or no labels are given.
    """
    if req.model not in zero_shot_models:
        raise HTTPException(status_code=400, detail="unknown model: " + req.model)

    if not req.labels:
        # Reject up front: with no candidates there is nothing to classify and
        # indexing the first predicted label below could not succeed.
        raise HTTPException(status_code=400, detail="labels must not be empty")

    classifier = zero_shot_models[req.model]

    # Group the caller's labels by normalized form. Distinct labels can
    # normalize to the same string (e.g. "weather?" and "Weather"); grouping
    # sends each candidate to the classifier once and still lets every
    # original label receive a score.
    originals_by_normalized = {}
    for label in req.labels:
        originals_by_normalized.setdefault(remove_punctuations(label), []).append(label)

    predicted_classes = classifier(
        req.input,
        candidate_labels=list(originals_by_normalized),
        multi_label=True,
    )

    # Translate the classifier output back to the caller's labels, keeping the
    # order in which the classifier returned them.
    final_scores = {}
    for normalized, score in zip(predicted_classes["labels"], predicted_classes["scores"]):
        for original in originals_by_normalized[normalized]:
            final_scores[original] = score

    # The first label returned by the classifier is treated as the prediction.
    predicted_class = originals_by_normalized[predicted_classes["labels"][0]][0]

    return {
        "predicted_class": predicted_class,
        "predicted_class_score": final_scores[predicted_class],
        "scores": final_scores,
        "model": req.model,
    }
|
||||
|
||||
|
||||
class WeatherRequest(BaseModel):
|
||||
city: str
|
||||
|
||||
|
|
|
|||
97
envoyfilter/Cargo.lock
generated
97
envoyfilter/Cargo.lock
generated
|
|
@ -2,6 +2,15 @@
|
|||
# It is not intended for manual editing.
|
||||
version = 3
|
||||
|
||||
[[package]]
|
||||
name = "acap"
|
||||
version = "0.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6851a0b3b2d5729a0b7e61e3c36923ed9d72240146b0efda61121b0b84ad595d"
|
||||
dependencies = [
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "addr2line"
|
||||
version = "0.21.0"
|
||||
|
|
@ -13,18 +22,18 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "addr2line"
|
||||
version = "0.22.0"
|
||||
version = "0.24.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6e4503c46a5c0c7844e948c9a4d6acd9f50cccb4de1c48eb9e291ea17470c678"
|
||||
checksum = "f5fb1d8e4442bd405fdfd1dacb42792696b0cf9cb15882e5d097b742a676d375"
|
||||
dependencies = [
|
||||
"gimli 0.29.0",
|
||||
"gimli 0.31.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "adler"
|
||||
version = "1.0.2"
|
||||
name = "adler2"
|
||||
version = "2.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe"
|
||||
checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627"
|
||||
|
||||
[[package]]
|
||||
name = "ahash"
|
||||
|
|
@ -116,17 +125,17 @@ checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0"
|
|||
|
||||
[[package]]
|
||||
name = "backtrace"
|
||||
version = "0.3.73"
|
||||
version = "0.3.74"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5cc23269a4f8976d0a4d2e7109211a419fe30e8d88d677cd60b6bc79c5732e0a"
|
||||
checksum = "8d82cb332cdfaed17ae235a638438ac4d4839913cc2af585c3c6746e8f8bee1a"
|
||||
dependencies = [
|
||||
"addr2line 0.22.0",
|
||||
"cc",
|
||||
"addr2line 0.24.1",
|
||||
"cfg-if 1.0.0",
|
||||
"libc",
|
||||
"miniz_oxide",
|
||||
"object",
|
||||
"rustc-demangle",
|
||||
"windows-targets",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -208,9 +217,9 @@ checksum = "8318a53db07bb3f8dca91a600466bdb3f2eaadeedfdbcf02e1accbad9271ba50"
|
|||
|
||||
[[package]]
|
||||
name = "cc"
|
||||
version = "1.1.17"
|
||||
version = "1.1.18"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a93fe60e2fc87b6ba2c117f67ae14f66e3fc7d6a1e612a25adb238cc980eadb3"
|
||||
checksum = "b62ac837cdb5cb22e10a256099b4fc502b1dfe560cb282963a974d7abd80e476"
|
||||
dependencies = [
|
||||
"jobserver",
|
||||
"libc",
|
||||
|
|
@ -718,9 +727,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "gimli"
|
||||
version = "0.29.0"
|
||||
version = "0.31.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "40ecd4077b5ae9fd2e9e169b102c6c330d0605168eb0e8bf79952b256dbefffd"
|
||||
checksum = "32085ea23f3234fc7846555e85283ba4de91e21016dc0455a16286d87a292d64"
|
||||
|
||||
[[package]]
|
||||
name = "governor"
|
||||
|
|
@ -910,9 +919,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "hyper-util"
|
||||
version = "0.1.7"
|
||||
version = "0.1.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cde7055719c54e36e95e8719f95883f22072a48ede39db7fc17a4e1d5281e9b9"
|
||||
checksum = "da62f120a8a37763efb0cf8fdf264b884c7b8b9ac8660b900c8661030c00e6ba"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"futures-channel",
|
||||
|
|
@ -959,6 +968,7 @@ dependencies = [
|
|||
name = "intelligent-prompt-gateway"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"acap",
|
||||
"governor",
|
||||
"http",
|
||||
"log",
|
||||
|
|
@ -976,9 +986,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "ipnet"
|
||||
version = "2.9.0"
|
||||
version = "2.10.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8f518f335dce6725a761382244631d86cf0ccb2863413590b31338feb467f9c3"
|
||||
checksum = "187674a687eed5fe42285b40c6291f9a01517d415fad1c3cbc6a9f778af7fcd4"
|
||||
|
||||
[[package]]
|
||||
name = "itertools"
|
||||
|
|
@ -1137,11 +1147,11 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "miniz_oxide"
|
||||
version = "0.7.4"
|
||||
version = "0.8.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b8a240ddb74feaf34a79a7add65a741f3167852fba007066dcac1ca548d89c08"
|
||||
checksum = "e2d80299ef12ff69b16a84bb182e3b9df68b5a91574d3d4fa6e41b65deec4df1"
|
||||
dependencies = [
|
||||
"adler",
|
||||
"adler2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -1194,6 +1204,15 @@ version = "0.3.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "38bf9645c8b145698bb0b18a4637dcacbc421ea49bef2317e4fd8065a387cf21"
|
||||
|
||||
[[package]]
|
||||
name = "num-traits"
|
||||
version = "0.2.19"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "object"
|
||||
version = "0.36.4"
|
||||
|
|
@ -1648,9 +1667,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "rustls"
|
||||
version = "0.23.12"
|
||||
version = "0.23.13"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c58f8c84392efc0a126acce10fa59ff7b3d2ac06ab451a33f2741989b806b044"
|
||||
checksum = "f2dabaac7466917e566adb06783a81ca48944c6898a1b08b9374106dd671f4c8"
|
||||
dependencies = [
|
||||
"once_cell",
|
||||
"rustls-pki-types",
|
||||
|
|
@ -1677,9 +1696,9 @@ checksum = "fc0a2ce646f8655401bb81e7927b812614bd5d91dbc968696be50603510fcaf0"
|
|||
|
||||
[[package]]
|
||||
name = "rustls-webpki"
|
||||
version = "0.102.7"
|
||||
version = "0.102.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "84678086bd54edf2b415183ed7a94d0efb049f1b646a33e22a36f3794be6ae56"
|
||||
checksum = "64ca1bc8749bd4cf37b5ce386cc146580777b4e8572c7b97baf22c83f444bee9"
|
||||
dependencies = [
|
||||
"ring",
|
||||
"rustls-pki-types",
|
||||
|
|
@ -1694,20 +1713,20 @@ checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f"
|
|||
|
||||
[[package]]
|
||||
name = "scc"
|
||||
version = "2.1.16"
|
||||
version = "2.1.17"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "aeb7ac86243095b70a7920639507b71d51a63390d1ba26c4f60a552fbb914a37"
|
||||
checksum = "0c947adb109a8afce5fc9c7bf951f87f146e9147b3a6a58413105628fb1d1e66"
|
||||
dependencies = [
|
||||
"sdd",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "schannel"
|
||||
version = "0.1.23"
|
||||
version = "0.1.24"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fbc91545643bcf3a0bbb6569265615222618bdf33ce4ffbbd13c4bbd4c093534"
|
||||
checksum = "e9aaafd5a2b6e3d657ff009d82fbd630b6bd54dd4eb06f21693925cdf80f9b8b"
|
||||
dependencies = [
|
||||
"windows-sys 0.52.0",
|
||||
"windows-sys 0.59.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -1718,9 +1737,9 @@ checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
|
|||
|
||||
[[package]]
|
||||
name = "sdd"
|
||||
version = "3.0.2"
|
||||
version = "3.0.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0495e4577c672de8254beb68d01a9b62d0e8a13c099edecdbedccce3223cd29f"
|
||||
checksum = "60a7b59a5d9b0099720b417b6325d91a52cbf5b3dcb5041d864be53eefa58abc"
|
||||
|
||||
[[package]]
|
||||
name = "security-framework"
|
||||
|
|
@ -2431,9 +2450,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "wasm-encoder"
|
||||
version = "0.216.0"
|
||||
version = "0.217.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "04c23aebea22c8a75833ae08ed31ccc020835b12a41999e58c31464271b94a88"
|
||||
checksum = "7b88b0814c9a2b323a9b46c687e726996c255ac8b64aa237dd11c81ed4854760"
|
||||
dependencies = [
|
||||
"leb128",
|
||||
]
|
||||
|
|
@ -2721,22 +2740,22 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "wast"
|
||||
version = "216.0.0"
|
||||
version = "217.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f7eb1f2eecd913fdde0dc6c3439d0f24530a98ac6db6cb3d14d92a5328554a08"
|
||||
checksum = "79004ecebded92d3c710d4841383368c7f04b63d0992ddd6b0c7d5029b7629b7"
|
||||
dependencies = [
|
||||
"bumpalo",
|
||||
"leb128",
|
||||
"memchr",
|
||||
"unicode-width",
|
||||
"wasm-encoder 0.216.0",
|
||||
"wasm-encoder 0.217.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wat"
|
||||
version = "1.216.0"
|
||||
version = "1.217.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ac0409090fb5154f95fb5ba3235675fd9e579e731524d63b6a2f653e1280c82a"
|
||||
checksum = "c126271c3d92ca0f7c63e4e462e40c69cca52fd4245fcda730d1cf558fb55088"
|
||||
dependencies = [
|
||||
"wast",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ public-types = { path = "../public-types" }
|
|||
http = "1.1.0"
|
||||
governor = { version = "0.6.3", default-features = false, features = ["no_std"]}
|
||||
tiktoken-rs = "0.5.9"
|
||||
acap = "0.3.0"
|
||||
|
||||
[dev-dependencies]
|
||||
proxy-wasm-test-framework = { git = "https://github.com/katanemo/test-framework.git", branch = "main" }
|
||||
|
|
|
|||
|
|
@ -14,5 +14,6 @@ RUN cargo build --release --target wasm32-wasi
|
|||
FROM envoyproxy/envoy:v1.30-latest
|
||||
COPY --from=builder /envoyfilter/target/wasm32-wasi/release/intelligent_prompt_gateway.wasm /etc/envoy/proxy-wasm-plugins/intelligent_prompt_gateway.wasm
|
||||
COPY envoyfilter/envoy.yaml /etc/envoy.yaml
|
||||
CMD ["envoy", "-c", "/etc/envoy/envoy.yaml"]
|
||||
# CMD ["envoy", "-c", "/etc/envoy/envoy.yaml"]
|
||||
# CMD ["envoy", "-c", "/etc/envoy/envoy.yaml", "--log-level", "debug"]
|
||||
CMD ["envoy", "-c", "/etc/envoy/envoy.yaml", "--component-log-level", "wasm:debug"]
|
||||
|
|
|
|||
|
|
@ -165,26 +165,12 @@ static_resources:
|
|||
address: embeddingserver
|
||||
port_value: 80
|
||||
hostname: "embeddingserver"
|
||||
- name: qdrant
|
||||
connect_timeout: 5s
|
||||
type: STRICT_DNS
|
||||
lb_policy: ROUND_ROBIN
|
||||
load_assignment:
|
||||
cluster_name: qdrant
|
||||
endpoints:
|
||||
- lb_endpoints:
|
||||
- endpoint:
|
||||
address:
|
||||
socket_address:
|
||||
address: qdrant
|
||||
port_value: 6333
|
||||
hostname: "qdrant"
|
||||
- name: mistral_7b_instruct
|
||||
connect_timeout: 5s
|
||||
type: STRICT_DNS
|
||||
lb_policy: ROUND_ROBIN
|
||||
load_assignment:
|
||||
cluster_name: qdrant
|
||||
cluster_name: mistral_7b_instruct
|
||||
endpoints:
|
||||
- lb_endpoints:
|
||||
- endpoint:
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
pub const DEFAULT_EMBEDDING_MODEL: &str = "BAAI/bge-large-en-v1.5";
|
||||
pub const DEFAULT_COLLECTION_NAME: &str = "prompt_vector_store";
|
||||
pub const DEFAULT_PROMPT_TARGET_THRESHOLD: f64 = 0.6;
|
||||
pub const DEFAULT_INTENT_MODEL: &str = "tasksource/deberta-base-long-nli";
|
||||
pub const DEFAULT_PROMPT_TARGET_THRESHOLD: f64 = 0.8;
|
||||
pub const RATELIMIT_SELECTOR_HEADER_KEY: &str = "x-katanemo-ratelimit-selector";
|
||||
pub const SYSTEM_ROLE: &str = "system";
|
||||
pub const USER_ROLE: &str = "user";
|
||||
|
|
|
|||
|
|
@ -3,22 +3,20 @@ use crate::ratelimit;
|
|||
use crate::stats::{Counter, Gauge, RecordingMetric};
|
||||
use crate::stream_context::StreamContext;
|
||||
use log::debug;
|
||||
use md5::Digest;
|
||||
use open_message_format_embeddings::models::{
|
||||
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
|
||||
};
|
||||
use proxy_wasm::traits::*;
|
||||
use proxy_wasm::types::*;
|
||||
use public_types::common_types::{
|
||||
CallContext, EmbeddingRequest, StoreVectorEmbeddingsRequest, VectorPoint,
|
||||
};
|
||||
use public_types::common_types::EmbeddingType;
|
||||
use public_types::configuration::{Configuration, PromptTarget};
|
||||
use serde_json::to_string;
|
||||
use std::collections::HashMap;
|
||||
use std::rc::Rc;
|
||||
use std::sync::{OnceLock, RwLock};
|
||||
use std::time::Duration;
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
pub struct WasmMetrics {
|
||||
pub active_http_calls: Gauge,
|
||||
pub ratelimited_rq: Counter,
|
||||
|
|
@ -33,11 +31,29 @@ impl WasmMetrics {
|
|||
}
|
||||
}
|
||||
|
||||
/// Bookkeeping stored for each in-flight embeddings callout so that its
/// response can be attributed to the right prompt target and embedding kind.
#[derive(Debug)]
struct CallContext {
    /// Name of the prompt target the callout was issued for.
    prompt_target: String,
    /// Which embedding of that prompt target (e.g. name or description) the
    /// callout computes.
    embedding_type: EmbeddingType,
}
|
||||
|
||||
pub type EmbeddingTypeMap = HashMap<EmbeddingType, Vec<f64>>;
|
||||
|
||||
/// Root context of the filter: holds the parsed plugin configuration and the
/// state handed to each stream context it creates.
#[derive(Debug)]
pub struct FilterContext {
    /// Metrics handle; cloned (by `Rc`) into every stream context.
    metrics: Rc<WasmMetrics>,
    // callouts stores token_id to request mapping that we use during #on_http_call_response to match the response to the request.
    callouts: HashMap<u32, CallContext>,
    /// Parsed plugin configuration; `None` until the configuration is read.
    config: Option<Configuration>,
    /// Prompt targets from the configuration, keyed by prompt target name and
    /// shared (by `Rc`) with every stream context.
    prompt_targets: Rc<RwLock<HashMap<String, PromptTarget>>>,
}
|
||||
|
||||
pub fn embeddings_store() -> &'static RwLock<HashMap<String, EmbeddingTypeMap>> {
|
||||
static EMBEDDINGS: OnceLock<RwLock<HashMap<String, EmbeddingTypeMap>>> = OnceLock::new();
|
||||
EMBEDDINGS.get_or_init(|| {
|
||||
let embeddings: HashMap<String, EmbeddingTypeMap> = HashMap::new();
|
||||
RwLock::new(embeddings)
|
||||
})
|
||||
}
|
||||
|
||||
impl FilterContext {
|
||||
|
|
@ -46,75 +62,95 @@ impl FilterContext {
|
|||
callouts: HashMap::new(),
|
||||
config: None,
|
||||
metrics: Rc::new(WasmMetrics::new()),
|
||||
prompt_targets: Rc::new(RwLock::new(HashMap::new())),
|
||||
}
|
||||
}
|
||||
|
||||
fn process_prompt_targets(&mut self) {
|
||||
for prompt_target in &self.config.as_ref().unwrap().prompt_targets {
|
||||
for few_shot_example in &prompt_target.few_shot_examples {
|
||||
let embeddings_input = CreateEmbeddingRequest {
|
||||
input: Box::new(CreateEmbeddingRequestInput::String(
|
||||
few_shot_example.to_string(),
|
||||
)),
|
||||
model: String::from(DEFAULT_EMBEDDING_MODEL),
|
||||
encoding_format: None,
|
||||
dimensions: None,
|
||||
user: None,
|
||||
};
|
||||
|
||||
// TODO: Handle potential errors
|
||||
let json_data: String = to_string(&embeddings_input).unwrap();
|
||||
|
||||
let token_id = match self.dispatch_http_call(
|
||||
"embeddingserver",
|
||||
vec![
|
||||
(":method", "POST"),
|
||||
(":path", "/embeddings"),
|
||||
(":authority", "embeddingserver"),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
("x-envoy-upstream-rq-timeout-ms", "20000"),
|
||||
],
|
||||
Some(json_data.as_bytes()),
|
||||
vec![],
|
||||
Duration::from_secs(5),
|
||||
) {
|
||||
Ok(token_id) => token_id,
|
||||
Err(e) => {
|
||||
panic!("Error dispatching HTTP call: {:?}", e);
|
||||
}
|
||||
};
|
||||
let embedding_request = EmbeddingRequest {
|
||||
create_embedding_request: embeddings_input,
|
||||
// Need to clone prompt target to leave config string intact.
|
||||
prompt_target: prompt_target.clone(),
|
||||
};
|
||||
debug!(
|
||||
"dispatched HTTP call to embedding server token_id={}",
|
||||
token_id
|
||||
);
|
||||
if self
|
||||
.callouts
|
||||
.insert(token_id, {
|
||||
CallContext::EmbeddingRequest(embedding_request)
|
||||
})
|
||||
.is_some()
|
||||
{
|
||||
panic!("duplicate token_id")
|
||||
}
|
||||
self.metrics
|
||||
.active_http_calls
|
||||
.record(self.callouts.len().try_into().unwrap());
|
||||
let prompt_targets = match self.prompt_targets.read() {
|
||||
Ok(prompt_targets) => prompt_targets,
|
||||
Err(e) => {
|
||||
panic!("Error reading prompt targets: {:?}", e);
|
||||
}
|
||||
};
|
||||
for values in prompt_targets.iter() {
|
||||
let prompt_target = &values.1;
|
||||
|
||||
// schedule embeddings call for prompt target name
|
||||
let token_id = self.schedule_embeddings_call(prompt_target.name.clone());
|
||||
if self
|
||||
.callouts
|
||||
.insert(token_id, {
|
||||
CallContext {
|
||||
prompt_target: prompt_target.name.clone(),
|
||||
embedding_type: EmbeddingType::Name,
|
||||
}
|
||||
})
|
||||
.is_some()
|
||||
{
|
||||
panic!("duplicate token_id")
|
||||
}
|
||||
|
||||
// schedule embeddings call for prompt target description
|
||||
let token_id = self.schedule_embeddings_call(prompt_target.description.clone());
|
||||
if self
|
||||
.callouts
|
||||
.insert(token_id, {
|
||||
CallContext {
|
||||
prompt_target: prompt_target.name.clone(),
|
||||
embedding_type: EmbeddingType::Description,
|
||||
}
|
||||
})
|
||||
.is_some()
|
||||
{
|
||||
panic!("duplicate token_id")
|
||||
}
|
||||
|
||||
self.metrics
|
||||
.active_http_calls
|
||||
.record(self.callouts.len().try_into().unwrap());
|
||||
}
|
||||
}
|
||||
|
||||
fn embedding_request_handler(
|
||||
fn schedule_embeddings_call(&self, input: String) -> u32 {
|
||||
let embeddings_input = CreateEmbeddingRequest {
|
||||
input: Box::new(CreateEmbeddingRequestInput::String(input)),
|
||||
model: String::from(DEFAULT_EMBEDDING_MODEL),
|
||||
encoding_format: None,
|
||||
dimensions: None,
|
||||
user: None,
|
||||
};
|
||||
|
||||
let json_data = to_string(&embeddings_input).unwrap();
|
||||
let token_id = match self.dispatch_http_call(
|
||||
"embeddingserver",
|
||||
vec![
|
||||
(":method", "POST"),
|
||||
(":path", "/embeddings"),
|
||||
(":authority", "embeddingserver"),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-upstream-rq-timeout-ms", "60000"),
|
||||
],
|
||||
Some(json_data.as_bytes()),
|
||||
vec![],
|
||||
Duration::from_secs(60),
|
||||
) {
|
||||
Ok(token_id) => token_id,
|
||||
Err(e) => {
|
||||
panic!("Error dispatching HTTP call: {:?}", e);
|
||||
}
|
||||
};
|
||||
token_id
|
||||
}
|
||||
|
||||
fn embedding_response_handler(
|
||||
&mut self,
|
||||
body_size: usize,
|
||||
create_embedding_request: CreateEmbeddingRequest,
|
||||
prompt_target: PromptTarget,
|
||||
embedding_type: EmbeddingType,
|
||||
prompt_target_name: String,
|
||||
) {
|
||||
let prompt_targets = self.prompt_targets.read().unwrap();
|
||||
let prompt_target = prompt_targets.get(&prompt_target_name).unwrap();
|
||||
if let Some(body) = self.get_http_call_response_body(0, body_size) {
|
||||
if !body.is_empty() {
|
||||
let mut embedding_response: CreateEmbeddingResponse =
|
||||
|
|
@ -129,111 +165,24 @@ impl FilterContext {
|
|||
}
|
||||
};
|
||||
|
||||
let mut payload: HashMap<String, String> = HashMap::new();
|
||||
payload.insert(
|
||||
"prompt-target".to_string(),
|
||||
to_string(&prompt_target).unwrap(),
|
||||
let embeddings = embedding_response.data.remove(0).embedding;
|
||||
log::info!(
|
||||
"Adding embeddings for prompt target name: {:?}, description: {:?}, embedding type: {:?}",
|
||||
prompt_target.name,
|
||||
prompt_target.description,
|
||||
embedding_type
|
||||
);
|
||||
let id: Option<Digest>;
|
||||
match *create_embedding_request.input {
|
||||
CreateEmbeddingRequestInput::String(input) => {
|
||||
id = Some(md5::compute(&input));
|
||||
payload.insert("input".to_string(), input);
|
||||
}
|
||||
CreateEmbeddingRequestInput::Array(_) => todo!(),
|
||||
}
|
||||
|
||||
let create_vector_store_points = StoreVectorEmbeddingsRequest {
|
||||
points: vec![VectorPoint {
|
||||
id: format!("{:x}", id.unwrap()),
|
||||
payload,
|
||||
vector: embedding_response.data.remove(0).embedding,
|
||||
}],
|
||||
};
|
||||
let json_data = to_string(&create_vector_store_points).unwrap(); // Handle potential errors
|
||||
let token_id = match self.dispatch_http_call(
|
||||
"qdrant",
|
||||
vec![
|
||||
(":method", "PUT"),
|
||||
(":path", "/collections/prompt_vector_store/points"),
|
||||
(":authority", "qdrant"),
|
||||
("content-type", "application/json"),
|
||||
],
|
||||
Some(json_data.as_bytes()),
|
||||
vec![],
|
||||
Duration::from_secs(5),
|
||||
) {
|
||||
Ok(token_id) => token_id,
|
||||
Err(e) => {
|
||||
panic!("Error dispatching HTTP call: {:?}", e);
|
||||
}
|
||||
};
|
||||
|
||||
if self
|
||||
.callouts
|
||||
.insert(
|
||||
token_id,
|
||||
CallContext::StoreVectorEmbeddings(create_vector_store_points),
|
||||
)
|
||||
.is_some()
|
||||
{
|
||||
panic!("duplicate token_id")
|
||||
}
|
||||
self.metrics
|
||||
.active_http_calls
|
||||
.record(self.callouts.len().try_into().unwrap());
|
||||
// Merge into the prompt target's existing map instead of replacing it.
// Each prompt target gets separate callouts per embedding type, and a
// plain `insert` of a fresh single-entry map would discard whichever
// embedding had been stored by an earlier response.
embeddings_store()
    .write()
    .unwrap()
    .entry(prompt_target.name.clone())
    .or_default()
    .insert(embedding_type, embeddings);
|
||||
}
|
||||
} else {
|
||||
panic!("No body in response");
|
||||
}
|
||||
}
|
||||
|
||||
fn create_vector_store_points_handler(&self, body_size: usize) {
|
||||
if let Some(body) = self.get_http_call_response_body(0, body_size) {
|
||||
if !body.is_empty() {
|
||||
debug!(
|
||||
"response body: len {:?}",
|
||||
String::from_utf8(body).unwrap().len()
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//TODO: run once per envoy instance, right now it runs once per worker
|
||||
fn init_vector_store(&mut self) {
|
||||
let token_id = match self.dispatch_http_call(
|
||||
"qdrant",
|
||||
vec![
|
||||
(":method", "PUT"),
|
||||
(":path", "/collections/prompt_vector_store"),
|
||||
(":authority", "qdrant"),
|
||||
("content-type", "application/json"),
|
||||
],
|
||||
Some(b"{ \"vectors\": { \"size\": 1024, \"distance\": \"Cosine\"}}"),
|
||||
vec![],
|
||||
Duration::from_secs(5),
|
||||
) {
|
||||
Ok(token_id) => token_id,
|
||||
Err(e) => {
|
||||
panic!("Error dispatching HTTP call for init-vector-store: {:?}", e);
|
||||
}
|
||||
};
|
||||
if self
|
||||
.callouts
|
||||
.insert(
|
||||
token_id,
|
||||
CallContext::CreateVectorCollection("prompt_vector_store".to_string()),
|
||||
)
|
||||
.is_some()
|
||||
{
|
||||
panic!("duplicate token_id")
|
||||
}
|
||||
self.metrics
|
||||
.active_http_calls
|
||||
.record(self.callouts.len().try_into().unwrap());
|
||||
}
|
||||
}
|
||||
|
||||
impl Context for FilterContext {
|
||||
fn on_http_call_response(
|
||||
&mut self,
|
||||
|
|
@ -242,37 +191,18 @@ impl Context for FilterContext {
|
|||
body_size: usize,
|
||||
_num_trailers: usize,
|
||||
) {
|
||||
let callout_data = self
|
||||
.callouts
|
||||
.remove(&token_id)
|
||||
.expect("invalid token_id: {}");
|
||||
debug!("on_http_call_response called with token_id: {:?}", token_id);
|
||||
let callout_data = self.callouts.remove(&token_id).expect("invalid token_id");
|
||||
|
||||
self.metrics
|
||||
.active_http_calls
|
||||
.record(self.callouts.len().try_into().unwrap());
|
||||
|
||||
match callout_data {
|
||||
CallContext::EmbeddingRequest(EmbeddingRequest {
|
||||
create_embedding_request,
|
||||
prompt_target,
|
||||
}) => {
|
||||
self.embedding_request_handler(body_size, create_embedding_request, prompt_target)
|
||||
}
|
||||
CallContext::StoreVectorEmbeddings(_) => {
|
||||
self.create_vector_store_points_handler(body_size)
|
||||
}
|
||||
CallContext::CreateVectorCollection(_) => {
|
||||
let mut http_status_code = "Nil".to_string();
|
||||
self.get_http_call_response_headers()
|
||||
.iter()
|
||||
.for_each(|(k, v)| {
|
||||
if k == ":status" {
|
||||
http_status_code.clone_from(v);
|
||||
}
|
||||
});
|
||||
debug!("CreateVectorCollection response: {}", http_status_code);
|
||||
}
|
||||
}
|
||||
self.embedding_response_handler(
|
||||
body_size,
|
||||
callout_data.embedding_type,
|
||||
callout_data.prompt_target,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -282,6 +212,13 @@ impl RootContext for FilterContext {
|
|||
if let Some(config_bytes) = self.get_plugin_configuration() {
|
||||
self.config = serde_yaml::from_slice(&config_bytes).unwrap();
|
||||
|
||||
for pt in self.config.clone().unwrap().prompt_targets {
|
||||
self.prompt_targets
|
||||
.write()
|
||||
.unwrap()
|
||||
.insert(pt.name.clone(), pt.clone());
|
||||
}
|
||||
|
||||
debug!("set configuration object: {:?}", self.config);
|
||||
|
||||
if let Some(ratelimits_config) = self
|
||||
|
|
@ -301,6 +238,7 @@ impl RootContext for FilterContext {
|
|||
ratelimit_selector: None,
|
||||
callouts: HashMap::new(),
|
||||
metrics: Rc::clone(&self.metrics),
|
||||
prompt_targets: Rc::clone(&self.prompt_targets),
|
||||
}))
|
||||
}
|
||||
|
||||
|
|
@ -314,7 +252,6 @@ impl RootContext for FilterContext {
|
|||
}
|
||||
|
||||
fn on_tick(&mut self) {
|
||||
self.init_vector_store();
|
||||
self.process_prompt_targets();
|
||||
self.set_tick_period(Duration::from_secs(0));
|
||||
}
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@ pub trait RecordingMetric: Metric {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
pub struct Counter {
|
||||
id: u32,
|
||||
}
|
||||
|
|
@ -55,7 +55,7 @@ impl Metric for Counter {
|
|||
|
||||
impl IncrementingMetric for Counter {}
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
pub struct Gauge {
|
||||
id: u32,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,13 +1,14 @@
|
|||
use crate::consts::{
|
||||
BOLT_FC_CLUSTER, BOLT_FC_REQUEST_TIMEOUT_MS, DEFAULT_COLLECTION_NAME, DEFAULT_EMBEDDING_MODEL,
|
||||
BOLT_FC_CLUSTER, BOLT_FC_REQUEST_TIMEOUT_MS, DEFAULT_EMBEDDING_MODEL, DEFAULT_INTENT_MODEL,
|
||||
DEFAULT_PROMPT_TARGET_THRESHOLD, GPT_35_TURBO, RATELIMIT_SELECTOR_HEADER_KEY, SYSTEM_ROLE,
|
||||
USER_ROLE,
|
||||
};
|
||||
use crate::filter_context::WasmMetrics;
|
||||
use crate::filter_context::{embeddings_store, WasmMetrics};
|
||||
use crate::ratelimit;
|
||||
use crate::ratelimit::Header;
|
||||
use crate::stats::IncrementingMetric;
|
||||
use crate::tokenizer;
|
||||
use acap::cos;
|
||||
use http::StatusCode;
|
||||
use log::{debug, error, info, warn};
|
||||
use open_message_format_embeddings::models::{
|
||||
|
|
@ -15,31 +16,31 @@ use open_message_format_embeddings::models::{
|
|||
};
|
||||
use proxy_wasm::traits::*;
|
||||
use proxy_wasm::types::*;
|
||||
use public_types::common_types::open_ai::{ChatCompletions, Message};
|
||||
use public_types::common_types::{
|
||||
open_ai::{ChatCompletions, Message},
|
||||
SearchPointsRequest, SearchPointsResponse,
|
||||
};
|
||||
use public_types::common_types::{
|
||||
BoltFCResponse, BoltFCToolsCall, ToolParameter, ToolParameters, ToolsDefinition,
|
||||
BoltFCResponse, BoltFCToolsCall, EmbeddingType, ToolParameter, ToolParameters, ToolsDefinition,
|
||||
ZeroShotClassificationRequest, ZeroShotClassificationResponse,
|
||||
};
|
||||
use public_types::configuration::{PromptTarget, PromptType};
|
||||
use std::collections::HashMap;
|
||||
use std::num::NonZero;
|
||||
use std::rc::Rc;
|
||||
use std::sync::RwLock;
|
||||
use std::time::Duration;
|
||||
|
||||
enum RequestType {
|
||||
GetEmbedding,
|
||||
SearchPoints,
|
||||
enum ResponseHandlerType {
|
||||
GetEmbeddings,
|
||||
FunctionResolver,
|
||||
FunctionCallResponse,
|
||||
FunctionCall,
|
||||
ZeroShotIntent,
|
||||
}
|
||||
|
||||
pub struct CallContext {
|
||||
request_type: RequestType,
|
||||
response_handler_type: ResponseHandlerType,
|
||||
user_message: Option<String>,
|
||||
prompt_target: Option<PromptTarget>,
|
||||
request_body: ChatCompletions,
|
||||
similarity_scores: Option<Vec<(String, f64)>>,
|
||||
}
|
||||
|
||||
pub struct StreamContext {
|
||||
|
|
@ -47,6 +48,7 @@ pub struct StreamContext {
|
|||
pub ratelimit_selector: Option<Header>,
|
||||
pub callouts: HashMap<u32, CallContext>,
|
||||
pub metrics: Rc<WasmMetrics>,
|
||||
pub prompt_targets: Rc<RwLock<HashMap<String, PromptTarget>>>,
|
||||
}
|
||||
|
||||
impl StreamContext {
|
||||
|
|
@ -61,7 +63,6 @@ impl StreamContext {
|
|||
// However, a missing Content-Length header is not grounds for bad requests given that intermediary hops could
|
||||
// manipulate the body in benign ways e.g., compression.
|
||||
self.set_http_request_header("content-length", None);
|
||||
// self.set_http_request_header("authorization", None);
|
||||
}
|
||||
|
||||
fn modify_path_header(&mut self) {
|
||||
|
|
@ -85,7 +86,7 @@ impl StreamContext {
|
|||
});
|
||||
}
|
||||
|
||||
fn send_server_error(&mut self, error: String) {
|
||||
fn send_server_error(&self, error: String) {
|
||||
debug!("server error occurred: {}", error);
|
||||
self.send_http_response(
|
||||
StatusCode::INTERNAL_SERVER_ERROR.as_u16().into(),
|
||||
|
|
@ -103,30 +104,85 @@ impl StreamContext {
|
|||
}
|
||||
};
|
||||
|
||||
let search_points_request = SearchPointsRequest {
|
||||
vector: embedding_response.data[0].embedding.clone(),
|
||||
limit: 10,
|
||||
with_payload: true,
|
||||
};
|
||||
let embeddings_vector = &embedding_response.data[0].embedding;
|
||||
|
||||
let json_data: String = match serde_json::to_string(&search_points_request) {
|
||||
Ok(json_data) => json_data,
|
||||
debug!(
|
||||
"embedding model: {}, vector length: {:?}",
|
||||
embedding_response.model,
|
||||
embeddings_vector.len()
|
||||
);
|
||||
|
||||
let prompt_target_embeddings = match embeddings_store().read() {
|
||||
Ok(embeddings) => embeddings,
|
||||
Err(e) => {
|
||||
self.send_server_error(format!("Error serializing search_points_request: {:?}", e));
|
||||
let error_message = format!("Error reading embeddings store: {:?}", e);
|
||||
warn!("{}", error_message);
|
||||
self.send_server_error(error_message);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let path = format!("/collections/{}/points/search", DEFAULT_COLLECTION_NAME);
|
||||
let prompt_targets = match self.prompt_targets.read() {
|
||||
Ok(prompt_targets) => prompt_targets,
|
||||
Err(e) => {
|
||||
let error_message = format!("Error reading prompt targets: {:?}", e);
|
||||
warn!("{}", error_message);
|
||||
self.send_server_error(error_message);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let prompt_target_names = prompt_targets
|
||||
.iter()
|
||||
.map(|(name, _)| name.clone())
|
||||
.collect();
|
||||
|
||||
let similarity_scores: Vec<(String, f64)> = prompt_targets
|
||||
.iter()
|
||||
.map(|(prompt_name, _prompt_target)| {
|
||||
let default_embeddings = HashMap::new();
|
||||
let pte = prompt_target_embeddings
|
||||
.get(prompt_name)
|
||||
.unwrap_or(&default_embeddings);
|
||||
let description_embeddings = pte.get(&EmbeddingType::Description);
|
||||
let similarity_score_description = cos::cosine_similarity(
|
||||
&embeddings_vector,
|
||||
&description_embeddings.unwrap_or(&vec![0.0]),
|
||||
);
|
||||
(prompt_name.clone(), similarity_score_description)
|
||||
})
|
||||
.collect();
|
||||
|
||||
debug!(
|
||||
"similarity scores based on description embeddings match: {:?}",
|
||||
similarity_scores
|
||||
);
|
||||
|
||||
callout_context.similarity_scores = Some(similarity_scores);
|
||||
|
||||
let zero_shot_classification_request = ZeroShotClassificationRequest {
|
||||
// Need to clone into input because user_message is used below.
|
||||
input: callout_context.user_message.as_ref().unwrap().clone(),
|
||||
model: String::from(DEFAULT_INTENT_MODEL),
|
||||
labels: prompt_target_names,
|
||||
};
|
||||
|
||||
let json_data: String = match serde_json::to_string(&zero_shot_classification_request) {
|
||||
Ok(json_data) => json_data,
|
||||
Err(error) => {
|
||||
panic!("Error serializing zero shot request: {}", error);
|
||||
}
|
||||
};
|
||||
|
||||
let token_id = match self.dispatch_http_call(
|
||||
"qdrant",
|
||||
"embeddingserver",
|
||||
vec![
|
||||
(":method", "POST"),
|
||||
(":path", &path),
|
||||
(":authority", "qdrant"),
|
||||
(":path", "/zeroshot"),
|
||||
(":authority", "embeddingserver"),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
("x-envoy-upstream-rq-timeout-ms", "60000"),
|
||||
],
|
||||
Some(json_data.as_bytes()),
|
||||
vec![],
|
||||
|
|
@ -134,39 +190,60 @@ impl StreamContext {
|
|||
) {
|
||||
Ok(token_id) => token_id,
|
||||
Err(e) => {
|
||||
panic!("Error dispatching HTTP call for get-embeddings: {:?}", e);
|
||||
panic!(
|
||||
"Error dispatching embedding server HTTP call for zero-shot-intent-detection: {:?}",
|
||||
e
|
||||
);
|
||||
}
|
||||
};
|
||||
debug!(
|
||||
"dispatched HTTP call to embedding server for zero-shot-intent-detection token_id={}",
|
||||
token_id
|
||||
);
|
||||
|
||||
callout_context.response_handler_type = ResponseHandlerType::ZeroShotIntent;
|
||||
|
||||
callout_context.request_type = RequestType::SearchPoints;
|
||||
if self.callouts.insert(token_id, callout_context).is_some() {
|
||||
panic!("duplicate token_id")
|
||||
panic!(
|
||||
"duplicate token_id={} in embedding server requests",
|
||||
token_id
|
||||
)
|
||||
}
|
||||
self.metrics.active_http_calls.increment(1);
|
||||
}
|
||||
|
||||
fn search_points_handler(&mut self, body: Vec<u8>, mut callout_context: CallContext) {
|
||||
let search_points_response: SearchPointsResponse = match serde_json::from_slice(&body) {
|
||||
Ok(search_points_response) => search_points_response,
|
||||
Err(e) => {
|
||||
self.send_server_error(format!(
|
||||
"Error deserializing search_points_response: {:?}",
|
||||
e
|
||||
));
|
||||
fn zero_shot_intent_detection_resp_handler(
|
||||
&mut self,
|
||||
body: Vec<u8>,
|
||||
mut callout_context: CallContext,
|
||||
) {
|
||||
let zeroshot_intent_response: ZeroShotClassificationResponse =
|
||||
match serde_json::from_slice(&body) {
|
||||
Ok(zeroshot_response) => zeroshot_response,
|
||||
Err(e) => {
|
||||
warn!(
|
||||
"Error deserializing zeroshot intent detection response: {:?}",
|
||||
e
|
||||
);
|
||||
info!("body: {:?}", String::from_utf8(body).unwrap());
|
||||
self.resume_http_request();
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
return;
|
||||
}
|
||||
};
|
||||
debug!("zeroshot intent response: {:?}", zeroshot_intent_response);
|
||||
|
||||
let search_results = &search_points_response.result;
|
||||
let prompt_target_similarity_score = zeroshot_intent_response.predicted_class_score * 0.7
|
||||
+ callout_context.similarity_scores.as_ref().unwrap()[0].1 * 0.3;
|
||||
|
||||
if search_results.is_empty() {
|
||||
info!("No prompt target matched");
|
||||
self.resume_http_request();
|
||||
return;
|
||||
}
|
||||
debug!(
|
||||
"similarity score: {}, intent score: {}, description embedding score: {}",
|
||||
prompt_target_similarity_score,
|
||||
zeroshot_intent_response.predicted_class_score,
|
||||
callout_context.similarity_scores.as_ref().unwrap()[0].1
|
||||
);
|
||||
|
||||
let prompt_target_name = zeroshot_intent_response.predicted_class.clone();
|
||||
|
||||
info!("similarity score: {}", search_results[0].score);
|
||||
// Check to see who responded to user message. This will help us identify if control should be passed to Bolt FC or not.
|
||||
// If the last message was from Bolt FC, then Bolt FC is handling the conversation (possibly for parameter collection).
|
||||
let mut bolt_assistant = false;
|
||||
|
|
@ -175,7 +252,6 @@ impl StreamContext {
|
|||
let latest_assistant_message = &messages[messages.len() - 2];
|
||||
if let Some(model) = latest_assistant_message.model.as_ref() {
|
||||
if model.starts_with("Bolt") {
|
||||
info!("Bolt assistant message found");
|
||||
bolt_assistant = true;
|
||||
}
|
||||
}
|
||||
|
|
@ -183,23 +259,30 @@ impl StreamContext {
|
|||
info!("no assistant message found, probably first interaction");
|
||||
}
|
||||
|
||||
if search_results[0].score < DEFAULT_PROMPT_TARGET_THRESHOLD && !bolt_assistant {
|
||||
info!(
|
||||
"prompt target below threshold: {}",
|
||||
DEFAULT_PROMPT_TARGET_THRESHOLD
|
||||
);
|
||||
self.resume_http_request();
|
||||
return;
|
||||
}
|
||||
let prompt_target_str = search_results[0].payload.get("prompt-target").unwrap();
|
||||
let prompt_target: PromptTarget = match serde_json::from_slice(prompt_target_str.as_bytes())
|
||||
{
|
||||
Ok(prompt_target) => prompt_target,
|
||||
Err(e) => {
|
||||
self.send_server_error(format!("Error deserializing prompt_target: {:?}", e));
|
||||
// check to ensure that the prompt target similarity score is above the threshold
|
||||
if prompt_target_similarity_score < DEFAULT_PROMPT_TARGET_THRESHOLD && !bolt_assistant {
|
||||
// if bolt fc responded to the user message, then we don't need to check the similarity score
|
||||
// it may be that bolt fc is handling the conversation for parameter collection
|
||||
if bolt_assistant {
|
||||
info!("bolt assistant is handling the conversation");
|
||||
} else {
|
||||
info!(
|
||||
"prompt target below threshold: {}, continue conversation with user",
|
||||
prompt_target_similarity_score,
|
||||
);
|
||||
self.resume_http_request();
|
||||
return;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
let prompt_target = self
|
||||
.prompt_targets
|
||||
.read()
|
||||
.unwrap()
|
||||
.get(&prompt_target_name)
|
||||
.unwrap()
|
||||
.clone();
|
||||
|
||||
info!(
|
||||
"prompt_target name: {:?}, type: {:?}",
|
||||
prompt_target.name, prompt_target.prompt_type
|
||||
|
|
@ -231,7 +314,7 @@ impl StreamContext {
|
|||
|
||||
let tools_defintion: ToolsDefinition = ToolsDefinition {
|
||||
name: prompt_target.name.clone(),
|
||||
description: prompt_target.description.clone().unwrap_or("".to_string()),
|
||||
description: prompt_target.description.clone(),
|
||||
parameters: tools_parameters,
|
||||
};
|
||||
|
||||
|
|
@ -283,7 +366,7 @@ impl StreamContext {
|
|||
BOLT_FC_CLUSTER, token_id
|
||||
);
|
||||
|
||||
callout_context.request_type = RequestType::FunctionResolver;
|
||||
callout_context.response_handler_type = ResponseHandlerType::FunctionResolver;
|
||||
callout_context.prompt_target = Some(prompt_target);
|
||||
if self.callouts.insert(token_id, callout_context).is_some() {
|
||||
panic!("duplicate token_id")
|
||||
|
|
@ -342,7 +425,7 @@ impl StreamContext {
|
|||
{
|
||||
warn!("boltfc did not extract required parameter: {}", param.name);
|
||||
return self.send_http_response(
|
||||
StatusCode::BAD_REQUEST.as_u16().into(),
|
||||
StatusCode::INTERNAL_SERVER_ERROR.as_u16().into(),
|
||||
vec![],
|
||||
Some("missing required parameter".as_bytes()),
|
||||
);
|
||||
|
|
@ -380,7 +463,7 @@ impl StreamContext {
|
|||
}
|
||||
};
|
||||
|
||||
callout_context.request_type = RequestType::FunctionCallResponse;
|
||||
callout_context.response_handler_type = ResponseHandlerType::FunctionCall;
|
||||
if self.callouts.insert(token_id, callout_context).is_some() {
|
||||
panic!("duplicate token_id")
|
||||
}
|
||||
|
|
@ -468,7 +551,6 @@ impl StreamContext {
|
|||
}
|
||||
}
|
||||
|
||||
debug!("sending request to openai: msg {}", json_string);
|
||||
self.set_http_request_body(0, json_string.len(), &json_string.into_bytes());
|
||||
self.resume_http_request();
|
||||
}
|
||||
|
|
@ -582,10 +664,11 @@ impl HttpContext for StreamContext {
|
|||
);
|
||||
|
||||
let call_context = CallContext {
|
||||
request_type: RequestType::GetEmbedding,
|
||||
response_handler_type: ResponseHandlerType::GetEmbeddings,
|
||||
user_message: Some(user_message),
|
||||
prompt_target: None,
|
||||
request_body: deserialized_body,
|
||||
similarity_scores: None,
|
||||
};
|
||||
if self.callouts.insert(token_id, call_context).is_some() {
|
||||
panic!(
|
||||
|
|
@ -593,6 +676,7 @@ impl HttpContext for StreamContext {
|
|||
token_id
|
||||
)
|
||||
}
|
||||
|
||||
self.metrics.active_http_calls.increment(1);
|
||||
|
||||
Action::Pause
|
||||
|
|
@ -611,18 +695,24 @@ impl Context for StreamContext {
|
|||
self.metrics.active_http_calls.increment(-1);
|
||||
|
||||
if let Some(body) = self.get_http_call_response_body(0, body_size) {
|
||||
match callout_context.request_type {
|
||||
RequestType::GetEmbedding => self.embeddings_handler(body, callout_context),
|
||||
RequestType::SearchPoints => self.search_points_handler(body, callout_context),
|
||||
RequestType::FunctionResolver => {
|
||||
match callout_context.response_handler_type {
|
||||
ResponseHandlerType::GetEmbeddings => {
|
||||
self.embeddings_handler(body, callout_context)
|
||||
}
|
||||
ResponseHandlerType::FunctionResolver => {
|
||||
self.function_resolver_handler(body, callout_context)
|
||||
}
|
||||
RequestType::FunctionCallResponse => {
|
||||
ResponseHandlerType::FunctionCall => {
|
||||
self.function_call_response_handler(body, callout_context)
|
||||
}
|
||||
ResponseHandlerType::ZeroShotIntent => {
|
||||
self.zero_shot_intent_detection_resp_handler(body, callout_context)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
warn!("No response body in inline HTTP request");
|
||||
let error_message = "No response body in inline HTTP request";
|
||||
warn!("{}", error_message);
|
||||
self.send_server_error(error_message.to_owned());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8,16 +8,10 @@ use proxy_wasm_test_framework::tester::{self, Tester};
|
|||
use proxy_wasm_test_framework::types::{
|
||||
Action, BufferType, LogLevel, MapType, MetricType, ReturnType,
|
||||
};
|
||||
use public_types::{
|
||||
common_types::{
|
||||
open_ai::Message, BoltFCResponse, BoltFCToolsCall, IntOrString, ToolCallDetail,
|
||||
},
|
||||
configuration::{self, Endpoint, PromptTarget},
|
||||
};
|
||||
use public_types::{
|
||||
common_types::{SearchPointResult, SearchPointsResponse},
|
||||
configuration::Configuration,
|
||||
use public_types::common_types::{
|
||||
open_ai::Message, BoltFCResponse, BoltFCToolsCall, IntOrString, ToolCallDetail,
|
||||
};
|
||||
use public_types::{common_types::ZeroShotClassificationResponse, configuration::Configuration};
|
||||
use serial_test::serial;
|
||||
use std::collections::HashMap;
|
||||
use std::path::Path;
|
||||
|
|
@ -118,59 +112,39 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
|
|||
.expect_metric_increment("active_http_calls", -1)
|
||||
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
|
||||
.returning(Some(&embeddings_response_buffer))
|
||||
.expect_http_call(Some("qdrant"), None, None, None, None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_http_call(Some("embeddingserver"), None, None, None, None)
|
||||
.returning(Some(2))
|
||||
.expect_metric_increment("active_http_calls", 1)
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
let prompt_target = PromptTarget {
|
||||
name: String::from("test-prompt-target"),
|
||||
description: None,
|
||||
prompt_type: configuration::PromptType::FunctionResolver,
|
||||
few_shot_examples: vec![],
|
||||
parameters: Some(vec![configuration::Parameter {
|
||||
name: String::from("test-entity"),
|
||||
parameter_type: Some(String::from("string")),
|
||||
description: String::from("test-description"),
|
||||
required: Some(true),
|
||||
}]),
|
||||
endpoint: Some(Endpoint {
|
||||
cluster: String::from("test-endpoint-cluster"),
|
||||
path: None,
|
||||
method: None,
|
||||
}),
|
||||
system_prompt: None,
|
||||
let zero_shot_response = ZeroShotClassificationResponse {
|
||||
predicted_class: "weather_forecast".to_string(),
|
||||
predicted_class_score: 0.1,
|
||||
scores: HashMap::new(),
|
||||
model: "test-model".to_string(),
|
||||
};
|
||||
let prompt_target_str = serde_json::to_string(&prompt_target).unwrap();
|
||||
let search_points_response = SearchPointsResponse {
|
||||
status: String::new(),
|
||||
time: 0.0,
|
||||
result: vec![SearchPointResult {
|
||||
id: String::new(),
|
||||
version: 0,
|
||||
score: 0.7,
|
||||
payload: HashMap::from([(String::from("prompt-target"), prompt_target_str)]),
|
||||
}],
|
||||
};
|
||||
let search_points_response_buffer = serde_json::to_string(&search_points_response).unwrap();
|
||||
let zeroshot_intent_detection_buffer = serde_json::to_string(&zero_shot_response).unwrap();
|
||||
module
|
||||
.call_proxy_on_http_call_response(
|
||||
http_context,
|
||||
2,
|
||||
0,
|
||||
search_points_response_buffer.len() as i32,
|
||||
zeroshot_intent_detection_buffer.len() as i32,
|
||||
0,
|
||||
)
|
||||
.expect_metric_increment("active_http_calls", -1)
|
||||
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
|
||||
.returning(Some(&search_points_response_buffer))
|
||||
.expect_log(Some(LogLevel::Info), None)
|
||||
.expect_log(Some(LogLevel::Info), None)
|
||||
.returning(Some(&zeroshot_intent_detection_buffer))
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Info), None)
|
||||
.expect_http_call(Some("bolt_fc_1b"), None, None, None, None)
|
||||
.returning(Some(3))
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_metric_increment("active_http_calls", 1)
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
|
@ -199,20 +173,22 @@ system_prompt: |
|
|||
prompt_targets:
|
||||
- type: function_resolver
|
||||
name: weather_forecast
|
||||
few_shot_examples:
|
||||
- what is the weather in New York?
|
||||
description: This resolver provides weather forecast information.
|
||||
endpoint:
|
||||
cluster: weatherhost
|
||||
path: /weather
|
||||
entities:
|
||||
- name: location
|
||||
parameters:
|
||||
- name: city
|
||||
required: true
|
||||
description: "The location for which the weather is requested"
|
||||
description: The city for which the weather forecast is requested.
|
||||
- name: days
|
||||
description: The number of days for which the weather forecast is requested.
|
||||
- name: units
|
||||
description: The units in which the weather forecast is requested.
|
||||
|
||||
- type: function_resolver
|
||||
name: weather_forecast_2
|
||||
few_shot_examples:
|
||||
- what is the weather in New York?
|
||||
description: This resolver provides weather forecast information.
|
||||
endpoint:
|
||||
cluster: weatherhost
|
||||
path: /weather
|
||||
|
|
@ -450,10 +426,10 @@ fn request_ratelimited() {
|
|||
normal_flow(&mut module, filter_context, http_context);
|
||||
|
||||
let tool_call_detail = vec![ToolCallDetail {
|
||||
name: String::from("test-tool"),
|
||||
name: String::from("weather_forecast"),
|
||||
arguments: HashMap::from([(
|
||||
String::from("test-entity"),
|
||||
IntOrString::Text(String::from("test-value")),
|
||||
String::from("city"),
|
||||
IntOrString::Text(String::from("seattle")),
|
||||
)]),
|
||||
}];
|
||||
|
||||
|
|
@ -485,7 +461,7 @@ fn request_ratelimited() {
|
|||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_http_call(Some("test-endpoint-cluster"), None, None, None, None)
|
||||
.expect_http_call(Some("weatherhost"), None, None, None, None)
|
||||
.returning(Some(4))
|
||||
.expect_metric_increment("active_http_calls", 1)
|
||||
.execute_and_expect(ReturnType::None)
|
||||
|
|
@ -555,10 +531,10 @@ fn request_not_ratelimited() {
|
|||
normal_flow(&mut module, filter_context, http_context);
|
||||
|
||||
let tool_call_detail = vec![ToolCallDetail {
|
||||
name: String::from("test-tool"),
|
||||
name: String::from("weather_forecast"),
|
||||
arguments: HashMap::from([(
|
||||
String::from("test-entity"),
|
||||
IntOrString::Text(String::from("test-value")),
|
||||
String::from("city"),
|
||||
IntOrString::Text(String::from("seattle")),
|
||||
)]),
|
||||
}];
|
||||
|
||||
|
|
@ -590,7 +566,7 @@ fn request_not_ratelimited() {
|
|||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_http_call(Some("test-endpoint-cluster"), None, None, None, None)
|
||||
.expect_http_call(Some("weatherhost"), None, None, None, None)
|
||||
.returning(Some(4))
|
||||
.expect_metric_increment("active_http_calls", 1)
|
||||
.execute_and_expect(ReturnType::None)
|
||||
|
|
@ -608,7 +584,7 @@ fn request_not_ratelimited() {
|
|||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
// .expect_log(Some(LogLevel::Debug), None)
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,14 +1,18 @@
|
|||
use crate::configuration::PromptTarget;
|
||||
use open_message_format_embeddings::models::CreateEmbeddingRequest;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Bundles a `CreateEmbeddingRequest` with a `PromptTarget`.
// NOTE(review): presumably the prompt target is the one whose text is being
// embedded — confirm against the code that constructs this struct.
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct EmbeddingRequest {
|
||||
/// The embedding-creation request payload.
pub create_embedding_request: CreateEmbeddingRequest,
|
||||
/// The prompt target associated with that request.
pub prompt_target: PromptTarget,
|
||||
}
|
||||
|
||||
/// Identifies which kind of embedding is stored for a prompt target.
///
/// Derives `Hash`, `PartialEq` and `Eq` so values can be used as `HashMap`
/// keys; the stream filter looks up `EmbeddingType::Description` in a
/// per-prompt-target map when scoring similarity.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
|
||||
pub enum EmbeddingType {
|
||||
// presumably the embedding of the prompt target's name — no use is visible
// here; confirm against the code that fills the embeddings store.
Name,
|
||||
/// Embedding compared (by cosine similarity) with the user message
/// embedding in the stream filter.
Description,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct VectorPoint {
|
||||
pub id: String,
|
||||
|
|
@ -21,21 +25,6 @@ pub struct StoreVectorEmbeddingsRequest {
|
|||
pub points: Vec<VectorPoint>,
|
||||
}
|
||||
|
||||
/// Closed set of request payloads, one variant per request kind.
///
/// The clippy `large_enum_variant` lint is explicitly allowed here rather than
/// boxing the larger payloads.
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[allow(clippy::large_enum_variant)]
pub enum CallContext {
|
||||
/// Carries an `EmbeddingRequest`.
EmbeddingRequest(EmbeddingRequest),
|
||||
/// Carries a `StoreVectorEmbeddingsRequest`.
StoreVectorEmbeddings(StoreVectorEmbeddingsRequest),
|
||||
// presumably the string is the name of the collection to create — confirm
// against the caller.
CreateVectorCollection(String),
|
||||
}
|
||||
|
||||
// https://api.qdrant.tech/master/api-reference/search/points
|
||||
/// Request body for a Qdrant points search (see the API reference linked
/// above for the authoritative field semantics).
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SearchPointsRequest {
|
||||
/// Query vector; the stream filter fills it from the first embedding in
/// the embeddings response.
pub vector: Vec<f64>,
|
||||
/// Result limit; the stream filter passes 10.
pub limit: i32,
|
||||
/// The stream filter passes `true`.
// presumably asks the server to return each hit's stored payload — confirm
// in the API reference.
pub with_payload: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SearchPointResult {
|
||||
|
|
@ -45,13 +34,6 @@ pub struct SearchPointResult {
|
|||
pub payload: HashMap<String, String>,
|
||||
}
|
||||
|
||||
/// Response body deserialized from a points-search callout.
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SearchPointsResponse {
|
||||
/// Search hits. The stream filter treats an empty list as "no prompt
/// target matched" and otherwise reads the score and payload of the first
/// element.
// NOTE(review): that usage assumes hits arrive best-first — confirm in the
// search API reference.
pub result: Vec<SearchPointResult>,
|
||||
// presumably a server-reported status string — confirm in the API reference.
pub status: String,
|
||||
// presumably the server-reported elapsed time — units not visible here;
// confirm in the API reference.
pub time: f64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ToolParameter {
|
||||
#[serde(rename = "type")]
|
||||
|
|
@ -125,3 +107,18 @@ pub mod open_ai {
|
|||
pub model: Option<String>,
|
||||
}
|
||||
}
|
||||
|
||||
/// JSON body the stream filter POSTs to the embedding server's `/zeroshot`
/// path to classify a piece of text against a set of candidate labels.
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ZeroShotClassificationRequest {
|
||||
/// Text to classify; the stream filter passes the user message.
pub input: String,
|
||||
/// Candidate labels; the stream filter passes the names of the configured
/// prompt targets.
pub labels: Vec<String>,
|
||||
/// Identifier of the classification model; the stream filter passes
/// `DEFAULT_INTENT_MODEL`.
pub model: String,
|
||||
}
|
||||
|
||||
/// Response body deserialized from the zero-shot classification callout.
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ZeroShotClassificationResponse {
|
||||
/// The label chosen by the classifier. The stream filter uses it as the key
/// when looking up the matching prompt target.
pub predicted_class: String,
|
||||
/// Score reported for `predicted_class`. The stream filter blends it
/// (weight 0.7) with a description-embedding similarity (weight 0.3) before
/// comparing the result against its prompt-target threshold.
pub predicted_class_score: f64,
|
||||
// presumably one score per candidate label, keyed by label — confirm
// against the embedding server's `/zeroshot` implementation.
pub scores: HashMap<String, f64>,
|
||||
// presumably the identifier of the model that produced the prediction —
// confirm against the embedding server.
pub model: String,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -99,8 +99,7 @@ pub struct PromptTarget {
|
|||
#[serde(rename = "type")]
|
||||
pub prompt_type: PromptType,
|
||||
pub name: String,
|
||||
pub description: Option<String>,
|
||||
pub few_shot_examples: Vec<String>,
|
||||
pub description: String,
|
||||
pub parameters: Option<Vec<Parameter>>,
|
||||
pub endpoint: Option<Endpoint>,
|
||||
pub system_prompt: Option<String>,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue