plano/crates/hermesllm/src/apis/openai.rs
2025-12-11 13:53:44 -08:00

1352 lines
46 KiB
Rust

use serde::{Deserialize, Serialize};
use serde_json::Value;
use serde_with::skip_serializing_none;
use std::collections::HashMap;
use std::fmt::Display;
use thiserror::Error;
use super::ApiDefinition;
use crate::providers::request::{ProviderRequest, ProviderRequestError};
use crate::providers::response::{ProviderResponse, TokenUsage};
use crate::providers::streaming_response::ProviderStreamResponse;
use crate::transforms::lib::ExtractText;
use crate::{CHAT_COMPLETIONS_PATH, OPENAI_RESPONSES_API_PATH};
// ============================================================================
// OPENAI API ENUMERATION
// ============================================================================
/// Enum for all supported OpenAI APIs.
///
/// Each variant maps to exactly one endpoint path through
/// [`ApiDefinition::endpoint`], and back again through
/// [`ApiDefinition::from_endpoint`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum OpenAIApi {
    /// The Chat Completions API, served at `CHAT_COMPLETIONS_PATH`.
    ChatCompletions,
    /// The Responses API, served at `OPENAI_RESPONSES_API_PATH`.
    Responses,
    // Future APIs can be added here (remember to extend every `match` in the
    // `ApiDefinition` impl and `all_variants`):
    // Embeddings,
    // FineTuning,
    // etc.
}
/// Capability table for the OpenAI API family.
///
/// The capability queries use exhaustive `match`es on purpose: adding a new
/// `OpenAIApi` variant forces each capability to be decided explicitly.
impl ApiDefinition for OpenAIApi {
    /// Path of the HTTP endpoint that serves this API.
    fn endpoint(&self) -> &'static str {
        match self {
            Self::ChatCompletions => CHAT_COMPLETIONS_PATH,
            Self::Responses => OPENAI_RESPONSES_API_PATH,
        }
    }

    /// Reverse lookup of [`endpoint`](Self::endpoint); `None` for unknown paths.
    fn from_endpoint(endpoint: &str) -> Option<Self> {
        Self::all_variants()
            .into_iter()
            .find(|api| api.endpoint() == endpoint)
    }

    /// Every OpenAI API handled here supports SSE streaming.
    fn supports_streaming(&self) -> bool {
        match self {
            Self::ChatCompletions | Self::Responses => true,
        }
    }

    /// Every OpenAI API handled here supports tool/function calling.
    fn supports_tools(&self) -> bool {
        match self {
            Self::ChatCompletions | Self::Responses => true,
        }
    }

    /// Every OpenAI API handled here accepts image inputs.
    fn supports_vision(&self) -> bool {
        match self {
            Self::ChatCompletions | Self::Responses => true,
        }
    }

    /// All variants, in declaration order.
    fn all_variants() -> Vec<Self> {
        vec![Self::ChatCompletions, Self::Responses]
    }
}
/// Chat completions API request
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct ChatCompletionsRequest {
pub messages: Vec<Message>,
pub model: String,
// pub audio: Option<Audio> // GOOD FIRST ISSUE: future support for audio input
pub frequency_penalty: Option<f32>,
// Function calling configuration has been deprecated, but we keep it for compatibility
pub function_call: Option<FunctionChoice>,
pub functions: Option<Vec<Tool>>,
pub logit_bias: Option<HashMap<String, i32>>,
pub logprobs: Option<bool>,
pub max_completion_tokens: Option<u32>,
// Maximum tokens in the response has been deprecated, but we keep it for compatibility
pub max_tokens: Option<u32>,
pub modalities: Option<Vec<String>>,
pub metadata: Option<HashMap<String, Value>>,
pub n: Option<u32>,
pub presence_penalty: Option<f32>,
pub parallel_tool_calls: Option<bool>,
pub prediction: Option<StaticContent>,
// pub reasoning_effect: Option<bool>, // GOOD FIRST ISSUE: Future support for reasoning effects
pub response_format: Option<Value>,
pub reasoning_effort: Option<String>, // e.g., "none", "low", "medium", "high"
// pub safety_identifier: Option<String>, // GOOD FIRST ISSUE: Future support for safety identifiers
pub seed: Option<i32>,
pub service_tier: Option<String>,
pub stop: Option<Vec<String>>,
pub store: Option<bool>,
pub stream: Option<bool>,
pub stream_options: Option<StreamOptions>,
pub temperature: Option<f32>,
pub tool_choice: Option<ToolChoice>,
pub tools: Option<Vec<Tool>>,
pub top_p: Option<f32>,
pub top_logprobs: Option<u32>,
pub user: Option<String>,
// pub web_search: Option<bool>, // GOOD FIRST ISSUE: Future support for web search
// VLLM-specific parameters (used by Arch-Function)
pub top_k: Option<u32>,
pub stop_token_ids: Option<Vec<u32>>,
pub continue_final_message: Option<bool>,
pub add_generation_prompt: Option<bool>,
}
impl ChatCompletionsRequest {
    /// Clears the deprecated `max_tokens` field for o3-family models.
    ///
    /// Matches `o3` and `o3-*`, with or without a single leading
    /// `openrouter/` prefix.
    pub fn suppress_max_tokens_if_o3(&mut self) {
        let is_o3 = {
            let model = self.model.as_str();
            let base = model.strip_prefix("openrouter/").unwrap_or(model);
            base == "o3" || base.starts_with("o3-")
        };
        if is_o3 {
            self.max_tokens = None;
        }
    }

