Simplify Entity struct (#33)

Signed-off-by: José Ulises Niño Rivera <junr03@users.noreply.github.com>
This commit is contained in:
José Ulises Niño Rivera 2024-08-06 17:04:32 -07:00 committed by GitHub
parent 1fa5215753
commit b49fc2f264
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 13 additions and 37 deletions

View file

@ -42,26 +42,19 @@ pub struct Endpoint {
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EntityDetail {
pub struct Entity {
pub name: String,
pub required: Option<bool>,
pub description: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum EntityType {
Vec(Vec<String>),
Struct(Vec<EntityDetail>),
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PromptTarget {
#[serde(rename = "type")]
pub prompt_type: String,
pub name: String,
pub few_shot_examples: Vec<String>,
pub entities: Option<EntityType>,
pub entities: Option<Vec<Entity>>,
pub endpoint: Option<Endpoint>,
pub system_prompt: Option<String>,
}
@ -110,7 +103,7 @@ prompt_targets:
cluster: weatherhost
path: /weather
entities:
- city
- name: city
"#;
#[test]

View file

@ -2,9 +2,7 @@ use crate::common_types::{
open_ai::{ChatCompletions, Message},
NERRequest, NERResponse, SearchPointsRequest, SearchPointsResponse,
};
use crate::configuration::EntityDetail;
use crate::configuration::EntityType;
use crate::configuration::PromptTarget;
use crate::configuration::{Entity, PromptTarget};
use crate::consts::{
DEFAULT_COLLECTION_NAME, DEFAULT_EMBEDDING_MODEL, DEFAULT_NER_MODEL, DEFAULT_NER_THRESHOLD,
DEFAULT_PROMPT_TARGET_THRESHOLD, SYSTEM_ROLE, USER_ROLE,
@ -159,10 +157,12 @@ impl StreamContext {
info!("prompt_target name: {:?}", prompt_target.name);
// only extract entity names
let entity_names = get_entity_details(&prompt_target)
.into_iter()
.map(|entity| entity.name)
.collect();
let entity_names: Vec<String> = match prompt_target.entities {
// Clone is unavoidable here because we don't want to move the values out of the prompt target struct.
Some(ref entities) => entities.iter().map(|entity| entity.name.clone()).collect(),
None => vec![],
};
let ner_request = NERRequest {
input: callout_context.user_message.take().unwrap(),
labels: entity_names,
@ -227,8 +227,9 @@ impl StreamContext {
}
let prompt_target = callout_context.prompt_target.as_ref().unwrap();
let entity_details = get_entity_details(prompt_target);
for entity in entity_details {
let empty_vec: Vec<Entity> = vec![];
for entity in prompt_target.entities.as_ref().unwrap_or(&empty_vec) {
if entity.required.unwrap_or(false) && !request_params.contains_key(&entity.name) {
warn!(
"required entity missing or score of entity was too low: {}",
@ -488,21 +489,3 @@ impl Context for StreamContext {
}
}
}
fn get_entity_details(prompt_target: &PromptTarget) -> Vec<EntityDetail> {
match prompt_target.entities.as_ref() {
Some(EntityType::Vec(entity_names)) => {
let mut entity_details: Vec<EntityDetail> = Vec::new();
for entity_name in entity_names {
entity_details.push(EntityDetail {
name: entity_name.clone(),
required: Some(true),
description: None,
});
}
entity_details
}
Some(EntityType::Struct(entity_details)) => entity_details.clone(),
None => Vec::new(),
}
}