mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
add openai protocol
This commit is contained in:
parent
21ca21dc3c
commit
59dbbd6743
7 changed files with 217 additions and 136 deletions
7
crates/Cargo.lock
generated
7
crates/Cargo.lock
generated
|
|
@ -79,9 +79,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "anyhow"
|
||||
version = "1.0.90"
|
||||
version = "1.0.98"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "37bf3594c4c988a53154954629820791dde498571819ae4ca50ca811e060cc95"
|
||||
checksum = "e16d2d3311acee920a9eb8d33b8cbc1787ce4a264e85f964c2404b969bdcd487"
|
||||
|
||||
[[package]]
|
||||
name = "arbitrary"
|
||||
|
|
@ -1011,6 +1011,9 @@ name = "hermesllm"
|
|||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"common",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"thiserror 2.0.12",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
|
|||
|
|
@ -5,3 +5,6 @@ edition = "2021"
|
|||
|
||||
[dependencies]
|
||||
common = { version = "0.1.0", path = "../common" }
|
||||
serde = "1.0.219"
|
||||
serde_json = "1.0.140"
|
||||
thiserror = "2.0.12"
|
||||
|
|
|
|||
|
|
@ -1,145 +1,33 @@
|
|||
//! hermesllm: A library for translating LLM API requests and responses
|
||||
//! between Mistral, Grok, Gemini, and OpenAI-compliant formats.
|
||||
|
||||
/// Supported LLM providers.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum Provider {
|
||||
Mistral,
|
||||
Grok,
|
||||
Gemini,
|
||||
OpenAI,
|
||||
}
|
||||
|
||||
/// OpenAI API request format (placeholder).
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct OpenAIRequest {
|
||||
// Add OpenAI request fields here
|
||||
pub prompt: String,
|
||||
// ...
|
||||
}
|
||||
|
||||
/// OpenAI API response format (placeholder).
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct OpenAIResponse {
|
||||
// Add OpenAI response fields here
|
||||
pub completion: String,
|
||||
// ...
|
||||
}
|
||||
|
||||
/// Mistral API request format (placeholder).
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MistralRequest {
|
||||
pub input: String,
|
||||
// ...
|
||||
}
|
||||
|
||||
/// Mistral API response format (placeholder).
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MistralResponse {
|
||||
pub output: String,
|
||||
// ...
|
||||
}
|
||||
|
||||
/// Grok API request format (placeholder).
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct GrokRequest {
|
||||
pub message: String,
|
||||
// ...
|
||||
}
|
||||
|
||||
/// Grok API response format (placeholder).
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct GrokResponse {
|
||||
pub reply: String,
|
||||
// ...
|
||||
}
|
||||
|
||||
/// Gemini API request format (placeholder).
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct GeminiRequest {
|
||||
pub query: String,
|
||||
// ...
|
||||
}
|
||||
|
||||
/// Gemini API response format (placeholder).
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct GeminiResponse {
|
||||
pub answer: String,
|
||||
// ...
|
||||
}
|
||||
|
||||
/// Trait for translating provider-specific requests to OpenAI format.
|
||||
pub trait ToOpenAIRequest {
|
||||
fn to_openai(&self) -> OpenAIRequest;
|
||||
}
|
||||
|
||||
/// Trait for translating OpenAI responses to provider-specific format.
|
||||
pub trait FromOpenAIResponse: Sized {
|
||||
fn from_openai(resp: &OpenAIResponse) -> Self;
|
||||
}
|
||||
|
||||
// Implementations for Mistral
|
||||
impl ToOpenAIRequest for MistralRequest {
|
||||
fn to_openai(&self) -> OpenAIRequest {
|
||||
OpenAIRequest {
|
||||
prompt: self.input.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
impl FromOpenAIResponse for MistralResponse {
|
||||
fn from_openai(resp: &OpenAIResponse) -> Self {
|
||||
MistralResponse {
|
||||
output: resp.completion.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Implementations for Grok
|
||||
impl ToOpenAIRequest for GrokRequest {
|
||||
fn to_openai(&self) -> OpenAIRequest {
|
||||
OpenAIRequest {
|
||||
prompt: self.message.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
impl FromOpenAIResponse for GrokResponse {
|
||||
fn from_openai(resp: &OpenAIResponse) -> Self {
|
||||
GrokResponse {
|
||||
reply: resp.completion.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Implementations for Gemini
|
||||
impl ToOpenAIRequest for GeminiRequest {
|
||||
fn to_openai(&self) -> OpenAIRequest {
|
||||
OpenAIRequest {
|
||||
prompt: self.query.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
impl FromOpenAIResponse for GeminiResponse {
|
||||
fn from_openai(resp: &OpenAIResponse) -> Self {
|
||||
GeminiResponse {
|
||||
answer: resp.completion.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Optionally, add more conversion traits as needed for bidirectional translation.
|
||||
pub mod providers;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::providers::openai::types::OpenAIRequest;
|
||||
|
||||
#[test]
|
||||
fn mistral_to_openai_and_back() {
|
||||
let mistral_req = MistralRequest { input: "Hello".into() };
|
||||
let openai_req = mistral_req.to_openai();
|
||||
assert_eq!(openai_req.prompt, "Hello");
|
||||
fn openai_builder() {
|
||||
let request = OpenAIRequest::builder("gpt-3.5-turbo")
|
||||
.temperature(0.7)
|
||||
.top_p(0.9)
|
||||
.n(1)
|
||||
.max_tokens(100)
|
||||
.stream(false)
|
||||
.stop(vec!["\n".to_string()])
|
||||
.presence_penalty(0.0)
|
||||
.frequency_penalty(0.0)
|
||||
.build();
|
||||
|
||||
let openai_resp = OpenAIResponse { completion: "Hi!".into() };
|
||||
let mistral_resp = MistralResponse::from_openai(&openai_resp);
|
||||
assert_eq!(mistral_resp.output, "Hi!");
|
||||
assert_eq!(request.model, "gpt-3.5-turbo");
|
||||
assert_eq!(request.temperature, Some(0.7));
|
||||
assert_eq!(request.top_p, Some(0.9));
|
||||
assert_eq!(request.n, Some(1));
|
||||
assert_eq!(request.max_tokens, Some(100));
|
||||
assert_eq!(request.stream, Some(false));
|
||||
assert_eq!(request.stop, Some(vec!["\n".to_string()]));
|
||||
assert_eq!(request.presence_penalty, Some(0.0));
|
||||
assert_eq!(request.frequency_penalty, Some(0.0));
|
||||
}
|
||||
}
|
||||
|
|
|
|||
0
crates/hermesllm/src/providers/groq/mod.rs
Normal file
0
crates/hermesllm/src/providers/groq/mod.rs
Normal file
9
crates/hermesllm/src/providers/mod.rs
Normal file
9
crates/hermesllm/src/providers/mod.rs
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
pub mod openai;
|
||||
pub mod groq;
|
||||
|
||||
/// Supported LLM providers.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum Provider {
|
||||
Grok,
|
||||
OpenAI,
|
||||
}
|
||||
27
crates/hermesllm/src/providers/openai/mod.rs
Normal file
27
crates/hermesllm/src/providers/openai/mod.rs
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
pub mod types;
|
||||
|
||||
use thiserror::Error;
|
||||
|
||||
use crate::providers::openai::types::{OpenAIRequest, OpenAIResponse};
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum OpenAIError {
|
||||
#[error("json error: {0}")]
|
||||
JsonParseError(#[from] serde_json::Error),
|
||||
}
|
||||
|
||||
type Result<T> = std::result::Result<T, OpenAIError>;
|
||||
|
||||
impl TryFrom<&[u8]> for OpenAIRequest {
|
||||
type Error = OpenAIError;
|
||||
fn try_from(bytes: &[u8]) -> Result<Self> {
|
||||
serde_json::from_slice(bytes).map_err(OpenAIError::from)
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<&[u8]> for OpenAIResponse {
|
||||
type Error = OpenAIError;
|
||||
fn try_from(bytes: &[u8]) -> Result<Self> {
|
||||
serde_json::from_slice(bytes).map_err(OpenAIError::from)
|
||||
}
|
||||
}
|
||||
151
crates/hermesllm/src/providers/openai/types.rs
Normal file
151
crates/hermesllm/src/providers/openai/types.rs
Normal file
|
|
@ -0,0 +1,151 @@
|
|||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Represents a request to the OpenAI API (compatible with both chat and completion endpoints).
|
||||
///
|
||||
/// Fields are based on the OpenAI API schema:
|
||||
/// https://platform.openai.com/docs/api-reference/chat/create
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct OpenAIRequest {
|
||||
/// The model to use (e.g., "gpt-3.5-turbo", "gpt-4").
|
||||
pub model: String,
|
||||
/// The list of messages for chat endpoints (use None for completion).
|
||||
pub messages: Option<Vec<Message>>,
|
||||
/// Sampling temperature to use (higher values = more random).
|
||||
pub temperature: Option<f32>,
|
||||
/// Nucleus sampling parameter.
|
||||
pub top_p: Option<f32>,
|
||||
/// How many completions to generate for each prompt/message.
|
||||
pub n: Option<u32>,
|
||||
/// Maximum number of tokens to generate.
|
||||
pub max_tokens: Option<u32>,
|
||||
/// Whether to stream back partial progress.
|
||||
pub stream: Option<bool>,
|
||||
/// Up to 4 sequences where the API will stop generating further tokens.
|
||||
pub stop: Option<Vec<String>>,
|
||||
/// Penalizes new tokens based on whether they appear in the text so far.
|
||||
pub presence_penalty: Option<f32>,
|
||||
/// Penalizes new tokens based on their frequency in the text so far.
|
||||
pub frequency_penalty: Option<f32>,
|
||||
}
|
||||
|
||||
/// Builder for `OpenAIRequest`.
|
||||
#[derive(Debug, Default, Clone)]
|
||||
pub struct OpenAIRequestBuilder {
|
||||
model: String,
|
||||
messages: Option<Vec<Message>>,
|
||||
temperature: Option<f32>,
|
||||
top_p: Option<f32>,
|
||||
n: Option<u32>,
|
||||
max_tokens: Option<u32>,
|
||||
stream: Option<bool>,
|
||||
stop: Option<Vec<String>>,
|
||||
presence_penalty: Option<f32>,
|
||||
frequency_penalty: Option<f32>,
|
||||
}
|
||||
|
||||
impl OpenAIRequestBuilder {
|
||||
pub fn new(model: impl Into<String>) -> Self {
|
||||
Self {
|
||||
model: model.into(),
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn messages(mut self, messages: Vec<Message>) -> Self {
|
||||
self.messages = Some(messages);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn temperature(mut self, temperature: f32) -> Self {
|
||||
self.temperature = Some(temperature);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn top_p(mut self, top_p: f32) -> Self {
|
||||
self.top_p = Some(top_p);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn n(mut self, n: u32) -> Self {
|
||||
self.n = Some(n);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn max_tokens(mut self, max_tokens: u32) -> Self {
|
||||
self.max_tokens = Some(max_tokens);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn stream(mut self, stream: bool) -> Self {
|
||||
self.stream = Some(stream);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn stop(mut self, stop: Vec<String>) -> Self {
|
||||
self.stop = Some(stop);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn presence_penalty(mut self, presence_penalty: f32) -> Self {
|
||||
self.presence_penalty = Some(presence_penalty);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn frequency_penalty(mut self, frequency_penalty: f32) -> Self {
|
||||
self.frequency_penalty = Some(frequency_penalty);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn build(self) -> OpenAIRequest {
|
||||
OpenAIRequest {
|
||||
model: self.model,
|
||||
messages: self.messages,
|
||||
temperature: self.temperature,
|
||||
top_p: self.top_p,
|
||||
n: self.n,
|
||||
max_tokens: self.max_tokens,
|
||||
stream: self.stream,
|
||||
stop: self.stop,
|
||||
presence_penalty: self.presence_penalty,
|
||||
frequency_penalty: self.frequency_penalty,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl OpenAIRequest {
|
||||
pub fn builder(model: impl Into<String>) -> OpenAIRequestBuilder {
|
||||
OpenAIRequestBuilder::new(model)
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents a message in the OpenAI chat API.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Message {
|
||||
/// The role of the message sender ("system", "user", or "assistant").
|
||||
pub role: String,
|
||||
/// The content of the message.
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct OpenAIResponse {
|
||||
pub id: String,
|
||||
pub object: String,
|
||||
pub created: u64,
|
||||
pub choices: Vec<Choice>,
|
||||
pub usage: Option<Usage>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct Choice {
|
||||
pub index: u32,
|
||||
pub message: Message,
|
||||
pub finish_reason: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct Usage {
|
||||
pub prompt_tokens: u32,
|
||||
pub completion_tokens: u32,
|
||||
pub total_tokens: u32,
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue