diff --git a/crates/hermesllm/src/apis/openai.rs b/crates/hermesllm/src/apis/openai.rs index eb270510..5fd13b98 100644 --- a/crates/hermesllm/src/apis/openai.rs +++ b/crates/hermesllm/src/apis/openai.rs @@ -420,56 +420,142 @@ pub struct StreamOptions { mod tests { use super::*; use serde_json::json; + use std::collections::HashMap; #[test] - fn test_improved_naming_structure() { - // Test that the new clean naming works well - let message = Message { - role: Role::User, - content: MessageContent::Text("Hello, world!".to_string()), - name: None, - tool_calls: None, - tool_call_id: None, + fn test_required_fields() { + // Test ChatCompletionsRequest with only required fields + let minimal_request = ChatCompletionsRequest { + model: "gpt-4".to_string(), + messages: vec![Message { + role: Role::User, + content: MessageContent::Text("Hello, world!".to_string()), + name: None, + tool_calls: None, + tool_call_id: None, + }], + // All other fields are optional + frequency_penalty: None, + function_call: None, + functions: None, + logit_bias: None, + logprobs: None, + max_completion_tokens: None, + max_tokens: None, + modalities: None, + metadata: None, + n: None, + presence_penalty: None, + parallel_tool_calls: None, + prediction: None, + response_format: None, + seed: None, + service_tier: None, + stop: None, + store: None, + stream: None, + stream_options: None, + temperature: None, + tool_choice: None, + tools: None, + top_p: None, + top_logprobs: None, + user: None, }; - let request = ChatCompletionsRequest { + // Test serialization of minimal request + let json = serde_json::to_value(&minimal_request).unwrap(); + let obj = json.as_object().unwrap(); + + // Required fields should be present + assert_eq!(obj["model"], "gpt-4"); + assert!(obj.contains_key("messages")); + assert_eq!(obj["messages"].as_array().unwrap().len(), 1); + + // Test message structure + let message = &obj["messages"].as_array().unwrap()[0]; + assert_eq!(message["role"], "user"); + assert_eq!(message["content"], "Hello, world!"); + + // Optional fields should not be present + assert!(!obj.contains_key("temperature")); + assert!(!obj.contains_key("max_tokens")); + assert!(!obj.contains_key("stream")); + } + + #[test] + fn test_optional_fields_serialization() { + // Test that optional fields work correctly and None fields are skipped + let request_with_options = ChatCompletionsRequest { model: "gpt-4".to_string(), - messages: vec![message], + messages: vec![Message { + role: Role::User, + content: MessageContent::Text("Test message".to_string()), + name: Some("test_user".to_string()), // Optional field with value + tool_calls: None, // Optional field as None + tool_call_id: None, + }], temperature: Some(0.7), + max_tokens: Some(150), + stream: Some(true), + stream_options: Some(StreamOptions { + include_usage: Some(true), + }), + metadata: Some(HashMap::from([ + ("user_id".to_string(), "123".to_string()), + ])), + // These should be None and skipped top_p: None, - max_tokens: Some(100), - max_completion_tokens: None, - stream: Some(false), - stream_options: None, + frequency_penalty: None, + presence_penalty: None, stop: None, tools: None, tool_choice: None, parallel_tool_calls: None, user: None, - presence_penalty: None, - frequency_penalty: None, logit_bias: None, logprobs: None, - top_logprobs: None, + max_completion_tokens: None, + modalities: None, n: None, - seed: None, + prediction: None, response_format: None, + seed: None, service_tier: None, store: None, - metadata: None, - modalities: None, + top_logprobs: None, function_call: None, functions: None, - prediction: None, }; - assert_eq!(request.model, "gpt-4"); - assert_eq!(request.messages.len(), 1); - assert_eq!(request.messages[0].role, Role::User); + let json = serde_json::to_value(&request_with_options).unwrap(); + let obj = json.as_object().unwrap(); + + // Fields with Some values should be present + assert!((obj["temperature"].as_f64().unwrap() - 0.7).abs() < 1e-6); + assert_eq!(obj["max_tokens"], 150); + assert_eq!(obj["stream"], true); + assert!(obj.contains_key("stream_options")); + assert!(obj.contains_key("metadata")); + + // Message name should be present + let message = &obj["messages"].as_array().unwrap()[0]; + assert_eq!(message["name"], "test_user"); + assert!(!message.as_object().unwrap().contains_key("tool_calls")); + + // None fields should be skipped + assert!(!obj.contains_key("top_p")); + assert!(!obj.contains_key("frequency_penalty")); + assert!(!obj.contains_key("presence_penalty")); + assert!(!obj.contains_key("stop")); + assert!(!obj.contains_key("tools")); } #[test] - fn test_tool_structure() { + fn test_nested_types_serialization() { + // Test tools, message parts, static content, and streaming deltas + + // Test tool serialization let tool = Tool { tool_type: "function".to_string(), function: Function { @@ -481,381 +567,18 @@ mod tests { "location": {"type": "string"} } }), - strict: None, + strict: Some(true), }, }; - assert_eq!(tool.function.name, "get_weather"); - assert!(tool.function.description.is_some()); - } + let tool_json = serde_json::to_value(&tool).unwrap(); + assert_eq!(tool_json["type"], "function"); + assert_eq!(tool_json["function"]["name"], "get_weather"); + assert!(tool_json["function"].as_object().unwrap().contains_key("description")); + assert_eq!(tool_json["function"]["strict"], true); - #[test] - fn test_content_parts() { - let text_part = ContentPart::Text { - text: "Describe this image".to_string(), - }; - - let image_part = ContentPart::ImageUrl { - image_url: ImageUrl { - url: "https://example.com/image.jpg".to_string(), - detail: Some("high".to_string()), - }, - }; - - let content = MessageContent::Parts(vec![text_part, image_part]); - - if let MessageContent::Parts(parts) = content { - assert_eq!(parts.len(), 2); - - if let ContentPart::Text { text } = &parts[0] { - assert_eq!(text, "Describe this image"); - } else { - panic!("Expected text part"); - } - - if let ContentPart::ImageUrl { image_url } = &parts[1] { - assert_eq!(image_url.url, "https://example.com/image.jpg"); - assert_eq!(image_url.detail, Some("high".to_string())); - } else { - panic!("Expected image part"); - } - } else { - panic!("Expected parts content"); - } - } - - #[test] - fn test_api_enum() { - let api = OpenAIApi::ChatCompletions; - assert_eq!(api.endpoint(), "/v1/chat/completions"); - assert!(api.supports_streaming()); - assert!(api.supports_tools()); - assert!(api.supports_vision()); - - let found_api = OpenAIApi::from_endpoint("/v1/chat/completions"); - assert_eq!(found_api, Some(OpenAIApi::ChatCompletions)); - } - - #[test] - fn test_api_specific_naming() { - // Test that the API-specific naming is clear and intuitive - let request = ChatCompletionsRequest { - model: "gpt-4".to_string(), - messages: vec![Message { - role: Role::User, - content: MessageContent::Text("Test message".to_string()), - name: None, - tool_calls: None, - tool_call_id: None, - }], - temperature: Some(0.7), - top_p: None, - max_tokens: Some(100), - max_completion_tokens: None, - stream: Some(false), - stream_options: None, - stop: None, - tools: None, - tool_choice: None, - parallel_tool_calls: None, - user: None, - presence_penalty: None, - frequency_penalty: None, - logit_bias: None, - logprobs: None, - top_logprobs: None, - n: None, - seed: None, - response_format: None, - service_tier: None, - store: None, - metadata: None, - modalities: None, - function_call: None, - functions: None, - prediction: None, - }; - - let response = ChatCompletionsResponse { - id: "chatcmpl-123".to_string(), - object: "chat.completion".to_string(), - created: 1234567890, - model: "gpt-4".to_string(), - choices: vec![Choice { - index: 0, - message: ResponseMessage { - role: Role::Assistant, - content: Some("Hello!".to_string()), - refusal: None, - annotations: None, - audio: None, - function_call: None, - tool_calls: None, - }, - finish_reason: Some(FinishReason::Stop), - logprobs: None, - }], - usage: Usage { - prompt_tokens: 10, - completion_tokens: 5, - total_tokens: 15, - prompt_tokens_details: None, - completion_tokens_details: None, - }, - system_fingerprint: None, - }; - - let stream_response = ChatCompletionsStreamResponse { - id: "chatcmpl-stream-123".to_string(), - object: "chat.completion.chunk".to_string(), - created: 1234567890, - model: "gpt-4".to_string(), - choices: vec![StreamChoice { - index: 0, - delta: MessageDelta { - role: Some(Role::Assistant), - content: Some("Hello".to_string()), - refusal: None, - function_call: None, - tool_calls: None, - }, - finish_reason: None, - logprobs: None, - }], - usage: None, - system_fingerprint: None, - service_tier: None, - }; - - // Assert the naming makes sense - assert_eq!(request.model, "gpt-4"); - assert_eq!(response.choices[0].message.role, Role::Assistant); - assert_eq!(stream_response.object, "chat.completion.chunk"); - } - - #[test] - fn test_skip_serializing_none_annotations() { - // Test that skip_serializing_none works correctly - let message = Message { - role: Role::User, - content: MessageContent::Text("Hello".to_string()), - name: None, // Should be skipped - tool_calls: None, // Should be skipped - tool_call_id: None, // Should be skipped - }; - - let json = serde_json::to_value(&message).unwrap(); - - // Verify that None fields are not present in the JSON - assert!(!json.as_object().unwrap().contains_key("name")); - assert!(!json.as_object().unwrap().contains_key("tool_calls")); - assert!(!json.as_object().unwrap().contains_key("tool_call_id")); - - // Verify that required fields are present - assert!(json.as_object().unwrap().contains_key("role")); - assert!(json.as_object().unwrap().contains_key("content")); - - // Test with Some values to ensure they are included - let message_with_name = Message { - role: Role::Assistant, - content: MessageContent::Text("Hello back".to_string()), - name: Some("assistant".to_string()), - tool_calls: None, - tool_call_id: None, - }; - - let json_with_name = serde_json::to_value(&message_with_name).unwrap(); - assert!(json_with_name.as_object().unwrap().contains_key("name")); - assert!(!json_with_name.as_object().unwrap().contains_key("tool_calls")); - } - - #[test] - fn test_api_provider_trait_implementation() { - use super::ApiDefinition; - - // Test that OpenAIApi implements ApiDefinition trait correctly - let api = OpenAIApi::ChatCompletions; - - // Test trait methods - assert_eq!(ApiDefinition::endpoint(&api), "/v1/chat/completions"); - assert!(ApiDefinition::supports_streaming(&api)); - assert!(ApiDefinition::supports_tools(&api)); - assert!(ApiDefinition::supports_vision(&api)); - - // Test from_endpoint trait method - let found_api = OpenAIApi::from_endpoint("/v1/chat/completions"); - assert_eq!(found_api, Some(OpenAIApi::ChatCompletions)); - - let not_found = OpenAIApi::from_endpoint("/v1/unknown"); - assert_eq!(not_found, None); - } - - #[test] - fn test_new_api_fields() { - // Test that new fields are properly handled - let request = ChatCompletionsRequest { - model: "gpt-4".to_string(), - messages: vec![Message { - role: Role::User, - content: MessageContent::Text("Test".to_string()), - name: None, - tool_calls: None, - tool_call_id: None, - }], - temperature: Some(0.7), - top_p: None, - max_tokens: Some(100), - max_completion_tokens: Some(150), // New field - stream: Some(true), - stream_options: Some(StreamOptions { - include_usage: Some(true), // New field - }), - stop: None, - tools: None, - tool_choice: None, - parallel_tool_calls: None, - user: None, - presence_penalty: None, - frequency_penalty: None, - logit_bias: None, - logprobs: None, - top_logprobs: None, - n: None, - seed: None, - response_format: None, - service_tier: Some("default".to_string()), // New field - store: Some(true), // New field - metadata: Some(HashMap::from([ - ("user_id".to_string(), "123".to_string()), - ("session_id".to_string(), "abc".to_string()), - ])), // New field - modalities: None, - function_call: None, - functions: None, - prediction: None, - }; - - // Test serialization works - let json = serde_json::to_value(&request).unwrap(); - assert!(json.as_object().unwrap().contains_key("max_completion_tokens")); - assert!(json.as_object().unwrap().contains_key("stream_options")); - assert!(json.as_object().unwrap().contains_key("service_tier")); - assert!(json.as_object().unwrap().contains_key("store")); - assert!(json.as_object().unwrap().contains_key("metadata")); - - // Test that None fields are skipped - let minimal_request = ChatCompletionsRequest { - model: "gpt-4".to_string(), - messages: vec![Message { - role: Role::User, - content: MessageContent::Text("Test".to_string()), - name: None, - tool_calls: None, - tool_call_id: None, - }], - temperature: None, - top_p: None, - max_tokens: None, - max_completion_tokens: None, - stream: None, - stream_options: None, - stop: None, - tools: None, - tool_choice: None, - parallel_tool_calls: None, - user: None, - presence_penalty: None, - frequency_penalty: None, - logit_bias: None, - logprobs: None, - top_logprobs: None, - n: None, - seed: None, - response_format: None, - service_tier: None, - store: None, - metadata: None, - modalities: None, - function_call: None, - functions: None, - prediction: None, - }; - - let minimal_json = serde_json::to_value(&minimal_request).unwrap(); - let obj = minimal_json.as_object().unwrap(); - - // These new fields should not be present when None - assert!(!obj.contains_key("max_completion_tokens")); - assert!(!obj.contains_key("stream_options")); - assert!(!obj.contains_key("service_tier")); - assert!(!obj.contains_key("store")); - assert!(!obj.contains_key("metadata")); - } - - #[test] - fn test_token_usage_details() { - // Test that the detailed token usage types work correctly - let usage_with_details = Usage { - prompt_tokens: 100, - completion_tokens: 50, - total_tokens: 150, - prompt_tokens_details: Some(PromptTokensDetails { - cached_tokens: Some(20), - audio_tokens: Some(10), - }), - completion_tokens_details: Some(CompletionTokensDetails { - reasoning_tokens: Some(30), - audio_tokens: Some(5), - accepted_prediction_tokens: Some(15), - rejected_prediction_tokens: Some(3), - }), - }; - - let json = serde_json::to_value(&usage_with_details).unwrap(); - let obj = json.as_object().unwrap(); - - assert!(obj.contains_key("prompt_tokens_details")); - assert!(obj.contains_key("completion_tokens_details")); - - // Test basic usage without details - let basic_usage = Usage { - prompt_tokens: 100, - completion_tokens: 50, - total_tokens: 150, - prompt_tokens_details: None, - completion_tokens_details: None, - }; - - let basic_json = serde_json::to_value(&basic_usage).unwrap(); - let basic_obj = basic_json.as_object().unwrap(); - - // None fields should be skipped - assert!(!basic_obj.contains_key("prompt_tokens_details")); - assert!(!basic_obj.contains_key("completion_tokens_details")); - } - - #[test] - fn test_message_content_serialization() { - // Test that MessageContent serializes correctly for OpenAI API - - // Test simple text content - let text_message = Message { - role: Role::User, - content: MessageContent::Text("Hello, how are you?".to_string()), - name: None, - tool_calls: None, - tool_call_id: None, - }; - - let text_json = serde_json::to_value(&text_message).unwrap(); - let content_value = &text_json["content"]; - - // Should serialize as a simple string - assert!(content_value.is_string()); - assert_eq!(content_value.as_str().unwrap(), "Hello, how are you?"); - - // Test multimodal content with text and image - let parts_message = Message { + // Test multimodal message content + let multimodal_message = Message { role: Role::User, content: MessageContent::Parts(vec![ ContentPart::Text { @@ -873,431 +596,155 @@ mod tests { tool_call_id: None, }; - let parts_json = serde_json::to_value(&parts_message).unwrap(); - let parts_content_value = &parts_json["content"]; - - // Should serialize as an array - assert!(parts_content_value.is_array()); - let parts_array = parts_content_value.as_array().unwrap(); - assert_eq!(parts_array.len(), 2); - - // First part should be text - assert_eq!(parts_array[0]["type"], "text"); - assert_eq!(parts_array[0]["text"], "What's in this image?"); - - // Second part should be image - assert_eq!(parts_array[1]["type"], "image_url"); - assert_eq!(parts_array[1]["image_url"]["url"], "https://example.com/image.jpg"); - assert_eq!(parts_array[1]["image_url"]["detail"], "high"); - - // Test deserialization back - let deserialized_text: Message = serde_json::from_value(text_json).unwrap(); - if let MessageContent::Text(text) = deserialized_text.content { - assert_eq!(text, "Hello, how are you?"); - } else { - panic!("Expected text content"); - } - - let deserialized_parts: Message = serde_json::from_value(parts_json).unwrap(); - if let MessageContent::Parts(parts) = deserialized_parts.content { - assert_eq!(parts.len(), 2); - if let ContentPart::Text { text } = &parts[0] { - assert_eq!(text, "What's in this image?"); - } else { - panic!("Expected text part"); - } - } else { - panic!("Expected parts content"); - } - } - - #[test] - fn test_static_content_serialization() { - // Test StaticContent serialization for prediction functionality - - // Test simple text static content - let text_static = StaticContent { - content_type: "content".to_string(), - content: StaticContentType::Text("This is the predicted text output".to_string()), - }; - - let text_json = serde_json::to_value(&text_static).unwrap(); - assert_eq!(text_json["type"], "content"); - assert_eq!(text_json["content"], "This is the predicted text output"); - - // Test structured static content with parts - let parts_static = StaticContent { - content_type: "content".to_string(), - content: StaticContentType::Parts(vec![ - ContentPart::Text { - text: "First part of predicted content".to_string(), - }, - ContentPart::Text { - text: "Second part of predicted content".to_string(), - }, - ]), - }; - - let parts_json = serde_json::to_value(&parts_static).unwrap(); - assert_eq!(parts_json["type"], "content"); - assert!(parts_json["content"].is_array()); - - let content_array = parts_json["content"].as_array().unwrap(); + let multimodal_json = serde_json::to_value(&multimodal_message).unwrap(); + let content_array = multimodal_json["content"].as_array().unwrap(); assert_eq!(content_array.len(), 2); assert_eq!(content_array[0]["type"], "text"); - assert_eq!(content_array[0]["text"], "First part of predicted content"); - assert_eq!(content_array[1]["type"], "text"); - assert_eq!(content_array[1]["text"], "Second part of predicted content"); + assert_eq!(content_array[1]["type"], "image_url"); + assert_eq!(content_array[1]["image_url"]["detail"], "high"); - // Test in a ChatCompletionsRequest - let request_with_prediction = ChatCompletionsRequest { - model: "gpt-4".to_string(), - messages: vec![Message { - role: Role::User, - content: MessageContent::Text("Continue this file:".to_string()), - name: None, - tool_calls: None, - tool_call_id: None, - }], - prediction: Some(text_static), - temperature: Some(0.1), - // ... other fields as None - top_p: None, - max_tokens: None, - max_completion_tokens: None, - stream: None, - stream_options: None, - stop: None, - tools: None, - tool_choice: None, - parallel_tool_calls: None, - user: None, - presence_penalty: None, - frequency_penalty: None, - logit_bias: None, - logprobs: None, - top_logprobs: None, - n: None, - seed: None, - response_format: None, - service_tier: None, - store: None, - metadata: None, - modalities: None, - function_call: None, - functions: None, + // Test static content for prediction + let static_content = StaticContent { + content_type: "content".to_string(), + content: StaticContentType::Text("Predicted output".to_string()), }; - let request_json = serde_json::to_value(&request_with_prediction).unwrap(); - assert!(request_json.as_object().unwrap().contains_key("prediction")); + let static_json = serde_json::to_value(&static_content).unwrap(); + assert_eq!(static_json["type"], "content"); + assert_eq!(static_json["content"], "Predicted output"); - let prediction_value = &request_json["prediction"]; - assert_eq!(prediction_value["type"], "content"); - assert_eq!(prediction_value["content"], "This is the predicted text output"); + // Test streaming delta + let stream_delta = MessageDelta { + role: Some(Role::Assistant), + content: Some("Streaming response".to_string()), + refusal: None, + function_call: None, + tool_calls: Some(vec![ToolCallDelta { + index: 0, + id: Some("call_123".to_string()), + call_type: Some("function".to_string()), + function: Some(FunctionCallDelta { + name: Some("get_weather".to_string()), + arguments: Some(r#"{"location":"NYC"}"#.to_string()), + }), + }]), + }; - // Test deserialization - let deserialized: StaticContent = serde_json::from_value(text_json).unwrap(); - assert_eq!(deserialized.content_type, "content"); - if let StaticContentType::Text(content) = deserialized.content { - assert_eq!(content, "This is the predicted text output"); - } else { - panic!("Expected text static content"); - } + let delta_json = serde_json::to_value(&stream_delta).unwrap(); + assert_eq!(delta_json["role"], "assistant"); + assert_eq!(delta_json["content"], "Streaming response"); + assert!(delta_json["tool_calls"].as_array().unwrap().len() == 1); + assert!(!delta_json.as_object().unwrap().contains_key("refusal")); // None should be skipped } #[test] - fn test_role_specific_message_fields() { - // Test that tool_calls and tool_call_id are properly serialized/skipped based on role + fn test_api_provider_trait() { + // Test the ApiDefinition trait implementation + let api = OpenAIApi::ChatCompletions; - // User message - should not have tool_calls or tool_call_id + // Test trait methods + assert_eq!(api.endpoint(), "/v1/chat/completions"); + assert!(api.supports_streaming()); + assert!(api.supports_tools()); + assert!(api.supports_vision()); + + // Test from_endpoint + let found_api = OpenAIApi::from_endpoint("/v1/chat/completions"); + assert_eq!(found_api, Some(OpenAIApi::ChatCompletions)); + + let not_found = OpenAIApi::from_endpoint("/v1/unknown"); + assert_eq!(not_found, None); + + // Test all_variants + let all_variants = OpenAIApi::all_variants(); + assert_eq!(all_variants.len(), 1); + assert_eq!(all_variants[0], OpenAIApi::ChatCompletions); + } + + #[test] + fn test_role_specific_behavior() { + // Test role-specific serialization behavior + + // User message - basic content, no tool-related fields let user_message = Message { role: Role::User, content: MessageContent::Text("Hello!".to_string()), - name: None, + name: Some("user123".to_string()), tool_calls: None, tool_call_id: None, }; let user_json = serde_json::to_value(&user_message).unwrap(); let user_obj = user_json.as_object().unwrap(); - assert_eq!(user_obj["role"], "user"); - assert_eq!(user_obj["content"], "Hello!"); - // These should be omitted when None + assert_eq!(user_obj["name"], "user123"); assert!(!user_obj.contains_key("tool_calls")); assert!(!user_obj.contains_key("tool_call_id")); - assert!(!user_obj.contains_key("name")); // Assistant message with tool calls let assistant_message = Message { role: Role::Assistant, - content: MessageContent::Text("I'll help you with that.".to_string()), + content: MessageContent::Text("I'll help with that.".to_string()), name: None, - tool_calls: Some(vec![ToolCall { - id: "call_123".to_string(), - call_type: "function".to_string(), - function: FunctionCall { - name: "get_weather".to_string(), - arguments: r#"{"location": "San Francisco"}"#.to_string(), - }, - }]), - tool_call_id: None, - }; - - let assistant_json = serde_json::to_value(&assistant_message).unwrap(); - let assistant_obj = assistant_json.as_object().unwrap(); - - assert_eq!(assistant_obj["role"], "assistant"); - assert_eq!(assistant_obj["content"], "I'll help you with that."); - assert!(assistant_obj.contains_key("tool_calls")); - assert!(!assistant_obj.contains_key("tool_call_id")); // Should be omitted for assistant - - let tool_calls = assistant_obj["tool_calls"].as_array().unwrap(); - assert_eq!(tool_calls.len(), 1); - assert_eq!(tool_calls[0]["id"], "call_123"); - - // Tool message responding to a tool call - let tool_message = Message { - role: Role::Tool, - content: MessageContent::Text("The weather in San Francisco is sunny, 72°F".to_string()), - name: None, - tool_calls: None, - tool_call_id: Some("call_123".to_string()), - }; - - let tool_json = serde_json::to_value(&tool_message).unwrap(); - let tool_obj = tool_json.as_object().unwrap(); - - assert_eq!(tool_obj["role"], "tool"); - assert_eq!(tool_obj["content"], "The weather in San Francisco is sunny, 72°F"); - assert_eq!(tool_obj["tool_call_id"], "call_123"); - assert!(!tool_obj.contains_key("tool_calls")); // Should be omitted for tool messages - - // Test deserialization - let deserialized_tool: Message = serde_json::from_value(tool_json).unwrap(); - assert_eq!(deserialized_tool.role, Role::Tool); - assert_eq!(deserialized_tool.tool_call_id, Some("call_123".to_string())); - assert!(deserialized_tool.tool_calls.is_none()); - } - - #[test] - fn test_request_vs_response_messages() { - // Test the difference between request Message and response ResponseMessage - - // Request message (used in ChatCompletionsRequest) - let request_message = Message { - role: Role::User, - content: MessageContent::Text("What's the weather like?".to_string()), - name: Some("user123".to_string()), - tool_calls: None, - tool_call_id: None, - }; - - let request_json = serde_json::to_value(&request_message).unwrap(); - let request_obj = request_json.as_object().unwrap(); - - // Request messages have complex content (MessageContent enum) - assert_eq!(request_obj["role"], "user"); - assert_eq!(request_obj["content"], "What's the weather like?"); - assert_eq!(request_obj["name"], "user123"); - assert!(!request_obj.contains_key("refusal")); - assert!(!request_obj.contains_key("annotations")); - assert!(!request_obj.contains_key("audio")); - - // Response message (used in ChatCompletionsResponse) - let response_message = ResponseMessage { - role: Role::Assistant, - content: Some("It's sunny and 75°F in San Francisco.".to_string()), - refusal: None, - annotations: Some(vec![serde_json::json!({"type": "web_search", "query": "San Francisco weather"})]), - audio: None, - function_call: None, tool_calls: Some(vec![ToolCall { id: "call_456".to_string(), call_type: "function".to_string(), function: FunctionCall { name: "get_weather".to_string(), - arguments: r#"{"location": "San Francisco"}"#.to_string(), + arguments: r#"{"location":"SF"}"#.to_string(), }, }]), + tool_call_id: None, // Should not be present for assistant }; - let response_json = serde_json::to_value(&response_message).unwrap(); - let response_obj = response_json.as_object().unwrap(); + let assistant_json = serde_json::to_value(&assistant_message).unwrap(); + let assistant_obj = assistant_json.as_object().unwrap(); + assert_eq!(assistant_obj["role"], "assistant"); + assert!(assistant_obj.contains_key("tool_calls")); + assert!(!assistant_obj.contains_key("tool_call_id")); // Not for assistant + assert!(!assistant_obj.contains_key("name")); // None, so skipped - // Response messages have simple string content and additional fields - assert_eq!(response_obj["role"], "assistant"); - assert_eq!(response_obj["content"], "It's sunny and 75°F in San Francisco."); - assert!(response_obj.contains_key("annotations")); - assert!(response_obj.contains_key("tool_calls")); - assert!(!response_obj.contains_key("name")); // Response messages don't have names - assert!(!response_obj.contains_key("tool_call_id")); // Only for tool messages in requests + // Tool message responding to a call + let tool_message = Message { + role: Role::Tool, + content: MessageContent::Text("Weather is sunny".to_string()), + name: None, + tool_calls: None, // Should not be present for tool messages + tool_call_id: Some("call_456".to_string()), + }; - // Test conversion from ResponseMessage to Message - let converted_message = response_message.to_message(); - assert_eq!(converted_message.role, Role::Assistant); - if let MessageContent::Text(text) = converted_message.content { - assert_eq!(text, "It's sunny and 75°F in San Francisco."); - } else { - panic!("Expected text content"); - } - assert_eq!(converted_message.tool_calls, response_message.tool_calls); + let tool_json = serde_json::to_value(&tool_message).unwrap(); + let tool_obj = tool_json.as_object().unwrap(); + assert_eq!(tool_obj["role"], "tool"); + assert_eq!(tool_obj["tool_call_id"], "call_456"); + assert!(!tool_obj.contains_key("tool_calls")); // Not for tool messages - // Test response message with refusal - let refusal_message = ResponseMessage { + // Test ResponseMessage vs Message differences + let response_message = ResponseMessage { role: Role::Assistant, - content: None, - refusal: Some("I cannot provide information about that topic.".to_string()), - annotations: None, + content: Some("Response content".to_string()), + refusal: None, + annotations: Some(vec![json!({"type": "citation"})]), audio: None, function_call: None, tool_calls: None, }; - let refusal_json = serde_json::to_value(&refusal_message).unwrap(); - let refusal_obj = refusal_json.as_object().unwrap(); + let response_json = serde_json::to_value(&response_message).unwrap(); + let response_obj = response_json.as_object().unwrap(); - assert_eq!(refusal_obj["role"], "assistant"); - // Content is None and gets skipped by #[skip_serializing_none] - assert!(!refusal_obj.contains_key("content")); - // Check if refusal field exists and has the expected value - if let Some(refusal_value) = refusal_obj.get("refusal") { - assert_eq!(refusal_value, "I cannot provide information about that topic."); + // ResponseMessage has different fields than Message + assert!(response_obj.contains_key("annotations")); + assert!(!response_obj.contains_key("name")); // Not in ResponseMessage + assert!(!response_obj.contains_key("tool_call_id")); // Not in ResponseMessage + + // Test conversion + let converted = response_message.to_message(); + assert_eq!(converted.role, Role::Assistant); + if let MessageContent::Text(text) = converted.content { + assert_eq!(text, "Response content"); } else { - panic!("Expected refusal field to be present"); + panic!("Expected text content"); } } - - #[test] - fn test_streaming_types_completeness() { - // Test that streaming types include all fields from OpenAI API spec - - // Test ChatCompletionsStreamResponse with all fields - let stream_response = ChatCompletionsStreamResponse { - id: "chatcmpl-stream-456".to_string(), - object: "chat.completion.chunk".to_string(), - created: 1234567890, - model: "gpt-4".to_string(), - choices: vec![StreamChoice { - index: 0, - delta: MessageDelta { - role: Some(Role::Assistant), - content: Some("Hello there!".to_string()), - refusal: None, - function_call: Some(FunctionCall { - name: "get_weather".to_string(), - arguments: r#"{"location": "NYC"}"#.to_string(), - }), - tool_calls: Some(vec![ToolCallDelta { - index: 0, - id: Some("call_789".to_string()), - call_type: Some("function".to_string()), - function: Some(FunctionCallDelta { - name: Some("get_temperature".to_string()), - arguments: Some(r#"{"city": "Boston"}"#.to_string()), - }), - }]), - }, - finish_reason: Some(FinishReason::ToolCalls), - logprobs: None, - }], - usage: Some(Usage { - prompt_tokens: 25, - completion_tokens: 15, - total_tokens: 40, - prompt_tokens_details: None, - completion_tokens_details: None, - }), - system_fingerprint: Some("fp_12345".to_string()), - service_tier: Some("default".to_string()), - }; - - let json = serde_json::to_value(&stream_response).unwrap(); - let obj = json.as_object().unwrap(); - - // Test all top-level fields are present - assert_eq!(obj["id"], "chatcmpl-stream-456"); - assert_eq!(obj["object"], "chat.completion.chunk"); - assert_eq!(obj["created"], 1234567890); - assert_eq!(obj["model"], "gpt-4"); - assert_eq!(obj["system_fingerprint"], "fp_12345"); - assert_eq!(obj["service_tier"], "default"); - assert!(obj.contains_key("choices")); - assert!(obj.contains_key("usage")); - - // Test choice structure - let choices = obj["choices"].as_array().unwrap(); - assert_eq!(choices.len(), 1); - let choice = &choices[0]; - assert_eq!(choice["index"], 0); - assert_eq!(choice["finish_reason"], "tool_calls"); - - // Test delta structure with all fields - let delta = &choice["delta"]; - assert_eq!(delta["role"], "assistant"); - assert_eq!(delta["content"], "Hello there!"); - assert!(delta.as_object().unwrap().contains_key("function_call")); - assert!(delta.as_object().unwrap().contains_key("tool_calls")); - // refusal is None so should be skipped - assert!(!delta.as_object().unwrap().contains_key("refusal")); - - // Test tool call delta structure - let tool_calls = delta["tool_calls"].as_array().unwrap(); - assert_eq!(tool_calls.len(), 1); - let tool_call = &tool_calls[0]; - assert_eq!(tool_call["index"], 0); - assert_eq!(tool_call["id"], "call_789"); - assert_eq!(tool_call["type"], "function"); - - let function = &tool_call["function"]; - assert_eq!(function["name"], "get_temperature"); - assert_eq!(function["arguments"], r#"{"city": "Boston"}"#); - - // Test minimal streaming response (like final chunk) - let minimal_response = ChatCompletionsStreamResponse { - id: "chatcmpl-stream-456".to_string(), - object: "chat.completion.chunk".to_string(), - created: 1234567890, - model: "gpt-4".to_string(), - choices: vec![], // Can be empty for final chunk with usage - usage: Some(Usage { - prompt_tokens: 25, - completion_tokens: 15, - total_tokens: 40, - prompt_tokens_details: None, - completion_tokens_details: None, - }), - system_fingerprint: None, - service_tier: None, - }; - - let minimal_json = serde_json::to_value(&minimal_response).unwrap(); - let minimal_obj = minimal_json.as_object().unwrap(); - - // Empty choices array should be present - assert_eq!(minimal_obj["choices"].as_array().unwrap().len(), 0); - assert!(minimal_obj.contains_key("usage")); - // None fields should be skipped - assert!(!minimal_obj.contains_key("system_fingerprint")); - assert!(!minimal_obj.contains_key("service_tier")); - - // Test delta with refusal - let refusal_delta = MessageDelta { - role: Some(Role::Assistant), - content: None, - refusal: Some("I can't help with that request.".to_string()), - function_call: None, - tool_calls: None, - }; - - let refusal_json = serde_json::to_value(&refusal_delta).unwrap(); - let refusal_obj = refusal_json.as_object().unwrap(); - - assert_eq!(refusal_obj["role"], "assistant"); - assert_eq!(refusal_obj["refusal"], "I can't help with that request."); - // None fields should be skipped - assert!(!refusal_obj.contains_key("content")); - assert!(!refusal_obj.contains_key("function_call")); - assert!(!refusal_obj.contains_key("tool_calls")); - } }