diff --git a/crates/hermesllm/src/lib.rs b/crates/hermesllm/src/lib.rs index 96ff21da..602336c9 100644 --- a/crates/hermesllm/src/lib.rs +++ b/crates/hermesllm/src/lib.rs @@ -9,7 +9,8 @@ mod tests { #[test] fn openai_builder() { - let request = OpenAIRequest::builder("gpt-3.5-turbo") + let request = OpenAIRequest::builder() + .model("gpt-3.5-turbo") .temperature(0.7) .top_p(0.9) .n(1) @@ -18,16 +19,17 @@ mod tests { .stop(vec!["\n".to_string()]) .presence_penalty(0.0) .frequency_penalty(0.0) - .build(); + .build() + .expect("Failed to build OpenAIRequest"); - assert_eq!(request.model, "gpt-3.5-turbo"); - assert_eq!(request.temperature, Some(0.7)); - assert_eq!(request.top_p, Some(0.9)); - assert_eq!(request.n, Some(1)); - assert_eq!(request.max_tokens, Some(100)); - assert_eq!(request.stream, Some(false)); - assert_eq!(request.stop, Some(vec!["\n".to_string()])); - assert_eq!(request.presence_penalty, Some(0.0)); - assert_eq!(request.frequency_penalty, Some(0.0)); + assert_eq!(request.base.model, "gpt-3.5-turbo"); + assert_eq!(request.base.temperature, Some(0.7)); + assert_eq!(request.base.top_p, Some(0.9)); + assert_eq!(request.base.n, Some(1)); + assert_eq!(request.base.max_tokens, Some(100)); + assert_eq!(request.base.stream, Some(false)); + assert_eq!(request.base.stop, Some(vec!["\n".to_string()])); + assert_eq!(request.base.presence_penalty, Some(0.0)); + assert_eq!(request.base.frequency_penalty, Some(0.0)); } } diff --git a/crates/hermesllm/src/providers/groq/mod.rs b/crates/hermesllm/src/providers/groq/mod.rs index e69de29b..84eefe46 100644 --- a/crates/hermesllm/src/providers/groq/mod.rs +++ b/crates/hermesllm/src/providers/groq/mod.rs @@ -0,0 +1,27 @@ +pub mod types; + +use thiserror::Error; + +use crate::providers::groq::types::{GroqRequest, GroqResponse}; + +#[derive(Debug, Error)] +pub enum GroqError { + #[error("json error: {0}")] + JsonParseError(#[from] serde_json::Error), +} + +type Result = std::result::Result; + +impl TryFrom<&[u8]> for GroqRequest { + type Error = GroqError; + fn try_from(bytes: &[u8]) -> Result { + serde_json::from_slice(bytes).map_err(GroqError::from) + } +} + +impl TryFrom<&[u8]> for GroqResponse { + type Error = GroqError; + fn try_from(bytes: &[u8]) -> Result { + serde_json::from_slice(bytes).map_err(GroqError::from) + } +} diff --git a/crates/hermesllm/src/providers/mod.rs b/crates/hermesllm/src/providers/mod.rs index c414c86b..8ffc9f57 100644 --- a/crates/hermesllm/src/providers/mod.rs +++ b/crates/hermesllm/src/providers/mod.rs @@ -1,9 +1,12 @@ pub mod openai; pub mod groq; +pub mod deepseek; +pub mod common_types; /// Supported LLM providers. #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum Provider { - Grok, + Groq, OpenAI, + DeepSeek, } diff --git a/crates/hermesllm/src/providers/openai/types.rs b/crates/hermesllm/src/providers/openai/types.rs index 6695d5a2..43893fa4 100644 --- a/crates/hermesllm/src/providers/openai/types.rs +++ b/crates/hermesllm/src/providers/openai/types.rs @@ -1,38 +1,15 @@ use serde::{Deserialize, Serialize}; +use crate::providers::common_types::{ChatRequestBase, ChatResponseBase}; -/// Represents a request to the OpenAI API (compatible with both chat and completion endpoints). -/// -/// Fields are based on the OpenAI API schema: -/// https://platform.openai.com/docs/api-reference/chat/create #[derive(Debug, Clone, Serialize, Deserialize)] pub struct OpenAIRequest { - /// The model to use (e.g., "gpt-3.5-turbo", "gpt-4"). - pub model: String, - /// The list of messages for chat endpoints (use None for completion). - pub messages: Option>, - /// Sampling temperature to use (higher values = more random). - pub temperature: Option, - /// Nucleus sampling parameter. - pub top_p: Option, - /// How many completions to generate for each prompt/message. - pub n: Option, - /// Maximum number of tokens to generate. - pub max_tokens: Option, - /// Whether to stream back partial progress. - pub stream: Option, - /// Up to 4 sequences where the API will stop generating further tokens. - pub stop: Option>, - /// Penalizes new tokens based on whether they appear in the text so far. - pub presence_penalty: Option, - /// Penalizes new tokens based on their frequency in the text so far. - pub frequency_penalty: Option, + #[serde(flatten)] + pub base: ChatRequestBase, } - -/// Builder for `OpenAIRequest`. #[derive(Debug, Default, Clone)] pub struct OpenAIRequestBuilder { - model: String, - messages: Option>, + model: Option, + messages: Option>, temperature: Option, top_p: Option, n: Option, @@ -44,14 +21,16 @@ pub struct OpenAIRequestBuilder { } impl OpenAIRequestBuilder { - pub fn new(model: impl Into) -> Self { - Self { - model: model.into(), - ..Default::default() - } + pub fn new() -> Self { + Self::default() } - pub fn messages(mut self, messages: Vec) -> Self { + pub fn model(mut self, model: impl Into) -> Self { + self.model = Some(model.into()); + self + } + + pub fn messages(mut self, messages: Vec) -> Self { self.messages = Some(messages); self } @@ -96,9 +75,10 @@ impl OpenAIRequestBuilder { self } - pub fn build(self) -> OpenAIRequest { - OpenAIRequest { - model: self.model, + pub fn build(self) -> Result { + let model = self.model.ok_or("model is required")?; + let base = crate::providers::common_types::ChatRequestBase { + model, messages: self.messages, temperature: self.temperature, top_p: self.top_p, @@ -108,44 +88,21 @@ impl OpenAIRequestBuilder { stop: self.stop, presence_penalty: self.presence_penalty, frequency_penalty: self.frequency_penalty, - } + }; + Ok(OpenAIRequest { base }) } } impl OpenAIRequest { - pub fn builder(model: impl Into) -> OpenAIRequestBuilder { - OpenAIRequestBuilder::new(model) + pub fn builder() -> OpenAIRequestBuilder { + OpenAIRequestBuilder::new() } } -/// Represents a message in the OpenAI chat API. #[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Message { - /// The role of the message sender ("system", "user", or "assistant"). - pub role: String, - /// The content of the message. - pub content: String, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] pub struct OpenAIResponse { - pub id: String, - pub object: String, - pub created: u64, - pub choices: Vec, - pub usage: Option, + #[serde(flatten)] + pub base: ChatResponseBase, } -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct Choice { - pub index: u32, - pub message: Message, - pub finish_reason: Option, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct Usage { - pub prompt_tokens: u32, - pub completion_tokens: u32, - pub total_tokens: u32, -} +pub use crate::providers::common_types::{Message, Choice, Usage};