From 59dbbd6743d810f6e3bb25834ffc45f4afac3122 Mon Sep 17 00:00:00 2001 From: Adil Hafeez Date: Mon, 2 Jun 2025 23:57:03 -0700 Subject: [PATCH] add openai protocol --- crates/Cargo.lock | 7 +- crates/hermesllm/Cargo.toml | 3 + crates/hermesllm/src/lib.rs | 156 +++--------------- crates/hermesllm/src/providers/groq/mod.rs | 0 crates/hermesllm/src/providers/mod.rs | 9 + crates/hermesllm/src/providers/openai/mod.rs | 27 +++ .../hermesllm/src/providers/openai/types.rs | 151 +++++++++++++++++ 7 files changed, 217 insertions(+), 136 deletions(-) create mode 100644 crates/hermesllm/src/providers/groq/mod.rs create mode 100644 crates/hermesllm/src/providers/mod.rs create mode 100644 crates/hermesllm/src/providers/openai/mod.rs create mode 100644 crates/hermesllm/src/providers/openai/types.rs diff --git a/crates/Cargo.lock b/crates/Cargo.lock index 49c29f1b..fe977e6f 100644 --- a/crates/Cargo.lock +++ b/crates/Cargo.lock @@ -79,9 +79,9 @@ dependencies = [ [[package]] name = "anyhow" -version = "1.0.90" +version = "1.0.98" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "37bf3594c4c988a53154954629820791dde498571819ae4ca50ca811e060cc95" +checksum = "e16d2d3311acee920a9eb8d33b8cbc1787ce4a264e85f964c2404b969bdcd487" [[package]] name = "arbitrary" @@ -1011,6 +1011,9 @@ name = "hermesllm" version = "0.1.0" dependencies = [ "common", + "serde", + "serde_json", + "thiserror 2.0.12", ] [[package]] diff --git a/crates/hermesllm/Cargo.toml b/crates/hermesllm/Cargo.toml index fdc60353..b49c8b97 100644 --- a/crates/hermesllm/Cargo.toml +++ b/crates/hermesllm/Cargo.toml @@ -5,3 +5,6 @@ edition = "2021" [dependencies] common = { version = "0.1.0", path = "../common" } +serde = "1.0.219" +serde_json = "1.0.140" +thiserror = "2.0.12" diff --git a/crates/hermesllm/src/lib.rs b/crates/hermesllm/src/lib.rs index 91283ab8..96ff21da 100644 --- a/crates/hermesllm/src/lib.rs +++ b/crates/hermesllm/src/lib.rs @@ -1,145 +1,33 @@ //! hermesllm: A library for translating LLM API requests and responses //! between Mistral, Grok, Gemini, and OpenAI-compliant formats. -/// Supported LLM providers. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum Provider { - Mistral, - Grok, - Gemini, - OpenAI, -} - -/// OpenAI API request format (placeholder). -#[derive(Debug, Clone)] -pub struct OpenAIRequest { - // Add OpenAI request fields here - pub prompt: String, - // ... -} - -/// OpenAI API response format (placeholder). -#[derive(Debug, Clone)] -pub struct OpenAIResponse { - // Add OpenAI response fields here - pub completion: String, - // ... -} - -/// Mistral API request format (placeholder). -#[derive(Debug, Clone)] -pub struct MistralRequest { - pub input: String, - // ... -} - -/// Mistral API response format (placeholder). -#[derive(Debug, Clone)] -pub struct MistralResponse { - pub output: String, - // ... -} - -/// Grok API request format (placeholder). -#[derive(Debug, Clone)] -pub struct GrokRequest { - pub message: String, - // ... -} - -/// Grok API response format (placeholder). -#[derive(Debug, Clone)] -pub struct GrokResponse { - pub reply: String, - // ... -} - -/// Gemini API request format (placeholder). -#[derive(Debug, Clone)] -pub struct GeminiRequest { - pub query: String, - // ... -} - -/// Gemini API response format (placeholder). -#[derive(Debug, Clone)] -pub struct GeminiResponse { - pub answer: String, - // ... -} - -/// Trait for translating provider-specific requests to OpenAI format. -pub trait ToOpenAIRequest { - fn to_openai(&self) -> OpenAIRequest; -} - -/// Trait for translating OpenAI responses to provider-specific format. -pub trait FromOpenAIResponse: Sized { - fn from_openai(resp: &OpenAIResponse) -> Self; -} - -// Implementations for Mistral -impl ToOpenAIRequest for MistralRequest { - fn to_openai(&self) -> OpenAIRequest { - OpenAIRequest { - prompt: self.input.clone(), - } - } -} -impl FromOpenAIResponse for MistralResponse { - fn from_openai(resp: &OpenAIResponse) -> Self { - MistralResponse { - output: resp.completion.clone(), - } - } -} - -// Implementations for Grok -impl ToOpenAIRequest for GrokRequest { - fn to_openai(&self) -> OpenAIRequest { - OpenAIRequest { - prompt: self.message.clone(), - } - } -} -impl FromOpenAIResponse for GrokResponse { - fn from_openai(resp: &OpenAIResponse) -> Self { - GrokResponse { - reply: resp.completion.clone(), - } - } -} - -// Implementations for Gemini -impl ToOpenAIRequest for GeminiRequest { - fn to_openai(&self) -> OpenAIRequest { - OpenAIRequest { - prompt: self.query.clone(), - } - } -} -impl FromOpenAIResponse for GeminiResponse { - fn from_openai(resp: &OpenAIResponse) -> Self { - GeminiResponse { - answer: resp.completion.clone(), - } - } -} - -// Optionally, add more conversion traits as needed for bidirectional translation. +pub mod providers; #[cfg(test)] mod tests { - use super::*; + use crate::providers::openai::types::OpenAIRequest; #[test] - fn mistral_to_openai_and_back() { - let mistral_req = MistralRequest { input: "Hello".into() }; - let openai_req = mistral_req.to_openai(); - assert_eq!(openai_req.prompt, "Hello"); + fn openai_builder() { + let request = OpenAIRequest::builder("gpt-3.5-turbo") + .temperature(0.7) + .top_p(0.9) + .n(1) + .max_tokens(100) + .stream(false) + .stop(vec!["\n".to_string()]) + .presence_penalty(0.0) + .frequency_penalty(0.0) + .build(); - let openai_resp = OpenAIResponse { completion: "Hi!".into() }; - let mistral_resp = MistralResponse::from_openai(&openai_resp); - assert_eq!(mistral_resp.output, "Hi!"); + assert_eq!(request.model, "gpt-3.5-turbo"); + assert_eq!(request.temperature, Some(0.7)); + assert_eq!(request.top_p, Some(0.9)); + assert_eq!(request.n, Some(1)); + assert_eq!(request.max_tokens, Some(100)); + assert_eq!(request.stream, Some(false)); + assert_eq!(request.stop, Some(vec!["\n".to_string()])); + assert_eq!(request.presence_penalty, Some(0.0)); + assert_eq!(request.frequency_penalty, Some(0.0)); } } diff --git a/crates/hermesllm/src/providers/groq/mod.rs b/crates/hermesllm/src/providers/groq/mod.rs new file mode 100644 index 00000000..e69de29b diff --git a/crates/hermesllm/src/providers/mod.rs b/crates/hermesllm/src/providers/mod.rs new file mode 100644 index 00000000..c414c86b --- /dev/null +++ b/crates/hermesllm/src/providers/mod.rs @@ -0,0 +1,9 @@ +pub mod openai; +pub mod groq; + +/// Supported LLM providers. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum Provider { + Grok, + OpenAI, +} diff --git a/crates/hermesllm/src/providers/openai/mod.rs b/crates/hermesllm/src/providers/openai/mod.rs new file mode 100644 index 00000000..4060b9bf --- /dev/null +++ b/crates/hermesllm/src/providers/openai/mod.rs @@ -0,0 +1,27 @@ +pub mod types; + +use thiserror::Error; + +use crate::providers::openai::types::{OpenAIRequest, OpenAIResponse}; + +#[derive(Debug, Error)] +pub enum OpenAIError { + #[error("json error: {0}")] + JsonParseError(#[from] serde_json::Error), +} + +type Result = std::result::Result; + +impl TryFrom<&[u8]> for OpenAIRequest { + type Error = OpenAIError; + fn try_from(bytes: &[u8]) -> Result { + serde_json::from_slice(bytes).map_err(OpenAIError::from) + } +} + +impl TryFrom<&[u8]> for OpenAIResponse { + type Error = OpenAIError; + fn try_from(bytes: &[u8]) -> Result { + serde_json::from_slice(bytes).map_err(OpenAIError::from) + } +} diff --git a/crates/hermesllm/src/providers/openai/types.rs b/crates/hermesllm/src/providers/openai/types.rs new file mode 100644 index 00000000..6695d5a2 --- /dev/null +++ b/crates/hermesllm/src/providers/openai/types.rs @@ -0,0 +1,151 @@ +use serde::{Deserialize, Serialize}; + +/// Represents a request to the OpenAI API (compatible with both chat and completion endpoints). +/// +/// Fields are based on the OpenAI API schema: +/// https://platform.openai.com/docs/api-reference/chat/create +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct OpenAIRequest { + /// The model to use (e.g., "gpt-3.5-turbo", "gpt-4"). + pub model: String, + /// The list of messages for chat endpoints (use None for completion). + pub messages: Option>, + /// Sampling temperature to use (higher values = more random). + pub temperature: Option, + /// Nucleus sampling parameter. + pub top_p: Option, + /// How many completions to generate for each prompt/message. + pub n: Option, + /// Maximum number of tokens to generate. + pub max_tokens: Option, + /// Whether to stream back partial progress. + pub stream: Option, + /// Up to 4 sequences where the API will stop generating further tokens. + pub stop: Option>, + /// Penalizes new tokens based on whether they appear in the text so far. + pub presence_penalty: Option, + /// Penalizes new tokens based on their frequency in the text so far. + pub frequency_penalty: Option, +} + +/// Builder for `OpenAIRequest`. +#[derive(Debug, Default, Clone)] +pub struct OpenAIRequestBuilder { + model: String, + messages: Option>, + temperature: Option, + top_p: Option, + n: Option, + max_tokens: Option, + stream: Option, + stop: Option>, + presence_penalty: Option, + frequency_penalty: Option, +} + +impl OpenAIRequestBuilder { + pub fn new(model: impl Into) -> Self { + Self { + model: model.into(), + ..Default::default() + } + } + + pub fn messages(mut self, messages: Vec) -> Self { + self.messages = Some(messages); + self + } + + pub fn temperature(mut self, temperature: f32) -> Self { + self.temperature = Some(temperature); + self + } + + pub fn top_p(mut self, top_p: f32) -> Self { + self.top_p = Some(top_p); + self + } + + pub fn n(mut self, n: u32) -> Self { + self.n = Some(n); + self + } + + pub fn max_tokens(mut self, max_tokens: u32) -> Self { + self.max_tokens = Some(max_tokens); + self + } + + pub fn stream(mut self, stream: bool) -> Self { + self.stream = Some(stream); + self + } + + pub fn stop(mut self, stop: Vec) -> Self { + self.stop = Some(stop); + self + } + + pub fn presence_penalty(mut self, presence_penalty: f32) -> Self { + self.presence_penalty = Some(presence_penalty); + self + } + + pub fn frequency_penalty(mut self, frequency_penalty: f32) -> Self { + self.frequency_penalty = Some(frequency_penalty); + self + } + + pub fn build(self) -> OpenAIRequest { + OpenAIRequest { + model: self.model, + messages: self.messages, + temperature: self.temperature, + top_p: self.top_p, + n: self.n, + max_tokens: self.max_tokens, + stream: self.stream, + stop: self.stop, + presence_penalty: self.presence_penalty, + frequency_penalty: self.frequency_penalty, + } + } +} + +impl OpenAIRequest { + pub fn builder(model: impl Into) -> OpenAIRequestBuilder { + OpenAIRequestBuilder::new(model) + } +} + +/// Represents a message in the OpenAI chat API. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Message { + /// The role of the message sender ("system", "user", or "assistant"). + pub role: String, + /// The content of the message. + pub content: String, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct OpenAIResponse { + pub id: String, + pub object: String, + pub created: u64, + pub choices: Vec, + pub usage: Option, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct Choice { + pub index: u32, + pub message: Message, + pub finish_reason: Option, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct Usage { + pub prompt_tokens: u32, + pub completion_tokens: u32, + pub total_tokens: u32, +}