From 6255595b8c3f7c5a8d640e4e59ef1dec5af8df3a Mon Sep 17 00:00:00 2001 From: Salman Paracha Date: Wed, 6 Aug 2025 17:47:52 -0700 Subject: [PATCH] this offers transformation support from claude to open ai and back --- crates/Cargo.lock | 292 +++- crates/hermesllm/Cargo.toml | 4 + crates/hermesllm/src/apis/anthropic.rs | 614 +++++-- crates/hermesllm/src/apis/openai.rs | 1164 ++++++++++--- crates/hermesllm/src/clients/endpoints.rs | 130 ++ crates/hermesllm/src/clients/lib.rs | 33 + crates/hermesllm/src/clients/mod.rs | 9 + crates/hermesllm/src/clients/transformer.rs | 1713 +++++++++++++++++++ crates/hermesllm/src/lib.rs | 6 +- crates/hermesllm/src/mod.rs | 2 + 10 files changed, 3551 insertions(+), 416 deletions(-) create mode 100644 crates/hermesllm/src/clients/endpoints.rs create mode 100644 crates/hermesllm/src/clients/lib.rs create mode 100644 crates/hermesllm/src/clients/mod.rs create mode 100644 crates/hermesllm/src/clients/transformer.rs create mode 100644 crates/hermesllm/src/mod.rs diff --git a/crates/Cargo.lock b/crates/Cargo.lock index 1f33b4c2..b9ea3f65 100644 --- a/crates/Cargo.lock +++ b/crates/Cargo.lock @@ -104,6 +104,43 @@ version = "1.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dde20b3d026af13f561bdd0f15edf01fc734f0dafcedbaf42bba506a9517f223" +[[package]] +name = "async-openai" +version = "0.29.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "31acf814d6b499e33ec894bb0fd7ddaf2665b44fbdd42b858d736449271fde0c" +dependencies = [ + "async-openai-macros", + "backoff", + "base64 0.22.1", + "bytes", + "derive_builder", + "eventsource-stream", + "futures", + "rand 0.8.5", + "reqwest", + "reqwest-eventsource", + "secrecy", + "serde", + "serde_json", + "thiserror 2.0.12", + "tokio", + "tokio-stream", + "tokio-util", + "tracing", +] + +[[package]] +name = "async-openai-macros" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0289cba6d5143bfe8251d57b4a8cac036adf158525a76533a7082ba65ec76398" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.101", +] + [[package]] name = "async-trait" version = "0.1.88" @@ -138,6 +175,20 @@ version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" +[[package]] +name = "backoff" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b62ddb9cb1ec0a098ad4bbf9344d0713fa193ae1a80af55febcff2627b6a00c1" +dependencies = [ + "futures-core", + "getrandom 0.2.16", + "instant", + "pin-project-lite", + "rand 0.8.5", + "tokio", +] + [[package]] name = "backtrace" version = "0.3.75" @@ -287,6 +338,12 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "cfg_aliases" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" + [[package]] name = "chrono" version = "0.4.41" @@ -354,6 +411,16 @@ dependencies = [ "libc", ] +[[package]] +name = "core-foundation" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b2a6cd9ae233e7f62ba4e9353e81a88df7fc8a5987b8d445b4d90c879bd156f6" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "core-foundation-sys" version = "0.8.7" @@ -415,7 +482,7 @@ dependencies = [ "hashbrown 0.14.5", "log", "regalloc2", - "rustc-hash", + "rustc-hash 1.1.0", "smallvec", "target-lexicon", ] @@ -609,6 +676,37 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "derive_builder" +version = "0.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "507dfb09ea8b7fa618fcf76e953f4f5e192547945816d5358edffe39f6f94947" +dependencies = [ + "derive_builder_macro", +] + +[[package]] +name = "derive_builder_core" +version = "0.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2d5bcf7b024d6835cfb3d473887cd966994907effbe9227e8c8219824d06c4e8" +dependencies = [ + "darling", + "proc-macro2", + "quote", + "syn 2.0.101", +] + +[[package]] +name = "derive_builder_macro" +version = "0.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ab63b0e2bf4d5928aff72e83a7dace85d7bba5fe12dcc3c5a572d78caffd3f3c" +dependencies = [ + "derive_builder_core", + "syn 2.0.101", +] + [[package]] name = "diff" version = "0.1.13" @@ -872,6 +970,12 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988" +[[package]] +name = "futures-timer" +version = "3.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24" + [[package]] name = "futures-util" version = "0.3.31" @@ -929,8 +1033,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "335ff9f135e4384c8150d6f27c6daed433577f86b4750418338c01a1a2528592" dependencies = [ "cfg-if 1.0.0", + "js-sys", "libc", "wasi 0.11.0+wasi-snapshot-preview1", + "wasm-bindgen", ] [[package]] @@ -940,9 +1046,11 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "26145e563e54f2cadc477553f1ec5ee650b00862f0a58bcd12cbdc5f0ea2d2f4" dependencies = [ "cfg-if 1.0.0", + "js-sys", "libc", "r-efi", "wasi 0.14.2+wasi-0.2.4", + "wasm-bindgen", ] [[package]] @@ -1085,10 +1193,12 @@ checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" name = "hermesllm" version = "0.1.0" dependencies = [ + "async-openai", "serde", "serde_json", "serde_with", "thiserror 2.0.12", + "tokio", ] [[package]] @@ -1230,7 +1340,7 @@ dependencies = [ "hyper 0.14.32", "log", "rustls 0.21.12", - "rustls-native-certs", + "rustls-native-certs 0.6.3", "tokio", "tokio-rustls 0.24.1", ] @@ -1245,6 +1355,7 @@ dependencies = [ "hyper 1.6.0", "hyper-util", "rustls 0.23.27", + "rustls-native-certs 0.8.1", "rustls-pki-types", "tokio", "tokio-rustls 0.26.2", @@ -1483,6 +1594,15 @@ dependencies = [ "serde", ] +[[package]] +name = "instant" +version = "0.1.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e0242819d153cba4b4b05a5a8f2a7e9bbf97b6055b2a002b395c96b5ff3c0222" +dependencies = [ + "cfg-if 1.0.0", +] + [[package]] name = "ipnet" version = "2.11.0" @@ -1660,6 +1780,12 @@ version = "0.4.27" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "13dc2df351e3202783a1fe0d44375f7295ffb4049267b0f3018346dc122a1d94" +[[package]] +name = "lru-slab" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154" + [[package]] name = "mach2" version = "0.4.2" @@ -1705,6 +1831,16 @@ version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" +[[package]] +name = "mime_guess" +version = "2.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f7c44f8e672c00fe5308fa235f821cb4198414e1c77935c1ab6948d3fd78550e" +dependencies = [ + "mime", + "unicase", +] + [[package]] name = "minimal-lexical" version = "0.2.1" @@ -1749,7 +1885,7 @@ dependencies = [ "openssl-probe", "openssl-sys", "schannel", - "security-framework", + "security-framework 2.11.1", "security-framework-sys", "tempfile", ] @@ -2199,6 +2335,61 @@ dependencies = [ "cc", ] +[[package]] +name = "quinn" +version = "0.11.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "626214629cda6781b6dc1d316ba307189c85ba657213ce642d9c77670f8202c8" +dependencies = [ + "bytes", + "cfg_aliases", + "pin-project-lite", + "quinn-proto", + "quinn-udp", + "rustc-hash 2.1.1", + "rustls 0.23.27", + "socket2", + "thiserror 2.0.12", + "tokio", + "tracing", + "web-time", +] + +[[package]] +name = "quinn-proto" +version = "0.11.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49df843a9161c85bb8aae55f101bc0bac8bcafd637a620d9122fd7e0b2f7422e" +dependencies = [ + "bytes", + "getrandom 0.3.3", + "lru-slab", + "rand 0.9.1", + "ring", + "rustc-hash 2.1.1", + "rustls 0.23.27", + "rustls-pki-types", + "slab", + "thiserror 2.0.12", + "tinyvec", + "tracing", + "web-time", +] + +[[package]] +name = "quinn-udp" +version = "0.5.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fcebb1209ee276352ef14ff8732e24cc2b02bbac986cd74a4c81bcb2f9881970" +dependencies = [ + "cfg_aliases", + "libc", + "once_cell", + "socket2", + "tracing", + "windows-sys 0.59.0", +] + [[package]] name = "quote" version = "1.0.40" @@ -2341,7 +2532,7 @@ checksum = "ad156d539c879b7a24a363a2016d77961786e71f48f2e2fc8302a92abd2429a6" dependencies = [ "hashbrown 0.13.2", "log", - "rustc-hash", + "rustc-hash 1.1.0", "slice-group-by", "smallvec", ] @@ -2414,10 +2605,14 @@ dependencies = [ "js-sys", "log", "mime", + "mime_guess", "native-tls", "once_cell", "percent-encoding", "pin-project-lite", + "quinn", + "rustls 0.23.27", + "rustls-native-certs 0.8.1", "rustls-pki-types", "serde", "serde_json", @@ -2425,6 +2620,7 @@ dependencies = [ "sync_wrapper", "tokio", "tokio-native-tls", + "tokio-rustls 0.26.2", "tokio-util", "tower 0.5.2", "tower-http", @@ -2436,6 +2632,22 @@ dependencies = [ "web-sys", ] +[[package]] +name = "reqwest-eventsource" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "632c55746dbb44275691640e7b40c907c16a2dc1a5842aa98aaec90da6ec6bde" +dependencies = [ + "eventsource-stream", + "futures-core", + "futures-timer", + "mime", + "nom", + "pin-project-lite", + "reqwest", + "thiserror 1.0.69", +] + [[package]] name = "ring" version = "0.17.14" @@ -2462,6 +2674,12 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" +[[package]] +name = "rustc-hash" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d" + [[package]] name = "rustix" version = "0.38.44" @@ -2507,6 +2725,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "730944ca083c1c233a75c09f199e973ca499344a2b7ba9e755c457e86fb4a321" dependencies = [ "once_cell", + "ring", "rustls-pki-types", "rustls-webpki 0.103.3", "subtle", @@ -2522,7 +2741,19 @@ dependencies = [ "openssl-probe", "rustls-pemfile", "schannel", - "security-framework", + "security-framework 2.11.1", +] + +[[package]] +name = "rustls-native-certs" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7fcff2dd52b58a8d98a70243663a0d234c4e2b79235637849d15913394a247d3" +dependencies = [ + "openssl-probe", + "rustls-pki-types", + "schannel", + "security-framework 3.3.0", ] [[package]] @@ -2540,6 +2771,7 @@ version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "229a4a4c221013e7e1f1a043678c5cc39fe5171437c88fb47151a21e6f5b5c79" dependencies = [ + "web-time", "zeroize", ] @@ -2628,6 +2860,16 @@ version = "3.0.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "584e070911c7017da6cb2eb0788d09f43d789029b5877d3e5ecc8acf86ceee21" +[[package]] +name = "secrecy" +version = "0.10.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e891af845473308773346dc847b2c23ee78fe442e0472ac50e22a18a93d3ae5a" +dependencies = [ + "serde", + "zeroize", +] + [[package]] name = "security-framework" version = "2.11.1" @@ -2635,7 +2877,20 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02" dependencies = [ "bitflags 2.9.1", - "core-foundation", + "core-foundation 0.9.4", + "core-foundation-sys", + "libc", + "security-framework-sys", +] + +[[package]] +name = "security-framework" +version = "3.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "80fb1d92c5028aa318b4b8bd7302a5bfcf48be96a37fc6fc790f806b0004ee0c" +dependencies = [ + "bitflags 2.9.1", + "core-foundation 0.10.1", "core-foundation-sys", "libc", "security-framework-sys", @@ -2963,7 +3218,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3c879d448e9d986b661742763247d3693ed13609438cf3d006f51f5368a5ba6b" dependencies = [ "bitflags 2.9.1", - "core-foundation", + "core-foundation 0.9.4", "system-configuration-sys", ] @@ -3076,7 +3331,7 @@ dependencies = [ "fancy-regex", "lazy_static", "parking_lot", - "rustc-hash", + "rustc-hash 1.1.0", ] [[package]] @@ -3120,6 +3375,21 @@ dependencies = [ "zerovec", ] +[[package]] +name = "tinyvec" +version = "1.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09b3661f17e86524eccd4371ab0429194e0d7c008abb45f7a7495b1719463c71" +dependencies = [ + "tinyvec_macros", +] + +[[package]] +name = "tinyvec_macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" + [[package]] name = "tokio" version = "1.45.1" @@ -3436,6 +3706,12 @@ version = "1.18.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1dccffe3ce07af9386bfd29e80c0ab1a8205a2fc34e4bcd40364df902cfa8f3f" +[[package]] +name = "unicase" +version = "2.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75b844d17643ee918803943289730bec8aac480150456169e647ed0b576ba539" + [[package]] name = "unicode-ident" version = "1.0.18" diff --git a/crates/hermesllm/Cargo.toml b/crates/hermesllm/Cargo.toml index c995e85c..cbfea10e 100644 --- a/crates/hermesllm/Cargo.toml +++ b/crates/hermesllm/Cargo.toml @@ -8,3 +8,7 @@ serde = {version = "1.0.219", features = ["derive"]} serde_json = "1.0.140" serde_with = "3.12.0" thiserror = "2.0.12" + +[dev-dependencies] +async-openai = "0.29.0" +tokio = { version = "1.0", features = ["full"] } diff --git a/crates/hermesllm/src/apis/anthropic.rs b/crates/hermesllm/src/apis/anthropic.rs index 7ae085d5..0ffe4e8d 100644 --- a/crates/hermesllm/src/apis/anthropic.rs +++ b/crates/hermesllm/src/apis/anthropic.rs @@ -384,179 +384,504 @@ impl MessagesStreamEvent { #[cfg(test)] mod tests { use super::*; - use serde_json; + use serde_json::json; #[test] - fn test_anthropic_skip_serializing_none_annotations() { - // Test that skip_serializing_none works correctly for MessagesRequest - let request = MessagesRequest { - model: "claude-3-sonnet-20240229".to_string(), - system: None, // Should be skipped - messages: vec![MessagesMessage { - role: MessagesRole::User, - content: MessagesMessageContent::Single("Hello".to_string()), - }], - max_tokens: 100, - container: None, // Should be skipped - mcp_servers: None, // Should be skipped - service_tier: None, // Should be skipped - thinking: None, // Should be skipped - temperature: None, // Should be skipped - top_p: Some(0.9), // Should be included - top_k: None, // Should be skipped - stream: None, // Should be skipped - stop_sequences: None, // Should be skipped - tools: None, // Should be skipped - tool_choice: None, // Should be skipped - metadata: None, // Should be skipped - }; - - let json = serde_json::to_value(&request).unwrap(); - let obj = json.as_object().unwrap(); - - // Verify that None fields are not present in the JSON - assert!(!obj.contains_key("system")); - assert!(!obj.contains_key("container")); - assert!(!obj.contains_key("mcp_servers")); - assert!(!obj.contains_key("service_tier")); - assert!(!obj.contains_key("thinking")); - assert!(!obj.contains_key("temperature")); - assert!(!obj.contains_key("top_k")); - assert!(!obj.contains_key("stream")); - assert!(!obj.contains_key("stop_sequences")); - assert!(!obj.contains_key("tools")); - assert!(!obj.contains_key("tool_choice")); - assert!(!obj.contains_key("metadata")); - - // Verify that required fields and Some fields are present - assert!(obj.contains_key("model")); - assert!(obj.contains_key("messages")); - assert!(obj.contains_key("max_tokens")); - assert!(obj.contains_key("top_p")); // This was Some(0.9) - } - - #[test] - fn test_anthropic_tool_serialization() { - // Test MessagesTool with skip_serializing_none - let tool = MessagesTool { - name: "get_weather".to_string(), - description: None, // Should be skipped - input_schema: serde_json::json!({ - "type": "object", - "properties": { - "location": {"type": "string"} + fn test_anthropic_required_fields() { + // Create a JSON object with only required fields + let original_json = json!({ + "model": "claude-3-sonnet-20240229", + "messages": [ + { + "role": "user", + "content": "Hello" } - }), - }; + ], + "max_tokens": 100 + }); - let json = serde_json::to_value(&tool).unwrap(); - let obj = json.as_object().unwrap(); + // Deserialize JSON into MessagesRequest + let deserialized_request: MessagesRequest = serde_json::from_value(original_json.clone()).unwrap(); - assert!(obj.contains_key("name")); - assert!(obj.contains_key("input_schema")); - assert!(!obj.contains_key("description")); // Should be skipped + // Validate required fields are properly set + assert_eq!(deserialized_request.model, "claude-3-sonnet-20240229"); + assert_eq!(deserialized_request.messages.len(), 1); + assert_eq!(deserialized_request.max_tokens, 100); - // Test with description present - let tool_with_desc = MessagesTool { - name: "get_weather".to_string(), - description: Some("Get weather information".to_string()), - input_schema: serde_json::json!({"type": "object"}), - }; + let message = &deserialized_request.messages[0]; + assert_eq!(message.role, MessagesRole::User); + if let MessagesMessageContent::Single(content) = &message.content { + assert_eq!(content, "Hello"); + } else { + panic!("Expected single content"); + } - let json_with_desc = serde_json::to_value(&tool_with_desc).unwrap(); - let obj_with_desc = json_with_desc.as_object().unwrap(); + // Validate optional fields are None + assert!(deserialized_request.system.is_none()); + assert!(deserialized_request.container.is_none()); + assert!(deserialized_request.mcp_servers.is_none()); + assert!(deserialized_request.service_tier.is_none()); + assert!(deserialized_request.thinking.is_none()); + assert!(deserialized_request.temperature.is_none()); + assert!(deserialized_request.top_p.is_none()); + assert!(deserialized_request.top_k.is_none()); + assert!(deserialized_request.stream.is_none()); + assert!(deserialized_request.stop_sequences.is_none()); + assert!(deserialized_request.tools.is_none()); + assert!(deserialized_request.tool_choice.is_none()); + assert!(deserialized_request.metadata.is_none()); - assert!(obj_with_desc.contains_key("description")); // Should be included + // Serialize back to JSON and compare + let serialized_json = serde_json::to_value(&deserialized_request).unwrap(); + assert_eq!(original_json, serialized_json); } #[test] - fn test_mcp_server_serialization() { - // Test McpServer with skip_serializing_none - let mcp_server = McpServer { - name: "test-server".to_string(), - server_type: McpServerType::Url, - url: "https://example.com/mcp".to_string(), - authorization_token: None, // Should be skipped - tool_configuration: Some(McpToolConfiguration { - allowed_tools: Some(vec!["tool1".to_string(), "tool2".to_string()]), - enabled: None, // Should be skipped - }), - }; + fn test_anthropic_optional_fields() { + // Create a JSON object with optional fields set + let original_json = json!({ + "model": "claude-3-sonnet-20240229", + "messages": [ + { + "role": "user", + "content": "Hello" + } + ], + "max_tokens": 100, + "temperature": 0.7, + "top_p": 0.9, + "system": "You are a helpful assistant", + "service_tier": "auto", + "thinking": { + "enabled": true + }, + "metadata": { + "user_id": "123" + } + }); - let json = serde_json::to_value(&mcp_server).unwrap(); - let obj = json.as_object().unwrap(); + // Deserialize JSON into MessagesRequest + let deserialized_request: MessagesRequest = serde_json::from_value(original_json.clone()).unwrap(); - // Verify required fields are present - assert!(obj.contains_key("name")); - assert!(obj.contains_key("type")); - assert!(obj.contains_key("url")); - assert!(obj.contains_key("tool_configuration")); + // Validate required fields + assert_eq!(deserialized_request.model, "claude-3-sonnet-20240229"); + assert_eq!(deserialized_request.messages.len(), 1); + assert_eq!(deserialized_request.max_tokens, 100); - // Verify None fields are not present - assert!(!obj.contains_key("authorization_token")); + // Validate optional fields are properly set + assert!((deserialized_request.temperature.unwrap() - 0.7).abs() < 1e-6); + assert!((deserialized_request.top_p.unwrap() - 0.9).abs() < 1e-6); + assert_eq!(deserialized_request.service_tier, Some(ServiceTier::Auto)); - // Check tool_configuration - let tool_config = obj.get("tool_configuration").unwrap().as_object().unwrap(); - assert!(tool_config.contains_key("allowed_tools")); - assert!(!tool_config.contains_key("enabled")); // Should be skipped + if let Some(MessagesSystemPrompt::Single(system)) = &deserialized_request.system { + assert_eq!(system, "You are a helpful assistant"); + } else { + panic!("Expected single system prompt"); + } - // Verify type serialization - assert_eq!(obj.get("type").unwrap().as_str().unwrap(), "url"); + if let Some(thinking) = &deserialized_request.thinking { + assert_eq!(thinking.enabled, true); + } else { + panic!("Expected thinking config"); + } + + assert!(deserialized_request.metadata.is_some()); + + // Validate fields not in JSON are None + assert!(deserialized_request.container.is_none()); + assert!(deserialized_request.mcp_servers.is_none()); + assert!(deserialized_request.top_k.is_none()); + assert!(deserialized_request.stream.is_none()); + assert!(deserialized_request.stop_sequences.is_none()); + assert!(deserialized_request.tools.is_none()); + assert!(deserialized_request.tool_choice.is_none()); + + // Serialize back to JSON and compare (handle floating point precision) + let serialized_json = serde_json::to_value(&deserialized_request).unwrap(); + + // Compare all fields except floating point ones + assert_eq!(serialized_json["model"], original_json["model"]); + assert_eq!(serialized_json["messages"], original_json["messages"]); + assert_eq!(serialized_json["max_tokens"], original_json["max_tokens"]); + assert_eq!(serialized_json["system"], original_json["system"]); + assert_eq!(serialized_json["service_tier"], original_json["service_tier"]); + assert_eq!(serialized_json["thinking"], original_json["thinking"]); + assert_eq!(serialized_json["metadata"], original_json["metadata"]); + + // Handle floating point fields with tolerance + let original_temp = original_json["temperature"].as_f64().unwrap(); + let serialized_temp = serialized_json["temperature"].as_f64().unwrap(); + assert!((original_temp - serialized_temp).abs() < 1e-6); + + let original_top_p = original_json["top_p"].as_f64().unwrap(); + let serialized_top_p = serialized_json["top_p"].as_f64().unwrap(); + assert!((original_top_p - serialized_top_p).abs() < 1e-6); } #[test] - fn test_service_tier_and_thinking_serialization() { - // Test with service_tier and thinking enabled - let request_with_fields = MessagesRequest { - model: "claude-3-sonnet".to_string(), - system: None, - messages: vec![MessagesMessage { - role: MessagesRole::User, - content: MessagesMessageContent::Single("Hello".to_string()), - }], - max_tokens: 100, - container: None, - mcp_servers: None, - service_tier: Some(ServiceTier::Auto), - thinking: Some(ThinkingConfig { enabled: true }), - temperature: None, - top_p: None, - top_k: None, - stream: None, - stop_sequences: None, - tools: None, - tool_choice: None, - metadata: None, - }; + fn test_anthropic_nested_types() { + // Create a comprehensive JSON object with nested types - a MessagesRequest with complex message content and tools + let original_json = json!({ + "model": "claude-3-sonnet-20240229", + "max_tokens": 1000, + "messages": [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "What can you see in this image and what's the weather like?" + }, + { + "type": "image", + "source": { + "base64": { + "media_type": "image/jpeg", + "data": "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==" + } + } + } + ] + }, + { + "role": "assistant", + "content": [ + { + "type": "thinking", + "text": "Let me analyze the image and then check the weather..." + }, + { + "type": "text", + "text": "I can see the image. Let me check the weather for you." + }, + { + "type": "tool_use", + "id": "toolu_weather123", + "name": "get_weather", + "input": { + "location": "San Francisco, CA" + } + } + ] + } + ], + "tools": [ + { + "name": "get_weather", + "description": "Get current weather information for a location", + "input_schema": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA" + } + }, + "required": ["location"] + } + } + ], + "tool_choice": { + "type": "auto" + }, + "system": [ + { + "type": "text", + "text": "You are a helpful assistant that can analyze images and provide weather information." + } + ] + }); - let json = serde_json::to_value(&request_with_fields).unwrap(); - let obj = json.as_object().unwrap(); + // Deserialize JSON into MessagesRequest + let deserialized_request: MessagesRequest = serde_json::from_value(original_json.clone()).unwrap(); - // Verify that Some fields are present - assert!(obj.contains_key("service_tier")); - assert!(obj.contains_key("thinking")); + // Validate top-level fields + assert_eq!(deserialized_request.model, "claude-3-sonnet-20240229"); + assert_eq!(deserialized_request.max_tokens, 1000); + assert_eq!(deserialized_request.messages.len(), 2); - // Verify service_tier serialization - assert_eq!(obj.get("service_tier").unwrap().as_str().unwrap(), "auto"); + // Validate first message (user with text and image content) + let user_message = &deserialized_request.messages[0]; + assert_eq!(user_message.role, MessagesRole::User); + if let MessagesMessageContent::Blocks(ref content_blocks) = user_message.content { + assert_eq!(content_blocks.len(), 2); - // Verify thinking serialization - let thinking = obj.get("thinking").unwrap().as_object().unwrap(); - assert!(thinking.contains_key("enabled")); - assert_eq!(thinking.get("enabled").unwrap().as_bool().unwrap(), true); + // Validate text content block + if let MessagesContentBlock::Text { text } = &content_blocks[0] { + assert_eq!(text, "What can you see in this image and what's the weather like?"); + } else { + panic!("Expected text content block"); + } + + // Validate image content block + if let MessagesContentBlock::Image { ref source } = content_blocks[1] { + if let MessagesImageSource::Base64 { media_type, data } = source { + assert_eq!(media_type, "image/jpeg"); + assert_eq!(data, "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="); + } else { + panic!("Expected base64 image source"); + } + } else { + panic!("Expected image content block"); + } + } else { + panic!("Expected content blocks for user message"); + } + + // Validate second message (assistant with thinking, text, and tool use) + let assistant_message = &deserialized_request.messages[1]; + assert_eq!(assistant_message.role, MessagesRole::Assistant); + if let MessagesMessageContent::Blocks(ref content_blocks) = assistant_message.content { + assert_eq!(content_blocks.len(), 3); + + // Validate thinking content block + if let MessagesContentBlock::Thinking { text } = &content_blocks[0] { + assert_eq!(text, "Let me analyze the image and then check the weather..."); + } else { + panic!("Expected thinking content block"); + } + + // Validate text content block + if let MessagesContentBlock::Text { text } = &content_blocks[1] { + assert_eq!(text, "I can see the image. Let me check the weather for you."); + } else { + panic!("Expected text content block"); + } + + // Validate tool use content block + if let MessagesContentBlock::ToolUse { ref id, ref name, ref input } = content_blocks[2] { + assert_eq!(id, "toolu_weather123"); + assert_eq!(name, "get_weather"); + assert_eq!(input["location"], "San Francisco, CA"); + } else { + panic!("Expected tool use content block"); + } + } else { + panic!("Expected content blocks for assistant message"); + } + + // Validate tools array + assert!(deserialized_request.tools.is_some()); + let tools = deserialized_request.tools.as_ref().unwrap(); + assert_eq!(tools.len(), 1); + + let tool = &tools[0]; + assert_eq!(tool.name, "get_weather"); + assert_eq!(tool.description, Some("Get current weather information for a location".to_string())); + assert_eq!(tool.input_schema["type"], "object"); + assert!(tool.input_schema["properties"]["location"].is_object()); + + // Validate tool choice + assert!(deserialized_request.tool_choice.is_some()); + let tool_choice = deserialized_request.tool_choice.as_ref().unwrap(); + assert_eq!(tool_choice.kind, MessagesToolChoiceType::Auto); + assert!(tool_choice.name.is_none()); + + // Validate system prompt with content blocks + assert!(deserialized_request.system.is_some()); + if let Some(MessagesSystemPrompt::Blocks(ref system_blocks)) = deserialized_request.system { + assert_eq!(system_blocks.len(), 1); + if let MessagesContentBlock::Text { text } = &system_blocks[0] { + assert_eq!(text, "You are a helpful assistant that can analyze images and provide weather information."); + } else { + panic!("Expected text content block in system prompt"); + } + } else { + panic!("Expected system prompt with content blocks"); + } + + // Serialize back to JSON and compare + let serialized_json = serde_json::to_value(&deserialized_request).unwrap(); + assert_eq!(original_json, serialized_json); + } + + #[test] + fn test_anthropic_mcp_server_configuration() { + // Test MCP Server configuration with JSON-first approach + let mcp_server_json = json!({ + "name": "test-server", + "type": "url", + "url": "https://example.com/mcp", + "authorization_token": "secret-token", + "tool_configuration": { + "allowed_tools": ["tool1", "tool2"], + "enabled": true + } + }); + + let deserialized_mcp: McpServer = serde_json::from_value(mcp_server_json.clone()).unwrap(); + assert_eq!(deserialized_mcp.name, "test-server"); + assert_eq!(deserialized_mcp.server_type, McpServerType::Url); + assert_eq!(deserialized_mcp.url, "https://example.com/mcp"); + assert_eq!(deserialized_mcp.authorization_token, Some("secret-token".to_string())); + + if let Some(tool_config) = &deserialized_mcp.tool_configuration { + assert_eq!(tool_config.allowed_tools, Some(vec!["tool1".to_string(), "tool2".to_string()])); + assert_eq!(tool_config.enabled, Some(true)); + } else { + panic!("Expected tool configuration"); + } + + let serialized_mcp_json = serde_json::to_value(&deserialized_mcp).unwrap(); + assert_eq!(mcp_server_json, serialized_mcp_json); + + // Test MCP Server with minimal configuration (optional fields as None) + let minimal_mcp_json = json!({ + "name": "minimal-server", + "type": "url", + "url": "https://minimal.com/mcp" + }); + + let deserialized_minimal: McpServer = serde_json::from_value(minimal_mcp_json.clone()).unwrap(); + assert_eq!(deserialized_minimal.name, "minimal-server"); + assert_eq!(deserialized_minimal.server_type, McpServerType::Url); + assert_eq!(deserialized_minimal.url, "https://minimal.com/mcp"); + assert!(deserialized_minimal.authorization_token.is_none()); + assert!(deserialized_minimal.tool_configuration.is_none()); + + let serialized_minimal_json = serde_json::to_value(&deserialized_minimal).unwrap(); + assert_eq!(minimal_mcp_json, serialized_minimal_json); + } + + #[test] + fn test_anthropic_response_types() { + // Test MessagesResponse deserialization + let response_json = json!({ + "id": "msg_01ABC123", + "type": "message", + "role": "assistant", + "content": [ + { + "type": "text", + "text": "Hello! How can I help you today?" + } + ], + "model": "claude-3-sonnet-20240229", + "stop_reason": "end_turn", + "usage": { + "input_tokens": 10, + "output_tokens": 25, + "cache_creation_input_tokens": 5, + "cache_read_input_tokens": 3 + } + }); + + let deserialized_response: MessagesResponse = serde_json::from_value(response_json.clone()).unwrap(); + assert_eq!(deserialized_response.id, "msg_01ABC123"); + assert_eq!(deserialized_response.obj_type, "message"); + assert_eq!(deserialized_response.role, MessagesRole::Assistant); + assert_eq!(deserialized_response.model, "claude-3-sonnet-20240229"); + assert_eq!(deserialized_response.stop_reason, MessagesStopReason::EndTurn); + assert!(deserialized_response.stop_sequence.is_none()); + assert!(deserialized_response.container.is_none()); + + // Check content + assert_eq!(deserialized_response.content.len(), 1); + if let MessagesContentBlock::Text { text } = &deserialized_response.content[0] { + assert_eq!(text, "Hello! How can I help you today?"); + } else { + panic!("Expected text content block"); + } + + // Check usage + assert_eq!(deserialized_response.usage.input_tokens, 10); + assert_eq!(deserialized_response.usage.output_tokens, 25); + assert_eq!(deserialized_response.usage.cache_creation_input_tokens, Some(5)); + assert_eq!(deserialized_response.usage.cache_read_input_tokens, Some(3)); + + let serialized_response_json = serde_json::to_value(&deserialized_response).unwrap(); + assert_eq!(response_json, serialized_response_json); + + // Test streaming event + let stream_event_json = json!({ + "type": "content_block_delta", + "index": 0, + "delta": { + "type": "text_delta", + "text": " How" + } + }); + + let deserialized_event: MessagesStreamEvent = serde_json::from_value(stream_event_json.clone()).unwrap(); + if let MessagesStreamEvent::ContentBlockDelta { index, ref delta } = deserialized_event { + assert_eq!(index, 0); + if let MessagesContentDelta::TextDelta { text } = delta { + assert_eq!(text, " How"); + } else { + panic!("Expected text delta"); + } + } else { + panic!("Expected content block delta event"); + } + + let serialized_event_json = serde_json::to_value(&deserialized_event).unwrap(); + assert_eq!(stream_event_json, serialized_event_json); + } + + #[test] + fn test_anthropic_tool_use_content() { + // Test tool use and tool result content blocks + let tool_use_json = json!({ + "type": "tool_use", + "id": "toolu_01ABC123", + "name": "get_weather", + "input": { + "location": "San Francisco, CA" + } + }); + + let deserialized_tool_use: MessagesContentBlock = serde_json::from_value(tool_use_json.clone()).unwrap(); + if let MessagesContentBlock::ToolUse { ref id, ref name, ref input } = deserialized_tool_use { + assert_eq!(id, "toolu_01ABC123"); + assert_eq!(name, "get_weather"); + assert_eq!(input["location"], "San Francisco, CA"); + } else { + panic!("Expected tool use content block"); + } + + let serialized_tool_use_json = serde_json::to_value(&deserialized_tool_use).unwrap(); + assert_eq!(tool_use_json, serialized_tool_use_json); + + // Test tool result content block + let tool_result_json = json!({ + "type": "tool_result", + "tool_use_id": "toolu_01ABC123", + "content": [ + { + "type": "text", + "text": "The weather in San Francisco is sunny, 72°F" + } + ] + }); + + let deserialized_tool_result: MessagesContentBlock = serde_json::from_value(tool_result_json.clone()).unwrap(); + if let MessagesContentBlock::ToolResult { ref tool_use_id, ref is_error, ref content } = deserialized_tool_result { + assert_eq!(tool_use_id, "toolu_01ABC123"); + assert!(is_error.is_none()); + assert_eq!(content.len(), 1); + if let MessagesContentBlock::Text { text } = &content[0] { + assert_eq!(text, "The weather in San Francisco is sunny, 72°F"); + } else { + panic!("Expected text content in tool result"); + } + } else { + panic!("Expected tool result content block"); + } + + let serialized_tool_result_json = serde_json::to_value(&deserialized_tool_result).unwrap(); + assert_eq!(tool_result_json, serialized_tool_result_json); } #[test] fn test_anthropic_api_provider_trait_implementation() { - use super::ApiDefinition; - // Test that AnthropicApi implements ApiDefinition trait correctly let api = AnthropicApi::Messages; // Test trait methods - assert_eq!(ApiDefinition::endpoint(&api), "/v1/messages"); - assert!(ApiDefinition::supports_streaming(&api)); - assert!(ApiDefinition::supports_tools(&api)); - assert!(ApiDefinition::supports_vision(&api)); + assert_eq!(api.endpoint(), "/v1/messages"); + assert!(api.supports_streaming()); + assert!(api.supports_tools()); + assert!(api.supports_vision()); // Test from_endpoint trait method let found_api = AnthropicApi::from_endpoint("/v1/messages"); @@ -564,5 +889,10 @@ mod tests { let not_found = AnthropicApi::from_endpoint("/v1/unknown"); assert_eq!(not_found, None); + + // Test all_variants + let all_variants = AnthropicApi::all_variants(); + assert_eq!(all_variants.len(), 1); + assert_eq!(all_variants[0], AnthropicApi::Messages); } } diff --git a/crates/hermesllm/src/apis/openai.rs b/crates/hermesllm/src/apis/openai.rs index 5fd13b98..ea2d8cd9 100644 --- a/crates/hermesllm/src/apis/openai.rs +++ b/crates/hermesllm/src/apis/openai.rs @@ -64,7 +64,7 @@ impl ApiDefinition for OpenAIApi { pub struct ChatCompletionsRequest { pub messages: Vec, pub model: String, - // pub auduio: Option