more changes

This commit is contained in:
Adil Hafeez 2025-06-05 16:14:40 -07:00
parent 22fde1f333
commit 25f1b72e7c
No known key found for this signature in database
GPG key ID: 9B18EF7691369645
8 changed files with 122 additions and 41 deletions

5
crates/Cargo.lock generated
View file

@ -1076,6 +1076,7 @@ checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8"
name = "hermesllm"
version = "0.1.0"
dependencies = [
"log",
"serde",
"serde_json",
"serde_with",
@ -1642,9 +1643,9 @@ dependencies = [
[[package]]
name = "log"
version = "0.4.22"
version = "0.4.27"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24"
checksum = "13dc2df351e3202783a1fe0d44375f7295ffb4049267b0f3018346dc122a1d94"
[[package]]
name = "mach2"

View file

@ -4,6 +4,7 @@ version = "0.1.0"
edition = "2021"
[dependencies]
log = "0.4.27"
serde = "1.0.219"
serde_json = "1.0.140"
serde_with = "3.12.0"

View file

@ -1,8 +1,46 @@
//! hermesllm: A library for translating LLM API requests and responses
//! between Mistral, Groq, Gemini, and OpenAI-compliant formats.
use std::fmt::Display;
pub mod providers;
/// Upstream LLM providers known to this crate.
///
/// The derived traits make values cheap to copy, printable with `{:?}`,
/// comparable with `==`, and usable as hash-map keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Provider {
    Mistral,
    Groq,
    Gemini,
    OpenAI,
    Claude,
    Github,
}

/// Error returned when a string does not name a known [`Provider`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParseProviderError {
    provider: String,
}

impl ParseProviderError {
    /// The rejected input, exactly as the caller supplied it.
    pub fn provider(&self) -> &str {
        &self.provider
    }
}

impl Display for ParseProviderError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Same wording as the panic message of `From<&str>` so logs stay consistent.
        write!(f, "Unknown provider: {}", self.provider)
    }
}

impl std::error::Error for ParseProviderError {}

impl std::str::FromStr for Provider {
    type Err = ParseProviderError;

    /// Parses a provider name case-insensitively without panicking.
    ///
    /// # Errors
    ///
    /// Returns [`ParseProviderError`] when `value` is not one of
    /// `mistral`, `groq`, `gemini`, `openai`, `claude` or `github`.
    fn from_str(value: &str) -> std::result::Result<Self, Self::Err> {
        match value.to_lowercase().as_str() {
            "mistral" => Ok(Provider::Mistral),
            "groq" => Ok(Provider::Groq),
            "gemini" => Ok(Provider::Gemini),
            "openai" => Ok(Provider::OpenAI),
            "claude" => Ok(Provider::Claude),
            "github" => Ok(Provider::Github),
            _ => Err(ParseProviderError {
                provider: value.to_string(),
            }),
        }
    }
}

impl From<&str> for Provider {
    /// Converts a provider name (case-insensitive) into a [`Provider`].
    ///
    /// Prefer `value.parse::<Provider>()` for input that is not known to be
    /// valid; this conversion is kept for existing callers.
    ///
    /// # Panics
    ///
    /// Panics with `Unknown provider: <value>` when the name is not recognised.
    fn from(value: &str) -> Self {
        value
            .parse::<Provider>()
            .unwrap_or_else(|_| panic!("Unknown provider: {}", value))
    }
}

