mirror of
https://github.com/katanemo/plano.git
synced 2026-04-25 08:46:24 +02:00
Add support for json based content types in Message (#480)
This commit is contained in:
parent
f5e77bbe65
commit
218e9c540d
16 changed files with 314 additions and 121 deletions
|
|
@ -1,6 +1,7 @@
|
|||
use crate::consts::{ARCH_FC_MODEL_NAME, ASSISTANT_ROLE};
|
||||
use serde::{ser::SerializeMap, Deserialize, Serialize};
|
||||
use serde_yaml::Value;
|
||||
use core::panic;
|
||||
use std::{
|
||||
collections::{HashMap, VecDeque},
|
||||
fmt::Display,
|
||||
|
|
@ -154,12 +155,54 @@ pub struct StreamOptions {
|
|||
pub include_usage: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
pub enum MultiPartContentType {
|
||||
#[serde(rename = "text")]
|
||||
Text,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
pub struct MultiPartContent {
|
||||
pub text: Option<String>,
|
||||
#[serde(rename = "type")]
|
||||
pub content_type: MultiPartContentType,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
#[serde(untagged)]
|
||||
pub enum ContentType {
|
||||
Text(String),
|
||||
MultiPart(Vec<MultiPartContent>),
|
||||
}
|
||||
|
||||
impl Display for ContentType {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
ContentType::Text(text) => write!(f, "{}", text),
|
||||
ContentType::MultiPart(multi_part) => {
|
||||
let text_parts: Vec<String> = multi_part
|
||||
.iter()
|
||||
.filter_map(|part| {
|
||||
if part.content_type == MultiPartContentType::Text {
|
||||
part.text.clone()
|
||||
} else {
|
||||
panic!("Unsupported content type: {:?}", part.content_type);
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
let combined_text = text_parts.join("\n");
|
||||
write!(f, "{}", combined_text)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Message {
|
||||
pub role: String,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub content: Option<String>,
|
||||
pub content: Option<ContentType>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub model: Option<String>,
|
||||
|
|
@ -237,7 +280,7 @@ impl ChatCompletionsResponse {
|
|||
choices: vec![Choice {
|
||||
message: Message {
|
||||
role: ASSISTANT_ROLE.to_string(),
|
||||
content: Some(message),
|
||||
content: Some(ContentType::Text(message)),
|
||||
model: Some(ARCH_FC_MODEL_NAME.to_string()),
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
|
|
@ -379,6 +422,8 @@ pub fn to_server_events(chunks: Vec<ChatCompletionStreamResponse>) -> String {
|
|||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use crate::api::open_ai::{ChatCompletionsRequest, ContentType, MultiPartContentType};
|
||||
|
||||
use super::{ChatCompletionStreamResponseServerEvents, Message};
|
||||
use pretty_assertions::assert_eq;
|
||||
use std::collections::HashMap;
|
||||
|
|
@ -448,7 +493,9 @@ mod test {
|
|||
model: "gpt-3.5-turbo".to_string(),
|
||||
messages: vec![Message {
|
||||
role: "user".to_string(),
|
||||
content: Some("What city do you want to know the weather for?".to_string()),
|
||||
content: Some(ContentType::Text(
|
||||
"What city do you want to know the weather for?".to_string(),
|
||||
)),
|
||||
model: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
|
|
@ -679,6 +726,111 @@ data: [DONE]
|
|||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_chat_completions_request() {
|
||||
const CHAT_COMPLETIONS_REQUEST: &str = r#"
|
||||
{
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What city do you want to know the weather for?"
|
||||
}
|
||||
]
|
||||
}"#;
|
||||
|
||||
let chat_completions_request: ChatCompletionsRequest =
|
||||
serde_json::from_str(CHAT_COMPLETIONS_REQUEST).unwrap();
|
||||
assert_eq!(chat_completions_request.model, "gpt-3.5-turbo");
|
||||
assert_eq!(
|
||||
chat_completions_request.messages[0].content,
|
||||
Some(ContentType::Text(
|
||||
"What city do you want to know the weather for?".to_string()
|
||||
))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_chat_completions_request_text_type() {
|
||||
const CHAT_COMPLETIONS_REQUEST: &str = r#"
|
||||
{
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "What city do you want to know the weather for?"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
"#;
|
||||
|
||||
let chat_completions_request: ChatCompletionsRequest =
|
||||
serde_json::from_str(CHAT_COMPLETIONS_REQUEST).unwrap();
|
||||
assert_eq!(chat_completions_request.model, "gpt-3.5-turbo");
|
||||
if let Some(ContentType::MultiPart(multi_part_content)) =
|
||||
chat_completions_request.messages[0].content.as_ref()
|
||||
{
|
||||
assert_eq!(multi_part_content[0].content_type, MultiPartContentType::Text);
|
||||
assert_eq!(
|
||||
multi_part_content[0].text,
|
||||
Some("What city do you want to know the weather for?".to_string())
|
||||
);
|
||||
} else {
|
||||
panic!("Expected MultiPartContent");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_chat_completions_request_text_type_array() {
|
||||
const CHAT_COMPLETIONS_REQUEST: &str = r#"
|
||||
{
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "What city do you want to know the weather for?"
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
"text": "hello world"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
"#;
|
||||
|
||||
let chat_completions_request: ChatCompletionsRequest =
|
||||
serde_json::from_str(CHAT_COMPLETIONS_REQUEST).unwrap();
|
||||
assert_eq!(chat_completions_request.model, "gpt-3.5-turbo");
|
||||
if let Some(ContentType::MultiPart(multi_part_content)) =
|
||||
chat_completions_request.messages[0].content.as_ref()
|
||||
{
|
||||
assert_eq!(multi_part_content.len(), 2);
|
||||
assert_eq!(multi_part_content[0].content_type, MultiPartContentType::Text);
|
||||
assert_eq!(
|
||||
multi_part_content[0].text,
|
||||
Some("What city do you want to know the weather for?".to_string())
|
||||
);
|
||||
assert_eq!(multi_part_content[1].content_type, MultiPartContentType::Text);
|
||||
assert_eq!(
|
||||
multi_part_content[1].text,
|
||||
Some("hello world".to_string())
|
||||
);
|
||||
} else {
|
||||
panic!("Expected MultiPartContent");
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn stream_chunk_parse_claude() {
|
||||
const CHUNK_RESPONSE: &str = r#"data: {"id":"msg_01DZDMxYSgq8aPQxMQoBv6Kb","choices":[{"index":0,"delta":{"role":"assistant"}}],"created":1747685264,"model":"claude-3-7-sonnet-latest","object":"chat.completion.chunk"}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue