mirror of
https://github.com/katanemo/plano.git
synced 2026-05-14 18:42:38 +02:00
refactor prompt gateway (#204)
This commit is contained in:
parent
dced8a5708
commit
2f374df034
9 changed files with 500 additions and 441 deletions
|
|
@ -1,12 +1,12 @@
|
|||
use common::{common_types::open_ai::Message, consts::USER_ROLE};
|
||||
|
||||
pub fn extract_messages_for_hallucination(messages: &Vec<Message>) -> Vec<String> {
|
||||
pub fn extract_messages_for_hallucination(messages: &[Message]) -> Vec<String> {
|
||||
let all_user_messages = messages
|
||||
.iter()
|
||||
.filter(|m| m.role == USER_ROLE)
|
||||
.map(|m| m.content.as_ref().unwrap().clone())
|
||||
.collect::<Vec<String>>();
|
||||
return all_user_messages;
|
||||
all_user_messages
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
|
@ -17,7 +17,7 @@ mod test {
|
|||
|
||||
#[test]
|
||||
fn test_hallucination_message() {
|
||||
let test_str = r#"
|
||||
let test_str = r#"
|
||||
[
|
||||
{
|
||||
"role": "system",
|
||||
|
|
@ -32,8 +32,8 @@ mod test {
|
|||
]
|
||||
"#;
|
||||
|
||||
let messages: Vec<Message> = serde_json::from_str(test_str).unwrap();
|
||||
let messages_for_halluncination = extract_messages_for_hallucination(&messages);
|
||||
assert_eq!(messages_for_halluncination.len(), 2);
|
||||
let messages: Vec<Message> = serde_json::from_str(test_str).unwrap();
|
||||
let messages_for_halluncination = extract_messages_for_hallucination(&messages);
|
||||
assert_eq!(messages_for_halluncination.len(), 2);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue