From b9ac7cbafd7ca824ff47d886bfb5ed2668079a24 Mon Sep 17 00:00:00 2001 From: Adil Hafeez Date: Tue, 14 Oct 2025 13:18:45 -0700 Subject: [PATCH] cargo fmt and return better json --- .../src/handlers/agent_chat_completions.rs | 37 +- .../src/handlers/chat_completions.rs | 7 +- .../src/handlers/response_handler.rs | 31 ++ crates/brightstaff/src/main.rs | 4 +- crates/brightstaff/src/utils/tracing.rs | 23 +- crates/common/src/routing.rs | 1 - crates/hermesllm/src/apis/anthropic.rs | 204 ++++--- crates/hermesllm/src/apis/mod.rs | 20 +- crates/hermesllm/src/apis/openai.rs | 109 ++-- crates/hermesllm/src/clients/endpoints.rs | 121 +++-- crates/hermesllm/src/clients/mod.rs | 4 +- crates/hermesllm/src/clients/transformer.rs | 499 +++++++++++------- crates/hermesllm/src/lib.rs | 73 +-- crates/hermesllm/src/providers/id.rs | 77 +-- crates/hermesllm/src/providers/mod.rs | 4 +- crates/hermesllm/src/providers/request.rs | 223 ++++---- crates/hermesllm/src/providers/response.rs | 189 ++++--- 17 files changed, 1023 insertions(+), 603 deletions(-) diff --git a/crates/brightstaff/src/handlers/agent_chat_completions.rs b/crates/brightstaff/src/handlers/agent_chat_completions.rs index b0715f5d..a1a00f88 100644 --- a/crates/brightstaff/src/handlers/agent_chat_completions.rs +++ b/crates/brightstaff/src/handlers/agent_chat_completions.rs @@ -37,11 +37,38 @@ pub async fn agent_chat( match handle_agent_chat(request, router_service, agents_list, listeners).await { Ok(response) => Ok(response), Err(err) => { - warn!("Agent chat error: {}", err); - Ok(ResponseHandler::create_internal_error(&format!( - "Internal error: {}", - err - ))) + // Print detailed error information with full error chain + let mut error_chain = Vec::new(); + let mut current_error: &dyn std::error::Error = &err; + + // Collect the full error chain + loop { + error_chain.push(current_error.to_string()); + match current_error.source() { + Some(source) => current_error = source, + None => break, + } + } + + // Log the complete error chain + warn!("Agent chat error chain: {:#?}", error_chain); + warn!("Root error: {:?}", err); + + // Create structured error response as JSON + let error_json = serde_json::json!({ + "error": { + "type": "AgentFilterChainError", + "message": err.to_string(), + "error_chain": error_chain, + "debug_info": format!("{:?}", err) + } + }); + + // Log the error for debugging + info!("Structured error info: {}", error_json); + + // Return JSON error response + Ok(ResponseHandler::create_json_error_response(&error_json)) } } } diff --git a/crates/brightstaff/src/handlers/chat_completions.rs b/crates/brightstaff/src/handlers/chat_completions.rs index e906986f..b96f1f52 100644 --- a/crates/brightstaff/src/handlers/chat_completions.rs +++ b/crates/brightstaff/src/handlers/chat_completions.rs @@ -136,7 +136,10 @@ pub async fn chat( const MAX_MESSAGE_LENGTH: usize = 50; let latest_message_for_log = if latest_message_for_log.chars().count() > MAX_MESSAGE_LENGTH { - let truncated: String = latest_message_for_log.chars().take(MAX_MESSAGE_LENGTH).collect(); + let truncated: String = latest_message_for_log + .chars() + .take(MAX_MESSAGE_LENGTH) + .collect(); format!("{}...", truncated) } else { latest_message_for_log @@ -162,7 +165,7 @@ pub async fn chat( Ok(route) => match route { Some((_, model_name)) => model_name, None => { - info!( + info!( "No route determined, using default model from request: {}", chat_completions_request_for_arch_router.model ); diff --git a/crates/brightstaff/src/handlers/response_handler.rs b/crates/brightstaff/src/handlers/response_handler.rs index 3d16a60c..2d647d2c 100644 --- a/crates/brightstaff/src/handlers/response_handler.rs +++ b/crates/brightstaff/src/handlers/response_handler.rs @@ -52,6 +52,20 @@ impl ResponseHandler { Self::create_error_response(StatusCode::INTERNAL_SERVER_ERROR, message) } + /// Create a JSON error response + pub fn create_json_error_response( + error_json: &serde_json::Value, + ) -> Response> { + let json_string = error_json.to_string(); + let mut response = Response::new(Self::create_full_body(json_string)); + *response.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; + response.headers_mut().insert( + hyper::header::CONTENT_TYPE, + "application/json".parse().unwrap(), + ); + response + } + /// Create a streaming response from a reqwest response pub async fn create_streaming_response( &self, @@ -131,6 +145,23 @@ mod tests { assert_eq!(response.status(), StatusCode::NOT_FOUND); } + #[test] + fn test_create_json_error_response() { + let error_json = serde_json::json!({ + "error": { + "type": "TestError", + "message": "Test error message" + } + }); + + let response = ResponseHandler::create_json_error_response(&error_json); + assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR); + assert_eq!( + response.headers().get("content-type").unwrap(), + "application/json" + ); + } + #[tokio::test] async fn test_create_streaming_response_with_mock() { use mockito::Server; diff --git a/crates/brightstaff/src/main.rs b/crates/brightstaff/src/main.rs index 19b5004d..57dd9fe9 100644 --- a/crates/brightstaff/src/main.rs +++ b/crates/brightstaff/src/main.rs @@ -142,7 +142,9 @@ async fn main() -> Result<(), Box> { .with_context(parent_cx) .await } - (&Method::GET, "/v1/models" | "/agents/v1/models") => Ok(list_models(llm_providers).await), + (&Method::GET, "/v1/models" | "/agents/v1/models") => { + Ok(list_models(llm_providers).await) + } // hack for now to get openw-web-ui to work (&Method::OPTIONS, "/v1/models" | "/agents/v1/models") => { let mut response = Response::new(empty()); diff --git a/crates/brightstaff/src/utils/tracing.rs b/crates/brightstaff/src/utils/tracing.rs index 7acb249a..6da4b631 100644 --- a/crates/brightstaff/src/utils/tracing.rs +++ b/crates/brightstaff/src/utils/tracing.rs @@ -1,20 +1,27 @@ -use std::sync::OnceLock; use std::fmt; +use std::sync::OnceLock; use opentelemetry::global; use opentelemetry_sdk::{propagation::TraceContextPropagator, trace::SdkTracerProvider}; use opentelemetry_stdout::SpanExporter; -use tracing_subscriber::EnvFilter; -use tracing_subscriber::fmt::{format, time::FormatTime, FmtContext, FormatEvent, FormatFields}; -use tracing::{Event, Subscriber}; use time::macros::format_description; +use tracing::{Event, Subscriber}; +use tracing_subscriber::fmt::{format, time::FormatTime, FmtContext, FormatEvent, FormatFields}; +use tracing_subscriber::EnvFilter; struct BracketedTime; impl FormatTime for BracketedTime { fn format_time(&self, w: &mut format::Writer<'_>) -> fmt::Result { let now = time::OffsetDateTime::now_utc(); - write!(w, "[{}]", now.format(&format_description!("[year]-[month]-[day] [hour]:[minute]:[second].[subsecond digits:3]")).unwrap()) + write!( + w, + "[{}]", + now.format(&format_description!( + "[year]-[month]-[day] [hour]:[minute]:[second].[subsecond digits:3]" + )) + .unwrap() + ) } } @@ -34,7 +41,11 @@ where let timer = BracketedTime; timer.format_time(&mut writer)?; - write!(writer, "[{}] ", event.metadata().level().to_string().to_lowercase())?; + write!( + writer, + "[{}] ", + event.metadata().level().to_string().to_lowercase() + )?; ctx.field_format().format_fields(writer.by_ref(), event)?; diff --git a/crates/common/src/routing.rs b/crates/common/src/routing.rs index 2e9bac09..f4baf896 100644 --- a/crates/common/src/routing.rs +++ b/crates/common/src/routing.rs @@ -33,7 +33,6 @@ pub fn get_llm_provider( return provider; } - if llm_providers.default().is_some() { return llm_providers.default().unwrap(); } diff --git a/crates/hermesllm/src/apis/anthropic.rs b/crates/hermesllm/src/apis/anthropic.rs index abfde5b7..a261be3c 100644 --- a/crates/hermesllm/src/apis/anthropic.rs +++ b/crates/hermesllm/src/apis/anthropic.rs @@ -5,10 +5,10 @@ use serde_with::skip_serializing_none; use std::collections::HashMap; use super::ApiDefinition; +use crate::clients::transformer::ExtractText; use crate::providers::request::{ProviderRequest, ProviderRequestError}; use crate::providers::response::{ProviderResponse, ProviderStreamResponse}; -use crate::clients::transformer::ExtractText; -use crate::{MESSAGES_PATH}; +use crate::MESSAGES_PATH; // Enum for all supported Anthropic APIs #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] @@ -52,9 +52,7 @@ impl ApiDefinition for AnthropicApi { } fn all_variants() -> Vec { - vec![ - AnthropicApi::Messages, - ] + vec![AnthropicApi::Messages] } } @@ -100,7 +98,6 @@ pub struct McpServer { pub tool_configuration: Option, } - #[skip_serializing_none] #[derive(Serialize, Deserialize, Debug, Clone)] pub struct MessagesRequest { @@ -121,10 +118,8 @@ pub struct MessagesRequest { pub stop_sequences: Option>, pub tools: Option>, pub tool_choice: Option, - } - // Messages API specific types #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] #[serde(rename_all = "lowercase")] @@ -235,34 +230,21 @@ impl ExtractText for Vec { } } - #[derive(Serialize, Deserialize, Debug, Clone)] #[serde(rename_all = "snake_case")] #[serde(tag = "type")] pub enum MessagesImageSource { - Base64 { - media_type: String, - data: String, - }, - Url { - url: String, - }, + Base64 { media_type: String, data: String }, + Url { url: String }, } #[derive(Serialize, Deserialize, Debug, Clone)] #[serde(rename_all = "snake_case")] #[serde(tag = "type")] pub enum MessagesDocumentSource { - Base64 { - media_type: String, - data: String, - }, - Url { - url: String, - }, - File { - file_id: String, - }, + Base64 { media_type: String, data: String }, + Url { url: String }, + File { file_id: String }, } #[derive(Serialize, Deserialize, Debug, Clone)] @@ -276,7 +258,7 @@ impl ExtractText for MessagesMessageContent { fn extract_text(&self) -> String { match self { MessagesMessageContent::Single(text) => text.clone(), - MessagesMessageContent::Blocks(parts) => parts.extract_text() + MessagesMessageContent::Blocks(parts) => parts.extract_text(), } } } @@ -320,7 +302,6 @@ pub struct MessagesToolChoice { pub disable_parallel_tool_use: Option, } - #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] #[serde(rename_all = "snake_case")] pub enum MessagesStopReason { @@ -457,7 +438,11 @@ impl ProviderResponse for MessagesResponse { Some(self) } fn extract_usage_counts(&self) -> Option<(usize, usize, usize)> { - Some((self.usage.input_tokens as usize, self.usage.output_tokens as usize, (self.usage.input_tokens + self.usage.output_tokens) as usize)) + Some(( + self.usage.input_tokens as usize, + self.usage.output_tokens as usize, + (self.usage.input_tokens + self.usage.output_tokens) as usize, + )) } } @@ -535,7 +520,7 @@ impl ProviderRequest for MessagesRequest { } fn metadata(&self) -> &Option> { - return &self.metadata; + return &self.metadata; } fn remove_metadata_key(&mut self, key: &str) -> bool { @@ -572,13 +557,11 @@ impl MessagesRole { impl ProviderStreamResponse for MessagesStreamEvent { fn content_delta(&self) -> Option<&str> { match self { - MessagesStreamEvent::ContentBlockDelta { delta, .. } => { - match delta { - MessagesContentDelta::TextDelta { text } => Some(text), - MessagesContentDelta::ThinkingDelta { thinking } => Some(thinking), - _ => None, - } - } + MessagesStreamEvent::ContentBlockDelta { delta, .. } => match delta { + MessagesContentDelta::TextDelta { text } => Some(text), + MessagesContentDelta::ThinkingDelta { thinking } => Some(thinking), + _ => None, + }, _ => None, } } @@ -627,7 +610,8 @@ mod tests { }); // Deserialize JSON into MessagesRequest - let deserialized_request: MessagesRequest = serde_json::from_value(original_json.clone()).unwrap(); + let deserialized_request: MessagesRequest = + serde_json::from_value(original_json.clone()).unwrap(); // Validate required fields are properly set assert_eq!(deserialized_request.model, "claude-3-sonnet-20240229"); @@ -687,7 +671,8 @@ mod tests { }); // Deserialize JSON into MessagesRequest - let deserialized_request: MessagesRequest = serde_json::from_value(original_json.clone()).unwrap(); + let deserialized_request: MessagesRequest = + serde_json::from_value(original_json.clone()).unwrap(); // Validate required fields assert_eq!(deserialized_request.model, "claude-3-sonnet-20240229"); @@ -730,7 +715,10 @@ mod tests { assert_eq!(serialized_json["messages"], original_json["messages"]); assert_eq!(serialized_json["max_tokens"], original_json["max_tokens"]); assert_eq!(serialized_json["system"], original_json["system"]); - assert_eq!(serialized_json["service_tier"], original_json["service_tier"]); + assert_eq!( + serialized_json["service_tier"], + original_json["service_tier"] + ); assert_eq!(serialized_json["thinking"], original_json["thinking"]); assert_eq!(serialized_json["metadata"], original_json["metadata"]); @@ -818,7 +806,8 @@ mod tests { }); // Deserialize JSON into MessagesRequest - let deserialized_request: MessagesRequest = serde_json::from_value(original_json.clone()).unwrap(); + let deserialized_request: MessagesRequest = + serde_json::from_value(original_json.clone()).unwrap(); // Validate top-level fields assert_eq!(deserialized_request.model, "claude-3-sonnet-20240229"); @@ -833,7 +822,10 @@ mod tests { // Validate text content block if let MessagesContentBlock::Text { text, .. } = &content_blocks[0] { - assert_eq!(text, "What can you see in this image and what's the weather like?"); + assert_eq!( + text, + "What can you see in this image and what's the weather like?" + ); } else { panic!("Expected text content block"); } @@ -861,20 +853,32 @@ mod tests { // Validate thinking content block if let MessagesContentBlock::Thinking { thinking, .. } = &content_blocks[0] { - assert_eq!(thinking, "Let me analyze the image and then check the weather..."); + assert_eq!( + thinking, + "Let me analyze the image and then check the weather..." + ); } else { panic!("Expected thinking content block"); } // Validate text content block if let MessagesContentBlock::Text { text, .. } = &content_blocks[1] { - assert_eq!(text, "I can see the image. Let me check the weather for you."); + assert_eq!( + text, + "I can see the image. Let me check the weather for you." + ); } else { panic!("Expected text content block"); } // Validate tool use content block - if let MessagesContentBlock::ToolUse { ref id, ref name, ref input, .. } = content_blocks[2] { + if let MessagesContentBlock::ToolUse { + ref id, + ref name, + ref input, + .. + } = content_blocks[2] + { assert_eq!(id, "toolu_weather123"); assert_eq!(name, "get_weather"); assert_eq!(input["location"], "San Francisco, CA"); @@ -892,7 +896,10 @@ mod tests { let tool = &tools[0]; assert_eq!(tool.name, "get_weather"); - assert_eq!(tool.description, Some("Get current weather information for a location".to_string())); + assert_eq!( + tool.description, + Some("Get current weather information for a location".to_string()) + ); assert_eq!(tool.input_schema["type"], "object"); assert!(tool.input_schema["properties"]["location"].is_object()); @@ -938,10 +945,16 @@ mod tests { assert_eq!(deserialized_mcp.name, "test-server"); assert_eq!(deserialized_mcp.server_type, McpServerType::Url); assert_eq!(deserialized_mcp.url, "https://example.com/mcp"); - assert_eq!(deserialized_mcp.authorization_token, Some("secret-token".to_string())); + assert_eq!( + deserialized_mcp.authorization_token, + Some("secret-token".to_string()) + ); if let Some(tool_config) = &deserialized_mcp.tool_configuration { - assert_eq!(tool_config.allowed_tools, Some(vec!["tool1".to_string(), "tool2".to_string()])); + assert_eq!( + tool_config.allowed_tools, + Some(vec!["tool1".to_string(), "tool2".to_string()]) + ); assert_eq!(tool_config.enabled, Some(true)); } else { panic!("Expected tool configuration"); @@ -957,7 +970,8 @@ mod tests { "url": "https://minimal.com/mcp" }); - let deserialized_minimal: McpServer = serde_json::from_value(minimal_mcp_json.clone()).unwrap(); + let deserialized_minimal: McpServer = + serde_json::from_value(minimal_mcp_json.clone()).unwrap(); assert_eq!(deserialized_minimal.name, "minimal-server"); assert_eq!(deserialized_minimal.server_type, McpServerType::Url); assert_eq!(deserialized_minimal.url, "https://minimal.com/mcp"); @@ -991,12 +1005,16 @@ mod tests { } }); - let deserialized_response: MessagesResponse = serde_json::from_value(response_json.clone()).unwrap(); + let deserialized_response: MessagesResponse = + serde_json::from_value(response_json.clone()).unwrap(); assert_eq!(deserialized_response.id, "msg_01ABC123"); assert_eq!(deserialized_response.obj_type, "message"); assert_eq!(deserialized_response.role, MessagesRole::Assistant); assert_eq!(deserialized_response.model, "claude-3-sonnet-20240229"); - assert_eq!(deserialized_response.stop_reason, MessagesStopReason::EndTurn); + assert_eq!( + deserialized_response.stop_reason, + MessagesStopReason::EndTurn + ); assert!(deserialized_response.stop_sequence.is_none()); assert!(deserialized_response.container.is_none()); @@ -1011,7 +1029,10 @@ mod tests { // Check usage assert_eq!(deserialized_response.usage.input_tokens, 10); assert_eq!(deserialized_response.usage.output_tokens, 25); - assert_eq!(deserialized_response.usage.cache_creation_input_tokens, Some(5)); + assert_eq!( + deserialized_response.usage.cache_creation_input_tokens, + Some(5) + ); assert_eq!(deserialized_response.usage.cache_read_input_tokens, Some(3)); let serialized_response_json = serde_json::to_value(&deserialized_response).unwrap(); @@ -1027,7 +1048,8 @@ mod tests { } }); - let deserialized_event: MessagesStreamEvent = serde_json::from_value(stream_event_json.clone()).unwrap(); + let deserialized_event: MessagesStreamEvent = + serde_json::from_value(stream_event_json.clone()).unwrap(); if let MessagesStreamEvent::ContentBlockDelta { index, ref delta } = deserialized_event { assert_eq!(index, 0); if let MessagesContentDelta::TextDelta { text } = delta { @@ -1055,8 +1077,15 @@ mod tests { } }); - let deserialized_tool_use: MessagesContentBlock = serde_json::from_value(tool_use_json.clone()).unwrap(); - if let MessagesContentBlock::ToolUse { ref id, ref name, ref input, .. } = deserialized_tool_use { + let deserialized_tool_use: MessagesContentBlock = + serde_json::from_value(tool_use_json.clone()).unwrap(); + if let MessagesContentBlock::ToolUse { + ref id, + ref name, + ref input, + .. + } = deserialized_tool_use + { assert_eq!(id, "toolu_01ABC123"); assert_eq!(name, "get_weather"); assert_eq!(input["location"], "San Francisco, CA"); @@ -1079,8 +1108,15 @@ mod tests { ] }); - let deserialized_tool_result: MessagesContentBlock = serde_json::from_value(tool_result_json.clone()).unwrap(); - if let MessagesContentBlock::ToolResult { ref tool_use_id, ref is_error, ref content, .. } = deserialized_tool_result { + let deserialized_tool_result: MessagesContentBlock = + serde_json::from_value(tool_result_json.clone()).unwrap(); + if let MessagesContentBlock::ToolResult { + ref tool_use_id, + ref is_error, + ref content, + .. + } = deserialized_tool_result + { assert_eq!(tool_use_id, "toolu_01ABC123"); assert!(is_error.is_none()); if let ToolResultContent::Blocks(blocks) = content { @@ -1229,7 +1265,8 @@ mod tests { }); // Deserialize the complex MessagesRequest - let deserialized_request: MessagesRequest = serde_json::from_value(complex_request_json.clone()).unwrap(); + let deserialized_request: MessagesRequest = + serde_json::from_value(complex_request_json.clone()).unwrap(); // Verify basic fields assert_eq!(deserialized_request.model, "claude-sonnet-4-20250514"); @@ -1239,8 +1276,15 @@ mod tests { // Verify system message with cache_control if let Some(MessagesSystemPrompt::Blocks(ref system_blocks)) = deserialized_request.system { assert_eq!(system_blocks.len(), 2); - if let MessagesContentBlock::Text { text, cache_control } = &system_blocks[0] { - assert_eq!(text, "You are Claude Code, Anthropic's official CLI for Claude."); + if let MessagesContentBlock::Text { + text, + cache_control, + } = &system_blocks[0] + { + assert_eq!( + text, + "You are Claude Code, Anthropic's official CLI for Claude." + ); assert_eq!(cache_control, &Some(MessagesCacheControl::Ephemeral)); } else { panic!("Expected text system message with cache_control"); @@ -1253,7 +1297,13 @@ mod tests { let assistant_message = &deserialized_request.messages[1]; assert_eq!(assistant_message.role, MessagesRole::Assistant); if let MessagesMessageContent::Blocks(ref content_blocks) = assistant_message.content { - if let MessagesContentBlock::ToolUse { id, name, input, cache_control } = &content_blocks[0] { + if let MessagesContentBlock::ToolUse { + id, + name, + input, + cache_control, + } = &content_blocks[0] + { assert_eq!(id, "call_kV50LtJQKHvvzZui5TW56DUl"); assert_eq!(name, "TodoWrite"); assert_eq!(cache_control, &Some(MessagesCacheControl::Ephemeral)); @@ -1272,7 +1322,12 @@ mod tests { let user_message = &deserialized_request.messages[2]; assert_eq!(user_message.role, MessagesRole::User); if let MessagesMessageContent::Blocks(ref content_blocks) = user_message.content { - if let MessagesContentBlock::ToolResult { tool_use_id, content, .. } = &content_blocks[0] { + if let MessagesContentBlock::ToolResult { + tool_use_id, + content, + .. + } = &content_blocks[0] + { assert_eq!(tool_use_id, "call_kV50LtJQKHvvzZui5TW56DUl"); if let ToolResultContent::Text(text) = content { assert!(text.contains("Todos have been modified successfully")); @@ -1284,7 +1339,11 @@ mod tests { } // Verify text content with cache_control - if let MessagesContentBlock::Text { text, cache_control } = &content_blocks[2] { + if let MessagesContentBlock::Text { + text, + cache_control, + } = &content_blocks[2] + { assert_eq!(text, "try again"); assert_eq!(cache_control, &Some(MessagesCacheControl::Ephemeral)); } else { @@ -1296,11 +1355,15 @@ mod tests { // Test serialization round-trip let serialized_request = serde_json::to_value(&deserialized_request).unwrap(); - let re_deserialized_request: MessagesRequest = serde_json::from_value(serialized_request).unwrap(); + let re_deserialized_request: MessagesRequest = + serde_json::from_value(serialized_request).unwrap(); // Verify round-trip consistency assert_eq!(deserialized_request.model, re_deserialized_request.model); - assert_eq!(deserialized_request.messages.len(), re_deserialized_request.messages.len()); + assert_eq!( + deserialized_request.messages.len(), + re_deserialized_request.messages.len() + ); } #[test] @@ -1339,7 +1402,8 @@ mod tests { } }); - let deserialized_event: MessagesStreamEvent = serde_json::from_value(thinking_delta_json.clone()).unwrap(); + let deserialized_event: MessagesStreamEvent = + serde_json::from_value(thinking_delta_json.clone()).unwrap(); if let MessagesStreamEvent::ContentBlockDelta { index, ref delta } = deserialized_event { assert_eq!(index, 0); if let MessagesContentDelta::ThinkingDelta { thinking } = delta { @@ -1352,7 +1416,10 @@ mod tests { } // Test that thinking delta is returned by content_delta() - assert_eq!(deserialized_event.content_delta(), Some(".\n\nI need to consider:\n1. Current")); + assert_eq!( + deserialized_event.content_delta(), + Some(".\n\nI need to consider:\n1. Current") + ); let serialized_event_json = serde_json::to_value(&deserialized_event).unwrap(); assert_eq!(thinking_delta_json, serialized_event_json); @@ -1376,7 +1443,8 @@ mod tests { } }); - let deserialized_request: MessagesRequest = serde_json::from_value(request_json.clone()).unwrap(); + let deserialized_request: MessagesRequest = + serde_json::from_value(request_json.clone()).unwrap(); assert_eq!(deserialized_request.model, "claude-sonnet-4-20250514"); assert_eq!(deserialized_request.max_tokens, 2048); diff --git a/crates/hermesllm/src/apis/mod.rs b/crates/hermesllm/src/apis/mod.rs index b175988c..99158dfa 100644 --- a/crates/hermesllm/src/apis/mod.rs +++ b/crates/hermesllm/src/apis/mod.rs @@ -3,7 +3,6 @@ pub mod openai; pub use anthropic::*; pub use openai::*; - pub trait ApiDefinition { /// Returns the endpoint path for this API fn endpoint(&self) -> &'static str; @@ -49,11 +48,7 @@ mod tests { #[test] fn test_api_detection_from_endpoints() { // Test that we can detect APIs from endpoints using the trait - let endpoints = vec![ - CHAT_COMPLETIONS_PATH, - MESSAGES_PATH, - "/v1/unknown" - ]; + let endpoints = vec![CHAT_COMPLETIONS_PATH, MESSAGES_PATH, "/v1/unknown"]; let mut detected_apis = Vec::new(); @@ -67,11 +62,14 @@ mod tests { } } - assert_eq!(detected_apis, vec![ - "OpenAI: ChatCompletions", - "Anthropic: Messages", - "Unknown API" - ]); + assert_eq!( + detected_apis, + vec![ + "OpenAI: ChatCompletions", + "Anthropic: Messages", + "Unknown API" + ] + ); } #[test] diff --git a/crates/hermesllm/src/apis/openai.rs b/crates/hermesllm/src/apis/openai.rs index 63b5fc58..58e4c8a5 100644 --- a/crates/hermesllm/src/apis/openai.rs +++ b/crates/hermesllm/src/apis/openai.rs @@ -5,11 +5,11 @@ use std::collections::HashMap; use std::fmt::Display; use thiserror::Error; +use super::ApiDefinition; +use crate::clients::transformer::ExtractText; use crate::providers::request::{ProviderRequest, ProviderRequestError}; use crate::providers::response::{ProviderResponse, ProviderStreamResponse, TokenUsage}; -use super::ApiDefinition; -use crate::clients::transformer::{ExtractText}; -use crate::{CHAT_COMPLETIONS_PATH}; +use crate::CHAT_COMPLETIONS_PATH; // ============================================================================ // OPENAI API ENUMERATION @@ -46,7 +46,7 @@ impl ApiDefinition for OpenAIApi { } fn supports_tools(&self) -> bool { - match self { + match self { OpenAIApi::ChatCompletions => true, } } @@ -58,9 +58,7 @@ impl ApiDefinition for OpenAIApi { } fn all_variants() -> Vec { - vec![ - OpenAIApi::ChatCompletions, - ] + vec![OpenAIApi::ChatCompletions] } } @@ -190,7 +188,9 @@ impl ResponseMessage { pub fn to_message(&self) -> Message { Message { role: self.role.clone(), - content: self.content.as_ref() + content: self + .content + .as_ref() .map(|s| MessageContent::Text(s.clone())) .unwrap_or(MessageContent::Text(String::new())), name: None, // Response messages don't have names in the same way request messages do @@ -215,7 +215,7 @@ impl ExtractText for MessageContent { fn extract_text(&self) -> String { match self { MessageContent::Text(text) => text.clone(), - MessageContent::Parts(parts) => parts.extract_text() + MessageContent::Parts(parts) => parts.extract_text(), } } } @@ -274,7 +274,6 @@ pub struct ImageUrl { /// A single message in a chat conversation - /// A tool call made by the assistant #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] pub struct ToolCall { @@ -374,7 +373,6 @@ pub enum StaticContentType { Parts(Vec), } - /// Chat completions API response #[skip_serializing_none] #[derive(Serialize, Deserialize, Debug, Clone)] @@ -496,7 +494,6 @@ pub struct ChatCompletionsStreamResponse { pub service_tier: Option, } - /// A choice in a streaming response #[skip_serializing_none] #[derive(Serialize, Deserialize, Debug, Clone)] @@ -566,7 +563,6 @@ pub struct Models { pub data: Vec, } - // Error type for streaming operations #[derive(Debug, thiserror::Error)] pub enum OpenAIStreamError { @@ -597,13 +593,13 @@ pub enum OpenAIError { /// Trait Implementations /// =========================================================================== - /// Parameterized conversion for ChatCompletionsRequest impl TryFrom<&[u8]> for ChatCompletionsRequest { type Error = OpenAIStreamError; fn try_from(bytes: &[u8]) -> Result { - let mut req: ChatCompletionsRequest = serde_json::from_slice(bytes).map_err(OpenAIStreamError::from)?; + let mut req: ChatCompletionsRequest = + serde_json::from_slice(bytes).map_err(OpenAIStreamError::from)?; // Use the centralized suppression logic req.suppress_max_tokens_if_o3(); req.fix_temperature_if_gpt5(); @@ -651,13 +647,18 @@ impl ProviderRequest for ChatCompletionsRequest { fn extract_messages_text(&self) -> String { self.messages.iter().fold(String::new(), |acc, m| { - acc + " " + &match &m.content { - MessageContent::Text(text) => text.clone(), - MessageContent::Parts(parts) => parts.iter().map(|part| match part { - ContentPart::Text { text } => text.clone(), - ContentPart::ImageUrl { .. } => "[Image]".to_string(), - }).collect::>().join(" ") - } + acc + " " + + &match &m.content { + MessageContent::Text(text) => text.clone(), + MessageContent::Parts(parts) => parts + .iter() + .map(|part| match part { + ContentPart::Text { text } => text.clone(), + ContentPart::ImageUrl { .. } => "[Image]".to_string(), + }) + .collect::>() + .join(" "), + } }) } @@ -721,14 +722,14 @@ impl ProviderStreamResponse for ChatCompletionsStreamResponse { } fn role(&self) -> Option<&str> { - self.choices - .first() - .and_then(|choice| choice.delta.role.as_ref().map(|r| match r { + self.choices.first().and_then(|choice| { + choice.delta.role.as_ref().map(|r| match r { Role::System => "system", Role::User => "user", Role::Assistant => "assistant", Role::Tool => "tool", - })) + }) + }) } fn event_type(&self) -> Option<&str> { @@ -736,7 +737,6 @@ impl ProviderStreamResponse for ChatCompletionsStreamResponse { } } - #[cfg(test)] mod tests { use super::*; @@ -756,7 +756,8 @@ mod tests { }); // Deserialize JSON into ChatCompletionsRequest - let deserialized_request: ChatCompletionsRequest = serde_json::from_value(original_json.clone()).unwrap(); + let deserialized_request: ChatCompletionsRequest = + serde_json::from_value(original_json.clone()).unwrap(); // Validate required fields are properly set assert_eq!(deserialized_request.model, "gpt-4"); @@ -799,7 +800,8 @@ mod tests { }); // Deserialize JSON into ChatCompletionsRequest - let deserialized_request: ChatCompletionsRequest = serde_json::from_value(original_json.clone()).unwrap(); + let deserialized_request: ChatCompletionsRequest = + serde_json::from_value(original_json.clone()).unwrap(); // Validate required fields assert_eq!(deserialized_request.model, "gpt-4"); @@ -836,7 +838,10 @@ mod tests { assert_eq!(serialized_json["messages"], original_json["messages"]); assert_eq!(serialized_json["max_tokens"], original_json["max_tokens"]); assert_eq!(serialized_json["stream"], original_json["stream"]); - assert_eq!(serialized_json["stream_options"], original_json["stream_options"]); + assert_eq!( + serialized_json["stream_options"], + original_json["stream_options"] + ); assert_eq!(serialized_json["metadata"], original_json["metadata"]); // Handle temperature with floating point tolerance @@ -917,7 +922,8 @@ mod tests { }); // Deserialize JSON into ChatCompletionsRequest - let deserialized_request: ChatCompletionsRequest = serde_json::from_value(original_json.clone()).unwrap(); + let deserialized_request: ChatCompletionsRequest = + serde_json::from_value(original_json.clone()).unwrap(); // Validate top-level fields assert_eq!(deserialized_request.model, "gpt-4-vision-preview"); @@ -953,7 +959,10 @@ mod tests { let assistant_message = &deserialized_request.messages[1]; assert_eq!(assistant_message.role, Role::Assistant); if let MessageContent::Text(text) = &assistant_message.content { - assert_eq!(text, "I can see a beautiful cityscape. Let me check the weather for you."); + assert_eq!( + text, + "I can see a beautiful cityscape. Let me check the weather for you." + ); } else { panic!("Expected text content for assistant message"); } @@ -967,7 +976,10 @@ mod tests { assert_eq!(tool_call.id, "call_weather123"); assert_eq!(tool_call.call_type, "function"); assert_eq!(tool_call.function.name, "get_weather"); - assert_eq!(tool_call.function.arguments, "{\"location\": \"New York, NY\"}"); + assert_eq!( + tool_call.function.arguments, + "{\"location\": \"New York, NY\"}" + ); // Validate third message (tool response) let tool_message = &deserialized_request.messages[2]; @@ -977,7 +989,10 @@ mod tests { } else { panic!("Expected text content for tool message"); } - assert_eq!(tool_message.tool_call_id, Some("call_weather123".to_string())); + assert_eq!( + tool_message.tool_call_id, + Some("call_weather123".to_string()) + ); // Validate tools array assert!(deserialized_request.tools.is_some()); @@ -987,7 +1002,10 @@ mod tests { let tool = &tools[0]; assert_eq!(tool.tool_type, "function"); assert_eq!(tool.function.name, "get_weather"); - assert_eq!(tool.function.description, Some("Get current weather information for a location".to_string())); + assert_eq!( + tool.function.description, + Some("Get current weather information for a location".to_string()) + ); assert_eq!(tool.function.strict, Some(true)); // Validate tool parameters schema @@ -1093,7 +1111,8 @@ mod tests { ] }); - let deserialized_assistant: Message = serde_json::from_value(assistant_json.clone()).unwrap(); + let deserialized_assistant: Message = + serde_json::from_value(assistant_json.clone()).unwrap(); assert_eq!(deserialized_assistant.role, Role::Assistant); if let MessageContent::Text(content) = &deserialized_assistant.content { assert_eq!(content, "I'll help with that."); @@ -1142,9 +1161,13 @@ mod tests { ] }); - let deserialized_response: ResponseMessage = serde_json::from_value(response_json.clone()).unwrap(); + let deserialized_response: ResponseMessage = + serde_json::from_value(response_json.clone()).unwrap(); assert_eq!(deserialized_response.role, Role::Assistant); - assert_eq!(deserialized_response.content, Some("Response content".to_string())); + assert_eq!( + deserialized_response.content, + Some("Response content".to_string()) + ); assert!(deserialized_response.annotations.is_some()); assert!(deserialized_response.refusal.is_none()); assert!(deserialized_response.function_call.is_none()); @@ -1186,7 +1209,10 @@ mod tests { let none_deserialized: ToolChoice = serde_json::from_value(json!("none")).unwrap(); assert_eq!(auto_deserialized, ToolChoice::Type(ToolChoiceType::Auto)); - assert_eq!(required_deserialized, ToolChoice::Type(ToolChoiceType::Required)); + assert_eq!( + required_deserialized, + ToolChoice::Type(ToolChoiceType::Required) + ); assert_eq!(none_deserialized, ToolChoice::Type(ToolChoiceType::None)); // Test that invalid string values fail deserialization (type safety!) @@ -1237,7 +1263,10 @@ mod tests { assert_eq!(response.created, 1756574706); assert_eq!(response.model, "gpt-4o-2024-08-06"); assert_eq!(response.service_tier, Some("default".to_string())); - assert_eq!(response.system_fingerprint, Some("fp_f33640a400".to_string())); + assert_eq!( + response.system_fingerprint, + Some("fp_f33640a400".to_string()) + ); assert_eq!(response.choices.len(), 1); assert_eq!(response.usage.prompt_tokens, 65); assert_eq!(response.usage.completion_tokens, 184); diff --git a/crates/hermesllm/src/clients/endpoints.rs b/crates/hermesllm/src/clients/endpoints.rs index e5c01f05..263ca674 100644 --- a/crates/hermesllm/src/clients/endpoints.rs +++ b/crates/hermesllm/src/clients/endpoints.rs @@ -21,7 +21,10 @@ //! assert!(endpoints.contains(&"/v1/messages")); //! ``` -use crate::{apis::{AnthropicApi, ApiDefinition, OpenAIApi}, ProviderId}; +use crate::{ + apis::{AnthropicApi, ApiDefinition, OpenAIApi}, + ProviderId, +}; use std::fmt; /// Unified enum representing all supported API endpoints across providers @@ -34,8 +37,12 @@ pub enum SupportedAPIs { impl fmt::Display for SupportedAPIs { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - SupportedAPIs::OpenAIChatCompletions(api) => write!(f, "OpenAI API ({})", api.endpoint()), - SupportedAPIs::AnthropicMessagesAPI(api) => write!(f, "Anthropic API ({})", api.endpoint()), + SupportedAPIs::OpenAIChatCompletions(api) => { + write!(f, "OpenAI API ({})", api.endpoint()) + } + SupportedAPIs::AnthropicMessagesAPI(api) => { + write!(f, "Anthropic API ({})", api.endpoint()) + } } } } @@ -62,61 +69,60 @@ impl SupportedAPIs { } } - pub fn target_endpoint_for_provider(&self, provider_id: &ProviderId, request_path: &str, model_id: &str) -> String { + pub fn target_endpoint_for_provider( + &self, + provider_id: &ProviderId, + request_path: &str, + model_id: &str, + ) -> String { let default_endpoint = "/v1/chat/completions".to_string(); match self { - SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages) => { - match provider_id { - ProviderId::Anthropic => "/v1/messages".to_string(), - _ => default_endpoint, + SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages) => match provider_id { + ProviderId::Anthropic => "/v1/messages".to_string(), + _ => default_endpoint, + }, + _ => match provider_id { + ProviderId::Groq => { + if request_path.starts_with("/v1/") { + format!("/openai{}", request_path) + } else { + default_endpoint + } } - } - _ => { - match provider_id { - ProviderId::Groq => { - if request_path.starts_with("/v1/") { - format!("/openai{}", request_path) - } else { - default_endpoint - } + ProviderId::Zhipu => { + if request_path.starts_with("/v1/") { + "/api/paas/v4/chat/completions".to_string() + } else { + default_endpoint } - ProviderId::Zhipu => { - if request_path.starts_with("/v1/") { - "/api/paas/v4/chat/completions".to_string() - } else { - default_endpoint - } - } - ProviderId::Qwen => { - if request_path.starts_with("/v1/") { - "/compatible-mode/v1/chat/completions".to_string() - } else { - default_endpoint - } - } - ProviderId::AzureOpenAI => { - if request_path.starts_with("/v1/") { - format!("/openai/deployments/{}/chat/completions?api-version=2025-01-01-preview", model_id) - } else { - default_endpoint - } - } - ProviderId::Gemini => { - if request_path.starts_with("/v1/") { - "/v1beta/openai/chat/completions".to_string() - } else { - default_endpoint - } - } - _ => default_endpoint, } - } + ProviderId::Qwen => { + if request_path.starts_with("/v1/") { + "/compatible-mode/v1/chat/completions".to_string() + } else { + default_endpoint + } + } + ProviderId::AzureOpenAI => { + if request_path.starts_with("/v1/") { + format!("/openai/deployments/{}/chat/completions?api-version=2025-01-01-preview", model_id) + } else { + default_endpoint + } + } + ProviderId::Gemini => { + if request_path.starts_with("/v1/") { + "/v1beta/openai/chat/completions".to_string() + } else { + default_endpoint + } + } + _ => default_endpoint, + }, } } } - - /// Get all supported endpoint paths pub fn supported_endpoints() -> Vec<&'static str> { let mut endpoints = Vec::new(); @@ -196,15 +202,26 @@ mod tests { // All OpenAI endpoints should be in the result for endpoint in openai_endpoints { - assert!(endpoints.contains(&endpoint), "Missing OpenAI endpoint: {}", endpoint); + assert!( + endpoints.contains(&endpoint), + "Missing OpenAI endpoint: {}", + endpoint + ); } // All Anthropic endpoints should be in the result for endpoint in anthropic_endpoints { - assert!(endpoints.contains(&endpoint), "Missing Anthropic endpoint: {}", endpoint); + assert!( + endpoints.contains(&endpoint), + "Missing Anthropic endpoint: {}", + endpoint + ); } // Total should match - assert_eq!(endpoints.len(), OpenAIApi::all_variants().len() + AnthropicApi::all_variants().len()); + assert_eq!( + endpoints.len(), + OpenAIApi::all_variants().len() + AnthropicApi::all_variants().len() + ); } } diff --git a/crates/hermesllm/src/clients/mod.rs b/crates/hermesllm/src/clients/mod.rs index 73972445..b93f910e 100644 --- a/crates/hermesllm/src/clients/mod.rs +++ b/crates/hermesllm/src/clients/mod.rs @@ -1,9 +1,9 @@ +pub mod endpoints; pub mod lib; pub mod transformer; -pub mod endpoints; // Re-export the main items for easier access +pub use endpoints::{identify_provider, SupportedAPIs}; pub use lib::*; -pub use endpoints::{SupportedAPIs, identify_provider}; // Note: transformer module contains TryFrom trait implementations that are automatically available diff --git a/crates/hermesllm/src/clients/transformer.rs b/crates/hermesllm/src/clients/transformer.rs index f6e508d4..11caae6f 100644 --- a/crates/hermesllm/src/clients/transformer.rs +++ b/crates/hermesllm/src/clients/transformer.rs @@ -42,10 +42,10 @@ //! # Ok::<(), Box>(()) //! ``` +use super::TransformError; +use crate::apis::*; use serde_json::Value; use std::time::{SystemTime, UNIX_EPOCH}; -use crate::apis::*; -use super::TransformError; // ============================================================================ // CONSTANTS @@ -66,7 +66,9 @@ pub trait ExtractText { /// Trait for utility functions on content collections trait ContentUtils { fn extract_tool_calls(&self) -> Result>, TransformError>; - fn split_for_openai(&self) -> Result<(Vec, Vec, Vec<(String, String, bool)>), TransformError>; + fn split_for_openai( + &self, + ) -> Result<(Vec, Vec, Vec<(String, String, bool)>), TransformError>; } // ============================================================================ @@ -75,7 +77,6 @@ trait ContentUtils { type AnthropicMessagesRequest = MessagesRequest; - impl TryFrom for ChatCompletionsRequest { type Error = TransformError; @@ -95,7 +96,8 @@ impl TryFrom for ChatCompletionsRequest { // Convert tools and tool choice let openai_tools = req.tools.map(|tools| convert_anthropic_tools(tools)); - let (openai_tool_choice, parallel_tool_calls) = convert_anthropic_tool_choice(req.tool_choice); + let (openai_tool_choice, parallel_tool_calls) = + convert_anthropic_tool_choice(req.tool_choice); let mut _chat_completions_req: ChatCompletionsRequest = ChatCompletionsRequest { model: req.model, @@ -137,13 +139,15 @@ impl TryFrom for AnthropicMessagesRequest { // Convert tools and tool choice let anthropic_tools = req.tools.map(|tools| convert_openai_tools(tools)); - let anthropic_tool_choice = convert_openai_tool_choice(req.tool_choice, req.parallel_tool_calls); + let anthropic_tool_choice = + convert_openai_tool_choice(req.tool_choice, req.parallel_tool_calls); Ok(AnthropicMessagesRequest { model: req.model, system: system_prompt, messages, - max_tokens: req.max_completion_tokens + max_tokens: req + .max_completion_tokens .or(req.max_tokens) .unwrap_or(DEFAULT_MAX_TOKENS), container: None, @@ -179,7 +183,11 @@ impl TryFrom for ChatCompletionsResponse { MessageContent::Text(text) => Some(text), MessageContent::Parts(parts) => { let text = parts.extract_text(); - if text.is_empty() { None } else { Some(text) } + if text.is_empty() { + None + } else { + Some(text) + } } }; @@ -225,11 +233,15 @@ impl TryFrom for MessagesResponse { type Error = TransformError; fn try_from(resp: ChatCompletionsResponse) -> Result { - let choice = resp.choices.into_iter().next() + let choice = resp + .choices + .into_iter() + .next() .ok_or_else(|| TransformError::MissingField("choices".to_string()))?; let content = convert_openai_message_to_anthropic_content(&choice.message.to_message())?; - let stop_reason = choice.finish_reason + let stop_reason = choice + .finish_reason .map(|fr| fr.into()) .unwrap_or(MessagesStopReason::EndTurn); @@ -263,33 +275,27 @@ impl TryFrom for ChatCompletionsStreamResponse { fn try_from(event: MessagesStreamEvent) -> Result { match event { - MessagesStreamEvent::MessageStart { message } => { - Ok(create_openai_chunk( - &message.id, - &message.model, - MessageDelta { - role: Some(Role::Assistant), - content: None, - refusal: None, - function_call: None, - tool_calls: None, - }, - None, - None, - )) - } + MessagesStreamEvent::MessageStart { message } => Ok(create_openai_chunk( + &message.id, + &message.model, + MessageDelta { + role: Some(Role::Assistant), + content: None, + refusal: None, + function_call: None, + tool_calls: None, + }, + None, + None, + )), MessagesStreamEvent::ContentBlockStart { content_block, .. } => { convert_content_block_start(content_block) } - MessagesStreamEvent::ContentBlockDelta { delta, .. } => { - convert_content_delta(delta) - } + MessagesStreamEvent::ContentBlockDelta { delta, .. } => convert_content_delta(delta), - MessagesStreamEvent::ContentBlockStop { .. } => { - Ok(create_empty_openai_chunk()) - } + MessagesStreamEvent::ContentBlockStop { .. } => Ok(create_empty_openai_chunk()), MessagesStreamEvent::MessageDelta { delta, usage } => { let finish_reason: Option = Some(delta.stop_reason.into()); @@ -310,34 +316,30 @@ impl TryFrom for ChatCompletionsStreamResponse { )) } - MessagesStreamEvent::MessageStop => { - Ok(create_openai_chunk( - "stream", - "unknown", - MessageDelta { - role: None, - content: None, - refusal: None, - function_call: None, - tool_calls: None, - }, - Some(FinishReason::Stop), - None, - )) - } + MessagesStreamEvent::MessageStop => Ok(create_openai_chunk( + "stream", + "unknown", + MessageDelta { + role: None, + content: None, + refusal: None, + function_call: None, + tool_calls: None, + }, + Some(FinishReason::Stop), + None, + )), - MessagesStreamEvent::Ping => { - Ok(ChatCompletionsStreamResponse { - id: "stream".to_string(), - object: Some("chat.completion.chunk".to_string()), - created: current_timestamp(), - model: "unknown".to_string(), - choices: vec![], - usage: None, - system_fingerprint: None, - service_tier: None, - }) - } + MessagesStreamEvent::Ping => Ok(ChatCompletionsStreamResponse { + id: "stream".to_string(), + object: Some("chat.completion.chunk".to_string()), + created: current_timestamp(), + model: "unknown".to_string(), + choices: vec![], + usage: None, + system_fingerprint: None, + service_tier: None, + }), } } } @@ -442,9 +444,7 @@ impl Into for MessagesSystemPrompt { fn into(self) -> Message { let system_content = match self { MessagesSystemPrompt::Single(text) => MessageContent::Text(text), - MessagesSystemPrompt::Blocks(blocks) => { - MessageContent::Text(blocks.extract_text()) - } + MessagesSystemPrompt::Blocks(blocks) => MessageContent::Text(blocks.extract_text()), }; Message { @@ -461,7 +461,7 @@ impl Into for Message { fn into(self) -> MessagesSystemPrompt { let system_text = match self.content { MessageContent::Text(text) => text, - MessageContent::Parts(parts) => parts.extract_text() + MessageContent::Parts(parts) => parts.extract_text(), }; MessagesSystemPrompt::Single(system_text) } @@ -505,7 +505,11 @@ impl TryFrom for Vec { role: message.role.into(), content, name: None, - tool_calls: if tool_calls.is_empty() { None } else { Some(tool_calls) }, + tool_calls: if tool_calls.is_empty() { + None + } else { + Some(tool_calls) + }, tool_call_id: None, }; result.push(main_message); @@ -526,8 +530,11 @@ impl TryFrom for MessagesMessage { Role::Assistant => MessagesRole::Assistant, Role::Tool => { // Tool messages become user messages with tool results - let tool_call_id = message.tool_call_id - .ok_or_else(|| TransformError::MissingField("tool_call_id required for Tool messages".to_string()))?; + let tool_call_id = message.tool_call_id.ok_or_else(|| { + TransformError::MissingField( + "tool_call_id required for Tool messages".to_string(), + ) + })?; return Ok(MessagesMessage { role: MessagesRole::User, @@ -545,7 +552,9 @@ impl TryFrom for MessagesMessage { }); } Role::System => { - return Err(TransformError::UnsupportedConversion("System messages should be handled separately".to_string())); + return Err(TransformError::UnsupportedConversion( + "System messages should be handled separately".to_string(), + )); } }; @@ -573,24 +582,36 @@ impl ContentUtils for Vec { for block in self { match block { - MessagesContentBlock::ToolUse { id, name, input, .. } | - MessagesContentBlock::ServerToolUse { id, name, input } | - MessagesContentBlock::McpToolUse { id, name, input } => { + MessagesContentBlock::ToolUse { + id, name, input, .. + } + | MessagesContentBlock::ServerToolUse { id, name, input } + | MessagesContentBlock::McpToolUse { id, name, input } => { let arguments = serde_json::to_string(&input)?; tool_calls.push(ToolCall { id: id.clone(), call_type: "function".to_string(), - function: FunctionCall { name: name.clone(), arguments }, + function: FunctionCall { + name: name.clone(), + arguments, + }, }); } _ => continue, } } - Ok(if tool_calls.is_empty() { None } else { Some(tool_calls) }) + Ok(if tool_calls.is_empty() { + None + } else { + Some(tool_calls) + }) } - fn split_for_openai(&self) -> Result<(Vec, Vec, Vec<(String, String, bool)>), TransformError> { + fn split_for_openai( + &self, + ) -> Result<(Vec, Vec, Vec<(String, String, bool)>), TransformError> + { let mut content_parts = Vec::new(); let mut tool_calls = Vec::new(); let mut tool_results = Vec::new(); @@ -609,25 +630,55 @@ impl ContentUtils for Vec { }, }); } - MessagesContentBlock::ToolUse { id, name, input, .. } | - MessagesContentBlock::ServerToolUse { id, name, input } | - MessagesContentBlock::McpToolUse { id, name, input } => { + MessagesContentBlock::ToolUse { + id, name, input, .. + } + | MessagesContentBlock::ServerToolUse { id, name, input } + | MessagesContentBlock::McpToolUse { id, name, input } => { let arguments = serde_json::to_string(&input)?; tool_calls.push(ToolCall { id: id.clone(), call_type: "function".to_string(), - function: FunctionCall { name: name.clone(), arguments }, + function: FunctionCall { + name: name.clone(), + arguments, + }, }); } - MessagesContentBlock::ToolResult { tool_use_id, content, is_error, .. } => { + MessagesContentBlock::ToolResult { + tool_use_id, + content, + is_error, + .. + } => { let result_text = content.extract_text(); - tool_results.push((tool_use_id.clone(), result_text, is_error.unwrap_or(false))); + tool_results.push(( + tool_use_id.clone(), + result_text, + is_error.unwrap_or(false), + )); } - MessagesContentBlock::WebSearchToolResult { tool_use_id, content, is_error } | - MessagesContentBlock::CodeExecutionToolResult { tool_use_id, content, is_error } | - MessagesContentBlock::McpToolResult { tool_use_id, content, is_error } => { + MessagesContentBlock::WebSearchToolResult { + tool_use_id, + content, + is_error, + } + | MessagesContentBlock::CodeExecutionToolResult { + tool_use_id, + content, + is_error, + } + | MessagesContentBlock::McpToolResult { + tool_use_id, + content, + is_error, + } => { let result_text = content.extract_text(); - tool_results.push((tool_use_id.clone(), result_text, is_error.unwrap_or(false))); + tool_results.push(( + tool_use_id.clone(), + result_text, + is_error.unwrap_or(false), + )); } _ => { // Skip unsupported content types @@ -696,7 +747,10 @@ impl Into for Usage { /// Helper to create a current unix timestamp fn current_timestamp() -> u64 { - SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs() + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs() } /// Helper to create OpenAI streaming chunk @@ -705,7 +759,7 @@ fn create_openai_chunk( model: &str, delta: MessageDelta, finish_reason: Option, - usage: Option + usage: Option, ) -> ChatCompletionsStreamResponse { ChatCompletionsStreamResponse { id: id.to_string(), @@ -743,7 +797,8 @@ fn create_empty_openai_chunk() -> ChatCompletionsStreamResponse { /// Convert Anthropic tools to OpenAI format fn convert_anthropic_tools(tools: Vec) -> Vec { - tools.into_iter() + tools + .into_iter() .map(|tool| Tool { tool_type: "function".to_string(), function: Function { @@ -758,7 +813,8 @@ fn convert_anthropic_tools(tools: Vec) -> Vec { /// Convert OpenAI tools to Anthropic format fn convert_openai_tools(tools: Vec) -> Vec { - tools.into_iter() + tools + .into_iter() .map(|tool| MessagesTool { name: tool.function.name, description: tool.function.description, @@ -768,7 +824,9 @@ fn convert_openai_tools(tools: Vec) -> Vec { } /// Convert Anthropic tool choice to OpenAI format -fn convert_anthropic_tool_choice(tool_choice: Option) -> (Option, Option) { +fn convert_anthropic_tool_choice( + tool_choice: Option, +) -> (Option, Option) { match tool_choice { Some(choice) => { let openai_choice = match choice.kind { @@ -789,45 +847,46 @@ fn convert_anthropic_tool_choice(tool_choice: Option) -> (Op let parallel = choice.disable_parallel_tool_use.map(|disable| !disable); (Some(openai_choice), parallel) } - None => (None, None) + None => (None, None), } } /// Convert OpenAI tool choice to Anthropic format fn convert_openai_tool_choice( tool_choice: Option, - parallel_tool_calls: Option + parallel_tool_calls: Option, ) -> Option { - tool_choice.map(|choice| { - match choice { - ToolChoice::Type(tool_type) => match tool_type { - ToolChoiceType::Auto => MessagesToolChoice { - kind: MessagesToolChoiceType::Auto, - name: None, - disable_parallel_tool_use: parallel_tool_calls.map(|p| !p), - }, - ToolChoiceType::Required => MessagesToolChoice { - kind: MessagesToolChoiceType::Any, - name: None, - disable_parallel_tool_use: parallel_tool_calls.map(|p| !p), - }, - ToolChoiceType::None => MessagesToolChoice { - kind: MessagesToolChoiceType::None, - name: None, - disable_parallel_tool_use: None, - }, - }, - ToolChoice::Function { function, .. } => MessagesToolChoice { - kind: MessagesToolChoiceType::Tool, - name: Some(function.name), + tool_choice.map(|choice| match choice { + ToolChoice::Type(tool_type) => match tool_type { + ToolChoiceType::Auto => MessagesToolChoice { + kind: MessagesToolChoiceType::Auto, + name: None, disable_parallel_tool_use: parallel_tool_calls.map(|p| !p), }, - } + ToolChoiceType::Required => MessagesToolChoice { + kind: MessagesToolChoiceType::Any, + name: None, + disable_parallel_tool_use: parallel_tool_calls.map(|p| !p), + }, + ToolChoiceType::None => MessagesToolChoice { + kind: MessagesToolChoiceType::None, + name: None, + disable_parallel_tool_use: None, + }, + }, + ToolChoice::Function { function, .. } => MessagesToolChoice { + kind: MessagesToolChoiceType::Tool, + name: Some(function.name), + disable_parallel_tool_use: parallel_tool_calls.map(|p| !p), + }, }) } /// Build OpenAI message content from parts and tool calls -fn build_openai_content(content_parts: Vec, tool_calls: &[ToolCall]) -> MessageContent { +fn build_openai_content( + content_parts: Vec, + tool_calls: &[ToolCall], +) -> MessageContent { if content_parts.len() == 1 && tool_calls.is_empty() { match &content_parts[0] { ContentPart::Text { text } => MessageContent::Text(text.clone()), @@ -855,7 +914,9 @@ fn build_anthropic_content(content_blocks: Vec) -> Message } /// Convert Anthropic content blocks to OpenAI message content -fn convert_anthropic_content_to_openai(content: &[MessagesContentBlock]) -> Result { +fn convert_anthropic_content_to_openai( + content: &[MessagesContentBlock], +) -> Result { let mut text_parts = Vec::new(); for block in content { @@ -877,21 +938,29 @@ fn convert_anthropic_content_to_openai(content: &[MessagesContentBlock]) -> Resu } /// Convert OpenAI message to Anthropic content blocks -fn convert_openai_message_to_anthropic_content(message: &Message) -> Result, TransformError> { +fn convert_openai_message_to_anthropic_content( + message: &Message, +) -> Result, TransformError> { let mut blocks = Vec::new(); // Handle regular content match &message.content { MessageContent::Text(text) => { if !text.is_empty() { - blocks.push(MessagesContentBlock::Text { text: text.clone(), cache_control: None }); + blocks.push(MessagesContentBlock::Text { + text: text.clone(), + cache_control: None, + }); } } MessageContent::Parts(parts) => { for part in parts { match part { ContentPart::Text { text } => { - blocks.push(MessagesContentBlock::Text { text: text.clone(), cache_control: None }); + blocks.push(MessagesContentBlock::Text { + text: text.clone(), + cache_control: None, + }); } ContentPart::ImageUrl { image_url } => { let source = convert_image_url_to_source(image_url); @@ -947,23 +1016,29 @@ fn convert_image_url_to_source(image_url: &ImageUrl) -> MessagesImageSource { data: data.to_string(), } } else { - MessagesImageSource::Url { url: image_url.url.clone() } + MessagesImageSource::Url { + url: image_url.url.clone(), + } } } else { - MessagesImageSource::Url { url: image_url.url.clone() } + MessagesImageSource::Url { + url: image_url.url.clone(), + } } } /// Convert content block start to OpenAI chunk -fn convert_content_block_start(content_block: MessagesContentBlock) -> Result { +fn convert_content_block_start( + content_block: MessagesContentBlock, +) -> Result { match content_block { MessagesContentBlock::Text { .. } => { // No immediate output for text block start Ok(create_empty_openai_chunk()) } - MessagesContentBlock::ToolUse { id, name, .. } | - MessagesContentBlock::ServerToolUse { id, name, .. } | - MessagesContentBlock::McpToolUse { id, name, .. } => { + MessagesContentBlock::ToolUse { id, name, .. } + | MessagesContentBlock::ServerToolUse { id, name, .. } + | MessagesContentBlock::McpToolUse { id, name, .. } => { // Tool use start → OpenAI chunk with tool_calls Ok(create_openai_chunk( "stream", @@ -987,71 +1062,71 @@ fn convert_content_block_start(content_block: MessagesContentBlock) -> Result Err(TransformError::UnsupportedContent("Unsupported content block type in stream start".to_string())), + _ => Err(TransformError::UnsupportedContent( + "Unsupported content block type in stream start".to_string(), + )), } } /// Convert content delta to OpenAI chunk -fn convert_content_delta(delta: MessagesContentDelta) -> Result { +fn convert_content_delta( + delta: MessagesContentDelta, +) -> Result { match delta { - MessagesContentDelta::TextDelta { text } => { - Ok(create_openai_chunk( - "stream", - "unknown", - MessageDelta { - role: None, - content: Some(text), - refusal: None, - function_call: None, - tool_calls: None, - }, - None, - None, - )) - } - MessagesContentDelta::ThinkingDelta { thinking } => { - Ok(create_openai_chunk( - "stream", - "unknown", - MessageDelta { - role: None, - content: Some(format!("thinking: {}", thinking)), - refusal: None, - function_call: None, - tool_calls: None, - }, - None, - None, - )) - } - MessagesContentDelta::InputJsonDelta { partial_json } => { - Ok(create_openai_chunk( - "stream", - "unknown", - MessageDelta { - role: None, - content: None, - refusal: None, - function_call: None, - tool_calls: Some(vec![ToolCallDelta { - index: 0, - id: None, - call_type: None, - function: Some(FunctionCallDelta { - name: None, - arguments: Some(partial_json), - }), - }]), - }, - None, - None, - )) - } + MessagesContentDelta::TextDelta { text } => Ok(create_openai_chunk( + "stream", + "unknown", + MessageDelta { + role: None, + content: Some(text), + refusal: None, + function_call: None, + tool_calls: None, + }, + None, + None, + )), + MessagesContentDelta::ThinkingDelta { thinking } => Ok(create_openai_chunk( + "stream", + "unknown", + MessageDelta { + role: None, + content: Some(format!("thinking: {}", thinking)), + refusal: None, + function_call: None, + tool_calls: None, + }, + None, + None, + )), + MessagesContentDelta::InputJsonDelta { partial_json } => Ok(create_openai_chunk( + "stream", + "unknown", + MessageDelta { + role: None, + content: None, + refusal: None, + function_call: None, + tool_calls: Some(vec![ToolCallDelta { + index: 0, + id: None, + call_type: None, + function: Some(FunctionCallDelta { + name: None, + arguments: Some(partial_json), + }), + }]), + }, + None, + None, + )), } } /// Convert tool call deltas to Anthropic stream events -fn convert_tool_call_deltas(tool_calls: Vec) -> Result { +fn convert_tool_call_deltas( + tool_calls: Vec, +) -> Result { for tool_call in tool_calls { if let Some(id) = &tool_call.id { // Tool call start @@ -1160,11 +1235,20 @@ mod tests { // Check key fields are preserved assert_eq!(original_anthropic.model, roundtrip_anthropic.model); - assert_eq!(original_anthropic.max_tokens, roundtrip_anthropic.max_tokens); - assert_eq!(original_anthropic.temperature, roundtrip_anthropic.temperature); + assert_eq!( + original_anthropic.max_tokens, + roundtrip_anthropic.max_tokens + ); + assert_eq!( + original_anthropic.temperature, + roundtrip_anthropic.temperature + ); assert_eq!(original_anthropic.top_p, roundtrip_anthropic.top_p); assert_eq!(original_anthropic.stream, roundtrip_anthropic.stream); - assert_eq!(original_anthropic.messages.len(), roundtrip_anthropic.messages.len()); + assert_eq!( + original_anthropic.messages.len(), + roundtrip_anthropic.messages.len() + ); } #[test] @@ -1308,7 +1392,10 @@ mod tests { let tool_calls = choice.delta.tool_calls.as_ref().unwrap(); assert_eq!(tool_calls.len(), 1); assert_eq!(tool_calls[0].id, Some("call_123".to_string())); - assert_eq!(tool_calls[0].function.as_ref().unwrap().name, Some("get_weather".to_string())); + assert_eq!( + tool_calls[0].function.as_ref().unwrap().name, + Some("get_weather".to_string()) + ); } #[test] @@ -1328,7 +1415,10 @@ mod tests { let tool_calls = choice.delta.tool_calls.as_ref().unwrap(); assert_eq!(tool_calls.len(), 1); - assert_eq!(tool_calls[0].function.as_ref().unwrap().arguments, Some(r#"{"location": "San Francisco"#.to_string())); + assert_eq!( + tool_calls[0].function.as_ref().unwrap().arguments, + Some(r#"{"location": "San Francisco"#.to_string()) + ); } #[test] @@ -1491,7 +1581,10 @@ mod tests { let anthropic_event: MessagesStreamEvent = openai_resp.try_into().unwrap(); match anthropic_event { - MessagesStreamEvent::ContentBlockStart { index, content_block } => { + MessagesStreamEvent::ContentBlockStart { + index, + content_block, + } => { assert_eq!(index, 0); match content_block { MessagesContentBlock::ToolUse { id, name, .. } => { @@ -1634,16 +1727,28 @@ mod tests { // Verify tool start let tool_calls = &openai_start.choices[0].delta.tool_calls.as_ref().unwrap(); assert_eq!(tool_calls[0].id, Some("call_weather".to_string())); - assert_eq!(tool_calls[0].function.as_ref().unwrap().name, Some("get_weather".to_string())); + assert_eq!( + tool_calls[0].function.as_ref().unwrap().name, + Some("get_weather".to_string()) + ); // Verify argument deltas let args1 = &openai_delta1.choices[0].delta.tool_calls.as_ref().unwrap()[0] - .function.as_ref().unwrap().arguments; + .function + .as_ref() + .unwrap() + .arguments; assert_eq!(args1, &Some(r#"{"location": "#.to_string())); let args2 = &openai_delta2.choices[0].delta.tool_calls.as_ref().unwrap()[0] - .function.as_ref().unwrap().arguments; - assert_eq!(args2, &Some(r#"San Francisco", "unit": "fahrenheit"}"#.to_string())); + .function + .as_ref() + .unwrap() + .arguments; + assert_eq!( + args2, + &Some(r#"San Francisco", "unit": "fahrenheit"}"#.to_string()) + ); } #[test] @@ -1671,14 +1776,23 @@ mod tests { }; let openai_resp: ChatCompletionsStreamResponse = event.try_into().unwrap(); - assert_eq!(openai_resp.choices[0].finish_reason, Some(expected_openai_reason)); + assert_eq!( + openai_resp.choices[0].finish_reason, + Some(expected_openai_reason) + ); // Test reverse conversion let roundtrip_event: MessagesStreamEvent = openai_resp.try_into().unwrap(); match roundtrip_event { MessagesStreamEvent::MessageDelta { delta, .. } => { // Note: Some precision may be lost in roundtrip due to mapping differences - assert!(matches!(delta.stop_reason, MessagesStopReason::EndTurn | MessagesStopReason::MaxTokens | MessagesStopReason::ToolUse | MessagesStopReason::StopSequence)); + assert!(matches!( + delta.stop_reason, + MessagesStopReason::EndTurn + | MessagesStopReason::MaxTokens + | MessagesStopReason::ToolUse + | MessagesStopReason::StopSequence + )); } _ => panic!("Expected MessageDelta after roundtrip"), } @@ -1711,7 +1825,8 @@ mod tests { }; // Should convert to Ping when no meaningful content - let anthropic_event: MessagesStreamEvent = openai_resp_with_missing_data.try_into().unwrap(); + let anthropic_event: MessagesStreamEvent = + openai_resp_with_missing_data.try_into().unwrap(); assert!(matches!(anthropic_event, MessagesStreamEvent::Ping)); } diff --git a/crates/hermesllm/src/lib.rs b/crates/hermesllm/src/lib.rs index a9e8c48e..2789947b 100644 --- a/crates/hermesllm/src/lib.rs +++ b/crates/hermesllm/src/lib.rs @@ -1,20 +1,21 @@ //! hermesllm: A library for translating LLM API requests and responses //! between Mistral, Grok, Gemini, and OpenAI-compliant formats. -pub mod providers; pub mod apis; pub mod clients; +pub mod providers; // Re-export important types and traits -pub use providers::request::{ProviderRequestType, ProviderRequest, ProviderRequestError}; -pub use providers::response::{ProviderResponseType, ProviderStreamResponseType, ProviderResponse, ProviderStreamResponse, ProviderResponseError, TokenUsage, SseEvent, SseStreamIter}; pub use providers::id::ProviderId; - +pub use providers::request::{ProviderRequest, ProviderRequestError, ProviderRequestType}; +pub use providers::response::{ + ProviderResponse, ProviderResponseError, ProviderResponseType, ProviderStreamResponse, + ProviderStreamResponseType, SseEvent, SseStreamIter, TokenUsage, +}; //TODO: Refactor such that commons doesn't depend on Hermes. For now this will clean up strings pub const CHAT_COMPLETIONS_PATH: &str = "/v1/chat/completions"; pub const MESSAGES_PATH: &str = "/v1/messages"; - #[cfg(test)] mod tests { use super::*; @@ -30,48 +31,50 @@ mod tests { #[test] fn test_provider_streaming_response() { // Test streaming response parsing with sample SSE data - let sse_data = r#"data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"gpt-4","choices":[{"index":0,"delta":{"role":"assistant","content":"Hello"},"finish_reason":null}]} + let sse_data = r#"data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"gpt-4","choices":[{"index":0,"delta":{"role":"assistant","content":"Hello"},"finish_reason":null}]} data: [DONE] "#; - use crate::clients::endpoints::SupportedAPIs; - let client_api = SupportedAPIs::OpenAIChatCompletions(crate::apis::OpenAIApi::ChatCompletions); - let upstream_api = SupportedAPIs::OpenAIChatCompletions(crate::apis::OpenAIApi::ChatCompletions); + use crate::clients::endpoints::SupportedAPIs; + let client_api = + SupportedAPIs::OpenAIChatCompletions(crate::apis::OpenAIApi::ChatCompletions); + let upstream_api = + SupportedAPIs::OpenAIChatCompletions(crate::apis::OpenAIApi::ChatCompletions); - // Test the new simplified architecture - create SseStreamIter directly - let sse_iter = SseStreamIter::try_from(sse_data.as_bytes()); - assert!(sse_iter.is_ok()); + // Test the new simplified architecture - create SseStreamIter directly + let sse_iter = SseStreamIter::try_from(sse_data.as_bytes()); + assert!(sse_iter.is_ok()); - let mut streaming_iter = sse_iter.unwrap(); + let mut streaming_iter = sse_iter.unwrap(); - // Test that we can iterate over SseEvents - let first_event = streaming_iter.next(); - assert!(first_event.is_some()); + // Test that we can iterate over SseEvents + let first_event = streaming_iter.next(); + assert!(first_event.is_some()); - let sse_event = first_event.unwrap(); + let sse_event = first_event.unwrap(); - // Test SseEvent properties - assert!(!sse_event.is_done()); - assert!(sse_event.data.as_ref().unwrap().contains("Hello")); + // Test SseEvent properties + assert!(!sse_event.is_done()); + assert!(sse_event.data.as_ref().unwrap().contains("Hello")); - // Test that we can parse the event into a provider stream response - let transformed_event = SseEvent::try_from((sse_event, &client_api, &upstream_api)); - if let Err(e) = &transformed_event { - println!("Transform error: {:?}", e); - } - assert!(transformed_event.is_ok()); + // Test that we can parse the event into a provider stream response + let transformed_event = SseEvent::try_from((sse_event, &client_api, &upstream_api)); + if let Err(e) = &transformed_event { + println!("Transform error: {:?}", e); + } + assert!(transformed_event.is_ok()); - let transformed_event = transformed_event.unwrap(); - let provider_response = transformed_event.provider_response(); - assert!(provider_response.is_ok()); + let transformed_event = transformed_event.unwrap(); + let provider_response = transformed_event.provider_response(); + assert!(provider_response.is_ok()); - let stream_response = provider_response.unwrap(); - assert_eq!(stream_response.content_delta(), Some("Hello")); - assert!(!stream_response.is_final()); + let stream_response = provider_response.unwrap(); + assert_eq!(stream_response.content_delta(), Some("Hello")); + assert!(!stream_response.is_final()); - // Test that stream ends properly with [DONE] (SseStreamIter should stop before [DONE]) - let final_event = streaming_iter.next(); - assert!(final_event.is_none()); // Should be None because iterator stops at [DONE] + // Test that stream ends properly with [DONE] (SseStreamIter should stop before [DONE]) + let final_event = streaming_iter.next(); + assert!(final_event.is_none()); // Should be None because iterator stops at [DONE] } } diff --git a/crates/hermesllm/src/providers/id.rs b/crates/hermesllm/src/providers/id.rs index 46b9cf93..b898d7d7 100644 --- a/crates/hermesllm/src/providers/id.rs +++ b/crates/hermesllm/src/providers/id.rs @@ -1,6 +1,6 @@ -use std::fmt::Display; +use crate::apis::{AnthropicApi, OpenAIApi}; use crate::clients::endpoints::SupportedAPIs; -use crate::apis::{OpenAIApi, AnthropicApi}; +use std::fmt::Display; /// Provider identifier enum - simple enum for identifying providers #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] @@ -50,41 +50,50 @@ impl ProviderId { pub fn compatible_api_for_client(&self, client_api: &SupportedAPIs) -> SupportedAPIs { match (self, client_api) { // Claude/Anthropic providers natively support Anthropic APIs - (ProviderId::Anthropic, SupportedAPIs::AnthropicMessagesAPI(_)) => SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages), - (ProviderId::Anthropic, SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions)) => SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions), + (ProviderId::Anthropic, SupportedAPIs::AnthropicMessagesAPI(_)) => { + SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages) + } + ( + ProviderId::Anthropic, + SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions), + ) => SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions), // OpenAI-compatible providers only support OpenAI chat completions - (ProviderId::OpenAI - | ProviderId::Groq - | ProviderId::Mistral - | ProviderId::Deepseek - | ProviderId::Arch - | ProviderId::Gemini - | ProviderId::GitHub - | ProviderId::AzureOpenAI - | ProviderId::XAI - | ProviderId::TogetherAI - | ProviderId::Ollama - | ProviderId::Moonshotai - | ProviderId::Zhipu - | ProviderId::Qwen, - SupportedAPIs::AnthropicMessagesAPI(_)) => SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions), + ( + ProviderId::OpenAI + | ProviderId::Groq + | ProviderId::Mistral + | ProviderId::Deepseek + | ProviderId::Arch + | ProviderId::Gemini + | ProviderId::GitHub + | ProviderId::AzureOpenAI + | ProviderId::XAI + | ProviderId::TogetherAI + | ProviderId::Ollama + | ProviderId::Moonshotai + | ProviderId::Zhipu + | ProviderId::Qwen, + SupportedAPIs::AnthropicMessagesAPI(_), + ) => SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions), - (ProviderId::OpenAI - | ProviderId::Groq - | ProviderId::Mistral - | ProviderId::Deepseek - | ProviderId::Arch - | ProviderId::Gemini - | ProviderId::GitHub - | ProviderId::AzureOpenAI - | ProviderId::XAI - | ProviderId::TogetherAI - | ProviderId::Ollama - | ProviderId::Moonshotai - | ProviderId::Zhipu - | ProviderId::Qwen, - SupportedAPIs::OpenAIChatCompletions(_)) => SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions), + ( + ProviderId::OpenAI + | ProviderId::Groq + | ProviderId::Mistral + | ProviderId::Deepseek + | ProviderId::Arch + | ProviderId::Gemini + | ProviderId::GitHub + | ProviderId::AzureOpenAI + | ProviderId::XAI + | ProviderId::TogetherAI + | ProviderId::Ollama + | ProviderId::Moonshotai + | ProviderId::Zhipu + | ProviderId::Qwen, + SupportedAPIs::OpenAIChatCompletions(_), + ) => SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions), } } } diff --git a/crates/hermesllm/src/providers/mod.rs b/crates/hermesllm/src/providers/mod.rs index 601af955..97b14285 100644 --- a/crates/hermesllm/src/providers/mod.rs +++ b/crates/hermesllm/src/providers/mod.rs @@ -8,5 +8,5 @@ pub mod request; pub mod response; pub use id::ProviderId; -pub use request::{ProviderRequestType, ProviderRequest, ProviderRequestError} ; -pub use response::{ProviderResponseType, ProviderResponse, ProviderStreamResponse, TokenUsage }; +pub use request::{ProviderRequest, ProviderRequestError, ProviderRequestType}; +pub use response::{ProviderResponse, ProviderResponseType, ProviderStreamResponse, TokenUsage}; diff --git a/crates/hermesllm/src/providers/request.rs b/crates/hermesllm/src/providers/request.rs index 3603edf2..1cee7169 100644 --- a/crates/hermesllm/src/providers/request.rs +++ b/crates/hermesllm/src/providers/request.rs @@ -1,11 +1,11 @@ -use crate::apis::openai::ChatCompletionsRequest; use crate::apis::anthropic::MessagesRequest; +use crate::apis::openai::ChatCompletionsRequest; use crate::clients::endpoints::SupportedAPIs; use serde_json::Value; +use std::collections::HashMap; use std::error::Error; use std::fmt; -use std::collections::HashMap; #[derive(Clone)] pub enum ProviderRequestType { ChatCompletionsRequest(ChatCompletionsRequest), @@ -103,15 +103,18 @@ impl TryFrom<(&[u8], &SupportedAPIs)> for ProviderRequestType { // Use SupportedApi to determine the appropriate request type match client_api { SupportedAPIs::OpenAIChatCompletions(_) => { - let chat_completion_request: ChatCompletionsRequest = ChatCompletionsRequest::try_from(bytes) - .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; - Ok(ProviderRequestType::ChatCompletionsRequest(chat_completion_request)) - } - SupportedAPIs::AnthropicMessagesAPI(_) => { - let messages_request: MessagesRequest = MessagesRequest::try_from(bytes) + let chat_completion_request: ChatCompletionsRequest = + ChatCompletionsRequest::try_from(bytes) .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; - Ok(ProviderRequestType::MessagesRequest(messages_request)) - } + Ok(ProviderRequestType::ChatCompletionsRequest( + chat_completion_request, + )) + } + SupportedAPIs::AnthropicMessagesAPI(_) => { + let messages_request: MessagesRequest = MessagesRequest::try_from(bytes) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; + Ok(ProviderRequestType::MessagesRequest(messages_request)) + } } } } @@ -120,40 +123,55 @@ impl TryFrom<(&[u8], &SupportedAPIs)> for ProviderRequestType { impl TryFrom<(ProviderRequestType, &SupportedAPIs)> for ProviderRequestType { type Error = ProviderRequestError; - fn try_from((request, upstream_api): (ProviderRequestType, &SupportedAPIs)) -> Result { + fn try_from( + (request, upstream_api): (ProviderRequestType, &SupportedAPIs), + ) -> Result { match (request, upstream_api) { // Same API - no conversion needed, just clone the reference - (ProviderRequestType::ChatCompletionsRequest(chat_req), SupportedAPIs::OpenAIChatCompletions(_)) => { - Ok(ProviderRequestType::ChatCompletionsRequest(chat_req)) - } - (ProviderRequestType::MessagesRequest(messages_req), SupportedAPIs::AnthropicMessagesAPI(_)) => { - Ok(ProviderRequestType::MessagesRequest(messages_req)) - } + ( + ProviderRequestType::ChatCompletionsRequest(chat_req), + SupportedAPIs::OpenAIChatCompletions(_), + ) => Ok(ProviderRequestType::ChatCompletionsRequest(chat_req)), + ( + ProviderRequestType::MessagesRequest(messages_req), + SupportedAPIs::AnthropicMessagesAPI(_), + ) => Ok(ProviderRequestType::MessagesRequest(messages_req)), // Cross-API conversion - cloning is necessary for transformation - (ProviderRequestType::ChatCompletionsRequest(chat_req), SupportedAPIs::AnthropicMessagesAPI(_)) => { - let messages_req = MessagesRequest::try_from(chat_req) - .map_err(|e| ProviderRequestError { - message: format!("Failed to convert ChatCompletionsRequest to MessagesRequest: {}", e), - source: Some(Box::new(e)) + ( + ProviderRequestType::ChatCompletionsRequest(chat_req), + SupportedAPIs::AnthropicMessagesAPI(_), + ) => { + let messages_req = + MessagesRequest::try_from(chat_req).map_err(|e| ProviderRequestError { + message: format!( + "Failed to convert ChatCompletionsRequest to MessagesRequest: {}", + e + ), + source: Some(Box::new(e)), })?; Ok(ProviderRequestType::MessagesRequest(messages_req)) } - (ProviderRequestType::MessagesRequest(messages_req), SupportedAPIs::OpenAIChatCompletions(_)) => { - let chat_req = ChatCompletionsRequest::try_from(messages_req) - .map_err(|e| ProviderRequestError { - message: format!("Failed to convert MessagesRequest to ChatCompletionsRequest: {}", e), - source: Some(Box::new(e)) - })?; + ( + ProviderRequestType::MessagesRequest(messages_req), + SupportedAPIs::OpenAIChatCompletions(_), + ) => { + let chat_req = ChatCompletionsRequest::try_from(messages_req).map_err(|e| { + ProviderRequestError { + message: format!( + "Failed to convert MessagesRequest to ChatCompletionsRequest: {}", + e + ), + source: Some(Box::new(e)), + } + })?; Ok(ProviderRequestType::ChatCompletionsRequest(chat_req)) } } } } - - /// Error types for provider operations #[derive(Debug)] pub struct ProviderRequestError { @@ -169,19 +187,20 @@ impl fmt::Display for ProviderRequestError { impl Error for ProviderRequestError { fn source(&self) -> Option<&(dyn Error + 'static)> { - self.source.as_ref().map(|e| e.as_ref() as &(dyn Error + 'static)) + self.source + .as_ref() + .map(|e| e.as_ref() as &(dyn Error + 'static)) } } - #[cfg(test)] mod tests { use super::*; - use crate::clients::endpoints::SupportedAPIs; use crate::apis::anthropic::AnthropicApi::Messages; - use crate::apis::openai::OpenAIApi::ChatCompletions; use crate::apis::anthropic::MessagesRequest as AnthropicMessagesRequest; - use crate::apis::openai::{ChatCompletionsRequest}; + use crate::apis::openai::ChatCompletionsRequest; + use crate::apis::openai::OpenAIApi::ChatCompletions; + use crate::clients::endpoints::SupportedAPIs; use crate::clients::transformer::ExtractText; use serde_json::json; @@ -202,7 +221,7 @@ mod tests { ProviderRequestType::ChatCompletionsRequest(r) => { assert_eq!(r.model, "gpt-4"); assert_eq!(r.messages.len(), 2); - }, + } _ => panic!("Expected ChatCompletionsRequest variant"), } } @@ -225,7 +244,7 @@ mod tests { ProviderRequestType::MessagesRequest(r) => { assert_eq!(r.model, "claude-3-sonnet"); assert_eq!(r.messages.len(), 1); - }, + } _ => panic!("Expected MessagesRequest variant"), } } @@ -247,7 +266,7 @@ mod tests { ProviderRequestType::ChatCompletionsRequest(r) => { assert_eq!(r.model, "gpt-4"); assert_eq!(r.messages.len(), 2); - }, + } _ => panic!("Expected ChatCompletionsRequest variant"), } } @@ -271,7 +290,7 @@ mod tests { ProviderRequestType::ChatCompletionsRequest(r) => { assert_eq!(r.model, "claude-3-sonnet"); assert_eq!(r.messages.len(), 1); - }, + } _ => panic!("Expected ChatCompletionsRequest variant"), } } @@ -280,13 +299,15 @@ mod tests { fn test_v1_messages_to_v1_chat_completions_roundtrip() { let anthropic_req = AnthropicMessagesRequest { model: "claude-3-sonnet".to_string(), - system: Some(crate::apis::anthropic::MessagesSystemPrompt::Single("You are a helpful assistant".to_string())), - messages: vec![ - crate::apis::anthropic::MessagesMessage { - role: crate::apis::anthropic::MessagesRole::User, - content: crate::apis::anthropic::MessagesMessageContent::Single("Hello!".to_string()), - } - ], + system: Some(crate::apis::anthropic::MessagesSystemPrompt::Single( + "You are a helpful assistant".to_string(), + )), + messages: vec![crate::apis::anthropic::MessagesMessage { + role: crate::apis::anthropic::MessagesRole::User, + content: crate::apis::anthropic::MessagesMessageContent::Single( + "Hello!".to_string(), + ), + }], max_tokens: 128, container: None, mcp_servers: None, @@ -302,16 +323,27 @@ mod tests { metadata: None, }; - let openai_req = ChatCompletionsRequest::try_from(anthropic_req.clone()).expect("Anthropic->OpenAI conversion failed"); - let anthropic_req2 = AnthropicMessagesRequest::try_from(openai_req).expect("OpenAI->Anthropic conversion failed"); + let openai_req = ChatCompletionsRequest::try_from(anthropic_req.clone()) + .expect("Anthropic->OpenAI conversion failed"); + let anthropic_req2 = AnthropicMessagesRequest::try_from(openai_req) + .expect("OpenAI->Anthropic conversion failed"); assert_eq!(anthropic_req.model, anthropic_req2.model); // Compare system prompt text if present assert_eq!( - anthropic_req.system.as_ref().and_then(|s| match s { crate::apis::anthropic::MessagesSystemPrompt::Single(t) => Some(t), _ => None }), - anthropic_req2.system.as_ref().and_then(|s| match s { crate::apis::anthropic::MessagesSystemPrompt::Single(t) => Some(t), _ => None }) + anthropic_req.system.as_ref().and_then(|s| match s { + crate::apis::anthropic::MessagesSystemPrompt::Single(t) => Some(t), + _ => None, + }), + anthropic_req2.system.as_ref().and_then(|s| match s { + crate::apis::anthropic::MessagesSystemPrompt::Single(t) => Some(t), + _ => None, + }) + ); + assert_eq!( + anthropic_req.messages[0].role, + anthropic_req2.messages[0].role ); - assert_eq!(anthropic_req.messages[0].role, anthropic_req2.messages[0].role); // Compare message content text if present assert_eq!( anthropic_req.messages[0].content.extract_text(), @@ -320,49 +352,54 @@ mod tests { assert_eq!(anthropic_req.max_tokens, anthropic_req2.max_tokens); } - #[test] - fn test_v1_chat_completions_to_v1_messages_roundtrip() { - use crate::apis::anthropic::MessagesRequest as AnthropicMessagesRequest; - use crate::apis::openai::{ChatCompletionsRequest, Message, Role, MessageContent}; + #[test] + fn test_v1_chat_completions_to_v1_messages_roundtrip() { + use crate::apis::anthropic::MessagesRequest as AnthropicMessagesRequest; + use crate::apis::openai::{ChatCompletionsRequest, Message, MessageContent, Role}; - let openai_req = ChatCompletionsRequest { - model: "gpt-4".to_string(), - messages: vec![ - Message { - role: Role::System, - content: MessageContent::Text("You are a helpful assistant".to_string()), - name: None, - tool_calls: None, - tool_call_id: None, - }, - Message { - role: Role::User, - content: MessageContent::Text("Hello!".to_string()), - name: None, - tool_calls: None, - tool_call_id: None, - } - ], - temperature: Some(0.7), - top_p: Some(1.0), - max_tokens: Some(128), - stream: Some(false), - stop: Some(vec!["\n".to_string()]), - tools: None, - tool_choice: None, - parallel_tool_calls: None, - ..Default::default() - }; + let openai_req = ChatCompletionsRequest { + model: "gpt-4".to_string(), + messages: vec![ + Message { + role: Role::System, + content: MessageContent::Text("You are a helpful assistant".to_string()), + name: None, + tool_calls: None, + tool_call_id: None, + }, + Message { + role: Role::User, + content: MessageContent::Text("Hello!".to_string()), + name: None, + tool_calls: None, + tool_call_id: None, + }, + ], + temperature: Some(0.7), + top_p: Some(1.0), + max_tokens: Some(128), + stream: Some(false), + stop: Some(vec!["\n".to_string()]), + tools: None, + tool_choice: None, + parallel_tool_calls: None, + ..Default::default() + }; - let anthropic_req = AnthropicMessagesRequest::try_from(openai_req.clone()).expect("OpenAI->Anthropic conversion failed"); - let openai_req2 = ChatCompletionsRequest::try_from(anthropic_req).expect("Anthropic->OpenAI conversion failed"); + let anthropic_req = AnthropicMessagesRequest::try_from(openai_req.clone()) + .expect("OpenAI->Anthropic conversion failed"); + let openai_req2 = ChatCompletionsRequest::try_from(anthropic_req) + .expect("Anthropic->OpenAI conversion failed"); - assert_eq!(openai_req.model, openai_req2.model); - assert_eq!(openai_req.messages[0].role, openai_req2.messages[0].role); - assert_eq!(openai_req.messages[0].content.extract_text(), openai_req2.messages[0].content.extract_text()); - // After roundtrip, deprecated max_tokens should be converted to max_completion_tokens - let original_max_tokens = openai_req.max_completion_tokens.or(openai_req.max_tokens); - let roundtrip_max_tokens = openai_req2.max_completion_tokens.or(openai_req2.max_tokens); - assert_eq!(original_max_tokens, roundtrip_max_tokens); - } + assert_eq!(openai_req.model, openai_req2.model); + assert_eq!(openai_req.messages[0].role, openai_req2.messages[0].role); + assert_eq!( + openai_req.messages[0].content.extract_text(), + openai_req2.messages[0].content.extract_text() + ); + // After roundtrip, deprecated max_tokens should be converted to max_completion_tokens + let original_max_tokens = openai_req.max_completion_tokens.or(openai_req.max_tokens); + let roundtrip_max_tokens = openai_req2.max_completion_tokens.or(openai_req2.max_tokens); + assert_eq!(original_max_tokens, roundtrip_max_tokens); + } } diff --git a/crates/hermesllm/src/providers/response.rs b/crates/hermesllm/src/providers/response.rs index 6bc4e25f..5f4607df 100644 --- a/crates/hermesllm/src/providers/response.rs +++ b/crates/hermesllm/src/providers/response.rs @@ -1,15 +1,15 @@ use crate::providers::id::ProviderId; -use serde::{Serialize, Deserialize}; +use serde::{Deserialize, Serialize}; +use std::convert::TryFrom; use std::error::Error; use std::fmt; -use std::convert::TryFrom; use std::str::FromStr; +use crate::apis::anthropic::MessagesResponse; +use crate::apis::anthropic::MessagesStreamEvent; use crate::apis::openai::ChatCompletionsResponse; use crate::apis::openai::ChatCompletionsStreamResponse; -use crate::apis::anthropic::MessagesStreamEvent; use crate::clients::endpoints::SupportedAPIs; -use crate::apis::anthropic::MessagesResponse; /// Trait for token usage information pub trait TokenUsage { @@ -38,7 +38,8 @@ pub trait ProviderResponse: Send + Sync { /// Extract token counts for metrics fn extract_usage_counts(&self) -> Option<(usize, usize, usize)> { - self.usage().map(|u| (u.prompt_tokens(), u.completion_tokens(), u.total_tokens())) + self.usage() + .map(|u| (u.prompt_tokens(), u.completion_tokens(), u.total_tokens())) } } @@ -110,19 +111,19 @@ impl ProviderStreamResponse for ProviderStreamResponseType { #[derive(Debug, Clone, Serialize, Deserialize)] pub struct SseEvent { #[serde(rename = "data")] - pub data: Option, // The JSON payload after "data: " + pub data: Option, // The JSON payload after "data: " #[serde(skip_serializing_if = "Option::is_none")] - pub event: Option, // Optional event type (e.g., "message_start", "content_block_delta") + pub event: Option, // Optional event type (e.g., "message_start", "content_block_delta") #[serde(skip_serializing, skip_deserializing)] - pub raw_line: String, // The complete line as received including "data: " prefix and "\n\n" - - #[serde(skip_serializing, skip_deserializing)] - pub sse_transform_buffer: String, // The complete line as received including "data: " prefix and "\n\n" + pub raw_line: String, // The complete line as received including "data: " prefix and "\n\n" #[serde(skip_serializing, skip_deserializing)] - pub provider_stream_response: Option, // Parsed provider stream response object + pub sse_transform_buffer: String, // The complete line as received including "data: " prefix and "\n\n" + + #[serde(skip_serializing, skip_deserializing)] + pub provider_stream_response: Option, // Parsed provider stream response object } impl SseEvent { @@ -145,13 +146,13 @@ impl SseEvent { /// Get the parsed provider response if available pub fn provider_response(&self) -> Result<&dyn ProviderStreamResponse, std::io::Error> { - self.provider_stream_response.as_ref() + self.provider_stream_response + .as_ref() .map(|resp| resp as &dyn ProviderStreamResponse) .ok_or_else(|| { std::io::Error::new(std::io::ErrorKind::NotFound, "Provider response not found") }) } - } impl FromStr for SseEvent { @@ -172,7 +173,8 @@ impl FromStr for SseEvent { sse_transform_buffer: line.to_string(), provider_stream_response: None, }) - } else if line.starts_with("event: ") { //used by Anthropic + } else if line.starts_with("event: ") { + //used by Anthropic let event_type = line[7..].to_string(); if event_type.is_empty() { return Err(SseParseError { @@ -207,12 +209,13 @@ impl Into> for SseEvent { } } - // --- Response transformation logic for client API compatibility --- impl TryFrom<(&[u8], &SupportedAPIs, &ProviderId)> for ProviderResponseType { type Error = std::io::Error; - fn try_from((bytes, client_api, provider_id): (&[u8], &SupportedAPIs, &ProviderId)) -> Result { + fn try_from( + (bytes, client_api, provider_id): (&[u8], &SupportedAPIs, &ProviderId), + ) -> Result { let upstream_api = provider_id.compatible_api_for_client(client_api); match (&upstream_api, client_api) { (SupportedAPIs::OpenAIChatCompletions(_), SupportedAPIs::OpenAIChatCompletions(_)) => { @@ -230,8 +233,13 @@ impl TryFrom<(&[u8], &SupportedAPIs, &ProviderId)> for ProviderResponseType { .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; // Transform to OpenAI ChatCompletions format using the transformer - let chat_resp: ChatCompletionsResponse = anthropic_resp.try_into() - .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Transformation error: {}", e)))?; + let chat_resp: ChatCompletionsResponse = + anthropic_resp.try_into().map_err(|e| { + std::io::Error::new( + std::io::ErrorKind::InvalidData, + format!("Transformation error: {}", e), + ) + })?; Ok(ProviderResponseType::ChatCompletionsResponse(chat_resp)) } (SupportedAPIs::OpenAIChatCompletions(_), SupportedAPIs::AnthropicMessagesAPI(_)) => { @@ -239,8 +247,12 @@ impl TryFrom<(&[u8], &SupportedAPIs, &ProviderId)> for ProviderResponseType { .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; // Transform to Anthropic Messages format using the transformer - let messages_resp: MessagesResponse = openai_resp.try_into() - .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Transformation error: {}", e)))?; + let messages_resp: MessagesResponse = openai_resp.try_into().map_err(|e| { + std::io::Error::new( + std::io::ErrorKind::InvalidData, + format!("Transformation error: {}", e), + ) + })?; Ok(ProviderResponseType::MessagesResponse(messages_resp)) } } @@ -251,36 +263,50 @@ impl TryFrom<(&[u8], &SupportedAPIs, &ProviderId)> for ProviderResponseType { impl TryFrom<(&[u8], &SupportedAPIs, &SupportedAPIs)> for ProviderStreamResponseType { type Error = Box; - fn try_from((bytes, client_api, upstream_api): (&[u8], &SupportedAPIs, &SupportedAPIs)) -> Result { + fn try_from( + (bytes, client_api, upstream_api): (&[u8], &SupportedAPIs, &SupportedAPIs), + ) -> Result { match (upstream_api, client_api) { (SupportedAPIs::OpenAIChatCompletions(_), SupportedAPIs::OpenAIChatCompletions(_)) => { - let resp: crate::apis::openai::ChatCompletionsStreamResponse = serde_json::from_slice(bytes)?; - Ok(ProviderStreamResponseType::ChatCompletionsStreamResponse(resp)) + let resp: crate::apis::openai::ChatCompletionsStreamResponse = + serde_json::from_slice(bytes)?; + Ok(ProviderStreamResponseType::ChatCompletionsStreamResponse( + resp, + )) } (SupportedAPIs::AnthropicMessagesAPI(_), SupportedAPIs::AnthropicMessagesAPI(_)) => { - let resp: crate::apis::anthropic::MessagesStreamEvent = serde_json::from_slice(bytes)?; + let resp: crate::apis::anthropic::MessagesStreamEvent = + serde_json::from_slice(bytes)?; Ok(ProviderStreamResponseType::MessagesStreamEvent(resp)) } (SupportedAPIs::AnthropicMessagesAPI(_), SupportedAPIs::OpenAIChatCompletions(_)) => { - let anthropic_resp: crate::apis::anthropic::MessagesStreamEvent = serde_json::from_slice(bytes)?; + let anthropic_resp: crate::apis::anthropic::MessagesStreamEvent = + serde_json::from_slice(bytes)?; // Transform to OpenAI ChatCompletions stream format using the transformer - let chat_resp: crate::apis::openai::ChatCompletionsStreamResponse = anthropic_resp.try_into()?; - Ok(ProviderStreamResponseType::ChatCompletionsStreamResponse(chat_resp)) + let chat_resp: crate::apis::openai::ChatCompletionsStreamResponse = + anthropic_resp.try_into()?; + Ok(ProviderStreamResponseType::ChatCompletionsStreamResponse( + chat_resp, + )) } (SupportedAPIs::OpenAIChatCompletions(_), SupportedAPIs::AnthropicMessagesAPI(_)) => { // Special case: Handle [DONE] marker for OpenAI -> Anthropic conversion if bytes == b"[DONE]" { return Ok(ProviderStreamResponseType::MessagesStreamEvent( - crate::apis::anthropic::MessagesStreamEvent::MessageStop + crate::apis::anthropic::MessagesStreamEvent::MessageStop, )); } - let openai_resp: crate::apis::openai::ChatCompletionsStreamResponse = serde_json::from_slice(bytes)?; + let openai_resp: crate::apis::openai::ChatCompletionsStreamResponse = + serde_json::from_slice(bytes)?; // Transform to Anthropic Messages stream format using the transformer - let messages_resp: crate::apis::anthropic::MessagesStreamEvent = openai_resp.try_into()?; - Ok(ProviderStreamResponseType::MessagesStreamEvent(messages_resp)) + let messages_resp: crate::apis::anthropic::MessagesStreamEvent = + openai_resp.try_into()?; + Ok(ProviderStreamResponseType::MessagesStreamEvent( + messages_resp, + )) } } } @@ -290,7 +316,9 @@ impl TryFrom<(&[u8], &SupportedAPIs, &SupportedAPIs)> for ProviderStreamResponse impl TryFrom<(SseEvent, &SupportedAPIs, &SupportedAPIs)> for SseEvent { type Error = Box; - fn try_from((sse_event, client_api, upstream_api): (SseEvent, &SupportedAPIs, &SupportedAPIs)) -> Result { + fn try_from( + (sse_event, client_api, upstream_api): (SseEvent, &SupportedAPIs, &SupportedAPIs), + ) -> Result { // Create a new transformed event based on the original let mut transformed_event = sse_event; @@ -298,7 +326,8 @@ impl TryFrom<(SseEvent, &SupportedAPIs, &SupportedAPIs)> for SseEvent { if transformed_event.data.is_some() { let data_str = transformed_event.data.as_ref().unwrap(); let data_bytes = data_str.as_bytes(); - let transformed_response = ProviderStreamResponseType::try_from((data_bytes, client_api, upstream_api))?; + let transformed_response = + ProviderStreamResponseType::try_from((data_bytes, client_api, upstream_api))?; let transformed_json = serde_json::to_string(&transformed_response)?; transformed_event.sse_transform_buffer = format!("data: {}\n\n", transformed_json); transformed_event.provider_stream_response = Some(transformed_response); @@ -344,7 +373,10 @@ impl TryFrom<(SseEvent, &SupportedAPIs, &SupportedAPIs)> for SseEvent { transformed_event.sse_transform_buffer ); } else { - transformed_event.sse_transform_buffer = format!("event: {}\n{}", event_type, transformed_event.sse_transform_buffer); + transformed_event.sse_transform_buffer = format!( + "event: {}\n{}", + event_type, transformed_event.sse_transform_buffer + ); } } // If event_type is None, we just keep the data line as-is without an event line @@ -396,7 +428,10 @@ where I::Item: AsRef, { pub fn new(lines: I) -> Self { - Self { lines, done_seen: false } + Self { + lines, + done_seen: false, + } } } @@ -451,7 +486,6 @@ pub struct ProviderResponseError { pub source: Option>, } - impl fmt::Display for ProviderResponseError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "Provider response error: {}", self.message) @@ -460,17 +494,19 @@ impl fmt::Display for ProviderResponseError { impl Error for ProviderResponseError { fn source(&self) -> Option<&(dyn Error + 'static)> { - self.source.as_ref().map(|e| e.as_ref() as &(dyn Error + 'static)) + self.source + .as_ref() + .map(|e| e.as_ref() as &(dyn Error + 'static)) } } #[cfg(test)] mod tests { use super::*; + use crate::apis::anthropic::AnthropicApi; + use crate::apis::openai::OpenAIApi; use crate::clients::endpoints::SupportedAPIs; use crate::providers::id::ProviderId; - use crate::apis::openai::OpenAIApi; - use crate::apis::anthropic::AnthropicApi; use serde_json::json; #[test] @@ -491,13 +527,17 @@ mod tests { "system_fingerprint": null }); let bytes = serde_json::to_vec(&resp).unwrap(); - let result = ProviderResponseType::try_from((bytes.as_slice(), &SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions), &ProviderId::OpenAI)); + let result = ProviderResponseType::try_from(( + bytes.as_slice(), + &SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions), + &ProviderId::OpenAI, + )); assert!(result.is_ok()); match result.unwrap() { ProviderResponseType::ChatCompletionsResponse(r) => { assert_eq!(r.model, "gpt-4"); assert_eq!(r.choices.len(), 1); - }, + } _ => panic!("Expected ChatCompletionsResponse variant"), } } @@ -516,13 +556,17 @@ mod tests { "usage": { "input_tokens": 10, "output_tokens": 25, "cache_creation_input_tokens": 5, "cache_read_input_tokens": 3 } }); let bytes = serde_json::to_vec(&resp).unwrap(); - let result = ProviderResponseType::try_from((bytes.as_slice(), &SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages), &ProviderId::Anthropic)); + let result = ProviderResponseType::try_from(( + bytes.as_slice(), + &SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages), + &ProviderId::Anthropic, + )); assert!(result.is_ok()); match result.unwrap() { ProviderResponseType::MessagesResponse(r) => { assert_eq!(r.model, "claude-3-sonnet-20240229"); assert_eq!(r.content.len(), 1); - }, + } _ => panic!("Expected MessagesResponse variant"), } } @@ -546,14 +590,18 @@ mod tests { "usage": { "prompt_tokens": 10, "completion_tokens": 25, "total_tokens": 35 } }); let bytes = serde_json::to_vec(&resp).unwrap(); - let result = ProviderResponseType::try_from((bytes.as_slice(), &SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages), &ProviderId::OpenAI)); + let result = ProviderResponseType::try_from(( + bytes.as_slice(), + &SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages), + &ProviderId::OpenAI, + )); assert!(result.is_ok()); match result.unwrap() { ProviderResponseType::MessagesResponse(r) => { assert_eq!(r.model, "gpt-4"); assert_eq!(r.usage.input_tokens, 10); assert_eq!(r.usage.output_tokens, 25); - }, + } _ => panic!("Expected MessagesResponse variant"), } } @@ -584,14 +632,18 @@ mod tests { } }); let bytes = serde_json::to_vec(&resp).unwrap(); - let result = ProviderResponseType::try_from((bytes.as_slice(), &SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions), &ProviderId::Anthropic)); + let result = ProviderResponseType::try_from(( + bytes.as_slice(), + &SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions), + &ProviderId::Anthropic, + )); assert!(result.is_ok()); match result.unwrap() { ProviderResponseType::ChatCompletionsResponse(r) => { assert_eq!(r.model, "claude-3-sonnet-20240229"); assert_eq!(r.usage.prompt_tokens, 10); assert_eq!(r.usage.completion_tokens, 25); - }, + } _ => panic!("Expected ChatCompletionsResponse variant"), } } @@ -603,11 +655,17 @@ mod tests { let event: Result = line.parse(); assert!(event.is_ok()); let event = event.unwrap(); - assert_eq!(event.data, Some("{\"id\":\"test\",\"object\":\"chat.completion.chunk\"}\n\n".to_string())); + assert_eq!( + event.data, + Some("{\"id\":\"test\",\"object\":\"chat.completion.chunk\"}\n\n".to_string()) + ); // Test conversion back to line using Display trait let wire_format = event.to_string(); - assert_eq!(wire_format, "data: {\"id\":\"test\",\"object\":\"chat.completion.chunk\"}\n\n"); + assert_eq!( + wire_format, + "data: {\"id\":\"test\",\"object\":\"chat.completion.chunk\"}\n\n" + ); // Test [DONE] marker - should be valid SSE event let done_line = "data: [DONE]"; @@ -639,10 +697,12 @@ mod tests { event: None, raw_line: r#"data: {"id":"test","object":"chat.completion.chunk"} - "#.to_string(), + "# + .to_string(), sse_transform_buffer: r#"data: {"id":"test","object":"chat.completion.chunk"} - "#.to_string(), + "# + .to_string(), provider_stream_response: None, }; @@ -679,7 +739,8 @@ mod tests { data: Some(r#"{"id": "test", "object": "chat.completion.chunk"}"#.to_string()), event: Some("content_block_delta".to_string()), raw_line: r#"data: {"id": "test", "object": "chat.completion.chunk"}"#.to_string(), - sse_transform_buffer: r#"data: {"id": "test", "object": "chat.completion.chunk"}"#.to_string(), + sse_transform_buffer: r#"data: {"id": "test", "object": "chat.completion.chunk"}"# + .to_string(), provider_stream_response: None, }; assert!(!normal_event.should_skip()); @@ -705,7 +766,7 @@ mod tests { "data: {\"type\": \"ping\"}".to_string(), // This should be filtered out "data: {\"id\": \"msg2\", \"object\": \"chat.completion.chunk\"}".to_string(), "data: {\"type\": \"ping\"}".to_string(), // This should be filtered out - "data: [DONE]".to_string(), // This should end the stream + "data: [DONE]".to_string(), // This should end the stream ]; let mut iter = SseStreamIter::new(test_lines.into_iter()); @@ -773,13 +834,15 @@ mod tests { #[test] fn test_provider_stream_response_event_type() { - use crate::apis::anthropic::{MessagesStreamEvent, MessagesContentDelta}; + use crate::apis::anthropic::{MessagesContentDelta, MessagesStreamEvent}; use crate::apis::openai::ChatCompletionsStreamResponse; // Test Anthropic event type let anthropic_event = MessagesStreamEvent::ContentBlockDelta { index: 0, - delta: MessagesContentDelta::TextDelta { text: "Hello".to_string() }, + delta: MessagesContentDelta::TextDelta { + text: "Hello".to_string(), + }, }; let provider_type = ProviderStreamResponseType::MessagesStreamEvent(anthropic_event); assert_eq!(provider_type.event_type(), Some("content_block_delta")); @@ -806,15 +869,23 @@ mod tests { // Test that [DONE] marker is properly converted to MessageStop in the transformation layer let done_bytes = b"[DONE]"; let client_api = SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages); - let upstream_api = SupportedAPIs::OpenAIChatCompletions(crate::apis::openai::OpenAIApi::ChatCompletions); + let upstream_api = + SupportedAPIs::OpenAIChatCompletions(crate::apis::openai::OpenAIApi::ChatCompletions); - let result = ProviderStreamResponseType::try_from((done_bytes.as_slice(), &client_api, &upstream_api)); + let result = ProviderStreamResponseType::try_from(( + done_bytes.as_slice(), + &client_api, + &upstream_api, + )); assert!(result.is_ok()); if let Ok(ProviderStreamResponseType::MessagesStreamEvent(event)) = result { // Verify it's a MessageStop event assert_eq!(event.event_type(), Some("message_stop")); - assert!(matches!(event, crate::apis::anthropic::MessagesStreamEvent::MessageStop)); + assert!(matches!( + event, + crate::apis::anthropic::MessagesStreamEvent::MessageStop + )); } else { panic!("Expected MessagesStreamEvent::MessageStop"); }