From 25f1b72e7cbd86a25c0cfb86ed2d6320d31d581c Mon Sep 17 00:00:00 2001 From: Adil Hafeez Date: Thu, 5 Jun 2025 16:14:40 -0700 Subject: [PATCH] more changes --- crates/Cargo.lock | 5 +- crates/hermesllm/Cargo.toml | 1 + crates/hermesllm/src/lib.rs | 38 ++++++++++++++ crates/hermesllm/src/providers/groq/mod.rs | 1 - crates/hermesllm/src/providers/groq/types.rs | 19 ------- crates/hermesllm/src/providers/mod.rs | 1 - .../hermesllm/src/providers/openai/types.rs | 49 ++++++++++++++++++- crates/llm_gateway/src/stream_context.rs | 49 +++++++++++++------ 8 files changed, 122 insertions(+), 41 deletions(-) delete mode 100644 crates/hermesllm/src/providers/groq/mod.rs delete mode 100644 crates/hermesllm/src/providers/groq/types.rs diff --git a/crates/Cargo.lock b/crates/Cargo.lock index 3b14a246..7edb0a62 100644 --- a/crates/Cargo.lock +++ b/crates/Cargo.lock @@ -1076,6 +1076,7 @@ checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" name = "hermesllm" version = "0.1.0" dependencies = [ + "log", "serde", "serde_json", "serde_with", @@ -1642,9 +1643,9 @@ dependencies = [ [[package]] name = "log" -version = "0.4.22" +version = "0.4.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" +checksum = "13dc2df351e3202783a1fe0d44375f7295ffb4049267b0f3018346dc122a1d94" [[package]] name = "mach2" diff --git a/crates/hermesllm/Cargo.toml b/crates/hermesllm/Cargo.toml index c7917f9a..991c64a5 100644 --- a/crates/hermesllm/Cargo.toml +++ b/crates/hermesllm/Cargo.toml @@ -4,6 +4,7 @@ version = "0.1.0" edition = "2021" [dependencies] +log = "0.4.27" serde = "1.0.219" serde_json = "1.0.140" serde_with = "3.12.0" diff --git a/crates/hermesllm/src/lib.rs b/crates/hermesllm/src/lib.rs index 3a9c78b0..b26c4eb6 100644 --- a/crates/hermesllm/src/lib.rs +++ b/crates/hermesllm/src/lib.rs @@ -1,8 +1,46 @@ //! hermesllm: A library for translating LLM API requests and responses //! between Mistral, Grok, Gemini, and OpenAI-compliant formats. +use std::fmt::Display; + pub mod providers; +pub enum Provider { + Mistral, + Groq, + Gemini, + OpenAI, + Claude, + Github +} + +impl From<&str> for Provider { + fn from(value: &str) -> Self { + match value.to_lowercase().as_str() { + "mistral" => Provider::Mistral, + "groq" => Provider::Groq, + "gemini" => Provider::Gemini, + "openai" => Provider::OpenAI, + "claude" => Provider::Claude, + "github" => Provider::Github, + _ => panic!("Unknown provider: {}", value), + } + } +} + +impl Display for Provider { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Provider::Mistral => write!(f, "Mistral"), + Provider::Groq => write!(f, "Groq"), + Provider::Gemini => write!(f, "Gemini"), + Provider::OpenAI => write!(f, "OpenAI"), + Provider::Claude => write!(f, "Claude"), + Provider::Github => write!(f, "Github"), + } + } +} + #[cfg(test)] mod tests { use crate::providers::openai::types::ChatCompletionsRequest; diff --git a/crates/hermesllm/src/providers/groq/mod.rs b/crates/hermesllm/src/providers/groq/mod.rs deleted file mode 100644 index cd408564..00000000 --- a/crates/hermesllm/src/providers/groq/mod.rs +++ /dev/null @@ -1 +0,0 @@ -pub mod types; diff --git a/crates/hermesllm/src/providers/groq/types.rs b/crates/hermesllm/src/providers/groq/types.rs deleted file mode 100644 index 67b7b47b..00000000 --- a/crates/hermesllm/src/providers/groq/types.rs +++ /dev/null @@ -1,19 +0,0 @@ -use crate::providers::openai::types::{ChatCompletionsRequest, ChatCompletionsResponse}; -pub use crate::providers::openai::types::{Choice, Message, Usage}; - -use serde::{Deserialize, Serialize}; -use serde_with::skip_serializing_none; - -#[skip_serializing_none] -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct GroqRequest { - #[serde(flatten)] - pub base: ChatCompletionsRequest, -} - -#[skip_serializing_none] -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct GroqResponse { - #[serde(flatten)] - pub base: ChatCompletionsResponse, -} diff --git a/crates/hermesllm/src/providers/mod.rs b/crates/hermesllm/src/providers/mod.rs index 8ceda63a..5ee6632d 100644 --- a/crates/hermesllm/src/providers/mod.rs +++ b/crates/hermesllm/src/providers/mod.rs @@ -1,3 +1,2 @@ pub mod deepseek; -pub mod groq; pub mod openai; diff --git a/crates/hermesllm/src/providers/openai/types.rs b/crates/hermesllm/src/providers/openai/types.rs index 6693656c..265fd4fe 100644 --- a/crates/hermesllm/src/providers/openai/types.rs +++ b/crates/hermesllm/src/providers/openai/types.rs @@ -1,18 +1,22 @@ use std::fmt::Display; use serde::{Deserialize, Serialize}; -use serde_json::Value; +use serde_json::{Value}; use serde_with::skip_serializing_none; use std::convert::TryFrom; use std::str; use thiserror::Error; +use crate::Provider; + #[derive(Debug, Error)] pub enum OpenAIError { #[error("json error: {0}")] JsonParseError(#[from] serde_json::Error), #[error("utf8 parsing error: {0}")] Utf8Error(#[from] std::str::Utf8Error), + #[error("unsupported provider: {provider}")] + UnsupportedProvider { provider: String }, } type Result = std::result::Result; @@ -117,6 +121,30 @@ impl TryFrom<&[u8]> for ChatCompletionsResponse { } } +impl<'a> TryFrom<(&'a [u8], &'a Provider)> for ChatCompletionsResponse { + type Error = OpenAIError; + + fn try_from(input: (&'a [u8], &'a Provider)) -> Result { + // Use input.provider as needed, if necessary + serde_json::from_slice(input.0).map_err(OpenAIError::from) + } +} + +impl ChatCompletionsRequest { + pub fn to_bytes(&self, provider: Provider) -> Result> { + match provider { + Provider::OpenAI + | Provider::Mistral + | Provider::Groq + | Provider::Gemini + | Provider::Claude => serde_json::to_vec(self).map_err(OpenAIError::from), + _ => Err(OpenAIError::UnsupportedProvider { + provider: provider.to_string(), + }), + } + } +} + #[skip_serializing_none] #[derive(Debug, Clone, Deserialize, Serialize)] pub struct Choice { @@ -133,10 +161,17 @@ pub struct Usage { pub total_tokens: usize, } +#[skip_serializing_none] +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DeltaMessage { + pub role: Option, + pub content: Option, +} + #[derive(Debug, Clone, Deserialize, Serialize)] pub struct StreamChoice { pub index: u32, - pub delta: Message, + pub delta: DeltaMessage, pub finish_reason: Option, } @@ -193,6 +228,16 @@ where } } +impl<'a> TryFrom<(&'a [u8], &'a Provider)> for SseChatCompletionIter> { + type Error = OpenAIError; + + fn try_from(input: (&'a [u8], &'a Provider)) -> Result { + let s = std::str::from_utf8(input.0)?; + // Use input.provider as needed + Ok(SseChatCompletionIter::new(s.lines())) + } +} + impl<'a> TryFrom<&'a [u8]> for SseChatCompletionIter> { type Error = OpenAIError; diff --git a/crates/llm_gateway/src/stream_context.rs b/crates/llm_gateway/src/stream_context.rs index d7d4ad23..d02e6987 100644 --- a/crates/llm_gateway/src/stream_context.rs +++ b/crates/llm_gateway/src/stream_context.rs @@ -14,6 +14,7 @@ use hermesllm::providers::openai::types::{ChatCompletionsRequest, SseChatComplet use hermesllm::providers::openai::types::{ ChatCompletionsResponse, ContentType, Message, StreamOptions, }; +use hermesllm::Provider; use http::StatusCode; use log::{debug, info, warn}; use proxy_wasm::hostcalls::get_current_time; @@ -338,13 +339,6 @@ impl HttpContext for StreamContext { model_name.unwrap_or(&"None".to_string()), ); - let chat_completion_request_str = serde_json::to_string(&deserialized_body).unwrap(); - - debug!( - "on_http_request_body: request body: {}", - chat_completion_request_str - ); - if deserialized_body.stream.unwrap_or_default() { self.streaming_response = true; } @@ -379,7 +373,23 @@ impl HttpContext for StreamContext { return Action::Continue; } - self.set_http_request_body(0, body_size, chat_completion_request_str.as_bytes()); + // convert chat completion request to llm provider specific request + let deserialized_body_bytes = match deserialized_body.to_bytes(hermesllm::Provider::OpenAI) + { + Ok(bytes) => bytes, + Err(e) => { + warn!("Failed to serialize request body: {}", e); + self.send_server_error(ServerError::OpenAIPError(e), Some(StatusCode::BAD_REQUEST)); + return Action::Pause; + } + }; + + debug!( + "on_http_request_body: request body string: {}", + String::from_utf8_lossy(&deserialized_body_bytes) + ); + + self.set_http_request_body(0, body_size, &deserialized_body_bytes); Action::Continue } @@ -534,9 +544,12 @@ impl HttpContext for StreamContext { } }; + let llm_provider_str = self.llm_provider().provider_interface.to_string(); + let hermes_llm_provider = Provider::from(llm_provider_str.as_str()); + if self.streaming_response { let chat_completions_chunk_response_events = - match SseChatCompletionIter::try_from(body.as_slice()) { + match SseChatCompletionIter::try_from((body.as_slice(), &hermes_llm_provider)) { Ok(events) => events, Err(e) => { warn!("could not parse response: {}", e); @@ -580,14 +593,18 @@ impl HttpContext for StreamContext { } } else { debug!("non streaming response"); - let chat_completions_response: ChatCompletionsResponse = - match serde_json::from_slice(body.as_slice()) { + let chat_completions_response = + match ChatCompletionsResponse::try_from((body.as_slice(), &hermes_llm_provider)) { Ok(de) => de, - Err(err) => { - info!( - "non chat-completion compliant response received err: {}, body: {:?}", - err, - String::from_utf8(body) + Err(e) => { + warn!("could not parse response: {}", e); + debug!( + "on_http_response_body: response body: {}", + String::from_utf8_lossy(&body) + ); + self.send_server_error( + ServerError::OpenAIPError(e), + Some(StatusCode::BAD_REQUEST), ); return Action::Continue; }