add openai protocol

This commit is contained in:
Adil Hafeez 2025-06-02 23:57:03 -07:00
parent 21ca21dc3c
commit 59dbbd6743
No known key found for this signature in database
GPG key ID: 9B18EF7691369645
7 changed files with 217 additions and 136 deletions

7
crates/Cargo.lock generated
View file

@ -79,9 +79,9 @@ dependencies = [
[[package]]
name = "anyhow"
version = "1.0.90"
version = "1.0.98"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37bf3594c4c988a53154954629820791dde498571819ae4ca50ca811e060cc95"
checksum = "e16d2d3311acee920a9eb8d33b8cbc1787ce4a264e85f964c2404b969bdcd487"
[[package]]
name = "arbitrary"
@ -1011,6 +1011,9 @@ name = "hermesllm"
version = "0.1.0"
dependencies = [
"common",
"serde",
"serde_json",
"thiserror 2.0.12",
]
[[package]]

View file

@ -5,3 +5,6 @@ edition = "2021"
[dependencies]
common = { version = "0.1.0", path = "../common" }
serde = "1.0.219"
serde_json = "1.0.140"
thiserror = "2.0.12"

View file

@ -1,145 +1,33 @@
//! hermesllm: A library for translating LLM API requests and responses
//! between Mistral, Grok, Gemini, and OpenAI-compliant formats.
/// Supported LLM providers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Provider {
Mistral,
Grok,
Gemini,
OpenAI,
}
/// OpenAI API request format (placeholder).
#[derive(Debug, Clone)]
pub struct OpenAIRequest {
// Add OpenAI request fields here
pub prompt: String,
// ...
}
/// OpenAI API response format (placeholder).
#[derive(Debug, Clone)]
pub struct OpenAIResponse {
// Add OpenAI response fields here
pub completion: String,
// ...
}
/// Mistral API request format (placeholder).
#[derive(Debug, Clone)]
pub struct MistralRequest {
pub input: String,
// ...
}
/// Mistral API response format (placeholder).
#[derive(Debug, Clone)]
pub struct MistralResponse {
pub output: String,
// ...
}
/// Grok API request format (placeholder).
#[derive(Debug, Clone)]
pub struct GrokRequest {
pub message: String,
// ...
}
/// Grok API response format (placeholder).
#[derive(Debug, Clone)]
pub struct GrokResponse {
pub reply: String,
// ...
}
/// Gemini API request format (placeholder).
#[derive(Debug, Clone)]
pub struct GeminiRequest {
pub query: String,
// ...
}
/// Gemini API response format (placeholder).
#[derive(Debug, Clone)]
pub struct GeminiResponse {
pub answer: String,
// ...
}
/// Trait for translating provider-specific requests to OpenAI format.
pub trait ToOpenAIRequest {
fn to_openai(&self) -> OpenAIRequest;
}
/// Trait for translating OpenAI responses to provider-specific format.
pub trait FromOpenAIResponse: Sized {
fn from_openai(resp: &OpenAIResponse) -> Self;
}
// Implementations for Mistral
impl ToOpenAIRequest for MistralRequest {
fn to_openai(&self) -> OpenAIRequest {
OpenAIRequest {
prompt: self.input.clone(),
}
}
}
impl FromOpenAIResponse for MistralResponse {
fn from_openai(resp: &OpenAIResponse) -> Self {
MistralResponse {
output: resp.completion.clone(),
}
}
}
// Implementations for Grok
impl ToOpenAIRequest for GrokRequest {
fn to_openai(&self) -> OpenAIRequest {
OpenAIRequest {
prompt: self.message.clone(),
}
}
}
impl FromOpenAIResponse for GrokResponse {
fn from_openai(resp: &OpenAIResponse) -> Self {
GrokResponse {
reply: resp.completion.clone(),
}
}
}
// Implementations for Gemini
impl ToOpenAIRequest for GeminiRequest {
fn to_openai(&self) -> OpenAIRequest {
OpenAIRequest {
prompt: self.query.clone(),
}
}
}
impl FromOpenAIResponse for GeminiResponse {
fn from_openai(resp: &OpenAIResponse) -> Self {
GeminiResponse {
answer: resp.completion.clone(),
}
}
}
// Optionally, add more conversion traits as needed for bidirectional translation.
pub mod providers;
#[cfg(test)]
mod tests {
use super::*;
use crate::providers::openai::types::OpenAIRequest;
#[test]
fn mistral_to_openai_and_back() {
let mistral_req = MistralRequest { input: "Hello".into() };
let openai_req = mistral_req.to_openai();
assert_eq!(openai_req.prompt, "Hello");
fn openai_builder() {
let request = OpenAIRequest::builder("gpt-3.5-turbo")
.temperature(0.7)
.top_p(0.9)
.n(1)
.max_tokens(100)
.stream(false)
.stop(vec!["\n".to_string()])
.presence_penalty(0.0)
.frequency_penalty(0.0)
.build();
let openai_resp = OpenAIResponse { completion: "Hi!".into() };
let mistral_resp = MistralResponse::from_openai(&openai_resp);
assert_eq!(mistral_resp.output, "Hi!");
assert_eq!(request.model, "gpt-3.5-turbo");
assert_eq!(request.temperature, Some(0.7));
assert_eq!(request.top_p, Some(0.9));
assert_eq!(request.n, Some(1));
assert_eq!(request.max_tokens, Some(100));
assert_eq!(request.stream, Some(false));
assert_eq!(request.stop, Some(vec!["\n".to_string()]));
assert_eq!(request.presence_penalty, Some(0.0));
assert_eq!(request.frequency_penalty, Some(0.0));
}
}

View file

@ -0,0 +1,9 @@
pub mod openai;
pub mod groq;
/// Supported LLM providers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Provider {
Grok,
OpenAI,
}