    /// Pins `temperature` to `1.0` for any model whose name starts with
    /// `gpt-5`. This applies whether or not the client set a temperature.
    pub fn fix_temperature_if_gpt5(&mut self) {
        if !self.model.starts_with("gpt-5") {
            return;
        }
        self.temperature = Some(1.0);
    }
}
// ============================================================================
// CHAT COMPLETIONS API TYPES
// ============================================================================
/// Message role in a chat conversation.
///
/// Serialized in lowercase: `"system"`, `"user"`, `"assistant"`, `"tool"`.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    /// Instructions that frame the conversation.
    System,
    /// A message from the end user.
    User,
    /// A message produced by the model (may carry `tool_calls`).
    Assistant,
    /// The result of a tool call, linked back via `tool_call_id`.
    Tool,
}
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Message {
pub role: Role,
pub content: MessageContent,
pub name: Option<String>,
/// Tool calls made by the assistant (only present for assistant role)
pub tool_calls: Option<Vec<ToolCall>>,
/// ID of the tool call that this message is responding to (only present for tool role)
pub tool_call_id: Option<String>,
}
/// The message returned inside a [`Choice`] of a non-streaming chat
/// completions response.
///
/// Unlike the request-side [`Message`], `content` is optional and plain text
/// only. `None` fields are omitted when serialized.
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ResponseMessage {
    pub role: Role,
    /// The contents of the message (can be null for some cases)
    pub content: Option<String>,
    /// The refusal message generated by the model
    pub refusal: Option<String>,
    /// Annotations for the message, when applicable, as when using the web search tool
    pub annotations: Option<Vec<Value>>,
    /// If the audio output modality is requested, this object contains data about the audio response
    pub audio: Option<Value>,
    /// Deprecated and replaced by tool_calls. The name and arguments of a function that should be called
    pub function_call: Option<FunctionCall>,
    /// The tool calls generated by the model, such as function calls
    pub tool_calls: Option<Vec<ToolCall>>,
}
/// An empty assistant message: role `Assistant`, every optional field `None`.
impl Default for ResponseMessage {
    fn default() -> Self {
        Self {
            tool_calls: None,
            function_call: None,
            audio: None,
            annotations: None,
            refusal: None,
            content: None,
            role: Role::Assistant,
        }
    }
}
impl ResponseMessage {
    /// Converts this response message into a request-side [`Message`].
    ///
    /// `content` becomes `MessageContent::Text`, using an empty string when
    /// the response had no content. `tool_calls` are carried over. `name` and
    /// `tool_call_id` have no counterpart here and are left `None`.
    /// `refusal`, `annotations`, `audio` and `function_call` are dropped.
    pub fn to_message(&self) -> Message {
        let text = self.content.clone().unwrap_or_default();
        Message {
            role: self.role.clone(),
            content: MessageContent::Text(text),
            name: None,
            tool_calls: self.tool_calls.clone(),
            tool_call_id: None,
        }
    }
}
/// Content of a request message.
///
/// In the OpenAI API, this is represented as either:
/// - A string for simple text content
/// - An array of content parts for multimodal content (text + images)
///
/// The enum is `untagged`, so serde tries the variants in declaration order
/// (`Text` first, then `Parts`). Do not reorder them.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(untagged)]
pub enum MessageContent {
    /// A plain string.
    Text(String),
    /// An array of typed content parts.
    Parts(Vec<ContentPart>),
}
// Content Extraction
/// Plain text of the message: the string itself, or the text parts joined by
/// newlines (non-text parts are skipped).
impl ExtractText for MessageContent {
    fn extract_text(&self) -> String {
        if let MessageContent::Parts(parts) = self {
            return parts.extract_text();
        }
        let MessageContent::Text(text) = self else {
            unreachable!("MessageContent has only Text and Parts variants");
        };
        text.to_owned()
    }
}
/// Joins the `Text` parts with `"\n"`, skipping every non-text part.
impl ExtractText for Vec<ContentPart> {
    fn extract_text(&self) -> String {
        let mut joined = String::new();
        let mut first = true;
        for part in self {
            if let ContentPart::Text { text } = part {
                // Track "first" explicitly so empty text parts still get a separator.
                if !first {
                    joined.push('\n');
                }
                joined.push_str(text);
                first = false;
            }
        }
        joined
    }
}
/// Renders the textual content only: plain text as-is, or text parts joined
/// by newlines. Image parts (URLs or inline data) are intentionally omitted.
impl Display for MessageContent {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let parts = match self {
            MessageContent::Text(text) => return f.write_str(text),
            MessageContent::Parts(parts) => parts,
        };
        let joined = parts
            .iter()
            .filter_map(|part| match part {
                ContentPart::Text { text } => Some(text.as_str()),
                ContentPart::ImageUrl { .. } => None,
            })
            .collect::<Vec<&str>>()
            .join("\n");
        f.write_str(&joined)
    }
}
/// Individual content part within a message (text or image).
///
/// Internally tagged by the JSON `type` key (`"text"` or `"image_url"`).
/// NOTE(review): other OpenAI part types (e.g. audio or file inputs) are not
/// modeled, so a message containing them fails to deserialize.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(tag = "type")]
pub enum ContentPart {
    /// `{"type": "text", "text": ...}`
    #[serde(rename = "text")]
    Text { text: String },
    /// `{"type": "image_url", "image_url": {...}}`
    #[serde(rename = "image_url")]
    ImageUrl { image_url: ImageUrl },
}
/// Image URL configuration for vision capabilities.
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ImageUrl {
    /// Location of the image (a URL; presumably also `data:` URIs — confirm).
    pub url: String,
    /// Optional detail level, e.g. `"high"`; omitted when `None`.
    pub detail: Option<String>,
}
/// A tool call made by the assistant.
///
/// Appears in `Message::tool_calls` and `ResponseMessage::tool_calls`. A
/// later `Role::Tool` message answers it by echoing `id` in `tool_call_id`.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct ToolCall {
    /// Identifier that the tool response refers back to.
    pub id: String,
    /// Serialized as `type` (e.g. `"function"`).
    #[serde(rename = "type")]
    pub call_type: String,
    /// Which function to invoke, and with what arguments.
    pub function: FunctionCall,
}
/// Function call within a tool call.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct FunctionCall {
    /// Name of the function to call.
    pub name: String,
    /// Arguments as a JSON-encoded *string* (not a JSON object), e.g.
    /// `"{\"location\": \"SF\"}"`.
    pub arguments: String,
}
/// Tool definition for function calling, as listed in
/// `ChatCompletionsRequest::tools`.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Tool {
    /// Serialized as `type` (e.g. `"function"`).
    #[serde(rename = "type")]
    pub tool_type: String,
    /// The function this tool exposes.
    pub function: Function,
}
/// Function definition within a tool
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Function {
pub name: String,
pub description: Option<String>,
pub parameters: Value,
pub strict: Option<bool>,
}
/// Tool choice string values, serialized in lowercase (`"auto"`,
/// `"required"`, `"none"`).
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum ToolChoiceType {
    /// Let the model automatically decide whether to call tools
    Auto,
    /// Force the model to call at least one tool
    Required,
    /// Prevent the model from calling any tools
    None,
}
/// Tool choice configuration.
///
/// `untagged`: serde first tries the bare-string form (`Type`), then the
/// object form (`Function`). Variant order is significant.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(untagged)]
pub enum ToolChoice {
    /// String-based tool choice (auto, required, none)
    Type(ToolChoiceType),
    /// Specific function to call, e.g.
    /// `{"type": "function", "function": {"name": "..."}}`
    Function {
        #[serde(rename = "type")]
        choice_type: String,
        function: FunctionChoice,
    },
}
/// Specific function choice: names the function the model must call.
///
/// Also used by the deprecated `ChatCompletionsRequest::function_call` field.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct FunctionChoice {
    pub name: String,
}
/// Static content for prediction/prefill functionality
///
/// Static predicted output content, such as the content of a text file
/// that is being regenerated. Sent as `ChatCompletionsRequest::prediction`.
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct StaticContent {
    /// The type of the predicted content you want to provide.
    /// This type is currently always "content".
    #[serde(rename = "type")]
    pub content_type: String,
    /// The content that should be matched when generating a model response.
    /// If generated tokens would match this content, the entire model response
    /// can be returned much more quickly.
    ///
    /// Can be either:
    /// - A string for simple text content
    /// - An array of content parts for structured content
    pub content: StaticContentType,
}
/// Content type for static/predicted content.
///
/// `untagged`: `Text` is tried before `Parts`; keep that order.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(untagged)]
pub enum StaticContentType {
    /// Simple text content - the content used for a Predicted Output.
    /// This is often the text of a file you are regenerating with minor changes.
    Text(String),
    /// An array of content parts with a defined type.
    /// Can contain text inputs and other supported content types.
    Parts(Vec<ContentPart>),
}
/// Chat completions API response (non-streaming).
///
/// NOTE(review): `usage` is required here, so an upstream response without a
/// `usage` object fails to deserialize. Confirm that every provider routed
/// through this type always returns it.
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ChatCompletionsResponse {
    pub id: String,
    /// Object type string; optional here even though OpenAI always sends it.
    pub object: Option<String>,
    /// Creation timestamp as sent by the provider (presumably Unix seconds).
    pub created: u64,
    pub model: String,
    pub choices: Vec<Choice>,
    pub usage: Usage,
    pub system_fingerprint: Option<String>,
    pub service_tier: Option<String>,
    // This isn't a standard OpenAI field, but we include it for extensibility
    pub metadata: Option<HashMap<String, Value>>,
}
/// An empty response: blank id/model, `created = 0`, no choices, zeroed usage,
/// and every optional field `None`.
impl Default for ChatCompletionsResponse {
    fn default() -> Self {
        let usage = Usage::default();
        Self {
            id: String::new(),
            model: String::new(),
            created: 0,
            choices: Vec::new(),
            usage,
            object: None,
            system_fingerprint: None,
            service_tier: None,
            metadata: None,
        }
    }
}
/// Finish reason for completion.
///
/// Serialized in snake_case: `"stop"`, `"length"`, `"tool_calls"`,
/// `"content_filter"`, `"function_call"`.
/// NOTE(review): any other string from an upstream provider fails
/// deserialization of the whole choice.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum FinishReason {
    Stop,
    Length,
    ToolCalls,
    ContentFilter,
    FunctionCall, // Legacy
}
/// Token usage information.
///
/// The three counters are required on input; the detail breakdowns are
/// optional and omitted when `None`.
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
    pub prompt_tokens_details: Option<PromptTokensDetails>,
    pub completion_tokens_details: Option<CompletionTokensDetails>,
}
impl Default for Usage {
fn default() -> Self {
Usage {
prompt_tokens: 0,
completion_tokens: 0,
total_tokens: 0,
prompt_tokens_details: None,
completion_tokens_details: None,
}
}
}
/// Detailed breakdown of prompt tokens. Every field is optional and omitted
/// when `None`.
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct PromptTokensDetails {
    /// Prompt tokens served from the provider's prompt cache.
    pub cached_tokens: Option<u32>,
    /// Prompt tokens attributed to audio input.
    pub audio_tokens: Option<u32>,
}
/// Detailed breakdown of completion tokens. Every field is optional and
/// omitted when `None`.
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct CompletionTokensDetails {
    /// Tokens spent on hidden reasoning.
    pub reasoning_tokens: Option<u32>,
    /// Completion tokens attributed to audio output.
    pub audio_tokens: Option<u32>,
    /// Predicted-output tokens that appeared in the completion.
    pub accepted_prediction_tokens: Option<u32>,
    /// Predicted-output tokens that did not appear in the completion.
    pub rejected_prediction_tokens: Option<u32>,
}
/// A single choice in the response.
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Choice {
    /// Position of this choice in `choices`.
    pub index: u32,
    /// The generated assistant message.
    pub message: ResponseMessage,
    /// Why generation stopped; optional on input.
    pub finish_reason: Option<FinishReason>,
    /// Log-probability data, kept as raw JSON.
    pub logprobs: Option<Value>,
}
/// Choice 0 holding an empty assistant message, with no finish reason and
/// no logprobs.
impl Default for Choice {
    fn default() -> Self {
        let message = ResponseMessage::default();
        Self {
            message,
            index: 0,
            logprobs: None,
            finish_reason: None,
        }
    }
}
// ============================================================================
// STREAMING API TYPES
// ============================================================================
/// Streaming response from chat completions API: one SSE chunk.
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ChatCompletionsStreamResponse {
    pub id: String,
    pub object: Option<String>,
    pub created: u64,
    pub model: String,
    /// Incremental choices carried by this chunk (may be empty).
    pub choices: Vec<StreamChoice>,
    pub usage: Option<Usage>, // Only in final chunk
    pub system_fingerprint: Option<String>,
    /// Specifies the processing type used for serving the request
    pub service_tier: Option<String>,
}
/// A choice in a streaming response.
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct StreamChoice {
    /// Which choice this delta belongs to.
    pub index: u32,
    /// The incremental update for that choice.
    pub delta: MessageDelta,
    /// Set on the chunk that ends this choice; `None` otherwise.
    pub finish_reason: Option<FinishReason>,
    /// Log-probability data, kept as raw JSON.
    pub logprobs: Option<Value>,
}
/// Message delta for streaming updates. Every field is optional; a chunk
/// carries only what changed.
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct MessageDelta {
    /// Role of the message being streamed.
    pub role: Option<Role>,
    /// Next fragment of text content.
    pub content: Option<String>,
    /// The refusal message generated by the model
    pub refusal: Option<String>,
    /// Deprecated and replaced by tool_calls. The name and arguments of a function that should be called
    pub function_call: Option<FunctionCall>,
    /// Fragments of tool calls, keyed by `ToolCallDelta::index`.
    pub tool_calls: Option<Vec<ToolCallDelta>>,
}
/// Tool call delta for streaming tool call updates.
///
/// `index` identifies which tool call a fragment belongs to. `id`, `type`
/// and `function` are optional because a fragment may carry only some of
/// them.
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct ToolCallDelta {
    pub index: u32,
    pub id: Option<String>,
    #[serde(rename = "type")]
    pub call_type: Option<String>,
    pub function: Option<FunctionCallDelta>,
}
/// Function call delta for streaming function call updates.
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct FunctionCallDelta {
    /// Function name, when present in this fragment.
    pub name: Option<String>,
    /// A fragment of the JSON-encoded argument string.
    pub arguments: Option<String>,
}
/// Stream options for controlling streaming behavior.
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct StreamOptions {
    /// Ask for a `usage` object in the stream
    /// (see `ChatCompletionsStreamResponse::usage`).
    pub include_usage: Option<bool>,
}
/// One entry of a [`Models`] listing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelDetail {
    /// Model identifier.
    pub id: String,
    /// Object type string, if provided.
    pub object: Option<String>,
    /// Creation time (presumably a Unix timestamp — confirm with provider).
    pub created: usize,
    /// Owner/organization of the model.
    pub owned_by: String,
}
/// The `object` discriminator of a [`Models`] listing; only `"list"` is accepted.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ModelObject {
    #[serde(rename = "list")]
    List,
}
/// A model listing: `{"object": "list", "data": [...]}`.
/// Presumably the body of a `/v1/models` response — confirm against callers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Models {
    pub object: ModelObject,
    pub data: Vec<ModelDetail>,
}
// Error type for streaming operations
/// Error returned by the `TryFrom<&[u8]>` conversions in this module, and
/// presumably by stream parsing elsewhere.
#[derive(Debug, thiserror::Error)]
pub enum OpenAIStreamError {
    /// The bytes were not valid JSON for the target type.
    #[error("JSON parsing error: {0}")]
    JsonError(#[from] serde_json::Error),
    /// The bytes were not valid UTF-8.
    #[error("UTF-8 parsing error: {0}")]
    Utf8Error(#[from] std::str::Utf8Error),
    /// Structurally invalid streaming data, with a description.
    #[error("Invalid streaming data: {0}")]
    InvalidStreamingData(String),
}
/// General error type for OpenAI API handling.
#[derive(Debug, Error)]
pub enum OpenAIError {
    /// JSON (de)serialization failed.
    #[error("json error: {0}")]
    JsonParseError(#[from] serde_json::Error),
    /// Input bytes were not valid UTF-8.
    #[error("utf8 parsing error: {0}")]
    Utf8Error(#[from] std::str::Utf8Error),
    /// A streaming payload failed to parse; keeps both the serde error and
    /// the offending raw data for diagnostics.
    #[error("invalid streaming data err {source}, data: {data}")]
    InvalidStreamingData {
        source: serde_json::Error,
        data: String,
    },
    /// The requested provider is not supported.
    #[error("unsupported provider: {provider}")]
    UnsupportedProvider { provider: String },
}
// ============================================================================
// Trait Implementations
// ============================================================================
/// Parses a JSON request body, then applies the model-specific fix-ups
/// (`suppress_max_tokens_if_o3`, then `fix_temperature_if_gpt5`).
///
/// # Errors
/// `OpenAIStreamError::JsonError` when the bytes are not a valid request.
impl TryFrom<&[u8]> for ChatCompletionsRequest {
    type Error = OpenAIStreamError;

    fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
        let mut request = serde_json::from_slice::<ChatCompletionsRequest>(bytes)?;
        request.suppress_max_tokens_if_o3();
        request.fix_temperature_if_gpt5();
        Ok(request)
    }
}
/// Parses a JSON response body as-is.
///
/// # Errors
/// `OpenAIStreamError::JsonError` when the bytes are not a valid response.
impl TryFrom<&[u8]> for ChatCompletionsResponse {
    type Error = OpenAIStreamError;

    fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
        let response = serde_json::from_slice(bytes)?;
        Ok(response)
    }
}
/// Exposes the OpenAI `Usage` counters through the provider-neutral
/// `TokenUsage` trait, widened from `u32` to `usize`.
impl TokenUsage for Usage {
    fn prompt_tokens(&self) -> usize {
        let count = self.prompt_tokens;
        count as usize
    }

    fn completion_tokens(&self) -> usize {
        let count = self.completion_tokens;
        count as usize
    }

    fn total_tokens(&self) -> usize {
        let count = self.total_tokens;
        count as usize
    }
}
/// Implementation of ProviderRequest for ChatCompletionsRequest
impl ProviderRequest for ChatCompletionsRequest {
fn model(&self) -> &str {
&self.model
}
fn set_model(&mut self, model: String) {
self.model = model;
}
fn is_streaming(&self) -> bool {
self.stream.unwrap_or_default()
}
fn extract_messages_text(&self) -> String {
self.messages.iter().fold(String::new(), |acc, m| {
acc + " "
+ &match &m.content {
MessageContent::Text(text) => text.clone(),
MessageContent::Parts(parts) => parts
.iter()
.map(|part| match part {
ContentPart::Text { text } => text.clone(),
ContentPart::ImageUrl { .. } => "[Image]".to_string(),
})
.collect::<Vec<_>>()
.join(" "),
}
})
}
fn get_recent_user_message(&self) -> Option<String> {
self.messages.last().and_then(|msg| {
match &msg.content {
MessageContent::Text(text) => Some(text.clone()),
MessageContent::Parts(_) => None, // No user message in parts
}
})
}
fn get_tool_names(&self) -> Option<Vec<String>> {
// First check the 'tools' field (current API)
if let Some(tools) = &self.tools {
let names: Vec<String> = tools
.iter()
.map(|tool| tool.function.name.clone())
.collect();
if !names.is_empty() {
return Some(names);
}
}
// Fallback to 'functions' field (deprecated but still supported)
if let Some(functions) = &self.functions {
let names: Vec<String> = functions
.iter()
.map(|func| func.function.name.clone())
.collect();
if !names.is_empty() {
return Some(names);
}
}
None
}
fn to_bytes(&self) -> Result<Vec<u8>, ProviderRequestError> {
serde_json::to_vec(&self).map_err(|e| ProviderRequestError {
message: format!("Failed to serialize OpenAI request: {}", e),
source: Some(Box::new(e)),
})
}
fn metadata(&self) -> &Option<HashMap<String, Value>> {
return &self.metadata;
}
fn remove_metadata_key(&mut self, key: &str) -> bool {
if let Some(ref mut metadata) = self.metadata {
metadata.remove(key).is_some()
} else {
false
}
}
fn get_temperature(&self) -> Option<f32> {
self.temperature
}
}
/// Exposes the response's token usage through the provider-neutral traits.
/// Usage is always available, because `usage` is a required field.
impl ProviderResponse for ChatCompletionsResponse {
    fn usage(&self) -> Option<&dyn TokenUsage> {
        let usage: &dyn TokenUsage = &self.usage;
        Some(usage)
    }

    /// `(prompt, completion, total)` token counts.
    fn extract_usage_counts(&self) -> Option<(usize, usize, usize)> {
        let usage = &self.usage;
        let counts = (
            usage.prompt_tokens(),
            usage.completion_tokens(),
            usage.total_tokens(),
        );
        Some(counts)
    }
}
// Direct implementation of ProviderStreamResponse trait on ChatCompletionsStreamResponse
/// Every accessor looks only at the first choice of the chunk.
impl ProviderStreamResponse for ChatCompletionsStreamResponse {
    /// Text fragment carried by the first choice, if any.
    fn content_delta(&self) -> Option<&str> {
        let first = self.choices.first()?;
        first.delta.content.as_deref()
    }

    /// `true` once the first choice reports a finish reason.
    fn is_final(&self) -> bool {
        matches!(self.choices.first(), Some(choice) if choice.finish_reason.is_some())
    }

    /// Lowercase role name from the first choice's delta, if present.
    fn role(&self) -> Option<&str> {
        let role = self.choices.first()?.delta.role.as_ref()?;
        let name = match role {
            Role::System => "system",
            Role::User => "user",
            Role::Assistant => "assistant",
            Role::Tool => "tool",
        };
        Some(name)
    }

    /// OpenAI SSE chunks carry no `event:` type, so this is always `None`.
    fn event_type(&self) -> Option<&str> {
        None
    }
}
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
#[test]
fn test_required_fields() {
// Create a JSON object with only required fields
let original_json = json!({
"model": "gpt-4",
"messages": [
{
"content": "Hello, world!",
"role": "user"
}
]
});
// Deserialize JSON into ChatCompletionsRequest
let deserialized_request: ChatCompletionsRequest =
serde_json::from_value(original_json.clone()).unwrap();
// Validate required fields are properly set
assert_eq!(deserialized_request.model, "gpt-4");
assert_eq!(deserialized_request.messages.len(), 1);
let message = &deserialized_request.messages[0];
assert_eq!(message.role, Role::User);
if let MessageContent::Text(content) = &message.content {
assert_eq!(content, "Hello, world!");
} else {
panic!("Expected text content");
}
// Serialize the ChatCompletionsRequest back to JSON
let serialized_json = serde_json::to_value(&deserialized_request).unwrap();
assert_eq!(original_json, serialized_json);
}
#[test]
fn test_optional_fields_serialization() {
// Create a JSON object with optional fields set
let original_json = json!({
"model": "gpt-4",
"messages": [
{
"content": "Test message",
"role": "user",
"name": "test_user"
}
],
"temperature": 0.7,
"max_tokens": 150,
"stream": true,
"stream_options": {
"include_usage": true
},
"metadata": {
"user_id": "123"
}
});
// Deserialize JSON into ChatCompletionsRequest
let deserialized_request: ChatCompletionsRequest =
serde_json::from_value(original_json.clone()).unwrap();
// Validate required fields
assert_eq!(deserialized_request.model, "gpt-4");
assert_eq!(deserialized_request.messages.len(), 1);
let message = &deserialized_request.messages[0];
assert_eq!(message.role, Role::User);
if let MessageContent::Text(content) = &message.content {
assert_eq!(content, "Test message");
} else {
panic!("Expected text content");
}
assert_eq!(message.name, Some("test_user".to_string()));
// Validate optional fields are properly set
assert!((deserialized_request.temperature.unwrap() - 0.7).abs() < 1e-6);
assert_eq!(deserialized_request.max_tokens, Some(150));
assert_eq!(deserialized_request.stream, Some(true));
assert!(deserialized_request.stream_options.is_some());
assert!(deserialized_request.metadata.is_some());
// Validate fields not in JSON are None
assert!(deserialized_request.top_p.is_none());
assert!(deserialized_request.frequency_penalty.is_none());
assert!(deserialized_request.presence_penalty.is_none());
assert!(deserialized_request.stop.is_none());
assert!(deserialized_request.tools.is_none());
// Serialize back to JSON and compare (handle floating point precision)
let serialized_json = serde_json::to_value(&deserialized_request).unwrap();
// Compare all fields except temperature which needs floating point comparison
assert_eq!(serialized_json["model"], original_json["model"]);
assert_eq!(serialized_json["messages"], original_json["messages"]);
assert_eq!(serialized_json["max_tokens"], original_json["max_tokens"]);
assert_eq!(serialized_json["stream"], original_json["stream"]);
assert_eq!(
serialized_json["stream_options"],
original_json["stream_options"]
);
assert_eq!(serialized_json["metadata"], original_json["metadata"]);
// Handle temperature with floating point tolerance
let original_temp = original_json["temperature"].as_f64().unwrap();
let serialized_temp = serialized_json["temperature"].as_f64().unwrap();
assert!((original_temp - serialized_temp).abs() < 1e-6);
}
#[test]
fn test_nested_types_serialization() {
// Create a comprehensive JSON object with nested types - a ChatCompletionsRequest with complex message content and tools
let original_json = json!({
"model": "gpt-4-vision-preview",
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": "What can you see in this image and what's the weather like in the location shown?"
},
{
"type": "image_url",
"image_url": {
"url": "https://example.com/cityscape.jpg",
"detail": "high"
}
}
]
},
{
"role": "assistant",
"content": "I can see a beautiful cityscape. Let me check the weather for you.",
"tool_calls": [
{
"id": "call_weather123",
"type": "function",
"function": {
"name": "get_weather",
"arguments": "{\"location\": \"New York, NY\"}"
}
}
]
},
{
"role": "tool",
"content": "Current weather in New York: 72°F, sunny",
"tool_call_id": "call_weather123"
}
],
"tools": [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get current weather information for a location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA"
}
},
"required": ["location"]
},
"strict": true
}
}
],
"tool_choice": "auto",
"temperature": 0.7,
"max_tokens": 1000,
"prediction": {
"type": "content",
"content": "Based on the image analysis and weather data, I can provide you with comprehensive information."
}
});
// Deserialize JSON into ChatCompletionsRequest
let deserialized_request: ChatCompletionsRequest =
serde_json::from_value(original_json.clone()).unwrap();
// Validate top-level fields
assert_eq!(deserialized_request.model, "gpt-4-vision-preview");
assert_eq!(deserialized_request.messages.len(), 3);
assert!((deserialized_request.temperature.unwrap() - 0.7).abs() < 1e-6);
assert_eq!(deserialized_request.max_tokens, Some(1000));
// Validate first message (user with multimodal content)
let user_message = &deserialized_request.messages[0];
assert_eq!(user_message.role, Role::User);
if let MessageContent::Parts(ref content_parts) = user_message.content {
assert_eq!(content_parts.len(), 2);
// Validate text content part
if let ContentPart::Text { text } = &content_parts[0] {
assert_eq!(text, "What can you see in this image and what's the weather like in the location shown?");
} else {
panic!("Expected text content part");
}
// Validate image URL content part
if let ContentPart::ImageUrl { ref image_url } = content_parts[1] {
assert_eq!(image_url.url, "https://example.com/cityscape.jpg");
assert_eq!(image_url.detail, Some("high".to_string()));
} else {
panic!("Expected image URL content part");
}
} else {
panic!("Expected multimodal content parts for user message");
}
// Validate second message (assistant with tool calls)
let assistant_message = &deserialized_request.messages[1];
assert_eq!(assistant_message.role, Role::Assistant);
if let MessageContent::Text(text) = &assistant_message.content {
assert_eq!(
text,
"I can see a beautiful cityscape. Let me check the weather for you."
);
} else {
panic!("Expected text content for assistant message");
}
// Validate tool calls in assistant message
assert!(assistant_message.tool_calls.is_some());
let tool_calls = assistant_message.tool_calls.as_ref().unwrap();
assert_eq!(tool_calls.len(), 1);
let tool_call = &tool_calls[0];
assert_eq!(tool_call.id, "call_weather123");
assert_eq!(tool_call.call_type, "function");
assert_eq!(tool_call.function.name, "get_weather");
assert_eq!(
tool_call.function.arguments,
"{\"location\": \"New York, NY\"}"
);
// Validate third message (tool response)
let tool_message = &deserialized_request.messages[2];
assert_eq!(tool_message.role, Role::Tool);
if let MessageContent::Text(text) = &tool_message.content {
assert_eq!(text, "Current weather in New York: 72°F, sunny");
} else {
panic!("Expected text content for tool message");
}
assert_eq!(
tool_message.tool_call_id,
Some("call_weather123".to_string())
);
// Validate tools array
assert!(deserialized_request.tools.is_some());
let tools = deserialized_request.tools.as_ref().unwrap();
assert_eq!(tools.len(), 1);
let tool = &tools[0];
assert_eq!(tool.tool_type, "function");
assert_eq!(tool.function.name, "get_weather");
assert_eq!(
tool.function.description,
Some("Get current weather information for a location".to_string())
);
assert_eq!(tool.function.strict, Some(true));
// Validate tool parameters schema
let parameters = &tool.function.parameters;
assert_eq!(parameters["type"], "object");
assert!(parameters["properties"]["location"].is_object());
assert_eq!(parameters["required"], json!(["location"]));
// Validate tool choice
if let Some(ToolChoice::Type(choice)) = &deserialized_request.tool_choice {
assert_eq!(choice, &ToolChoiceType::Auto);
} else {
panic!("Expected auto tool choice");
}
// Validate prediction
assert!(deserialized_request.prediction.is_some());
let prediction = deserialized_request.prediction.as_ref().unwrap();
assert_eq!(prediction.content_type, "content");
if let StaticContentType::Text(text) = &prediction.content {
assert_eq!(text, "Based on the image analysis and weather data, I can provide you with comprehensive information.");
} else {
panic!("Expected text prediction content");
}
// Serialize back to JSON and compare (handle floating point precision)
let serialized_json = serde_json::to_value(&deserialized_request).unwrap();
// Compare all fields except floating point ones
assert_eq!(serialized_json["model"], original_json["model"]);
assert_eq!(serialized_json["messages"], original_json["messages"]);
assert_eq!(serialized_json["max_tokens"], original_json["max_tokens"]);
assert_eq!(serialized_json["tools"], original_json["tools"]);
assert_eq!(serialized_json["tool_choice"], original_json["tool_choice"]);
assert_eq!(serialized_json["prediction"], original_json["prediction"]);
// Handle floating point field with tolerance
let original_temp = original_json["temperature"].as_f64().unwrap();
let serialized_temp = serialized_json["temperature"].as_f64().unwrap();
assert!((original_temp - serialized_temp).abs() < 1e-6);
}
#[test]
fn test_api_provider_trait() {
// Test the ApiDefinition trait implementation
let api = OpenAIApi::ChatCompletions;
// Test trait methods
assert_eq!(api.endpoint(), CHAT_COMPLETIONS_PATH);
assert!(api.supports_streaming());
assert!(api.supports_tools());
assert!(api.supports_vision());
// Test from_endpoint
let found_api = OpenAIApi::from_endpoint(CHAT_COMPLETIONS_PATH);
assert_eq!(found_api, Some(OpenAIApi::ChatCompletions));
let not_found = OpenAIApi::from_endpoint("/v1/unknown");
assert_eq!(not_found, None);
// Test all_variants
let all_variants = OpenAIApi::all_variants();
assert_eq!(all_variants.len(), 2);
assert!(all_variants.contains(&OpenAIApi::ChatCompletions));
assert!(all_variants.contains(&OpenAIApi::Responses));
}
#[test]
fn test_role_specific_behavior() {
    // Each message role carries its own subset of optional fields. Every case checks
    // that the relevant fields deserialize, the irrelevant ones stay empty, and the
    // value serializes back to the exact input JSON.

    // Case 1: user message with a name and no tool metadata.
    let user_value = json!({
        "content": "Hello!",
        "role": "user",
        "name": "user123"
    });
    let user: Message = serde_json::from_value(user_value.clone()).unwrap();
    assert_eq!(user.role, Role::User);
    let MessageContent::Text(user_text) = &user.content else {
        panic!("Expected text content");
    };
    assert_eq!(user_text, "Hello!");
    assert_eq!(user.name, Some("user123".to_string()));
    assert!(user.tool_call_id.is_none());
    assert!(user.tool_calls.is_none());
    assert_eq!(user_value, serde_json::to_value(&user).unwrap());

    // Case 2: assistant message that requests a single tool call.
    let assistant_value = json!({
        "content": "I'll help with that.",
        "role": "assistant",
        "tool_calls": [
            {
                "id": "call_456",
                "type": "function",
                "function": {
                    "name": "get_weather",
                    "arguments": r#"{"location":"SF"}"#
                }
            }
        ]
    });
    let assistant: Message = serde_json::from_value(assistant_value.clone()).unwrap();
    assert_eq!(assistant.role, Role::Assistant);
    let MessageContent::Text(assistant_text) = &assistant.content else {
        panic!("Expected text content");
    };
    assert_eq!(assistant_text, "I'll help with that.");
    assert!(assistant.name.is_none());
    assert!(assistant.tool_call_id.is_none());
    assert!(assistant.tool_calls.is_some());
    let calls = assistant.tool_calls.as_ref().unwrap();
    assert_eq!(calls.len(), 1);
    let first_call = &calls[0];
    assert_eq!(first_call.id, "call_456");
    assert_eq!(first_call.function.name, "get_weather");
    assert_eq!(assistant_value, serde_json::to_value(&assistant).unwrap());

    // Case 3: tool message answering the call above by id.
    let tool_value = json!({
        "content": "Weather is sunny",
        "role": "tool",
        "tool_call_id": "call_456"
    });
    let tool: Message = serde_json::from_value(tool_value.clone()).unwrap();
    assert_eq!(tool.role, Role::Tool);
    let MessageContent::Text(tool_text) = &tool.content else {
        panic!("Expected text content");
    };
    assert_eq!(tool_text, "Weather is sunny");
    assert_eq!(tool.tool_call_id, Some("call_456".to_string()));
    assert!(tool.name.is_none());
    assert!(tool.tool_calls.is_none());
    assert_eq!(tool_value, serde_json::to_value(&tool).unwrap());

    // Case 4: ResponseMessage has a different optional field set
    // (annotations / refusal / function_call).
    let response_value = json!({
        "role": "assistant",
        "content": "Response content",
        "annotations": [
            {"type": "citation"}
        ]
    });
    let response: ResponseMessage = serde_json::from_value(response_value.clone()).unwrap();
    assert_eq!(response.role, Role::Assistant);
    assert_eq!(response.content, Some("Response content".to_string()));
    assert!(response.annotations.is_some());
    assert!(response.refusal.is_none());
    assert!(response.function_call.is_none());
    assert!(response.tool_calls.is_none());
    assert_eq!(response_value, serde_json::to_value(&response).unwrap());

    // to_message() keeps the role and text; name and tool_call_id stay empty.
    let converted = response.to_message();
    assert_eq!(converted.role, Role::Assistant);
    let MessageContent::Text(converted_text) = &converted.content else {
        panic!("Expected text content");
    };
    assert_eq!(converted_text, "Response content");
    assert!(converted.name.is_none());
    assert!(converted.tool_call_id.is_none());
}
#[test]
fn test_tool_choice_type_serialization() {
    // Each ToolChoiceType maps to exactly one bare string on the wire, in both directions.
    let cases = [
        (ToolChoiceType::Auto, "auto"),
        (ToolChoiceType::Required, "required"),
        (ToolChoiceType::None, "none"),
    ];

    for (variant, wire) in cases {
        let choice = ToolChoice::Type(variant);

        // Serializing produces the plain string value.
        let encoded = serde_json::to_value(&choice).unwrap();
        assert_eq!(encoded, wire);

        // Deserializing that string yields the same choice back.
        let decoded: ToolChoice = serde_json::from_value(json!(wire)).unwrap();
        assert_eq!(decoded, choice);
    }

    // An unrecognised string such as "invalid" must fail to deserialize (type safety).
    let rejected: Result<ToolChoice, _> = serde_json::from_value(json!("invalid"));
    assert!(rejected.is_err());
}
#[test]
fn test_chat_completions_response_with_service_tier() {
    // An OpenAI-style chat completion that includes `service_tier`, `system_fingerprint`
    // and detailed usage breakdowns must deserialize into ChatCompletionsResponse.
    let payload = r#"{
"id": "chatcmpl-CAJc2Df6QCc7Mv3RP0Cf2xlbDV1x2",
"object": "chat.completion",
"created": 1756574706,
"model": "gpt-4o-2024-08-06",
"choices": [{
"index": 0,
"message": {
"role": "assistant",
"content": "Test response content",
"annotations": []
},
"finish_reason": "stop"
}],
"usage": {
"prompt_tokens": 65,
"completion_tokens": 184,
"total_tokens": 249,
"prompt_tokens_details": {
"cached_tokens": 0,
"audio_tokens": 0
},
"completion_tokens_details": {
"reasoning_tokens": 0,
"audio_tokens": 0,
"accepted_prediction_tokens": 0,
"rejected_prediction_tokens": 0
}
},
"service_tier": "default",
"system_fingerprint": "fp_f33640a400"
}"#;
    let parsed: ChatCompletionsResponse = serde_json::from_str(payload).unwrap();

    // Top-level identity fields.
    assert_eq!(parsed.id, "chatcmpl-CAJc2Df6QCc7Mv3RP0Cf2xlbDV1x2");
    assert_eq!(parsed.object.as_deref(), Some("chat.completion"));
    assert_eq!(parsed.created, 1756574706);
    assert_eq!(parsed.model, "gpt-4o-2024-08-06");

    // Optional metadata fields that are present in this payload.
    assert_eq!(parsed.service_tier.as_deref(), Some("default"));
    assert_eq!(parsed.system_fingerprint.as_deref(), Some("fp_f33640a400"));

    // Choices and token accounting.
    assert_eq!(parsed.choices.len(), 1);
    let usage = &parsed.usage;
    assert_eq!(usage.prompt_tokens, 65);
    assert_eq!(usage.completion_tokens, 184);
    assert_eq!(usage.total_tokens, 249);
}
#[test]
fn test_chat_completions_response_without_service_tier() {
    // Backward compatibility: a payload with no `service_tier` / `system_fingerprint`
    // must still deserialize. The missing optional fields become None, and every field
    // that *is* present must still parse correctly rather than being silently defaulted.
    let json_response = r#"{
"id": "chatcmpl-123",
"object": "chat.completion",
"created": 1234567890,
"model": "gpt-4",
"choices": [{
"index": 0,
"message": {
"role": "assistant",
"content": "Test response"
},
"finish_reason": "stop"
}],
"usage": {
"prompt_tokens": 10,
"completion_tokens": 20,
"total_tokens": 30
}
}"#;
    let response: ChatCompletionsResponse = serde_json::from_str(json_response).unwrap();

    // Fields present in the payload.
    assert_eq!(response.id, "chatcmpl-123");
    assert_eq!(response.object.as_deref(), Some("chat.completion"));
    assert_eq!(response.created, 1234567890);
    assert_eq!(response.model, "gpt-4");
    assert_eq!(response.choices.len(), 1);
    assert_eq!(response.usage.prompt_tokens, 10);
    assert_eq!(response.usage.completion_tokens, 20);
    assert_eq!(response.usage.total_tokens, 30);

    // Optional fields absent from the payload must be None.
    assert_eq!(response.service_tier, None);
    assert_eq!(response.system_fingerprint, None);
}
}