Add support for json based content types in Message (#480)

This commit is contained in:
Adil Hafeez 2025-05-23 00:51:53 -07:00 committed by GitHub
parent f5e77bbe65
commit 218e9c540d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 314 additions and 121 deletions

View file

@ -1,6 +1,7 @@
use crate::consts::{ARCH_FC_MODEL_NAME, ASSISTANT_ROLE};
use serde::{ser::SerializeMap, Deserialize, Serialize};
use serde_yaml::Value;
use core::panic;
use std::{
collections::{HashMap, VecDeque},
fmt::Display,
@ -154,12 +155,54 @@ pub struct StreamOptions {
pub include_usage: bool,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum MultiPartContentType {
#[serde(rename = "text")]
Text,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct MultiPartContent {
pub text: Option<String>,
#[serde(rename = "type")]
pub content_type: MultiPartContentType,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(untagged)]
pub enum ContentType {
Text(String),
MultiPart(Vec<MultiPartContent>),
}
impl Display for ContentType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ContentType::Text(text) => write!(f, "{}", text),
ContentType::MultiPart(multi_part) => {
let text_parts: Vec<String> = multi_part
.iter()
.filter_map(|part| {
if part.content_type == MultiPartContentType::Text {
part.text.clone()
} else {
panic!("Unsupported content type: {:?}", part.content_type);
}
})
.collect();
let combined_text = text_parts.join("\n");
write!(f, "{}", combined_text)
}
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
pub role: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub content: Option<String>,
pub content: Option<ContentType>,
#[serde(skip_serializing_if = "Option::is_none")]
pub model: Option<String>,
@ -237,7 +280,7 @@ impl ChatCompletionsResponse {
choices: vec![Choice {
message: Message {
role: ASSISTANT_ROLE.to_string(),
content: Some(message),
content: Some(ContentType::Text(message)),
model: Some(ARCH_FC_MODEL_NAME.to_string()),
tool_calls: None,
tool_call_id: None,
@ -379,6 +422,8 @@ pub fn to_server_events(chunks: Vec<ChatCompletionStreamResponse>) -> String {
#[cfg(test)]
mod test {
use crate::api::open_ai::{ChatCompletionsRequest, ContentType, MultiPartContentType};
use super::{ChatCompletionStreamResponseServerEvents, Message};
use pretty_assertions::assert_eq;
use std::collections::HashMap;
@ -448,7 +493,9 @@ mod test {
model: "gpt-3.5-turbo".to_string(),
messages: vec![Message {
role: "user".to_string(),
content: Some("What city do you want to know the weather for?".to_string()),
content: Some(ContentType::Text(
"What city do you want to know the weather for?".to_string(),
)),
model: None,
tool_calls: None,
tool_call_id: None,
@ -679,6 +726,111 @@ data: [DONE]
);
}
#[test]
fn test_chat_completions_request() {
const CHAT_COMPLETIONS_REQUEST: &str = r#"
{
"model": "gpt-3.5-turbo",
"messages": [
{
"role": "user",
"content": "What city do you want to know the weather for?"
}
]
}"#;
let chat_completions_request: ChatCompletionsRequest =
serde_json::from_str(CHAT_COMPLETIONS_REQUEST).unwrap();
assert_eq!(chat_completions_request.model, "gpt-3.5-turbo");
assert_eq!(
chat_completions_request.messages[0].content,
Some(ContentType::Text(
"What city do you want to know the weather for?".to_string()
))
);
}
#[test]
fn test_chat_completions_request_text_type() {
const CHAT_COMPLETIONS_REQUEST: &str = r#"
{
"model": "gpt-3.5-turbo",
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": "What city do you want to know the weather for?"
}
]
}
]
}
"#;
let chat_completions_request: ChatCompletionsRequest =
serde_json::from_str(CHAT_COMPLETIONS_REQUEST).unwrap();
assert_eq!(chat_completions_request.model, "gpt-3.5-turbo");
if let Some(ContentType::MultiPart(multi_part_content)) =
chat_completions_request.messages[0].content.as_ref()
{
assert_eq!(multi_part_content[0].content_type, MultiPartContentType::Text);
assert_eq!(
multi_part_content[0].text,
Some("What city do you want to know the weather for?".to_string())
);
} else {
panic!("Expected MultiPartContent");
}
}
#[test]
fn test_chat_completions_request_text_type_array() {
const CHAT_COMPLETIONS_REQUEST: &str = r#"
{
"model": "gpt-3.5-turbo",
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": "What city do you want to know the weather for?"
},
{
"type": "text",
"text": "hello world"
}
]
}
]
}
"#;
let chat_completions_request: ChatCompletionsRequest =
serde_json::from_str(CHAT_COMPLETIONS_REQUEST).unwrap();
assert_eq!(chat_completions_request.model, "gpt-3.5-turbo");
if let Some(ContentType::MultiPart(multi_part_content)) =
chat_completions_request.messages[0].content.as_ref()
{
assert_eq!(multi_part_content.len(), 2);
assert_eq!(multi_part_content[0].content_type, MultiPartContentType::Text);
assert_eq!(
multi_part_content[0].text,
Some("What city do you want to know the weather for?".to_string())
);
assert_eq!(multi_part_content[1].content_type, MultiPartContentType::Text);
assert_eq!(
multi_part_content[1].text,
Some("hello world".to_string())
);
} else {
panic!("Expected MultiPartContent");
}
}
#[test]
fn stream_chunk_parse_claude() {
const CHUNK_RESPONSE: &str = r#"data: {"id":"msg_01DZDMxYSgq8aPQxMQoBv6Kb","choices":[{"index":0,"delta":{"role":"assistant"}}],"created":1747685264,"model":"claude-3-7-sonnet-latest","object":"chat.completion.chunk"}