impl Display for Provider {
    /// Writes the variant name (`Mistral`, `Groq`, `Gemini`, `OpenAI`, `Claude`, `Github`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Provider::Mistral => "Mistral",
            Provider::Groq => "Groq",
            Provider::Gemini => "Gemini",
            Provider::OpenAI => "OpenAI",
            Provider::Claude => "Claude",
            Provider::Github => "Github",
        };
        f.write_str(name)
    }
}
#[cfg(test)]
mod tests {
use crate::providers::openai::types::ChatCompletionsRequest;

View file

@ -1 +0,0 @@
pub mod types;

View file

@ -1,19 +0,0 @@
use crate::providers::openai::types::{ChatCompletionsRequest, ChatCompletionsResponse};
pub use crate::providers::openai::types::{Choice, Message, Usage};
use serde::{Deserialize, Serialize};
use serde_with::skip_serializing_none;
/// Groq request wrapper around [`ChatCompletionsRequest`].
///
/// `#[serde(flatten)]` on `base` makes the wrapper serialize and deserialize
/// with the base request's fields at the top level instead of nested under a
/// `base` key.
#[skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GroqRequest {
/// The wrapped chat-completions request, flattened into this struct's JSON.
#[serde(flatten)]
pub base: ChatCompletionsRequest,
}
/// Groq response wrapper around [`ChatCompletionsResponse`].
///
/// `#[serde(flatten)]` on `base` makes the wrapper serialize and deserialize
/// with the base response's fields at the top level instead of nested under a
/// `base` key.
#[skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GroqResponse {
/// The wrapped chat-completions response, flattened into this struct's JSON.
#[serde(flatten)]
pub base: ChatCompletionsResponse,
}

View file

@ -1,3 +1,2 @@
pub mod deepseek;
pub mod groq;
pub mod openai;

View file

@ -1,18 +1,22 @@
use std::fmt::Display;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use serde_json::{Value};
use serde_with::skip_serializing_none;
use std::convert::TryFrom;
use std::str;
use thiserror::Error;
use crate::Provider;
/// Errors returned by the conversions in this module.
///
/// The `#[from]` attributes let `?` convert the underlying `serde_json` and
/// UTF-8 errors into this type automatically.
#[derive(Debug, Error)]
pub enum OpenAIError {
/// `serde_json` failed to serialize or deserialize a payload.
#[error("json error: {0}")]
JsonParseError(#[from] serde_json::Error),
/// Input bytes were not valid UTF-8.
#[error("utf8 parsing error: {0}")]
Utf8Error(#[from] std::str::Utf8Error),
/// The operation does not handle the given provider; `provider` is the
/// name shown in the error message.
#[error("unsupported provider: {provider}")]
UnsupportedProvider { provider: String },
}
/// Module-local shorthand for results whose error type is [`OpenAIError`].
type Result<T> = std::result::Result<T, OpenAIError>;
@ -117,6 +121,30 @@ impl TryFrom<&[u8]> for ChatCompletionsResponse {
}
}
impl<'a> TryFrom<(&'a [u8], &'a Provider)> for ChatCompletionsResponse {
    type Error = OpenAIError;

    /// Deserializes a response from the JSON bytes in the first tuple element.
    ///
    /// The provider in the second element is accepted but not consulted yet:
    /// every provider's body is parsed the same way.
    ///
    /// # Errors
    ///
    /// Returns [`OpenAIError::JsonParseError`] when the bytes are not valid
    /// JSON for this type.
    fn try_from((body, _provider): (&'a [u8], &'a Provider)) -> Result<Self> {
        let response: Self = serde_json::from_slice(body)?;
        Ok(response)
    }
}
impl ChatCompletionsRequest {
    /// Serializes this request to JSON bytes for `provider`.
    ///
    /// Every supported provider currently receives the same `serde_json`
    /// encoding of `self`.
    ///
    /// # Errors
    ///
    /// * [`OpenAIError::JsonParseError`] if serialization fails.
    /// * [`OpenAIError::UnsupportedProvider`] for providers this method does
    ///   not handle (currently `Provider::Github`).
    pub fn to_bytes(&self, provider: Provider) -> Result<Vec<u8>> {
        // The match is deliberately exhaustive (no `_` arm): adding a new
        // `Provider` variant must force a decision here at compile time
        // instead of silently falling into the "unsupported" branch.
        match provider {
            Provider::OpenAI
            | Provider::Mistral
            | Provider::Groq
            | Provider::Gemini
            | Provider::Claude => Ok(serde_json::to_vec(self)?),
            Provider::Github => Err(OpenAIError::UnsupportedProvider {
                provider: provider.to_string(),
            }),
        }
    }
}
#[skip_serializing_none]
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Choice {
@ -133,10 +161,17 @@ pub struct Usage {
pub total_tokens: usize,
}
/// Message fragment in which both `role` and `content` are optional.
///
/// `#[skip_serializing_none]` omits fields that are `None` from the
/// serialized output instead of writing them as `null`.
#[skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeltaMessage {
/// Role of the message author, when present.
pub role: Option<String>,
/// Message content, when present.
pub content: Option<ContentType>,
}
/// One choice entry carrying a `delta` payload (presumably an element of a
/// streamed chat-completion chunk — confirm against the enclosing type).
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct StreamChoice {
/// Position of this choice.
pub index: u32,
// NOTE(review): two `delta` declarations are visible here, which looks like
// the before/after pair of a diff; compilable code can keep only one of them.
pub delta: Message,
pub delta: DeltaMessage,
/// Reason generation finished, when one is provided.
pub finish_reason: Option<String>,
}
@ -193,6 +228,16 @@ where
}
}
impl<'a> TryFrom<(&'a [u8], &'a Provider)> for SseChatCompletionIter<str::Lines<'a>> {
type Error = OpenAIError;
fn try_from(input: (&'a [u8], &'a Provider)) -> Result<Self> {
let s = std::str::from_utf8(input.0)?;
// Use input.provider as needed
Ok(SseChatCompletionIter::new(s.lines()))
}
}
impl<'a> TryFrom<&'a [u8]> for SseChatCompletionIter<str::Lines<'a>> {
type Error = OpenAIError;

