diff --git a/envoyfilter/src/configuration.rs b/envoyfilter/src/configuration.rs index 023301cd..b37e0659 100644 --- a/envoyfilter/src/configuration.rs +++ b/envoyfilter/src/configuration.rs @@ -42,26 +42,19 @@ pub struct Endpoint { } #[derive(Debug, Clone, Serialize, Deserialize)] -pub struct EntityDetail { +pub struct Entity { pub name: String, pub required: Option, pub description: Option, } -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(untagged)] -pub enum EntityType { - Vec(Vec), - Struct(Vec), -} - #[derive(Debug, Clone, Serialize, Deserialize)] pub struct PromptTarget { #[serde(rename = "type")] pub prompt_type: String, pub name: String, pub few_shot_examples: Vec, - pub entities: Option, + pub entities: Option>, pub endpoint: Option, pub system_prompt: Option, } @@ -110,7 +103,7 @@ prompt_targets: cluster: weatherhost path: /weather entities: - - city + - name: city "#; #[test] diff --git a/envoyfilter/src/stream_context.rs b/envoyfilter/src/stream_context.rs index 1ea15818..587dfe0f 100644 --- a/envoyfilter/src/stream_context.rs +++ b/envoyfilter/src/stream_context.rs @@ -2,9 +2,7 @@ use crate::common_types::{ open_ai::{ChatCompletions, Message}, NERRequest, NERResponse, SearchPointsRequest, SearchPointsResponse, }; -use crate::configuration::EntityDetail; -use crate::configuration::EntityType; -use crate::configuration::PromptTarget; +use crate::configuration::{Entity, PromptTarget}; use crate::consts::{ DEFAULT_COLLECTION_NAME, DEFAULT_EMBEDDING_MODEL, DEFAULT_NER_MODEL, DEFAULT_NER_THRESHOLD, DEFAULT_PROMPT_TARGET_THRESHOLD, SYSTEM_ROLE, USER_ROLE, @@ -159,10 +157,12 @@ impl StreamContext { info!("prompt_target name: {:?}", prompt_target.name); // only extract entity names - let entity_names = get_entity_details(&prompt_target) - .into_iter() - .map(|entity| entity.name) - .collect(); + let entity_names: Vec = match prompt_target.entities { + // Clone is unavoidable here because we don't want to move the values out of the prompt target struct. + Some(ref entities) => entities.iter().map(|entity| entity.name.clone()).collect(), + None => vec![], + }; + let ner_request = NERRequest { input: callout_context.user_message.take().unwrap(), labels: entity_names, @@ -227,8 +227,9 @@ impl StreamContext { } let prompt_target = callout_context.prompt_target.as_ref().unwrap(); - let entity_details = get_entity_details(prompt_target); - for entity in entity_details { + + let empty_vec: Vec = vec![]; + for entity in prompt_target.entities.as_ref().unwrap_or(&empty_vec) { if entity.required.unwrap_or(false) && !request_params.contains_key(&entity.name) { warn!( "required entity missing or score of entity was too low: {}", @@ -488,21 +489,3 @@ impl Context for StreamContext { } } } - -fn get_entity_details(prompt_target: &PromptTarget) -> Vec { - match prompt_target.entities.as_ref() { - Some(EntityType::Vec(entity_names)) => { - let mut entity_details: Vec = Vec::new(); - for entity_name in entity_names { - entity_details.push(EntityDetail { - name: entity_name.clone(), - required: Some(true), - description: None, - }); - } - entity_details - } - Some(EntityType::Struct(entity_details)) => entity_details.clone(), - None => Vec::new(), - } -}