mirror of
https://github.com/katanemo/plano.git
synced 2026-05-08 15:22:43 +02:00
making Messages.Content optional, and having the upstream LLM fail if the right fields aren't set (#699)
Co-authored-by: Salman Paracha <salmanparacha@MacBook-Pro-342.local>
This commit is contained in:
parent
626f556cc6
commit
cdc1d7cee2
17 changed files with 294 additions and 133 deletions
|
|
@ -225,7 +225,7 @@ impl ProviderRequest for ConverseRequest {
|
|||
if let SystemContentBlock::Text { text } = sys_block {
|
||||
openai_messages.push(Message {
|
||||
role: Role::System,
|
||||
content: MessageContent::Text(text.clone()),
|
||||
content: Some(MessageContent::Text(text.clone())),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
|
|
@ -258,7 +258,7 @@ impl ProviderRequest for ConverseRequest {
|
|||
|
||||
openai_messages.push(Message {
|
||||
role,
|
||||
content: MessageContent::Text(content),
|
||||
content: Some(MessageContent::Text(content)),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
|
|
@ -279,7 +279,7 @@ impl ProviderRequest for ConverseRequest {
|
|||
for msg in messages {
|
||||
match msg.role {
|
||||
crate::apis::openai::Role::System => {
|
||||
if let crate::apis::openai::MessageContent::Text(text) = &msg.content {
|
||||
if let Some(crate::apis::openai::MessageContent::Text(text)) = &msg.content {
|
||||
system_blocks.push(SystemContentBlock::Text { text: text.clone() });
|
||||
}
|
||||
}
|
||||
|
|
@ -290,12 +290,13 @@ impl ProviderRequest for ConverseRequest {
|
|||
_ => continue,
|
||||
};
|
||||
|
||||
let content =
|
||||
if let crate::apis::openai::MessageContent::Text(text) = &msg.content {
|
||||
vec![ContentBlock::Text { text: text.clone() }]
|
||||
} else {
|
||||
vec![]
|
||||
};
|
||||
let content = if let Some(crate::apis::openai::MessageContent::Text(text)) =
|
||||
&msg.content
|
||||
{
|
||||
vec![ContentBlock::Text { text: text.clone() }]
|
||||
} else {
|
||||
vec![]
|
||||
};
|
||||
|
||||
bedrock_messages.push(crate::apis::amazon_bedrock::Message { role, content });
|
||||
}
|
||||
|
|
|
|||
|
|
@ -584,7 +584,7 @@ impl ProviderRequest for MessagesRequest {
|
|||
let system_text = system_messages
|
||||
.iter()
|
||||
.filter_map(|msg| {
|
||||
if let crate::apis::openai::MessageContent::Text(text) = &msg.content {
|
||||
if let Some(crate::apis::openai::MessageContent::Text(text)) = &msg.content {
|
||||
Some(text.as_str())
|
||||
} else {
|
||||
None
|
||||
|
|
|
|||
|
|
@ -155,7 +155,8 @@ pub enum Role {
|
|||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub struct Message {
|
||||
pub role: Role,
|
||||
pub content: MessageContent,
|
||||
/// The contents of the message. Required unless tool_calls is specified (for assistant role)
|
||||
pub content: Option<MessageContent>,
|
||||
pub name: Option<String>,
|
||||
/// Tool calls made by the assistant (only present for assistant role)
|
||||
pub tool_calls: Option<Vec<ToolCall>>,
|
||||
|
|
@ -204,8 +205,7 @@ impl ResponseMessage {
|
|||
content: self
|
||||
.content
|
||||
.as_ref()
|
||||
.map(|s| MessageContent::Text(s.clone()))
|
||||
.unwrap_or(MessageContent::Text(String::new())),
|
||||
.map(|s| MessageContent::Text(s.clone())),
|
||||
name: None, // Response messages don't have names in the same way request messages do
|
||||
tool_calls: self.tool_calls.clone(),
|
||||
tool_call_id: None, // Response messages don't have tool_call_id
|
||||
|
|
@ -233,6 +233,12 @@ impl ExtractText for MessageContent {
|
|||
}
|
||||
}
|
||||
|
||||
impl ExtractText for Option<MessageContent> {
|
||||
fn extract_text(&self) -> String {
|
||||
self.as_ref().map(|c| c.extract_text()).unwrap_or_default()
|
||||
}
|
||||
}
|
||||
|
||||
impl ExtractText for Vec<ContentPart> {
|
||||
fn extract_text(&self) -> String {
|
||||
self.iter()
|
||||
|
|
@ -247,23 +253,7 @@ impl ExtractText for Vec<ContentPart> {
|
|||
|
||||
impl Display for MessageContent {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
MessageContent::Text(text) => write!(f, "{}", text),
|
||||
MessageContent::Parts(parts) => {
|
||||
let text_parts: Vec<String> = parts
|
||||
.iter()
|
||||
.filter_map(|part| match part {
|
||||
ContentPart::Text { text } => Some(text.clone()),
|
||||
ContentPart::ImageUrl { .. } => {
|
||||
// skip image URLs or their data in text representation
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
let combined_text = text_parts.join("\n");
|
||||
write!(f, "{}", combined_text)
|
||||
}
|
||||
}
|
||||
write!(f, "{}", self.extract_text())
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -622,8 +612,10 @@ impl ProviderRequest for ChatCompletionsRequest {
|
|||
|
||||
fn extract_messages_text(&self) -> String {
|
||||
self.messages.iter().fold(String::new(), |acc, m| {
|
||||
acc + " "
|
||||
+ &match &m.content {
|
||||
let content_text = m
|
||||
.content
|
||||
.as_ref()
|
||||
.map(|content| match content {
|
||||
MessageContent::Text(text) => text.clone(),
|
||||
MessageContent::Parts(parts) => parts
|
||||
.iter()
|
||||
|
|
@ -633,16 +625,18 @@ impl ProviderRequest for ChatCompletionsRequest {
|
|||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join(" "),
|
||||
}
|
||||
})
|
||||
.unwrap_or_default();
|
||||
acc + " " + &content_text
|
||||
})
|
||||
}
|
||||
|
||||
fn get_recent_user_message(&self) -> Option<String> {
|
||||
self.messages.last().and_then(|msg| {
|
||||
match &msg.content {
|
||||
msg.content.as_ref().and_then(|content| match content {
|
||||
MessageContent::Text(text) => Some(text.clone()),
|
||||
MessageContent::Parts(_) => None, // No user message in parts
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
|
|
@ -778,7 +772,8 @@ mod tests {
|
|||
|
||||
let message = &deserialized_request.messages[0];
|
||||
assert_eq!(message.role, Role::User);
|
||||
if let MessageContent::Text(content) = &message.content {
|
||||
assert!(message.content.is_some());
|
||||
if let Some(MessageContent::Text(content)) = &message.content {
|
||||
assert_eq!(content, "Hello, world!");
|
||||
} else {
|
||||
panic!("Expected text content");
|
||||
|
|
@ -822,7 +817,8 @@ mod tests {
|
|||
|
||||
let message = &deserialized_request.messages[0];
|
||||
assert_eq!(message.role, Role::User);
|
||||
if let MessageContent::Text(content) = &message.content {
|
||||
assert!(message.content.is_some());
|
||||
if let Some(MessageContent::Text(content)) = &message.content {
|
||||
assert_eq!(content, "Test message");
|
||||
} else {
|
||||
panic!("Expected text content");
|
||||
|
|
@ -947,7 +943,8 @@ mod tests {
|
|||
// Validate first message (user with multimodal content)
|
||||
let user_message = &deserialized_request.messages[0];
|
||||
assert_eq!(user_message.role, Role::User);
|
||||
if let MessageContent::Parts(ref content_parts) = user_message.content {
|
||||
assert!(user_message.content.is_some());
|
||||
if let Some(MessageContent::Parts(ref content_parts)) = user_message.content {
|
||||
assert_eq!(content_parts.len(), 2);
|
||||
|
||||
// Validate text content part
|
||||
|
|
@ -971,7 +968,8 @@ mod tests {
|
|||
// Validate second message (assistant with tool calls)
|
||||
let assistant_message = &deserialized_request.messages[1];
|
||||
assert_eq!(assistant_message.role, Role::Assistant);
|
||||
if let MessageContent::Text(text) = &assistant_message.content {
|
||||
assert!(assistant_message.content.is_some());
|
||||
if let Some(MessageContent::Text(text)) = &assistant_message.content {
|
||||
assert_eq!(
|
||||
text,
|
||||
"I can see a beautiful cityscape. Let me check the weather for you."
|
||||
|
|
@ -997,7 +995,8 @@ mod tests {
|
|||
// Validate third message (tool response)
|
||||
let tool_message = &deserialized_request.messages[2];
|
||||
assert_eq!(tool_message.role, Role::Tool);
|
||||
if let MessageContent::Text(text) = &tool_message.content {
|
||||
assert!(tool_message.content.is_some());
|
||||
if let Some(MessageContent::Text(text)) = &tool_message.content {
|
||||
assert_eq!(text, "Current weather in New York: 72°F, sunny");
|
||||
} else {
|
||||
panic!("Expected text content for tool message");
|
||||
|
|
@ -1061,6 +1060,62 @@ mod tests {
|
|||
assert!((original_temp - serialized_temp).abs() < 1e-6);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_assistant_message_with_tool_calls_no_content() {
|
||||
// Test that assistant messages can have tool_calls without content
|
||||
let json_with_tool_calls_no_content = json!({
|
||||
"model": "gpt-4",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What's the weather in San Francisco?"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_123",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": "{\"location\": \"San Francisco, CA\"}"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
});
|
||||
|
||||
// Should deserialize successfully
|
||||
let request: ChatCompletionsRequest =
|
||||
serde_json::from_value(json_with_tool_calls_no_content.clone()).unwrap();
|
||||
|
||||
assert_eq!(request.messages.len(), 2);
|
||||
|
||||
// Check user message
|
||||
let user_msg = &request.messages[0];
|
||||
assert_eq!(user_msg.role, Role::User);
|
||||
assert!(user_msg.content.is_some());
|
||||
|
||||
// Check assistant message - should have tool_calls but no content
|
||||
let assistant_msg = &request.messages[1];
|
||||
assert_eq!(assistant_msg.role, Role::Assistant);
|
||||
assert!(assistant_msg.content.is_none());
|
||||
assert!(assistant_msg.tool_calls.is_some());
|
||||
|
||||
let tool_calls = assistant_msg.tool_calls.as_ref().unwrap();
|
||||
assert_eq!(tool_calls.len(), 1);
|
||||
assert_eq!(tool_calls[0].id, "call_123");
|
||||
assert_eq!(tool_calls[0].function.name, "get_weather");
|
||||
|
||||
// Should serialize back without content field
|
||||
let serialized = serde_json::to_value(&request).unwrap();
|
||||
// Verify the assistant message doesn't have a content field in serialized JSON
|
||||
let serialized_assistant_msg = &serialized["messages"][1];
|
||||
assert!(serialized_assistant_msg.get("content").is_none());
|
||||
assert!(serialized_assistant_msg.get("tool_calls").is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_api_provider_trait() {
|
||||
// Test the ApiDefinition trait implementation
|
||||
|
|
@ -1097,7 +1152,7 @@ mod tests {
|
|||
|
||||
let deserialized_user: Message = serde_json::from_value(user_json.clone()).unwrap();
|
||||
assert_eq!(deserialized_user.role, Role::User);
|
||||
if let MessageContent::Text(content) = &deserialized_user.content {
|
||||
if let Some(MessageContent::Text(content)) = &deserialized_user.content {
|
||||
assert_eq!(content, "Hello!");
|
||||
} else {
|
||||
panic!("Expected text content");
|
||||
|
|
@ -1128,7 +1183,7 @@ mod tests {
|
|||
let deserialized_assistant: Message =
|
||||
serde_json::from_value(assistant_json.clone()).unwrap();
|
||||
assert_eq!(deserialized_assistant.role, Role::Assistant);
|
||||
if let MessageContent::Text(content) = &deserialized_assistant.content {
|
||||
if let Some(MessageContent::Text(content)) = &deserialized_assistant.content {
|
||||
assert_eq!(content, "I'll help with that.");
|
||||
} else {
|
||||
panic!("Expected text content");
|
||||
|
|
@ -1154,7 +1209,7 @@ mod tests {
|
|||
|
||||
let deserialized_tool: Message = serde_json::from_value(tool_json.clone()).unwrap();
|
||||
assert_eq!(deserialized_tool.role, Role::Tool);
|
||||
if let MessageContent::Text(content) = &deserialized_tool.content {
|
||||
if let Some(MessageContent::Text(content)) = &deserialized_tool.content {
|
||||
assert_eq!(content, "Weather is sunny");
|
||||
} else {
|
||||
panic!("Expected text content");
|
||||
|
|
@ -1193,7 +1248,7 @@ mod tests {
|
|||
// Test conversion from ResponseMessage to Message
|
||||
let converted = deserialized_response.to_message();
|
||||
assert_eq!(converted.role, Role::Assistant);
|
||||
if let MessageContent::Text(text) = converted.content {
|
||||
if let Some(MessageContent::Text(text)) = converted.content {
|
||||
assert_eq!(text, "Response content");
|
||||
} else {
|
||||
panic!("Expected text content");
|
||||
|
|
|
|||
|
|
@ -1146,7 +1146,7 @@ impl ProviderRequest for ResponsesAPIRequest {
|
|||
.iter()
|
||||
.filter(|msg| msg.role == crate::apis::openai::Role::System)
|
||||
.filter_map(|msg| {
|
||||
if let crate::apis::openai::MessageContent::Text(text) = &msg.content {
|
||||
if let Some(crate::apis::openai::MessageContent::Text(text)) = &msg.content {
|
||||
Some(text.as_str())
|
||||
} else {
|
||||
None
|
||||
|
|
@ -1170,7 +1170,8 @@ impl ProviderRequest for ResponsesAPIRequest {
|
|||
if !input_messages.is_empty() {
|
||||
// If there's only one message, use Text format
|
||||
if input_messages.len() == 1 {
|
||||
if let crate::apis::openai::MessageContent::Text(text) = &input_messages[0].content
|
||||
if let Some(crate::apis::openai::MessageContent::Text(text)) =
|
||||
&input_messages[0].content
|
||||
{
|
||||
self.input = crate::apis::openai_responses::InputParam::Text(text.clone());
|
||||
}
|
||||
|
|
@ -1180,7 +1181,8 @@ impl ProviderRequest for ResponsesAPIRequest {
|
|||
let combined_text = input_messages
|
||||
.iter()
|
||||
.filter_map(|msg| {
|
||||
if let crate::apis::openai::MessageContent::Text(text) = &msg.content {
|
||||
if let Some(crate::apis::openai::MessageContent::Text(text)) = &msg.content
|
||||
{
|
||||
Some(format!(
|
||||
"{}: {}",
|
||||
match msg.role {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue