diff --git a/chatbot_ui/app/run_stream.py b/chatbot_ui/app/run_stream.py index dbcf0df3..458508d3 100644 --- a/chatbot_ui/app/run_stream.py +++ b/chatbot_ui/app/run_stream.py @@ -5,8 +5,9 @@ from openai import OpenAI import gradio as gr api_key = os.getenv("OPENAI_API_KEY") +CHAT_COMPLETION_ENDPOINT = os.getenv("CHAT_COMPLETION_ENDPOINT", "https://api.openai.com/v1") -client = OpenAI(api_key=api_key) +client = OpenAI(api_key=api_key, base_url=CHAT_COMPLETION_ENDPOINT) def predict(message, history): history_openai_format = [] diff --git a/demos/function_calling/bolt_config.yaml b/demos/function_calling/bolt_config.yaml index 03bf869b..a3d115dd 100644 --- a/demos/function_calling/bolt_config.yaml +++ b/demos/function_calling/bolt_config.yaml @@ -2,6 +2,9 @@ default_prompt_endpoint: "127.0.0.1" load_balancing: "round_robin" timeout_ms: 5000 +overrides: + # confidence threshold for prompt target intent matching + prompt_target_intent_matching_threshold: 0.6 # should not be here embedding_provider: diff --git a/envoyfilter/src/filter_context.rs b/envoyfilter/src/filter_context.rs index 4774b010..9deeb9ec 100644 --- a/envoyfilter/src/filter_context.rs +++ b/envoyfilter/src/filter_context.rs @@ -9,7 +9,7 @@ use open_message_format_embeddings::models::{ use proxy_wasm::traits::*; use proxy_wasm::types::*; use public_types::common_types::EmbeddingType; -use public_types::configuration::{Configuration, PromptTarget}; +use public_types::configuration::{Configuration, Overrides, PromptTarget}; use serde_json::to_string; use std::collections::HashMap; use std::rc::Rc; @@ -45,6 +45,7 @@ pub struct FilterContext { // callouts stores token_id to request mapping that we use during #on_http_call_response to match the response to the request. 
callouts: HashMap, config: Option, + overrides: Rc>, prompt_targets: Rc>>, } @@ -63,6 +64,7 @@ impl FilterContext { config: None, metrics: Rc::new(WasmMetrics::new()), prompt_targets: Rc::new(RwLock::new(HashMap::new())), + overrides: Rc::new(None), } } @@ -212,6 +214,14 @@ impl RootContext for FilterContext { if let Some(config_bytes) = self.get_plugin_configuration() { self.config = serde_yaml::from_slice(&config_bytes).unwrap(); + if let Some(overrides_config) = self + .config + .as_mut() + .and_then(|config| config.overrides.as_mut()) + { + self.overrides = Rc::new(Some(std::mem::take(overrides_config))); + } + for pt in self.config.clone().unwrap().prompt_targets { self.prompt_targets .write() @@ -237,6 +247,7 @@ impl RootContext for FilterContext { context_id, Rc::clone(&self.metrics), Rc::clone(&self.prompt_targets), + Rc::clone(&self.overrides), ))) } diff --git a/envoyfilter/src/stream_context.rs b/envoyfilter/src/stream_context.rs index 771772ec..a2b4586d 100644 --- a/envoyfilter/src/stream_context.rs +++ b/envoyfilter/src/stream_context.rs @@ -24,7 +24,7 @@ use public_types::common_types::{ BoltFCResponse, BoltFCToolsCall, EmbeddingType, ToolParameter, ToolParameters, ToolsDefinition, ZeroShotClassificationRequest, ZeroShotClassificationResponse, }; -use public_types::configuration::{PromptTarget, PromptType}; +use public_types::configuration::{Overrides, PromptTarget, PromptType}; use std::collections::HashMap; use std::num::NonZero; use std::rc::Rc; @@ -50,6 +50,7 @@ pub struct StreamContext { pub context_id: u32, pub metrics: Rc, pub prompt_targets: Rc>>, + pub overrides: Rc>, callouts: HashMap, host_header: Option, ratelimit_selector: Option
, @@ -63,6 +64,7 @@ impl StreamContext { context_id: u32, metrics: Rc, prompt_targets: Rc>>, + overrides: Rc>, ) -> Self { StreamContext { context_id, @@ -74,6 +76,7 @@ impl StreamContext { streaming_response: false, response_tokens: 0, chat_completions_request: false, + overrides, } } fn save_host_header(&mut self) { @@ -263,7 +266,7 @@ impl StreamContext { + callout_context.similarity_scores.as_ref().unwrap()[0].1 * 0.3; debug!( - "similarity score: {}, intent score: {}, description embedding score: {}", + "similarity score: {:.3}, intent score: {:.3}, description embedding score: {:.3}", prompt_target_similarity_score, zeroshot_intent_response.predicted_class_score, callout_context.similarity_scores.as_ref().unwrap()[0].1 @@ -286,16 +289,28 @@ impl StreamContext { info!("no assistant message found, probably first interaction"); } + // get prompt target similarity threshold from overrides + let prompt_target_intent_matching_threshold = match self.overrides.as_ref() { + Some(overrides) => match overrides.prompt_target_intent_matching_threshold { + Some(threshold) => threshold, + None => DEFAULT_PROMPT_TARGET_THRESHOLD, + }, + None => DEFAULT_PROMPT_TARGET_THRESHOLD, + }; + + // check to ensure that the prompt target similarity score is above the threshold - if prompt_target_similarity_score < DEFAULT_PROMPT_TARGET_THRESHOLD && !bolt_assistant { + if prompt_target_similarity_score < prompt_target_intent_matching_threshold + && !bolt_assistant + { // if bolt fc responded to the user message, then we don't need to check the similarity score // it may be that bolt fc is handling the conversation for parameter collection if bolt_assistant { info!("bolt assistant is handling the conversation"); } else { info!( - "prompt target below threshold: {}, continue conversation with user", + "prompt target below limit: {:.3}, threshold: {:.3}, continue conversation with user", prompt_target_similarity_score, + prompt_target_intent_matching_threshold ); self.resume_http_request(); 
return; diff --git a/public_types/src/configuration.rs b/public_types/src/configuration.rs index 3d127782..40c481f1 100644 --- a/public_types/src/configuration.rs +++ b/public_types/src/configuration.rs @@ -1,10 +1,16 @@ use serde::{Deserialize, Serialize}; +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct Overrides { + pub prompt_target_intent_matching_threshold: Option, +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Configuration { pub default_prompt_endpoint: String, pub load_balancing: LoadBalancing, pub timeout_ms: u64, + pub overrides: Option, pub embedding_provider: EmbeddingProviver, pub llm_providers: Vec, pub prompt_guards: Option,