From f10e0fcecebb488090d5dabb6d6af1f17959be1b Mon Sep 17 00:00:00 2001 From: Adil Hafeez Date: Tue, 3 Jun 2025 15:00:57 -0700 Subject: [PATCH] add more changes --- crates/Cargo.lock | 224 +++++++++++++++- crates/brightstaff/Cargo.toml | 1 + .../src/handlers/chat_completions.rs | 8 +- crates/brightstaff/src/router/llm_router.rs | 7 +- crates/brightstaff/src/router/router_model.rs | 2 +- .../brightstaff/src/router/router_model_v1.rs | 21 +- crates/hermesllm/Cargo.toml | 1 + crates/hermesllm/src/lib.rs | 25 +- .../hermesllm/src/providers/common_types.rs | 44 ---- .../hermesllm/src/providers/deepseek/types.rs | 16 +- crates/hermesllm/src/providers/groq/types.rs | 15 +- crates/hermesllm/src/providers/mod.rs | 13 +- crates/hermesllm/src/providers/openai/mod.rs | 6 +- .../hermesllm/src/providers/openai/types.rs | 242 +++++++++++++++--- 14 files changed, 486 insertions(+), 139 deletions(-) delete mode 100644 crates/hermesllm/src/providers/common_types.rs diff --git a/crates/Cargo.lock b/crates/Cargo.lock index fe977e6f..ba5d3796 100644 --- a/crates/Cargo.lock +++ b/crates/Cargo.lock @@ -68,6 +68,21 @@ version = "0.2.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5c6cb57a04249c6480766f7f7cef5467412af1490f8d1e243141daddada3264f" +[[package]] +name = "android-tzdata" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0" + +[[package]] +name = "android_system_properties" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311" +dependencies = [ + "libc", +] + [[package]] name = "ansi_term" version = "0.12.1" @@ -196,6 +211,7 @@ dependencies = [ "eventsource-stream", "futures", "futures-util", + "hermesllm", "http-body 1.0.1", "http-body-util", "hyper 1.6.0", @@ -276,7 +292,11 @@ version = "0.4.41" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c469d952047f47f91b68d1cba3f10d63c11d73e4636f24f08daf0278abf01c4d" dependencies = [ + "android-tzdata", + "iana-time-zone", "num-traits", + "serde", + "windows-link", ] [[package]] @@ -288,7 +308,7 @@ dependencies = [ "ansi_term", "atty", "bitflags 1.3.2", - "strsim", + "strsim 0.8.0", "textwrap", "unicode-width", "vec_map", @@ -521,6 +541,41 @@ dependencies = [ "typenum", ] +[[package]] +name = "darling" +version = "0.20.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc7f46116c46ff9ab3eb1597a45688b6715c6e628b5c133e288e709a29bcb4ee" +dependencies = [ + "darling_core", + "darling_macro", +] + +[[package]] +name = "darling_core" +version = "0.20.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0d00b9596d185e565c2207a0b01f8bd1a135483d02d9b7b0a54b11da8d53412e" +dependencies = [ + "fnv", + "ident_case", + "proc-macro2", + "quote", + "strsim 0.11.1", + "syn 2.0.87", +] + +[[package]] +name = "darling_macro" +version = "0.20.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc34b93ccb385b40dc71c6fceac4b2ad23662c7eeb248cf10d529b7e055b6ead" +dependencies = [ + "darling_core", + "quote", + "syn 2.0.87", +] + [[package]] name = "debugid" version = "0.8.0" @@ -530,6 +585,16 @@ dependencies = [ "uuid", ] +[[package]] +name = "deranged" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c9e6a11ca8224451684bc0d7d5a7adbf8f2fd6887261a1cfc3c0432f9d4068e" +dependencies = [ + "powerfmt", + "serde", +] + [[package]] name = "derivative" version = "2.2.0" @@ -1013,6 +1078,7 @@ dependencies = [ "common", "serde", "serde_json", + "serde_with", "thiserror 2.0.12", ] @@ -1238,6 +1304,30 @@ dependencies = [ "tracing", ] +[[package]] +name = "iana-time-zone" +version = "0.1.63" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b0c919e5debc312ad217002b8048a17b7d83f80703865bbfcfebb0458b0b27d8" +dependencies = [ + "android_system_properties", + "core-foundation-sys", + "iana-time-zone-haiku", + "js-sys", + "log", + "wasm-bindgen", + "windows-core", +] + +[[package]] +name = "iana-time-zone-haiku" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f" +dependencies = [ + "cc", +] + [[package]] name = "icu_collections" version = "1.5.0" @@ -1362,6 +1452,12 @@ version = "2.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "25a2bc672d1148e28034f176e01fffebb08b35768468cc954630da77a1449005" +[[package]] +name = "ident_case" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" + [[package]] name = "idna" version = "1.0.3" @@ -1391,6 +1487,7 @@ checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99" dependencies = [ "autocfg", "hashbrown 0.12.3", + "serde", ] [[package]] @@ -1677,6 +1774,12 @@ dependencies = [ "winapi", ] +[[package]] +name = "num-conv" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" + [[package]] name = "num-traits" version = "0.2.19" @@ -1935,6 +2038,12 @@ dependencies = [ "serde", ] +[[package]] +name = "powerfmt" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" + [[package]] name = "ppv-lite86" version = "0.2.20" @@ -2542,6 +2651,36 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_with" +version = "3.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d6b6f7f2fcb69f747921f79f3926bd1e203fce4fef62c268dd3abfb6d86029aa" +dependencies = [ + "base64 0.22.1", + "chrono", + "hex", + "indexmap 1.9.3", + "indexmap 2.6.0", + "serde", + "serde_derive", + "serde_json", + "serde_with_macros", + "time", +] + +[[package]] +name = "serde_with_macros" +version = "3.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8d00caa5193a3c8362ac2b73be6b9e768aa5a4b2f721d8f4b339600c3cb51f8e" +dependencies = [ + "darling", + "proc-macro2", + "quote", + "syn 2.0.87", +] + [[package]] name = "serde_yaml" version = "0.9.34+deprecated" @@ -2676,6 +2815,12 @@ version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8ea5119cdb4c55b55d432abb513a0429384878c15dde60cc77b1c99de1a95a6a" +[[package]] +name = "strsim" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" + [[package]] name = "structopt" version = "0.3.26" @@ -2871,6 +3016,37 @@ dependencies = [ "rustc-hash", ] +[[package]] +name = "time" +version = "0.3.41" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a7619e19bc266e0f9c5e6686659d394bc57973859340060a69221e57dbc0c40" +dependencies = [ + "deranged", + "itoa", + "num-conv", + "powerfmt", + "serde", + "time-core", + "time-macros", +] + +[[package]] +name = "time-core" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c9e9a38711f559d9e3ce1cdb06dd7c5b8ea546bc90052da6d06bb76da74bb07c" + +[[package]] +name = "time-macros" +version = "0.2.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3526739392ec93fd8b359c8e98514cb3e8e021beb4e5f597b00a0221f8ed8a49" +dependencies = [ + "num-conv", + "time-core", +] + [[package]] name = "tinystr" version = "0.7.6" @@ -3775,6 +3951,41 @@ dependencies = [ "wasmtime-environ", ] +[[package]] +name = "windows-core" +version = "0.61.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4763c1de310c86d75a878046489e2e5ba02c649d185f21c67d4cf8a56d098980" +dependencies = [ + "windows-implement", + "windows-interface", + "windows-link", + "windows-result", + "windows-strings 0.4.2", +] + +[[package]] +name = "windows-implement" +version = "0.60.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a47fddd13af08290e67f4acabf4b459f647552718f683a7b415d290ac744a836" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.87", +] + +[[package]] +name = "windows-interface" +version = "0.59.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bd9211b69f8dcdfa817bfd14bf1c97c9188afa36f4750130fcdf3f400eca9fa8" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.87", +] + [[package]] name = "windows-link" version = "0.1.1" @@ -3788,7 +3999,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4286ad90ddb45071efd1a66dfa43eb02dd0dfbae1545ad6cc3c51cf34d7e8ba3" dependencies = [ "windows-result", - "windows-strings", + "windows-strings 0.3.1", "windows-targets 0.53.0", ] @@ -3810,6 +4021,15 @@ dependencies = [ "windows-link", ] +[[package]] +name = "windows-strings" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56e6c93f3a0c3b36176cb1327a4958a0353d5d166c2a35cb268ace15e91d3b57" +dependencies = [ + "windows-link", +] + [[package]] name = "windows-sys" version = "0.52.0" diff --git a/crates/brightstaff/Cargo.toml b/crates/brightstaff/Cargo.toml index 3f51b6a0..b8c2582c 100644 --- a/crates/brightstaff/Cargo.toml +++ b/crates/brightstaff/Cargo.toml @@ -10,6 +10,7 @@ eventsource-client = "0.15.0" eventsource-stream = "0.2.3" futures = "0.3.31" futures-util = "0.3.31" +hermesllm = { version = "0.1.0", path = "../hermesllm" } http-body = "1.0.1" http-body-util = "0.1.3" hyper = { version = "1.6.0", features = ["full"] } diff --git a/crates/brightstaff/src/handlers/chat_completions.rs b/crates/brightstaff/src/handlers/chat_completions.rs index 0a5bd25d..413dfddf 100644 --- a/crates/brightstaff/src/handlers/chat_completions.rs +++ b/crates/brightstaff/src/handlers/chat_completions.rs @@ -1,14 +1,13 @@ use std::sync::Arc; use bytes::Bytes; -use common::api::open_ai::ChatCompletionsRequest; use common::consts::ARCH_PROVIDER_HINT_HEADER; +use hermesllm::providers::openai::types::ChatCompletionsRequest; use http_body_util::combinators::BoxBody; use http_body_util::{BodyExt, Full, StreamBody}; use hyper::body::Frame; use hyper::header::{self}; use hyper::{Request, Response, StatusCode}; -use serde_json::Value; use tokio::sync::mpsc; use tokio_stream::wrappers::ReceiverStream; use tokio_stream::StreamExt; @@ -32,13 +31,12 @@ pub async fn chat_completions( let chat_request_bytes = request.collect().await?.to_bytes(); let chat_completion_request: ChatCompletionsRequest = - match serde_json::from_slice(&chat_request_bytes) { + match ChatCompletionsRequest::try_from(chat_request_bytes.as_ref()) { Ok(request) => request, Err(err) => { - let v: Value = serde_json::from_slice(&chat_request_bytes).unwrap(); + warn!("arch-router request body string: {}", String::from_utf8_lossy(&chat_request_bytes)); let err_msg = format!("Failed to parse request body: {}", err); warn!("{}", err_msg); - warn!("arch-router request body: {}", v.to_string()); let mut bad_request = Response::new(full(err_msg)); *bad_request.status_mut() = StatusCode::BAD_REQUEST; return Ok(bad_request); diff --git a/crates/brightstaff/src/router/llm_router.rs b/crates/brightstaff/src/router/llm_router.rs index d4158388..4a510caa 100644 --- a/crates/brightstaff/src/router/llm_router.rs +++ b/crates/brightstaff/src/router/llm_router.rs @@ -1,10 +1,10 @@ use std::sync::Arc; use common::{ - api::open_ai::{ChatCompletionsResponse, ContentType, Message}, configuration::{LlmProvider, LlmRoute}, consts::ARCH_PROVIDER_HINT_HEADER, }; +use hermesllm::providers::openai::types::{ChatCompletionsResponse, ContentType, Message}; use hyper::header; use thiserror::Error; use tracing::{debug, info, warn}; @@ -136,6 +136,11 @@ impl RouterService { } }; + if chat_completion_response.choices.is_empty() { + warn!("No choices in router response: {}", body); + return Ok(None); + } + if let Some(ContentType::Text(content)) = &chat_completion_response.choices[0].message.content { diff --git a/crates/brightstaff/src/router/router_model.rs b/crates/brightstaff/src/router/router_model.rs index 6e591e4c..c2ed43c9 100644 --- a/crates/brightstaff/src/router/router_model.rs +++ b/crates/brightstaff/src/router/router_model.rs @@ -1,4 +1,4 @@ -use common::api::open_ai::{ChatCompletionsRequest, Message}; +use hermesllm::providers::openai::types::{ChatCompletionsRequest, Message}; use thiserror::Error; #[derive(Debug, Error)] diff --git a/crates/brightstaff/src/router/router_model_v1.rs b/crates/brightstaff/src/router/router_model_v1.rs index bc69b475..f32a19b7 100644 --- a/crates/brightstaff/src/router/router_model_v1.rs +++ b/crates/brightstaff/src/router/router_model_v1.rs @@ -1,8 +1,8 @@ use common::{ - api::open_ai::{ChatCompletionsRequest, ContentType, Message}, configuration::LlmRoute, consts::{SYSTEM_ROLE, TOOL_ROLE, USER_ROLE}, }; +use hermesllm::providers::openai::types::{ChatCompletionsRequest, ContentType, Message}; use serde::{Deserialize, Serialize}; use tracing::{debug, warn}; @@ -121,11 +121,13 @@ impl RouterModel for RouterModelV1 { .iter() .rev() .map(|message| { - Message::new( - message.role.clone(), + Message { + role: message.role.clone(), // we can unwrap here because we have already filtered out messages without content - message.content.as_ref().unwrap().to_string(), - ) + content: Some(ContentType::Text( + message.content.as_ref().unwrap().to_string(), + )), + } }) .collect::>(); @@ -141,14 +143,9 @@ impl RouterModel for RouterModelV1 { messages: vec![Message { content: Some(ContentType::Text(messages_content)), role: USER_ROLE.to_string(), - model: None, - tool_calls: None, - tool_call_id: None, + }], - tools: None, - stream: false, - stream_options: None, - metadata: None, + ..Default::default() } } diff --git a/crates/hermesllm/Cargo.toml b/crates/hermesllm/Cargo.toml index b49c8b97..5393a9ad 100644 --- a/crates/hermesllm/Cargo.toml +++ b/crates/hermesllm/Cargo.toml @@ -7,4 +7,5 @@ edition = "2021" common = { version = "0.1.0", path = "../common" } serde = "1.0.219" serde_json = "1.0.140" +serde_with = "3.12.0" thiserror = "2.0.12" diff --git a/crates/hermesllm/src/lib.rs b/crates/hermesllm/src/lib.rs index 602336c9..192f8090 100644 --- a/crates/hermesllm/src/lib.rs +++ b/crates/hermesllm/src/lib.rs @@ -5,12 +5,13 @@ pub mod providers; #[cfg(test)] mod tests { - use crate::providers::openai::types::OpenAIRequest; + use crate::providers::openai::types::ChatCompletionsRequest; + + #[test] fn openai_builder() { - let request = OpenAIRequest::builder() - .model("gpt-3.5-turbo") + let request = ChatCompletionsRequest::builder("gpt-3.5-turbo", vec![]) .temperature(0.7) .top_p(0.9) .n(1) @@ -22,14 +23,14 @@ mod tests { .build() .expect("Failed to build OpenAIRequest"); - assert_eq!(request.base.model, "gpt-3.5-turbo"); - assert_eq!(request.base.temperature, Some(0.7)); - assert_eq!(request.base.top_p, Some(0.9)); - assert_eq!(request.base.n, Some(1)); - assert_eq!(request.base.max_tokens, Some(100)); - assert_eq!(request.base.stream, Some(false)); - assert_eq!(request.base.stop, Some(vec!["\n".to_string()])); - assert_eq!(request.base.presence_penalty, Some(0.0)); - assert_eq!(request.base.frequency_penalty, Some(0.0)); + assert_eq!(request.model, "gpt-3.5-turbo"); + assert_eq!(request.temperature, Some(0.7)); + assert_eq!(request.top_p, Some(0.9)); + assert_eq!(request.n, Some(1)); + assert_eq!(request.max_tokens, Some(100)); + assert_eq!(request.stream, Some(false)); + assert_eq!(request.stop, Some(vec!["\n".to_string()])); + assert_eq!(request.presence_penalty, Some(0.0)); + assert_eq!(request.frequency_penalty, Some(0.0)); } } diff --git a/crates/hermesllm/src/providers/common_types.rs b/crates/hermesllm/src/providers/common_types.rs deleted file mode 100644 index bb222909..00000000 --- a/crates/hermesllm/src/providers/common_types.rs +++ /dev/null @@ -1,44 +0,0 @@ -use serde::{Deserialize, Serialize}; - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Message { - pub role: String, - pub content: String, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ChatRequestBase { - pub model: String, - pub messages: Option>, - pub temperature: Option, - pub top_p: Option, - pub n: Option, - pub max_tokens: Option, - pub stream: Option, - pub stop: Option>, - pub presence_penalty: Option, - pub frequency_penalty: Option, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct ChatResponseBase { - pub id: String, - pub object: String, - pub created: u64, - pub choices: Vec, - pub usage: Option, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct Choice { - pub index: u32, - pub message: Message, - pub finish_reason: Option, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct Usage { - pub prompt_tokens: u32, - pub completion_tokens: u32, - pub total_tokens: u32, -} diff --git a/crates/hermesllm/src/providers/deepseek/types.rs b/crates/hermesllm/src/providers/deepseek/types.rs index 62b17388..e5585818 100644 --- a/crates/hermesllm/src/providers/deepseek/types.rs +++ b/crates/hermesllm/src/providers/deepseek/types.rs @@ -1,17 +1,19 @@ -use serde::{Deserialize, Serialize}; -use crate::providers::common_types::{ChatRequestBase, ChatResponseBase}; +use crate::providers::openai::types::{ChatCompletionsRequest, ChatCompletionsResponse}; +pub use crate::providers::openai::types::{Choice, Message, Usage}; +use serde::{Deserialize, Serialize}; +use serde_with::skip_serializing_none; + +#[skip_serializing_none] #[derive(Debug, Clone, Serialize, Deserialize)] pub struct DeepSeekRequest { #[serde(flatten)] - pub base: ChatRequestBase, + pub base: ChatCompletionsRequest, } +#[skip_serializing_none] #[derive(Debug, Clone, Serialize, Deserialize)] pub struct DeepSeekResponse { #[serde(flatten)] - pub base: ChatResponseBase, + pub base: ChatCompletionsResponse, } - -// Re-export for convenience -pub use crate::providers::common_types::{Message, Choice, Usage}; diff --git a/crates/hermesllm/src/providers/groq/types.rs b/crates/hermesllm/src/providers/groq/types.rs index 6f88231a..67b7b47b 100644 --- a/crates/hermesllm/src/providers/groq/types.rs +++ b/crates/hermesllm/src/providers/groq/types.rs @@ -1,16 +1,19 @@ -use serde::{Deserialize, Serialize}; -use crate::providers::common_types::{ChatRequestBase, ChatResponseBase}; +use crate::providers::openai::types::{ChatCompletionsRequest, ChatCompletionsResponse}; +pub use crate::providers::openai::types::{Choice, Message, Usage}; +use serde::{Deserialize, Serialize}; +use serde_with::skip_serializing_none; + +#[skip_serializing_none] #[derive(Debug, Clone, Serialize, Deserialize)] pub struct GroqRequest { #[serde(flatten)] - pub base: ChatRequestBase, + pub base: ChatCompletionsRequest, } +#[skip_serializing_none] #[derive(Debug, Clone, Serialize, Deserialize)] pub struct GroqResponse { #[serde(flatten)] - pub base: ChatResponseBase, + pub base: ChatCompletionsResponse, } - -pub use crate::providers::common_types::{Message, Choice, Usage}; diff --git a/crates/hermesllm/src/providers/mod.rs b/crates/hermesllm/src/providers/mod.rs index 8ffc9f57..8ceda63a 100644 --- a/crates/hermesllm/src/providers/mod.rs +++ b/crates/hermesllm/src/providers/mod.rs @@ -1,12 +1,3 @@ -pub mod openai; -pub mod groq; pub mod deepseek; -pub mod common_types; - -/// Supported LLM providers. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum Provider { - Groq, - OpenAI, - DeepSeek, -} +pub mod groq; +pub mod openai; diff --git a/crates/hermesllm/src/providers/openai/mod.rs b/crates/hermesllm/src/providers/openai/mod.rs index 4060b9bf..b4b7df04 100644 --- a/crates/hermesllm/src/providers/openai/mod.rs +++ b/crates/hermesllm/src/providers/openai/mod.rs @@ -2,7 +2,7 @@ pub mod types; use thiserror::Error; -use crate::providers::openai::types::{OpenAIRequest, OpenAIResponse}; +use crate::providers::openai::types::{ChatCompletionsRequest, ChatCompletionsResponse}; #[derive(Debug, Error)] pub enum OpenAIError { @@ -12,14 +12,14 @@ pub enum OpenAIError { type Result = std::result::Result; -impl TryFrom<&[u8]> for OpenAIRequest { +impl TryFrom<&[u8]> for ChatCompletionsRequest { type Error = OpenAIError; fn try_from(bytes: &[u8]) -> Result { serde_json::from_slice(bytes).map_err(OpenAIError::from) } } -impl TryFrom<&[u8]> for OpenAIResponse { +impl TryFrom<&[u8]> for ChatCompletionsResponse { type Error = OpenAIError; fn try_from(bytes: &[u8]) -> Result { serde_json::from_slice(bytes).map_err(OpenAIError::from) diff --git a/crates/hermesllm/src/providers/openai/types.rs b/crates/hermesllm/src/providers/openai/types.rs index 43893fa4..880d31d1 100644 --- a/crates/hermesllm/src/providers/openai/types.rs +++ b/crates/hermesllm/src/providers/openai/types.rs @@ -1,15 +1,124 @@ -use serde::{Deserialize, Serialize}; -use crate::providers::common_types::{ChatRequestBase, ChatResponseBase}; +use std::fmt::Display; -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct OpenAIRequest { - #[serde(flatten)] - pub base: ChatRequestBase, +use serde::{Deserialize, Serialize}; +use serde_with::skip_serializing_none; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub enum MultiPartContentType { + #[serde(rename = "text")] + Text, + #[serde(rename = "image_url")] + ImageUrl, } -#[derive(Debug, Default, Clone)] + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct MultiPartContent { + pub text: Option, + #[serde(rename = "type")] + pub content_type: MultiPartContentType, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[serde(untagged)] +pub enum ContentType { + Text(String), + MultiPart(Vec), +} + +impl Display for ContentType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ContentType::Text(text) => write!(f, "{}", text), + ContentType::MultiPart(multi_part) => { + let text_parts: Vec = multi_part + .iter() + .filter_map(|part| { + if part.content_type == MultiPartContentType::Text { + part.text.clone() + } else if part.content_type == MultiPartContentType::ImageUrl { + // skip image URLs or their data in text representation + None + } else { + panic!("Unsupported content type: {:?}", part.content_type); + } + }) + .collect(); + let combined_text = text_parts.join("\n"); + write!(f, "{}", combined_text) + } + } + } +} + +#[skip_serializing_none] +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Message { + pub role: String, + pub content: Option, +} + +#[skip_serializing_none] +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ChatCompletionsRequest { + pub model: String, + pub messages: Vec, + pub temperature: Option, + pub top_p: Option, + pub n: Option, + pub max_tokens: Option, + pub stream: Option, + pub stop: Option>, + pub presence_penalty: Option, + pub frequency_penalty: Option, +} + +impl Default for ChatCompletionsRequest { + fn default() -> Self { + ChatCompletionsRequest { + model: String::new(), + messages: Vec::new(), + temperature: None, + top_p: None, + n: None, + max_tokens: None, + stream: None, + stop: None, + presence_penalty: None, + frequency_penalty: None, + } + } +} + +#[skip_serializing_none] +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ChatCompletionsResponse { + pub id: String, + pub object: String, + pub created: u64, + pub choices: Vec, + pub usage: Option, +} + +#[skip_serializing_none] +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct Choice { + pub index: u32, + pub message: Message, + pub finish_reason: Option, +} + +#[skip_serializing_none] +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct Usage { + pub prompt_tokens: u32, + pub completion_tokens: u32, + pub total_tokens: u32, +} + +#[derive(Debug, Clone)] pub struct OpenAIRequestBuilder { - model: Option, - messages: Option>, + model: String, + messages: Vec, temperature: Option, top_p: Option, n: Option, @@ -21,18 +130,19 @@ pub struct OpenAIRequestBuilder { } impl OpenAIRequestBuilder { - pub fn new() -> Self { - Self::default() - } - - pub fn model(mut self, model: impl Into) -> Self { - self.model = Some(model.into()); - self - } - - pub fn messages(mut self, messages: Vec) -> Self { - self.messages = Some(messages); - self + pub fn new(model: impl Into, messages: Vec) -> Self { + Self { + model: model.into(), + messages, + temperature: None, + top_p: None, + n: None, + max_tokens: None, + stream: None, + stop: None, + presence_penalty: None, + frequency_penalty: None, + } } pub fn temperature(mut self, temperature: f32) -> Self { @@ -75,10 +185,9 @@ impl OpenAIRequestBuilder { self } - pub fn build(self) -> Result { - let model = self.model.ok_or("model is required")?; - let base = crate::providers::common_types::ChatRequestBase { - model, + pub fn build(self) -> Result { + let request = ChatCompletionsRequest { + model: self.model, messages: self.messages, temperature: self.temperature, top_p: self.top_p, @@ -89,20 +198,83 @@ impl OpenAIRequestBuilder { presence_penalty: self.presence_penalty, frequency_penalty: self.frequency_penalty, }; - Ok(OpenAIRequest { base }) + Ok(request) } } -impl OpenAIRequest { - pub fn builder() -> OpenAIRequestBuilder { - OpenAIRequestBuilder::new() +impl ChatCompletionsRequest { + pub fn builder(model: impl Into, messages: Vec) -> OpenAIRequestBuilder { + OpenAIRequestBuilder::new(model, messages) } } -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct OpenAIResponse { - #[serde(flatten)] - pub base: ChatResponseBase, -} +#[cfg(test)] +mod tests { + use super::*; -pub use crate::providers::common_types::{Message, Choice, Usage}; + #[test] + fn test_content_type_display() { + let text_content = ContentType::Text("Hello, world!".to_string()); + assert_eq!(text_content.to_string(), "Hello, world!"); + + let multi_part_content = ContentType::MultiPart(vec![ + MultiPartContent { + text: Some("This is a text part.".to_string()), + content_type: MultiPartContentType::Text, + }, + MultiPartContent { + text: Some("https://example.com/image.png".to_string()), + content_type: MultiPartContentType::ImageUrl, + }, + ]); + assert_eq!(multi_part_content.to_string(), "This is a text part."); + } + + #[test] + fn test_chat_completions_request_text_type_array() { + const CHAT_COMPLETIONS_REQUEST: &str = r#" + { + "model": "gpt-3.5-turbo", + "messages": [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "What city do you want to know the weather for?" + }, + { + "type": "text", + "text": "hello world" + } + ] + } + ] + } + "#; + + let chat_completions_request: ChatCompletionsRequest = + serde_json::from_str(CHAT_COMPLETIONS_REQUEST).unwrap(); + assert_eq!(chat_completions_request.model, "gpt-3.5-turbo"); + if let Some(ContentType::MultiPart(multi_part_content)) = + chat_completions_request.messages[0].content.as_ref() + { + assert_eq!(multi_part_content.len(), 2); + assert_eq!( + multi_part_content[0].content_type, + MultiPartContentType::Text + ); + assert_eq!( + multi_part_content[0].text, + Some("What city do you want to know the weather for?".to_string()) + ); + assert_eq!( + multi_part_content[1].content_type, + MultiPartContentType::Text + ); + assert_eq!(multi_part_content[1].text, Some("hello world".to_string())); + } else { + panic!("Expected MultiPartContent"); + } + } +}