diff --git a/demos/function-calling/docker-compose.yaml b/demos/function-calling/docker-compose.yaml index 541a5859..45d5c7c9 100644 --- a/demos/function-calling/docker-compose.yaml +++ b/demos/function-calling/docker-compose.yaml @@ -71,13 +71,6 @@ services: profiles: - manual - qdrant: - image: qdrant/qdrant - hostname: vector-db - ports: - - 16333:6333 - - 16334:6334 - chatbot-ui: build: context: ../../chatbot-ui diff --git a/embedding-server/app/load_models.py b/embedding-server/app/load_models.py index 24cf1671..a0467bfc 100644 --- a/embedding-server/app/load_models.py +++ b/embedding-server/app/load_models.py @@ -1,6 +1,7 @@ import os import sentence_transformers from gliner import GLiNER +from transformers import pipeline def load_transformers(models = os.getenv("MODELS", "BAAI/bge-large-en-v1.5")): transformers = {} @@ -17,3 +18,11 @@ def load_ner_models(models = os.getenv("NER_MODELS", "urchade/gliner_large-v2.1" ner_models[model] = GLiNER.from_pretrained(model) return ner_models + +def load_zero_shot_models(models = os.getenv("ZERO_SHOT_MODELS", "tasksource/deberta-base-long-nli")): + zero_shot_models = {} + + for model in models.split(','): + zero_shot_models[model] = pipeline("zero-shot-classification",model=model) + + return zero_shot_models diff --git a/embedding-server/app/main.py b/embedding-server/app/main.py index 1328872f..37a5fd4b 100644 --- a/embedding-server/app/main.py +++ b/embedding-server/app/main.py @@ -1,11 +1,13 @@ import random from fastapi import FastAPI, Response, HTTPException from pydantic import BaseModel -from load_models import load_ner_models, load_transformers +from load_models import load_ner_models, load_transformers, load_zero_shot_models from datetime import date, timedelta +import string transformers = load_transformers() ner_models = load_ner_models() +zero_shot_models = load_zero_shot_models() app = FastAPI() @@ -81,6 +83,42 @@ async def ner(req: NERRequest, res: Response): "object": "list", } +class ZeroShotRequest(BaseModel): + input: str + labels: list[str] + model: str + + +def remove_punctuations(s, lower=True): + s = s.translate(str.maketrans(string.punctuation, " " * len(string.punctuation))) + s = " ".join(s.split()) + if lower: + s = s.lower() + return s + + +@app.post("/zeroshot") +async def zeroshot(req: ZeroShotRequest, res: Response): + if req.model not in zero_shot_models: + raise HTTPException(status_code=400, detail="unknown model: " + req.model) + + classifier = zero_shot_models[req.model] + labels_without_punctuations = [remove_punctuations(label) for label in req.labels] + predicted_classes = classifier(req.input, candidate_labels=labels_without_punctuations, multi_label=True) + label_map = dict(zip(labels_without_punctuations, req.labels)) + + orig_map = [label_map[label] for label in predicted_classes["labels"]] + final_scores = dict(zip(orig_map, predicted_classes["scores"])) + predicted_class = label_map[predicted_classes["labels"][0]] + + return { + "predicted_class": predicted_class, + "predicted_class_score": final_scores[predicted_class], + "scores": final_scores, + "model": req.model, + } + + class WeatherRequest(BaseModel): city: str diff --git a/envoyfilter/Cargo.lock b/envoyfilter/Cargo.lock index dbf8ae79..a9391094 100644 --- a/envoyfilter/Cargo.lock +++ b/envoyfilter/Cargo.lock @@ -2,6 +2,15 @@ # It is not intended for manual editing. version = 3 +[[package]] +name = "acap" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6851a0b3b2d5729a0b7e61e3c36923ed9d72240146b0efda61121b0b84ad595d" +dependencies = [ + "num-traits", +] + [[package]] name = "addr2line" version = "0.21.0" @@ -13,18 +22,18 @@ dependencies = [ [[package]] name = "addr2line" -version = "0.22.0" +version = "0.24.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6e4503c46a5c0c7844e948c9a4d6acd9f50cccb4de1c48eb9e291ea17470c678" +checksum = "f5fb1d8e4442bd405fdfd1dacb42792696b0cf9cb15882e5d097b742a676d375" dependencies = [ - "gimli 0.29.0", + "gimli 0.31.0", ] [[package]] -name = "adler" -version = "1.0.2" +name = "adler2" +version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" +checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627" [[package]] name = "ahash" @@ -116,17 +125,17 @@ checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0" [[package]] name = "backtrace" -version = "0.3.73" +version = "0.3.74" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5cc23269a4f8976d0a4d2e7109211a419fe30e8d88d677cd60b6bc79c5732e0a" +checksum = "8d82cb332cdfaed17ae235a638438ac4d4839913cc2af585c3c6746e8f8bee1a" dependencies = [ - "addr2line 0.22.0", - "cc", + "addr2line 0.24.1", "cfg-if 1.0.0", "libc", "miniz_oxide", "object", "rustc-demangle", + "windows-targets", ] [[package]] @@ -208,9 +217,9 @@ checksum = "8318a53db07bb3f8dca91a600466bdb3f2eaadeedfdbcf02e1accbad9271ba50" [[package]] name = "cc" -version = "1.1.17" +version = "1.1.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a93fe60e2fc87b6ba2c117f67ae14f66e3fc7d6a1e612a25adb238cc980eadb3" +checksum = "b62ac837cdb5cb22e10a256099b4fc502b1dfe560cb282963a974d7abd80e476" dependencies = [ "jobserver", "libc", @@ -718,9 +727,9 @@ dependencies = [ [[package]] name = "gimli" -version = "0.29.0" +version = "0.31.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "40ecd4077b5ae9fd2e9e169b102c6c330d0605168eb0e8bf79952b256dbefffd" +checksum = "32085ea23f3234fc7846555e85283ba4de91e21016dc0455a16286d87a292d64" [[package]] name = "governor" @@ -910,9 +919,9 @@ dependencies = [ [[package]] name = "hyper-util" -version = "0.1.7" +version = "0.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cde7055719c54e36e95e8719f95883f22072a48ede39db7fc17a4e1d5281e9b9" +checksum = "da62f120a8a37763efb0cf8fdf264b884c7b8b9ac8660b900c8661030c00e6ba" dependencies = [ "bytes", "futures-channel", @@ -959,6 +968,7 @@ dependencies = [ name = "intelligent-prompt-gateway" version = "0.1.0" dependencies = [ + "acap", "governor", "http", "log", @@ -976,9 +986,9 @@ dependencies = [ [[package]] name = "ipnet" -version = "2.9.0" +version = "2.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f518f335dce6725a761382244631d86cf0ccb2863413590b31338feb467f9c3" +checksum = "187674a687eed5fe42285b40c6291f9a01517d415fad1c3cbc6a9f778af7fcd4" [[package]] name = "itertools" @@ -1137,11 +1147,11 @@ dependencies = [ [[package]] name = "miniz_oxide" -version = "0.7.4" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b8a240ddb74feaf34a79a7add65a741f3167852fba007066dcac1ca548d89c08" +checksum = "e2d80299ef12ff69b16a84bb182e3b9df68b5a91574d3d4fa6e41b65deec4df1" dependencies = [ - "adler", + "adler2", ] [[package]] @@ -1194,6 +1204,15 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "38bf9645c8b145698bb0b18a4637dcacbc421ea49bef2317e4fd8065a387cf21" +[[package]] +name = "num-traits" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", +] + [[package]] name = "object" version = "0.36.4" @@ -1648,9 +1667,9 @@ dependencies = [ [[package]] name = "rustls" -version = "0.23.12" +version = "0.23.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c58f8c84392efc0a126acce10fa59ff7b3d2ac06ab451a33f2741989b806b044" +checksum = "f2dabaac7466917e566adb06783a81ca48944c6898a1b08b9374106dd671f4c8" dependencies = [ "once_cell", "rustls-pki-types", @@ -1677,9 +1696,9 @@ checksum = "fc0a2ce646f8655401bb81e7927b812614bd5d91dbc968696be50603510fcaf0" [[package]] name = "rustls-webpki" -version = "0.102.7" +version = "0.102.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "84678086bd54edf2b415183ed7a94d0efb049f1b646a33e22a36f3794be6ae56" +checksum = "64ca1bc8749bd4cf37b5ce386cc146580777b4e8572c7b97baf22c83f444bee9" dependencies = [ "ring", "rustls-pki-types", @@ -1694,20 +1713,20 @@ checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f" [[package]] name = "scc" -version = "2.1.16" +version = "2.1.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aeb7ac86243095b70a7920639507b71d51a63390d1ba26c4f60a552fbb914a37" +checksum = "0c947adb109a8afce5fc9c7bf951f87f146e9147b3a6a58413105628fb1d1e66" dependencies = [ "sdd", ] [[package]] name = "schannel" -version = "0.1.23" +version = "0.1.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fbc91545643bcf3a0bbb6569265615222618bdf33ce4ffbbd13c4bbd4c093534" +checksum = "e9aaafd5a2b6e3d657ff009d82fbd630b6bd54dd4eb06f21693925cdf80f9b8b" dependencies = [ - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -1718,9 +1737,9 @@ checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" [[package]] name = "sdd" -version = "3.0.2" +version = "3.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0495e4577c672de8254beb68d01a9b62d0e8a13c099edecdbedccce3223cd29f" +checksum = "60a7b59a5d9b0099720b417b6325d91a52cbf5b3dcb5041d864be53eefa58abc" [[package]] name = "security-framework" @@ -2431,9 +2450,9 @@ dependencies = [ [[package]] name = "wasm-encoder" -version = "0.216.0" +version = "0.217.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "04c23aebea22c8a75833ae08ed31ccc020835b12a41999e58c31464271b94a88" +checksum = "7b88b0814c9a2b323a9b46c687e726996c255ac8b64aa237dd11c81ed4854760" dependencies = [ "leb128", ] @@ -2721,22 +2740,22 @@ dependencies = [ [[package]] name = "wast" -version = "216.0.0" +version = "217.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f7eb1f2eecd913fdde0dc6c3439d0f24530a98ac6db6cb3d14d92a5328554a08" +checksum = "79004ecebded92d3c710d4841383368c7f04b63d0992ddd6b0c7d5029b7629b7" dependencies = [ "bumpalo", "leb128", "memchr", "unicode-width", - "wasm-encoder 0.216.0", + "wasm-encoder 0.217.0", ] [[package]] name = "wat" -version = "1.216.0" +version = "1.217.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ac0409090fb5154f95fb5ba3235675fd9e579e731524d63b6a2f653e1280c82a" +checksum = "c126271c3d92ca0f7c63e4e462e40c69cca52fd4245fcda730d1cf558fb55088" dependencies = [ "wast", ] diff --git a/envoyfilter/Cargo.toml b/envoyfilter/Cargo.toml index 0c59bc22..62786171 100644 --- a/envoyfilter/Cargo.toml +++ b/envoyfilter/Cargo.toml @@ -19,6 +19,7 @@ public-types = { path = "../public-types" } http = "1.1.0" governor = { version = "0.6.3", default-features = false, features = ["no_std"]} tiktoken-rs = "0.5.9" +acap = "0.3.0" [dev-dependencies] proxy-wasm-test-framework = { git = "https://github.com/katanemo/test-framework.git", branch = "main" } diff --git a/envoyfilter/Dockerfile b/envoyfilter/Dockerfile index 9d48f20b..5e542c54 100644 --- a/envoyfilter/Dockerfile +++ b/envoyfilter/Dockerfile @@ -14,5 +14,6 @@ RUN cargo build --release --target wasm32-wasi FROM envoyproxy/envoy:v1.30-latest COPY --from=builder /envoyfilter/target/wasm32-wasi/release/intelligent_prompt_gateway.wasm /etc/envoy/proxy-wasm-plugins/intelligent_prompt_gateway.wasm COPY envoyfilter/envoy.yaml /etc/envoy.yaml -CMD ["envoy", "-c", "/etc/envoy/envoy.yaml"] +# CMD ["envoy", "-c", "/etc/envoy/envoy.yaml"] # CMD ["envoy", "-c", "/etc/envoy/envoy.yaml", "--log-level", "debug"] +CMD ["envoy", "-c", "/etc/envoy/envoy.yaml", "--component-log-level", "wasm:debug"] diff --git a/envoyfilter/envoy.template.yaml b/envoyfilter/envoy.template.yaml index 50610eb0..318fc2b7 100644 --- a/envoyfilter/envoy.template.yaml +++ b/envoyfilter/envoy.template.yaml @@ -165,26 +165,12 @@ static_resources: address: embeddingserver port_value: 80 hostname: "embeddingserver" - - name: qdrant - connect_timeout: 5s - type: STRICT_DNS - lb_policy: ROUND_ROBIN - load_assignment: - cluster_name: qdrant - endpoints: - - lb_endpoints: - - endpoint: - address: - socket_address: - address: qdrant - port_value: 6333 - hostname: "qdrant" - name: mistral_7b_instruct connect_timeout: 5s type: STRICT_DNS lb_policy: ROUND_ROBIN load_assignment: - cluster_name: qdrant + cluster_name: mistral_7b_instruct endpoints: - lb_endpoints: - endpoint: diff --git a/envoyfilter/src/consts.rs b/envoyfilter/src/consts.rs index 732a2bc5..363875dd 100644 --- a/envoyfilter/src/consts.rs +++ b/envoyfilter/src/consts.rs @@ -1,6 +1,6 @@ pub const DEFAULT_EMBEDDING_MODEL: &str = "BAAI/bge-large-en-v1.5"; -pub const DEFAULT_COLLECTION_NAME: &str = "prompt_vector_store"; -pub const DEFAULT_PROMPT_TARGET_THRESHOLD: f64 = 0.6; +pub const DEFAULT_INTENT_MODEL: &str = "tasksource/deberta-base-long-nli"; +pub const DEFAULT_PROMPT_TARGET_THRESHOLD: f64 = 0.8; pub const RATELIMIT_SELECTOR_HEADER_KEY: &str = "x-katanemo-ratelimit-selector"; pub const SYSTEM_ROLE: &str = "system"; pub const USER_ROLE: &str = "user"; diff --git a/envoyfilter/src/filter_context.rs b/envoyfilter/src/filter_context.rs index cbe24ecd..2b9719c8 100644 --- a/envoyfilter/src/filter_context.rs +++ b/envoyfilter/src/filter_context.rs @@ -3,22 +3,20 @@ use crate::ratelimit; use crate::stats::{Counter, Gauge, RecordingMetric}; use crate::stream_context::StreamContext; use log::debug; -use md5::Digest; use open_message_format_embeddings::models::{ CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse, }; use proxy_wasm::traits::*; use proxy_wasm::types::*; -use public_types::common_types::{ - CallContext, EmbeddingRequest, StoreVectorEmbeddingsRequest, VectorPoint, -}; +use public_types::common_types::EmbeddingType; use public_types::configuration::{Configuration, PromptTarget}; use serde_json::to_string; use std::collections::HashMap; use std::rc::Rc; +use std::sync::{OnceLock, RwLock}; use std::time::Duration; -#[derive(Copy, Clone)] +#[derive(Copy, Clone, Debug)] pub struct WasmMetrics { pub active_http_calls: Gauge, pub ratelimited_rq: Counter, @@ -33,11 +31,29 @@ impl WasmMetrics { } } +#[derive(Debug)] +struct CallContext { + prompt_target: String, + embedding_type: EmbeddingType, +} + +pub type EmbeddingTypeMap = HashMap>; + +#[derive(Debug)] pub struct FilterContext { metrics: Rc, // callouts stores token_id to request mapping that we use during #on_http_call_response to match the response to the request. callouts: HashMap, config: Option, + prompt_targets: Rc>>, +} + +pub fn embeddings_store() -> &'static RwLock> { + static EMBEDDINGS: OnceLock>> = OnceLock::new(); + EMBEDDINGS.get_or_init(|| { + let embeddings: HashMap = HashMap::new(); + RwLock::new(embeddings) + }) } impl FilterContext { @@ -46,75 +62,95 @@ impl FilterContext { callouts: HashMap::new(), config: None, metrics: Rc::new(WasmMetrics::new()), + prompt_targets: Rc::new(RwLock::new(HashMap::new())), } } fn process_prompt_targets(&mut self) { - for prompt_target in &self.config.as_ref().unwrap().prompt_targets { - for few_shot_example in &prompt_target.few_shot_examples { - let embeddings_input = CreateEmbeddingRequest { - input: Box::new(CreateEmbeddingRequestInput::String( - few_shot_example.to_string(), - )), - model: String::from(DEFAULT_EMBEDDING_MODEL), - encoding_format: None, - dimensions: None, - user: None, - }; - - // TODO: Handle potential errors - let json_data: String = to_string(&embeddings_input).unwrap(); - - let token_id = match self.dispatch_http_call( - "embeddingserver", - vec![ - (":method", "POST"), - (":path", "/embeddings"), - (":authority", "embeddingserver"), - ("content-type", "application/json"), - ("x-envoy-max-retries", "3"), - ("x-envoy-upstream-rq-timeout-ms", "20000"), - ], - Some(json_data.as_bytes()), - vec![], - Duration::from_secs(5), - ) { - Ok(token_id) => token_id, - Err(e) => { - panic!("Error dispatching HTTP call: {:?}", e); - } - }; - let embedding_request = EmbeddingRequest { - create_embedding_request: embeddings_input, - // Need to clone prompt target to leave config string intact. - prompt_target: prompt_target.clone(), - }; - debug!( - "dispatched HTTP call to embedding server token_id={}", - token_id - ); - if self - .callouts - .insert(token_id, { - CallContext::EmbeddingRequest(embedding_request) - }) - .is_some() - { - panic!("duplicate token_id") - } - self.metrics - .active_http_calls - .record(self.callouts.len().try_into().unwrap()); + let prompt_targets = match self.prompt_targets.read() { + Ok(prompt_targets) => prompt_targets, + Err(e) => { + panic!("Error reading prompt targets: {:?}", e); } + }; + for values in prompt_targets.iter() { + let prompt_target = &values.1; + + // schedule embeddings call for prompt target name + let token_id = self.schedule_embeddings_call(prompt_target.name.clone()); + if self + .callouts + .insert(token_id, { + CallContext { + prompt_target: prompt_target.name.clone(), + embedding_type: EmbeddingType::Name, + } + }) + .is_some() + { + panic!("duplicate token_id") + } + + // schedule embeddings call for prompt target description + let token_id = self.schedule_embeddings_call(prompt_target.description.clone()); + if self + .callouts + .insert(token_id, { + CallContext { + prompt_target: prompt_target.name.clone(), + embedding_type: EmbeddingType::Description, + } + }) + .is_some() + { + panic!("duplicate token_id") + } + + self.metrics + .active_http_calls + .record(self.callouts.len().try_into().unwrap()); } } - fn embedding_request_handler( + fn schedule_embeddings_call(&self, input: String) -> u32 { + let embeddings_input = CreateEmbeddingRequest { + input: Box::new(CreateEmbeddingRequestInput::String(input)), + model: String::from(DEFAULT_EMBEDDING_MODEL), + encoding_format: None, + dimensions: None, + user: None, + }; + + let json_data = to_string(&embeddings_input).unwrap(); + let token_id = match self.dispatch_http_call( + "embeddingserver", + vec![ + (":method", "POST"), + (":path", "/embeddings"), + (":authority", "embeddingserver"), + ("content-type", "application/json"), + ("x-envoy-upstream-rq-timeout-ms", "60000"), + ], + Some(json_data.as_bytes()), + vec![], + Duration::from_secs(60), + ) { + Ok(token_id) => token_id, + Err(e) => { + panic!("Error dispatching HTTP call: {:?}", e); + } + }; + token_id + } + + fn embedding_response_handler( &mut self, body_size: usize, - create_embedding_request: CreateEmbeddingRequest, - prompt_target: PromptTarget, + embedding_type: EmbeddingType, + prompt_target_name: String, ) { + let prompt_targets = self.prompt_targets.read().unwrap(); + let prompt_target = prompt_targets.get(&prompt_target_name).unwrap(); if let Some(body) = self.get_http_call_response_body(0, body_size) { if !body.is_empty() { let mut embedding_response: CreateEmbeddingResponse = @@ -129,111 +165,24 @@ impl FilterContext { } }; - let mut payload: HashMap = HashMap::new(); - payload.insert( - "prompt-target".to_string(), - to_string(&prompt_target).unwrap(), + let embeddings = embedding_response.data.remove(0).embedding; + log::info!( + "Adding embeddings for prompt target name: {:?}, description: {:?}, embedding type: {:?}", + prompt_target.name, + prompt_target.description, + embedding_type ); - let id: Option; - match *create_embedding_request.input { - CreateEmbeddingRequestInput::String(input) => { - id = Some(md5::compute(&input)); - payload.insert("input".to_string(), input); - } - CreateEmbeddingRequestInput::Array(_) => todo!(), - } - let create_vector_store_points = StoreVectorEmbeddingsRequest { - points: vec![VectorPoint { - id: format!("{:x}", id.unwrap()), - payload, - vector: embedding_response.data.remove(0).embedding, - }], - }; - let json_data = to_string(&create_vector_store_points).unwrap(); // Handle potential errors - let token_id = match self.dispatch_http_call( - "qdrant", - vec![ - (":method", "PUT"), - (":path", "/collections/prompt_vector_store/points"), - (":authority", "qdrant"), - ("content-type", "application/json"), - ], - Some(json_data.as_bytes()), - vec![], - Duration::from_secs(5), - ) { - Ok(token_id) => token_id, - Err(e) => { - panic!("Error dispatching HTTP call: {:?}", e); - } - }; - - if self - .callouts - .insert( - token_id, - CallContext::StoreVectorEmbeddings(create_vector_store_points), - ) - .is_some() - { - panic!("duplicate token_id") - } - self.metrics - .active_http_calls - .record(self.callouts.len().try_into().unwrap()); + embeddings_store().write().unwrap().insert( + prompt_target.name.clone(), + HashMap::from([(embedding_type, embeddings)]), + ); } } else { panic!("No body in response"); } } - - fn create_vector_store_points_handler(&self, body_size: usize) { - if let Some(body) = self.get_http_call_response_body(0, body_size) { - if !body.is_empty() { - debug!( - "response body: len {:?}", - String::from_utf8(body).unwrap().len() - ); - } - } - } - - //TODO: run once per envoy instance, right now it runs once per worker - fn init_vector_store(&mut self) { - let token_id = match self.dispatch_http_call( - "qdrant", - vec![ - (":method", "PUT"), - (":path", "/collections/prompt_vector_store"), - (":authority", "qdrant"), - ("content-type", "application/json"), - ], - Some(b"{ \"vectors\": { \"size\": 1024, \"distance\": \"Cosine\"}}"), - vec![], - Duration::from_secs(5), - ) { - Ok(token_id) => token_id, - Err(e) => { - panic!("Error dispatching HTTP call for init-vector-store: {:?}", e); - } - }; - if self - .callouts - .insert( - token_id, - CallContext::CreateVectorCollection("prompt_vector_store".to_string()), - ) - .is_some() - { - panic!("duplicate token_id") - } - self.metrics - .active_http_calls - .record(self.callouts.len().try_into().unwrap()); - } } - impl Context for FilterContext { fn on_http_call_response( &mut self, @@ -242,37 +191,18 @@ impl Context for FilterContext { body_size: usize, _num_trailers: usize, ) { - let callout_data = self - .callouts - .remove(&token_id) - .expect("invalid token_id: {}"); + debug!("on_http_call_response called with token_id: {:?}", token_id); + let callout_data = self.callouts.remove(&token_id).expect("invalid token_id"); self.metrics .active_http_calls .record(self.callouts.len().try_into().unwrap()); - match callout_data { - CallContext::EmbeddingRequest(EmbeddingRequest { - create_embedding_request, - prompt_target, - }) => { - self.embedding_request_handler(body_size, create_embedding_request, prompt_target) - } - CallContext::StoreVectorEmbeddings(_) => { - self.create_vector_store_points_handler(body_size) - } - CallContext::CreateVectorCollection(_) => { - let mut http_status_code = "Nil".to_string(); - self.get_http_call_response_headers() - .iter() - .for_each(|(k, v)| { - if k == ":status" { - http_status_code.clone_from(v); - } - }); - debug!("CreateVectorCollection response: {}", http_status_code); - } - } + self.embedding_response_handler( + body_size, + callout_data.embedding_type, + callout_data.prompt_target, + ) } } @@ -282,6 +212,13 @@ impl RootContext for FilterContext { if let Some(config_bytes) = self.get_plugin_configuration() { self.config = serde_yaml::from_slice(&config_bytes).unwrap(); + for pt in self.config.clone().unwrap().prompt_targets { + self.prompt_targets + .write() + .unwrap() + .insert(pt.name.clone(), pt.clone()); + } + debug!("set configuration object: {:?}", self.config); if let Some(ratelimits_config) = self @@ -301,6 +238,7 @@ impl RootContext for FilterContext { ratelimit_selector: None, callouts: HashMap::new(), metrics: Rc::clone(&self.metrics), + prompt_targets: Rc::clone(&self.prompt_targets), })) } @@ -314,7 +252,6 @@ impl RootContext for FilterContext { } fn on_tick(&mut self) { - self.init_vector_store(); self.process_prompt_targets(); self.set_tick_period(Duration::from_secs(0)); } diff --git a/envoyfilter/src/stats.rs b/envoyfilter/src/stats.rs index 26755469..250e9017 100644 --- a/envoyfilter/src/stats.rs +++ b/envoyfilter/src/stats.rs @@ -33,7 +33,7 @@ pub trait RecordingMetric: Metric { } } -#[derive(Copy, Clone)] +#[derive(Copy, Clone, Debug)] pub struct Counter { id: u32, } @@ -55,7 +55,7 @@ impl Metric for Counter { impl IncrementingMetric for Counter {} -#[derive(Copy, Clone)] +#[derive(Copy, Clone, Debug)] pub struct Gauge { id: u32, } diff --git a/envoyfilter/src/stream_context.rs b/envoyfilter/src/stream_context.rs index 3d155d26..c74cf097 100644 --- a/envoyfilter/src/stream_context.rs +++ b/envoyfilter/src/stream_context.rs @@ -1,13 +1,14 @@ use crate::consts::{ - BOLT_FC_CLUSTER, BOLT_FC_REQUEST_TIMEOUT_MS, DEFAULT_COLLECTION_NAME, DEFAULT_EMBEDDING_MODEL, + BOLT_FC_CLUSTER, BOLT_FC_REQUEST_TIMEOUT_MS, DEFAULT_EMBEDDING_MODEL, DEFAULT_INTENT_MODEL, DEFAULT_PROMPT_TARGET_THRESHOLD, GPT_35_TURBO, RATELIMIT_SELECTOR_HEADER_KEY, SYSTEM_ROLE, USER_ROLE, }; -use crate::filter_context::WasmMetrics; +use crate::filter_context::{embeddings_store, WasmMetrics}; use crate::ratelimit; use crate::ratelimit::Header; use crate::stats::IncrementingMetric; use crate::tokenizer; +use acap::cos; use http::StatusCode; use log::{debug, error, info, warn}; use open_message_format_embeddings::models::{ @@ -15,31 +16,31 @@ use open_message_format_embeddings::models::{ }; use proxy_wasm::traits::*; use proxy_wasm::types::*; +use public_types::common_types::open_ai::{ChatCompletions, Message}; use public_types::common_types::{ - open_ai::{ChatCompletions, Message}, - SearchPointsRequest, SearchPointsResponse, -}; -use public_types::common_types::{ - BoltFCResponse, BoltFCToolsCall, ToolParameter, ToolParameters, ToolsDefinition, + BoltFCResponse, BoltFCToolsCall, EmbeddingType, ToolParameter, ToolParameters, ToolsDefinition, + ZeroShotClassificationRequest, ZeroShotClassificationResponse, }; use public_types::configuration::{PromptTarget, PromptType}; use std::collections::HashMap; use std::num::NonZero; use std::rc::Rc; +use std::sync::RwLock; use std::time::Duration; -enum RequestType { - GetEmbedding, - SearchPoints, +enum ResponseHandlerType { + GetEmbeddings, FunctionResolver, - FunctionCallResponse, + FunctionCall, + ZeroShotIntent, } pub struct CallContext { - request_type: RequestType, + response_handler_type: ResponseHandlerType, user_message: Option, prompt_target: Option, request_body: ChatCompletions, + similarity_scores: Option>, } pub struct StreamContext { @@ -47,6 +48,7 @@ pub struct StreamContext { pub ratelimit_selector: Option
, pub callouts: HashMap, pub metrics: Rc, + pub prompt_targets: Rc>>, } impl StreamContext { @@ -61,7 +63,6 @@ impl StreamContext { // However, a missing Content-Length header is not grounds for bad requests given that intermediary hops could // manipulate the body in benign ways e.g., compression. self.set_http_request_header("content-length", None); - // self.set_http_request_header("authorization", None); } fn modify_path_header(&mut self) { @@ -85,7 +86,7 @@ impl StreamContext { }); } - fn send_server_error(&mut self, error: String) { + fn send_server_error(&self, error: String) { debug!("server error occurred: {}", error); self.send_http_response( StatusCode::INTERNAL_SERVER_ERROR.as_u16().into(), @@ -103,30 +104,85 @@ impl StreamContext { } }; - let search_points_request = SearchPointsRequest { - vector: embedding_response.data[0].embedding.clone(), - limit: 10, - with_payload: true, - }; + let embeddings_vector = &embedding_response.data[0].embedding; - let json_data: String = match serde_json::to_string(&search_points_request) { - Ok(json_data) => json_data, + debug!( + "embedding model: {}, vector length: {:?}", + embedding_response.model, + embeddings_vector.len() + ); + + let prompt_target_embeddings = match embeddings_store().read() { + Ok(embeddings) => embeddings, Err(e) => { - self.send_server_error(format!("Error serializing search_points_request: {:?}", e)); + let error_message = format!("Error reading embeddings store: {:?}", e); + warn!("{}", error_message); + self.send_server_error(error_message); return; } }; - let path = format!("/collections/{}/points/search", DEFAULT_COLLECTION_NAME); + let prompt_targets = match self.prompt_targets.read() { + Ok(prompt_targets) => prompt_targets, + Err(e) => { + let error_message = format!("Error reading prompt targets: {:?}", e); + warn!("{}", error_message); + self.send_server_error(error_message); + return; + } + }; + + let prompt_target_names = prompt_targets + .iter() + .map(|(name, _)| name.clone()) + .collect(); + + let similarity_scores: Vec<(String, f64)> = prompt_targets + .iter() + .map(|(prompt_name, _prompt_target)| { + let default_embeddings = HashMap::new(); + let pte = prompt_target_embeddings + .get(prompt_name) + .unwrap_or(&default_embeddings); + let description_embeddings = pte.get(&EmbeddingType::Description); + let similarity_score_description = cos::cosine_similarity( + &embeddings_vector, + &description_embeddings.unwrap_or(&vec![0.0]), + ); + (prompt_name.clone(), similarity_score_description) + }) + .collect(); + + debug!( + "similarity scores based on description embeddings match: {:?}", + similarity_scores + ); + + callout_context.similarity_scores = Some(similarity_scores); + + let zero_shot_classification_request = ZeroShotClassificationRequest { + // Need to clone into input because user_message is used below. + input: callout_context.user_message.as_ref().unwrap().clone(), + model: String::from(DEFAULT_INTENT_MODEL), + labels: prompt_target_names, + }; + + let json_data: String = match serde_json::to_string(&zero_shot_classification_request) { + Ok(json_data) => json_data, + Err(error) => { + panic!("Error serializing zero shot request: {}", error); + } + }; let token_id = match self.dispatch_http_call( - "qdrant", + "embeddingserver", vec![ (":method", "POST"), - (":path", &path), - (":authority", "qdrant"), + (":path", "/zeroshot"), + (":authority", "embeddingserver"), ("content-type", "application/json"), ("x-envoy-max-retries", "3"), + ("x-envoy-upstream-rq-timeout-ms", "60000"), ], Some(json_data.as_bytes()), vec![], @@ -134,39 +190,60 @@ impl StreamContext { ) { Ok(token_id) => token_id, Err(e) => { - panic!("Error dispatching HTTP call for get-embeddings: {:?}", e); + panic!( + "Error dispatching embedding server HTTP call for zero-shot-intent-detection: {:?}", + e + ); } }; + debug!( + "dispatched HTTP call to embedding server for zero-shot-intent-detection token_id={}", + token_id + ); + + callout_context.response_handler_type = ResponseHandlerType::ZeroShotIntent; - callout_context.request_type = RequestType::SearchPoints; if self.callouts.insert(token_id, callout_context).is_some() { - panic!("duplicate token_id") + panic!( + "duplicate token_id={} in embedding server requests", + token_id + ) } - self.metrics.active_http_calls.increment(1); } - fn search_points_handler(&mut self, body: Vec, mut callout_context: CallContext) { - let search_points_response: SearchPointsResponse = match serde_json::from_slice(&body) { - Ok(search_points_response) => search_points_response, - Err(e) => { - self.send_server_error(format!( - "Error deserializing search_points_response: {:?}", - e - )); + fn zero_shot_intent_detection_resp_handler( + &mut self, + body: Vec, + mut callout_context: CallContext, + ) { + let zeroshot_intent_response: ZeroShotClassificationResponse = + match serde_json::from_slice(&body) { + Ok(zeroshot_response) => zeroshot_response, + Err(e) => { + warn!( + "Error deserializing zeroshot intent detection response: {:?}", + e + ); + info!("body: {:?}", String::from_utf8(body).unwrap()); + self.resume_http_request(); + return; + } + }; - return; - } - }; + debug!("zeroshot intent response: {:?}", zeroshot_intent_response); - let search_results = &search_points_response.result; + let prompt_target_similarity_score = zeroshot_intent_response.predicted_class_score * 0.7 + + callout_context.similarity_scores.as_ref().unwrap()[0].1 * 0.3; - if search_results.is_empty() { - info!("No prompt target matched"); - self.resume_http_request(); - return; - } + debug!( + "similarity score: {}, intent score: {}, description embedding score: {}", + prompt_target_similarity_score, + zeroshot_intent_response.predicted_class_score, + callout_context.similarity_scores.as_ref().unwrap()[0].1 + ); + + let prompt_target_name = zeroshot_intent_response.predicted_class.clone(); - info!("similarity score: {}", search_results[0].score); // Check to see who responded to user message. This will help us identify if control should be passed to Bolt FC or not. // If the last message was from Bolt FC, then Bolt FC is handling the conversation (possibly for parameter collection). let mut bolt_assistant = false; @@ -175,7 +252,6 @@ impl StreamContext { let latest_assistant_message = &messages[messages.len() - 2]; if let Some(model) = latest_assistant_message.model.as_ref() { if model.starts_with("Bolt") { - info!("Bolt assistant message found"); bolt_assistant = true; } } @@ -183,23 +259,30 @@ impl StreamContext { info!("no assistant message found, probably first interaction"); } - if search_results[0].score < DEFAULT_PROMPT_TARGET_THRESHOLD && !bolt_assistant { - info!( - "prompt target below threshold: {}", - DEFAULT_PROMPT_TARGET_THRESHOLD - ); - self.resume_http_request(); - return; - } - let prompt_target_str = search_results[0].payload.get("prompt-target").unwrap(); - let prompt_target: PromptTarget = match serde_json::from_slice(prompt_target_str.as_bytes()) - { - Ok(prompt_target) => prompt_target, - Err(e) => { - self.send_server_error(format!("Error deserializing prompt_target: {:?}", e)); + // check to ensure that the prompt target similarity score is above the threshold + if prompt_target_similarity_score < DEFAULT_PROMPT_TARGET_THRESHOLD && !bolt_assistant { + // if bolt fc responded to the user message, then we don't need to check the similarity score + // it may be that bolt fc is handling the conversation for parameter collection + if bolt_assistant { + info!("bolt assistant is handling the conversation"); + } else { + info!( + "prompt target below threshold: {}, continue conversation with user", + prompt_target_similarity_score, + ); + self.resume_http_request(); return; } - }; + } + + let prompt_target = self + .prompt_targets + .read() + .unwrap() + .get(&prompt_target_name) + .unwrap() + .clone(); + info!( "prompt_target name: {:?}, type: {:?}", prompt_target.name, prompt_target.prompt_type @@ -231,7 +314,7 @@ impl StreamContext { let tools_defintion: ToolsDefinition = ToolsDefinition { name: prompt_target.name.clone(), - description: prompt_target.description.clone().unwrap_or("".to_string()), + description: prompt_target.description.clone(), parameters: tools_parameters, }; @@ -283,7 +366,7 @@ impl StreamContext { BOLT_FC_CLUSTER, token_id ); - callout_context.request_type = RequestType::FunctionResolver; + callout_context.response_handler_type = ResponseHandlerType::FunctionResolver; callout_context.prompt_target = Some(prompt_target); if self.callouts.insert(token_id, callout_context).is_some() { panic!("duplicate token_id") @@ -342,7 +425,7 @@ impl StreamContext { { warn!("boltfc did not extract required parameter: {}", param.name); return self.send_http_response( - StatusCode::BAD_REQUEST.as_u16().into(), + StatusCode::INTERNAL_SERVER_ERROR.as_u16().into(), vec![], Some("missing required parameter".as_bytes()), ); @@ -380,7 +463,7 @@ impl StreamContext { } }; - callout_context.request_type = RequestType::FunctionCallResponse; + callout_context.response_handler_type = ResponseHandlerType::FunctionCall; if self.callouts.insert(token_id, callout_context).is_some() { panic!("duplicate token_id") } @@ -468,7 +551,6 @@ impl StreamContext { } } - debug!("sending request to openai: msg {}", json_string); self.set_http_request_body(0, json_string.len(), &json_string.into_bytes()); self.resume_http_request(); } @@ -582,10 +664,11 @@ impl HttpContext for StreamContext { ); let call_context = CallContext { - request_type: RequestType::GetEmbedding, + response_handler_type: ResponseHandlerType::GetEmbeddings, user_message: Some(user_message), prompt_target: None, request_body: deserialized_body, + similarity_scores: None, }; if self.callouts.insert(token_id, call_context).is_some() { panic!( @@ -593,6 +676,7 @@ impl HttpContext for StreamContext { token_id ) } + self.metrics.active_http_calls.increment(1); Action::Pause @@ -611,18 +695,24 @@ impl Context for StreamContext { self.metrics.active_http_calls.increment(-1); if let Some(body) = self.get_http_call_response_body(0, body_size) { - match callout_context.request_type { - RequestType::GetEmbedding => self.embeddings_handler(body, callout_context), - RequestType::SearchPoints => self.search_points_handler(body, callout_context), - RequestType::FunctionResolver => { + match callout_context.response_handler_type { + ResponseHandlerType::GetEmbeddings => { + self.embeddings_handler(body, callout_context) + } + ResponseHandlerType::FunctionResolver => { self.function_resolver_handler(body, callout_context) } - RequestType::FunctionCallResponse => { + ResponseHandlerType::FunctionCall => { self.function_call_response_handler(body, callout_context) } + ResponseHandlerType::ZeroShotIntent => { + self.zero_shot_intent_detection_resp_handler(body, callout_context) + } } } else { - warn!("No response body in inline HTTP request"); + let error_message = "No response body in inline HTTP request"; + warn!("{}", error_message); + self.send_server_error(error_message.to_owned()); } } } diff --git a/envoyfilter/tests/integration.rs b/envoyfilter/tests/integration.rs index 724e9b6b..1c4f97cb 100644 --- a/envoyfilter/tests/integration.rs +++ b/envoyfilter/tests/integration.rs @@ -8,16 +8,10 @@ use proxy_wasm_test_framework::tester::{self, Tester}; use proxy_wasm_test_framework::types::{ Action, BufferType, LogLevel, MapType, MetricType, ReturnType, }; -use public_types::{ - common_types::{ - open_ai::Message, BoltFCResponse, BoltFCToolsCall, IntOrString, ToolCallDetail, - }, - configuration::{self, Endpoint, PromptTarget}, -}; -use public_types::{ - common_types::{SearchPointResult, SearchPointsResponse}, - configuration::Configuration, +use public_types::common_types::{ + open_ai::Message, BoltFCResponse, BoltFCToolsCall, IntOrString, ToolCallDetail, }; +use public_types::{common_types::ZeroShotClassificationResponse, configuration::Configuration}; use serial_test::serial; use std::collections::HashMap; use std::path::Path; @@ -118,59 +112,39 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) { .expect_metric_increment("active_http_calls", -1) .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody)) .returning(Some(&embeddings_response_buffer)) - .expect_http_call(Some("qdrant"), None, None, None, None) + .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Debug), None) + .expect_http_call(Some("embeddingserver"), None, None, None, None) .returning(Some(2)) .expect_metric_increment("active_http_calls", 1) .execute_and_expect(ReturnType::None) .unwrap(); - let prompt_target = PromptTarget { - name: String::from("test-prompt-target"), - description: None, - prompt_type: configuration::PromptType::FunctionResolver, - few_shot_examples: vec![], - parameters: Some(vec![configuration::Parameter { - name: String::from("test-entity"), - parameter_type: Some(String::from("string")), - description: String::from("test-description"), - required: Some(true), - }]), - endpoint: Some(Endpoint { - cluster: String::from("test-endpoint-cluster"), - path: None, - method: None, - }), - system_prompt: None, + let zero_shot_response = ZeroShotClassificationResponse { + predicted_class: "weather_forecast".to_string(), + predicted_class_score: 0.1, + scores: HashMap::new(), + model: "test-model".to_string(), }; - let prompt_target_str = serde_json::to_string(&prompt_target).unwrap(); - let search_points_response = SearchPointsResponse { - status: String::new(), - time: 0.0, - result: vec![SearchPointResult { - id: String::new(), - version: 0, - score: 0.7, - payload: HashMap::from([(String::from("prompt-target"), prompt_target_str)]), - }], - }; - let search_points_response_buffer = serde_json::to_string(&search_points_response).unwrap(); + let zeroshot_intent_detection_buffer = serde_json::to_string(&zero_shot_response).unwrap(); module .call_proxy_on_http_call_response( http_context, 2, 0, - search_points_response_buffer.len() as i32, + zeroshot_intent_detection_buffer.len() as i32, 0, ) .expect_metric_increment("active_http_calls", -1) .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody)) - .returning(Some(&search_points_response_buffer)) - .expect_log(Some(LogLevel::Info), None) - .expect_log(Some(LogLevel::Info), None) + .returning(Some(&zeroshot_intent_detection_buffer)) .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Info), None) .expect_http_call(Some("bolt_fc_1b"), None, None, None, None) .returning(Some(3)) .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Debug), None) .expect_metric_increment("active_http_calls", 1) .execute_and_expect(ReturnType::None) .unwrap(); @@ -199,20 +173,22 @@ system_prompt: | prompt_targets: - type: function_resolver name: weather_forecast - few_shot_examples: - - what is the weather in New York? + description: This resolver provides weather forecast information. endpoint: cluster: weatherhost path: /weather - entities: - - name: location + parameters: + - name: city required: true - description: "The location for which the weather is requested" + description: The city for which the weather forecast is requested. + - name: days + description: The number of days for which the weather forecast is requested. + - name: units + description: The units in which the weather forecast is requested. - type: function_resolver name: weather_forecast_2 - few_shot_examples: - - what is the weather in New York? + description: This resolver provides weather forecast information. endpoint: cluster: weatherhost path: /weather @@ -450,10 +426,10 @@ fn request_ratelimited() { normal_flow(&mut module, filter_context, http_context); let tool_call_detail = vec![ToolCallDetail { - name: String::from("test-tool"), + name: String::from("weather_forecast"), arguments: HashMap::from([( - String::from("test-entity"), - IntOrString::Text(String::from("test-value")), + String::from("city"), + IntOrString::Text(String::from("seattle")), )]), }]; @@ -485,7 +461,7 @@ fn request_ratelimited() { .expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Debug), None) - .expect_http_call(Some("test-endpoint-cluster"), None, None, None, None) + .expect_http_call(Some("weatherhost"), None, None, None, None) .returning(Some(4)) .expect_metric_increment("active_http_calls", 1) .execute_and_expect(ReturnType::None) @@ -555,10 +531,10 @@ fn request_not_ratelimited() { normal_flow(&mut module, filter_context, http_context); let tool_call_detail = vec![ToolCallDetail { - name: String::from("test-tool"), + name: String::from("weather_forecast"), arguments: HashMap::from([( - String::from("test-entity"), - IntOrString::Text(String::from("test-value")), + String::from("city"), + IntOrString::Text(String::from("seattle")), )]), }]; @@ -590,7 +566,7 @@ fn request_not_ratelimited() { .expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Debug), None) - .expect_http_call(Some("test-endpoint-cluster"), None, None, None, None) + .expect_http_call(Some("weatherhost"), None, None, None, None) .returning(Some(4)) .expect_metric_increment("active_http_calls", 1) .execute_and_expect(ReturnType::None) @@ -608,7 +584,7 @@ fn request_not_ratelimited() { .expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Debug), None) .expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None) - .expect_log(Some(LogLevel::Debug), None) + // .expect_log(Some(LogLevel::Debug), None) .execute_and_expect(ReturnType::None) .unwrap(); } diff --git a/public-types/src/common_types.rs b/public-types/src/common_types.rs index 838aa2b8..16d4bd8a 100644 --- a/public-types/src/common_types.rs +++ b/public-types/src/common_types.rs @@ -1,14 +1,18 @@ use crate::configuration::PromptTarget; -use open_message_format_embeddings::models::CreateEmbeddingRequest; use serde::{Deserialize, Serialize}; use std::collections::HashMap; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct EmbeddingRequest { - pub create_embedding_request: CreateEmbeddingRequest, pub prompt_target: PromptTarget, } +#[derive(Debug, Clone, Hash, PartialEq, Eq)] +pub enum EmbeddingType { + Name, + Description, +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct VectorPoint { pub id: String, @@ -21,21 +25,6 @@ pub struct StoreVectorEmbeddingsRequest { pub points: Vec, } -#[derive(Debug, Clone, Serialize, Deserialize)] -#[allow(clippy::large_enum_variant)] -pub enum CallContext { - EmbeddingRequest(EmbeddingRequest), - StoreVectorEmbeddings(StoreVectorEmbeddingsRequest), - CreateVectorCollection(String), -} - -// https://api.qdrant.tech/master/api-reference/search/points -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct SearchPointsRequest { - pub vector: Vec, - pub limit: i32, - pub with_payload: bool, -} #[derive(Debug, Clone, Serialize, Deserialize)] pub struct SearchPointResult { @@ -45,13 +34,6 @@ pub struct SearchPointResult { pub payload: HashMap, } -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct SearchPointsResponse { - pub result: Vec, - pub status: String, - pub time: f64, -} - #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ToolParameter { #[serde(rename = "type")] @@ -125,3 +107,18 @@ pub mod open_ai { pub model: Option, } } + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ZeroShotClassificationRequest { + pub input: String, + pub labels: Vec, + pub model: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ZeroShotClassificationResponse { + pub predicted_class: String, + pub predicted_class_score: f64, + pub scores: HashMap, + pub model: String, +} diff --git a/public-types/src/configuration.rs b/public-types/src/configuration.rs index 36e6eb8f..eb71f3ae 100644 --- a/public-types/src/configuration.rs +++ b/public-types/src/configuration.rs @@ -99,8 +99,7 @@ pub struct PromptTarget { #[serde(rename = "type")] pub prompt_type: PromptType, pub name: String, - pub description: Option, - pub few_shot_examples: Vec, + pub description: String, pub parameters: Option>, pub endpoint: Option, pub system_prompt: Option,