View file

@ -14,6 +14,7 @@ use hermesllm::providers::openai::types::{ChatCompletionsRequest, SseChatComplet
use hermesllm::providers::openai::types::{
ChatCompletionsResponse, ContentType, Message, StreamOptions,
};
use hermesllm::Provider;
use http::StatusCode;
use log::{debug, info, warn};
use proxy_wasm::hostcalls::get_current_time;
@ -338,13 +339,6 @@ impl HttpContext for StreamContext {
model_name.unwrap_or(&"None".to_string()),
);
let chat_completion_request_str = serde_json::to_string(&deserialized_body).unwrap();
debug!(
"on_http_request_body: request body: {}",
chat_completion_request_str
);
if deserialized_body.stream.unwrap_or_default() {
self.streaming_response = true;
}
@ -379,7 +373,23 @@ impl HttpContext for StreamContext {
return Action::Continue;
}
self.set_http_request_body(0, body_size, chat_completion_request_str.as_bytes());
// convert chat completion request to llm provider specific request
let deserialized_body_bytes = match deserialized_body.to_bytes(hermesllm::Provider::OpenAI)
{
Ok(bytes) => bytes,
Err(e) => {
warn!("Failed to serialize request body: {}", e);
self.send_server_error(ServerError::OpenAIPError(e), Some(StatusCode::BAD_REQUEST));
return Action::Pause;
}
};
debug!(
"on_http_request_body: request body string: {}",
String::from_utf8_lossy(&deserialized_body_bytes)
);
self.set_http_request_body(0, body_size, &deserialized_body_bytes);
Action::Continue
}
@ -534,9 +544,12 @@ impl HttpContext for StreamContext {
}
};
let llm_provider_str = self.llm_provider().provider_interface.to_string();
let hermes_llm_provider = Provider::from(llm_provider_str.as_str());
if self.streaming_response {
let chat_completions_chunk_response_events =
match SseChatCompletionIter::try_from(body.as_slice()) {
match SseChatCompletionIter::try_from((body.as_slice(), &hermes_llm_provider)) {
Ok(events) => events,
Err(e) => {
warn!("could not parse response: {}", e);
@ -580,14 +593,18 @@ impl HttpContext for StreamContext {
}
} else {
debug!("non streaming response");
let chat_completions_response: ChatCompletionsResponse =
match serde_json::from_slice(body.as_slice()) {
let chat_completions_response =
match ChatCompletionsResponse::try_from((body.as_slice(), &hermes_llm_provider)) {
Ok(de) => de,
Err(err) => {
info!(
"non chat-completion compliant response received err: {}, body: {:?}",
err,
String::from_utf8(body)
Err(e) => {
warn!("could not parse response: {}", e);
debug!(
"on_http_response_body: response body: {}",
String::from_utf8_lossy(&body)
);
self.send_server_error(
ServerError::OpenAIPError(e),
Some(StatusCode::BAD_REQUEST),
);
return Action::Continue;
}