add ability to override default values from config (#58)

This commit is contained in:
Adil Hafeez 2024-09-17 22:37:58 -07:00 committed by GitHub
parent 9f3c845610
commit 3135ba8eae
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 42 additions and 6 deletions

View file

@ -9,7 +9,7 @@ use open_message_format_embeddings::models::{
use proxy_wasm::traits::*;
use proxy_wasm::types::*;
use public_types::common_types::EmbeddingType;
use public_types::configuration::{Configuration, PromptTarget};
use public_types::configuration::{Configuration, Overrides, PromptTarget};
use serde_json::to_string;
use std::collections::HashMap;
use std::rc::Rc;
@ -45,6 +45,7 @@ pub struct FilterContext {
// callouts stores token_id to request mapping that we use during #on_http_call_response to match the response to the request.
callouts: HashMap<u32, CallContext>,
config: Option<Configuration>,
overrides: Rc<Option<Overrides>>,
prompt_targets: Rc<RwLock<HashMap<String, PromptTarget>>>,
}
@ -63,6 +64,7 @@ impl FilterContext {
config: None,
metrics: Rc::new(WasmMetrics::new()),
prompt_targets: Rc::new(RwLock::new(HashMap::new())),
overrides: Rc::new(None),
}
}
@ -212,6 +214,14 @@ impl RootContext for FilterContext {
if let Some(config_bytes) = self.get_plugin_configuration() {
self.config = serde_yaml::from_slice(&config_bytes).unwrap();
if let Some(overrides_config) = self
.config
.as_mut()
.and_then(|config| config.overrides.as_mut())
{
self.overrides = Rc::new(Some(std::mem::take(overrides_config)));
}
for pt in self.config.clone().unwrap().prompt_targets {
self.prompt_targets
.write()
@ -237,6 +247,7 @@ impl RootContext for FilterContext {
context_id,
Rc::clone(&self.metrics),
Rc::clone(&self.prompt_targets),
Rc::clone(&self.overrides),
)))
}

View file

@ -24,7 +24,7 @@ use public_types::common_types::{
BoltFCResponse, BoltFCToolsCall, EmbeddingType, ToolParameter, ToolParameters, ToolsDefinition,
ZeroShotClassificationRequest, ZeroShotClassificationResponse,
};
use public_types::configuration::{PromptTarget, PromptType};
use public_types::configuration::{Overrides, PromptTarget, PromptType};
use std::collections::HashMap;
use std::num::NonZero;
use std::rc::Rc;
@ -50,6 +50,7 @@ pub struct StreamContext {
pub context_id: u32,
pub metrics: Rc<WasmMetrics>,
pub prompt_targets: Rc<RwLock<HashMap<String, PromptTarget>>>,
pub overrides: Rc<Option<Overrides>>,
callouts: HashMap<u32, CallContext>,
host_header: Option<String>,
ratelimit_selector: Option<Header>,
@ -63,6 +64,7 @@ impl StreamContext {
context_id: u32,
metrics: Rc<WasmMetrics>,
prompt_targets: Rc<RwLock<HashMap<String, PromptTarget>>>,
overrides: Rc<Option<Overrides>>,
) -> Self {
StreamContext {
context_id,
@ -74,6 +76,7 @@ impl StreamContext {
streaming_response: false,
response_tokens: 0,
chat_completions_request: false,
overrides,
}
}
fn save_host_header(&mut self) {
@ -263,7 +266,7 @@ impl StreamContext {
+ callout_context.similarity_scores.as_ref().unwrap()[0].1 * 0.3;
debug!(
"similarity score: {}, intent score: {}, description embedding score: {}",
"similarity score: {:.3}, intent score: {:.3}, description embedding score: {:.3}",
prompt_target_similarity_score,
zeroshot_intent_response.predicted_class_score,
callout_context.similarity_scores.as_ref().unwrap()[0].1
@ -286,16 +289,28 @@ impl StreamContext {
info!("no assistant message found, probably first interaction");
}
// get prompt target similarity threshold from overrides
let prompt_target_intent_matching_threshold = match self.overrides.as_ref() {
Some(overrides) => match overrides.prompt_target_intent_matching_threshold {
Some(threshold) => threshold,
None => DEFAULT_PROMPT_TARGET_THRESHOLD,
},
None => DEFAULT_PROMPT_TARGET_THRESHOLD,
};
// check to ensure that the prompt target similarity score is above the threshold
if prompt_target_similarity_score < DEFAULT_PROMPT_TARGET_THRESHOLD && !bolt_assistant {
if prompt_target_similarity_score < prompt_target_intent_matching_threshold
&& !bolt_assistant
{
// if bolt fc responded to the user message, then we don't need to check the similarity score
// it may be that bolt fc is handling the conversation for parameter collection
if bolt_assistant {
info!("bolt assistant is handling the conversation");
} else {
info!(
"prompt target below threshold: {}, continue conversation with user",
"prompt target below limit: {:.3}, threshold: {:.3}, continue conversation with user",
prompt_target_similarity_score,
prompt_target_intent_matching_threshold
);
self.resume_http_request();
return;