View file

@ -0,0 +1,27 @@
pub mod types;
use thiserror::Error;
use crate::providers::openai::types::{OpenAIRequest, OpenAIResponse};
#[derive(Debug, Error)]
pub enum OpenAIError {
#[error("json error: {0}")]
JsonParseError(#[from] serde_json::Error),
}
type Result<T> = std::result::Result<T, OpenAIError>;
impl TryFrom<&[u8]> for OpenAIRequest {
type Error = OpenAIError;
fn try_from(bytes: &[u8]) -> Result<Self> {
serde_json::from_slice(bytes).map_err(OpenAIError::from)
}
}
impl TryFrom<&[u8]> for OpenAIResponse {
type Error = OpenAIError;
fn try_from(bytes: &[u8]) -> Result<Self> {
serde_json::from_slice(bytes).map_err(OpenAIError::from)
}
}

View file

@ -0,0 +1,151 @@
use serde::{Deserialize, Serialize};
/// Represents a request to the OpenAI API (compatible with both chat and completion endpoints).
///
/// Fields are based on the OpenAI API schema:
/// https://platform.openai.com/docs/api-reference/chat/create
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpenAIRequest {
/// The model to use (e.g., "gpt-3.5-turbo", "gpt-4").
pub model: String,
/// The list of messages for chat endpoints (use None for completion).
pub messages: Option<Vec<Message>>,
/// Sampling temperature to use (higher values = more random).
pub temperature: Option<f32>,
/// Nucleus sampling parameter.
pub top_p: Option<f32>,
/// How many completions to generate for each prompt/message.
pub n: Option<u32>,
/// Maximum number of tokens to generate.
pub max_tokens: Option<u32>,
/// Whether to stream back partial progress.
pub stream: Option<bool>,
/// Up to 4 sequences where the API will stop generating further tokens.
pub stop: Option<Vec<String>>,
/// Penalizes new tokens based on whether they appear in the text so far.
pub presence_penalty: Option<f32>,
/// Penalizes new tokens based on their frequency in the text so far.
pub frequency_penalty: Option<f32>,
}
/// Builder for `OpenAIRequest`.
#[derive(Debug, Default, Clone)]
pub struct OpenAIRequestBuilder {
model: String,
messages: Option<Vec<Message>>,
temperature: Option<f32>,
top_p: Option<f32>,
n: Option<u32>,
max_tokens: Option<u32>,
stream: Option<bool>,
stop: Option<Vec<String>>,
presence_penalty: Option<f32>,
frequency_penalty: Option<f32>,
}
impl OpenAIRequestBuilder {
pub fn new(model: impl Into<String>) -> Self {
Self {
model: model.into(),
..Default::default()
}
}
pub fn messages(mut self, messages: Vec<Message>) -> Self {
self.messages = Some(messages);
self
}
pub fn temperature(mut self, temperature: f32) -> Self {
self.temperature = Some(temperature);
self
}
pub fn top_p(mut self, top_p: f32) -> Self {
self.top_p = Some(top_p);
self
}
pub fn n(mut self, n: u32) -> Self {
self.n = Some(n);
self
}
pub fn max_tokens(mut self, max_tokens: u32) -> Self {
self.max_tokens = Some(max_tokens);
self
}
pub fn stream(mut self, stream: bool) -> Self {
self.stream = Some(stream);
self
}
pub fn stop(mut self, stop: Vec<String>) -> Self {
self.stop = Some(stop);
self
}
pub fn presence_penalty(mut self, presence_penalty: f32) -> Self {
self.presence_penalty = Some(presence_penalty);
self
}
pub fn frequency_penalty(mut self, frequency_penalty: f32) -> Self {
self.frequency_penalty = Some(frequency_penalty);
self
}
pub fn build(self) -> OpenAIRequest {
OpenAIRequest {
model: self.model,
messages: self.messages,
temperature: self.temperature,
top_p: self.top_p,
n: self.n,
max_tokens: self.max_tokens,
stream: self.stream,
stop: self.stop,
presence_penalty: self.presence_penalty,
frequency_penalty: self.frequency_penalty,
}
}
}
impl OpenAIRequest {
pub fn builder(model: impl Into<String>) -> OpenAIRequestBuilder {
OpenAIRequestBuilder::new(model)
}
}
/// Represents a message in the OpenAI chat API.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
/// The role of the message sender ("system", "user", or "assistant").
pub role: String,
/// The content of the message.
pub content: String,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct OpenAIResponse {
pub id: String,
pub object: String,
pub created: u64,
pub choices: Vec<Choice>,
pub usage: Option<Usage>,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Choice {
pub index: u32,
pub message: Message,
pub finish_reason: Option<String>,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Usage {
pub prompt_tokens: u32,
pub completion_tokens: u32,
pub total_tokens: u32,
}