more changes

This commit is contained in:
Adil Hafeez 2025-06-05 16:14:40 -07:00
parent 22fde1f333
commit 25f1b72e7c
No known key found for this signature in database
GPG key ID: 9B18EF7691369645
8 changed files with 122 additions and 41 deletions

View file

@ -1,8 +1,46 @@
//! hermesllm: A library for translating LLM API requests and responses
//! between Mistral, Grok, Gemini, and OpenAI-compliant formats.
use std::fmt::Display;
pub mod providers;
pub enum Provider {
Mistral,
Groq,
Gemini,
OpenAI,
Claude,
Github
}
impl From<&str> for Provider {
fn from(value: &str) -> Self {
match value.to_lowercase().as_str() {
"mistral" => Provider::Mistral,
"groq" => Provider::Groq,
"gemini" => Provider::Gemini,
"openai" => Provider::OpenAI,
"claude" => Provider::Claude,
"github" => Provider::Github,
_ => panic!("Unknown provider: {}", value),
}
}
}
impl Display for Provider {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Provider::Mistral => write!(f, "Mistral"),
Provider::Groq => write!(f, "Groq"),
Provider::Gemini => write!(f, "Gemini"),
Provider::OpenAI => write!(f, "OpenAI"),
Provider::Claude => write!(f, "Claude"),
Provider::Github => write!(f, "Github"),
}
}
}
#[cfg(test)]
mod tests {
use crate::providers::openai::types::ChatCompletionsRequest;

View file

@ -1 +0,0 @@
pub mod types;

View file

@ -1,19 +0,0 @@
use crate::providers::openai::types::{ChatCompletionsRequest, ChatCompletionsResponse};
pub use crate::providers::openai::types::{Choice, Message, Usage};
use serde::{Deserialize, Serialize};
use serde_with::skip_serializing_none;
#[skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GroqRequest {
#[serde(flatten)]
pub base: ChatCompletionsRequest,
}
#[skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GroqResponse {
#[serde(flatten)]
pub base: ChatCompletionsResponse,
}

View file

@ -1,3 +1,2 @@
pub mod deepseek;
pub mod groq;
pub mod openai;

View file

@ -1,18 +1,22 @@
use std::fmt::Display;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use serde_json::{Value};
use serde_with::skip_serializing_none;
use std::convert::TryFrom;
use std::str;
use thiserror::Error;
use crate::Provider;
#[derive(Debug, Error)]
pub enum OpenAIError {
#[error("json error: {0}")]
JsonParseError(#[from] serde_json::Error),
#[error("utf8 parsing error: {0}")]
Utf8Error(#[from] std::str::Utf8Error),
#[error("unsupported provider: {provider}")]
UnsupportedProvider { provider: String },
}
type Result<T> = std::result::Result<T, OpenAIError>;
@ -117,6 +121,30 @@ impl TryFrom<&[u8]> for ChatCompletionsResponse {
}
}
impl<'a> TryFrom<(&'a [u8], &'a Provider)> for ChatCompletionsResponse {
type Error = OpenAIError;
fn try_from(input: (&'a [u8], &'a Provider)) -> Result<Self> {
// Use input.provider as needed, if necessary
serde_json::from_slice(input.0).map_err(OpenAIError::from)
}
}
impl ChatCompletionsRequest {
pub fn to_bytes(&self, provider: Provider) -> Result<Vec<u8>> {
match provider {
Provider::OpenAI
| Provider::Mistral
| Provider::Groq
| Provider::Gemini
| Provider::Claude => serde_json::to_vec(self).map_err(OpenAIError::from),
_ => Err(OpenAIError::UnsupportedProvider {
provider: provider.to_string(),
}),
}
}
}
#[skip_serializing_none]
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Choice {
@ -133,10 +161,17 @@ pub struct Usage {
pub total_tokens: usize,
}
#[skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeltaMessage {
pub role: Option<String>,
pub content: Option<ContentType>,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct StreamChoice {
pub index: u32,
pub delta: Message,
pub delta: DeltaMessage,
pub finish_reason: Option<String>,
}
@ -193,6 +228,16 @@ where
}
}
impl<'a> TryFrom<(&'a [u8], &'a Provider)> for SseChatCompletionIter<str::Lines<'a>> {
type Error = OpenAIError;
fn try_from(input: (&'a [u8], &'a Provider)) -> Result<Self> {
let s = std::str::from_utf8(input.0)?;
// Use input.provider as needed
Ok(SseChatCompletionIter::new(s.lines()))
}
}
impl<'a> TryFrom<&'a [u8]> for SseChatCompletionIter<str::Lines<'a>> {
type Error = OpenAIError;