mirror of
https://github.com/katanemo/plano.git
synced 2026-05-03 12:52:56 +02:00
Simplify Entity struct (#33)
Signed-off-by: José Ulises Niño Rivera <junr03@users.noreply.github.com>
This commit is contained in:
parent
1fa5215753
commit
b49fc2f264
2 changed files with 13 additions and 37 deletions
|
|
@ -42,26 +42,19 @@ pub struct Endpoint {
|
|||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct EntityDetail {
|
||||
pub struct Entity {
|
||||
pub name: String,
|
||||
pub required: Option<bool>,
|
||||
pub description: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum EntityType {
|
||||
Vec(Vec<String>),
|
||||
Struct(Vec<EntityDetail>),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PromptTarget {
|
||||
#[serde(rename = "type")]
|
||||
pub prompt_type: String,
|
||||
pub name: String,
|
||||
pub few_shot_examples: Vec<String>,
|
||||
pub entities: Option<EntityType>,
|
||||
pub entities: Option<Vec<Entity>>,
|
||||
pub endpoint: Option<Endpoint>,
|
||||
pub system_prompt: Option<String>,
|
||||
}
|
||||
|
|
@ -110,7 +103,7 @@ prompt_targets:
|
|||
cluster: weatherhost
|
||||
path: /weather
|
||||
entities:
|
||||
- city
|
||||
- name: city
|
||||
"#;
|
||||
|
||||
#[test]
|
||||
|
|
|
|||
|
|
@ -2,9 +2,7 @@ use crate::common_types::{
|
|||
open_ai::{ChatCompletions, Message},
|
||||
NERRequest, NERResponse, SearchPointsRequest, SearchPointsResponse,
|
||||
};
|
||||
use crate::configuration::EntityDetail;
|
||||
use crate::configuration::EntityType;
|
||||
use crate::configuration::PromptTarget;
|
||||
use crate::configuration::{Entity, PromptTarget};
|
||||
use crate::consts::{
|
||||
DEFAULT_COLLECTION_NAME, DEFAULT_EMBEDDING_MODEL, DEFAULT_NER_MODEL, DEFAULT_NER_THRESHOLD,
|
||||
DEFAULT_PROMPT_TARGET_THRESHOLD, SYSTEM_ROLE, USER_ROLE,
|
||||
|
|
@ -159,10 +157,12 @@ impl StreamContext {
|
|||
info!("prompt_target name: {:?}", prompt_target.name);
|
||||
|
||||
// only extract entity names
|
||||
let entity_names = get_entity_details(&prompt_target)
|
||||
.into_iter()
|
||||
.map(|entity| entity.name)
|
||||
.collect();
|
||||
let entity_names: Vec<String> = match prompt_target.entities {
|
||||
// Clone is unavoidable here because we don't want to move the values out of the prompt target struct.
|
||||
Some(ref entities) => entities.iter().map(|entity| entity.name.clone()).collect(),
|
||||
None => vec![],
|
||||
};
|
||||
|
||||
let ner_request = NERRequest {
|
||||
input: callout_context.user_message.take().unwrap(),
|
||||
labels: entity_names,
|
||||
|
|
@ -227,8 +227,9 @@ impl StreamContext {
|
|||
}
|
||||
|
||||
let prompt_target = callout_context.prompt_target.as_ref().unwrap();
|
||||
let entity_details = get_entity_details(prompt_target);
|
||||
for entity in entity_details {
|
||||
|
||||
let empty_vec: Vec<Entity> = vec![];
|
||||
for entity in prompt_target.entities.as_ref().unwrap_or(&empty_vec) {
|
||||
if entity.required.unwrap_or(false) && !request_params.contains_key(&entity.name) {
|
||||
warn!(
|
||||
"required entity missing or score of entity was too low: {}",
|
||||
|
|
@ -488,21 +489,3 @@ impl Context for StreamContext {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn get_entity_details(prompt_target: &PromptTarget) -> Vec<EntityDetail> {
|
||||
match prompt_target.entities.as_ref() {
|
||||
Some(EntityType::Vec(entity_names)) => {
|
||||
let mut entity_details: Vec<EntityDetail> = Vec::new();
|
||||
for entity_name in entity_names {
|
||||
entity_details.push(EntityDetail {
|
||||
name: entity_name.clone(),
|
||||
required: Some(true),
|
||||
description: None,
|
||||
});
|
||||
}
|
||||
entity_details
|
||||
}
|
||||
Some(EntityType::Struct(entity_details)) => entity_details.clone(),
|
||||
None => Vec::new(),
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue