mirror of
https://github.com/katanemo/plano.git
synced 2026-07-02 15:51:02 +02:00
more changes
This commit is contained in:
parent
22fde1f333
commit
25f1b72e7c
8 changed files with 122 additions and 41 deletions
|
|
@ -1,8 +1,46 @@
|
|||
//! hermesllm: A library for translating LLM API requests and responses
|
||||
//! between Mistral, Grok, Gemini, and OpenAI-compliant formats.
|
||||
|
||||
use std::fmt::Display;
|
||||
|
||||
pub mod providers;
|
||||
|
||||
pub enum Provider {
|
||||
Mistral,
|
||||
Groq,
|
||||
Gemini,
|
||||
OpenAI,
|
||||
Claude,
|
||||
Github
|
||||
}
|
||||
|
||||
impl From<&str> for Provider {
|
||||
fn from(value: &str) -> Self {
|
||||
match value.to_lowercase().as_str() {
|
||||
"mistral" => Provider::Mistral,
|
||||
"groq" => Provider::Groq,
|
||||
"gemini" => Provider::Gemini,
|
||||
"openai" => Provider::OpenAI,
|
||||
"claude" => Provider::Claude,
|
||||
"github" => Provider::Github,
|
||||
_ => panic!("Unknown provider: {}", value),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for Provider {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Provider::Mistral => write!(f, "Mistral"),
|
||||
Provider::Groq => write!(f, "Groq"),
|
||||
Provider::Gemini => write!(f, "Gemini"),
|
||||
Provider::OpenAI => write!(f, "OpenAI"),
|
||||
Provider::Claude => write!(f, "Claude"),
|
||||
Provider::Github => write!(f, "Github"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::providers::openai::types::ChatCompletionsRequest;
|
||||
|
|
|
|||
|
|
@ -1 +0,0 @@
|
|||
pub mod types;
|
||||
|
|
@ -1,19 +0,0 @@
|
|||
use crate::providers::openai::types::{ChatCompletionsRequest, ChatCompletionsResponse};
|
||||
pub use crate::providers::openai::types::{Choice, Message, Usage};
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_with::skip_serializing_none;
|
||||
|
||||
#[skip_serializing_none]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct GroqRequest {
|
||||
#[serde(flatten)]
|
||||
pub base: ChatCompletionsRequest,
|
||||
}
|
||||
|
||||
#[skip_serializing_none]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct GroqResponse {
|
||||
#[serde(flatten)]
|
||||
pub base: ChatCompletionsResponse,
|
||||
}
|
||||
|
|
@ -1,3 +1,2 @@
|
|||
pub mod deepseek;
|
||||
pub mod groq;
|
||||
pub mod openai;
|
||||
|
|
|
|||
|
|
@ -1,18 +1,22 @@
|
|||
use std::fmt::Display;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use serde_json::{Value};
|
||||
use serde_with::skip_serializing_none;
|
||||
use std::convert::TryFrom;
|
||||
use std::str;
|
||||
use thiserror::Error;
|
||||
|
||||
use crate::Provider;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum OpenAIError {
|
||||
#[error("json error: {0}")]
|
||||
JsonParseError(#[from] serde_json::Error),
|
||||
#[error("utf8 parsing error: {0}")]
|
||||
Utf8Error(#[from] std::str::Utf8Error),
|
||||
#[error("unsupported provider: {provider}")]
|
||||
UnsupportedProvider { provider: String },
|
||||
}
|
||||
|
||||
type Result<T> = std::result::Result<T, OpenAIError>;
|
||||
|
|
@ -117,6 +121,30 @@ impl TryFrom<&[u8]> for ChatCompletionsResponse {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'a> TryFrom<(&'a [u8], &'a Provider)> for ChatCompletionsResponse {
|
||||
type Error = OpenAIError;
|
||||
|
||||
fn try_from(input: (&'a [u8], &'a Provider)) -> Result<Self> {
|
||||
// Use input.provider as needed, if necessary
|
||||
serde_json::from_slice(input.0).map_err(OpenAIError::from)
|
||||
}
|
||||
}
|
||||
|
||||
impl ChatCompletionsRequest {
|
||||
pub fn to_bytes(&self, provider: Provider) -> Result<Vec<u8>> {
|
||||
match provider {
|
||||
Provider::OpenAI
|
||||
| Provider::Mistral
|
||||
| Provider::Groq
|
||||
| Provider::Gemini
|
||||
| Provider::Claude => serde_json::to_vec(self).map_err(OpenAIError::from),
|
||||
_ => Err(OpenAIError::UnsupportedProvider {
|
||||
provider: provider.to_string(),
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[skip_serializing_none]
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct Choice {
|
||||
|
|
@ -133,10 +161,17 @@ pub struct Usage {
|
|||
pub total_tokens: usize,
|
||||
}
|
||||
|
||||
#[skip_serializing_none]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct DeltaMessage {
|
||||
pub role: Option<String>,
|
||||
pub content: Option<ContentType>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct StreamChoice {
|
||||
pub index: u32,
|
||||
pub delta: Message,
|
||||
pub delta: DeltaMessage,
|
||||
pub finish_reason: Option<String>,
|
||||
}
|
||||
|
||||
|
|
@ -193,6 +228,16 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
impl<'a> TryFrom<(&'a [u8], &'a Provider)> for SseChatCompletionIter<str::Lines<'a>> {
|
||||
type Error = OpenAIError;
|
||||
|
||||
fn try_from(input: (&'a [u8], &'a Provider)) -> Result<Self> {
|
||||
let s = std::str::from_utf8(input.0)?;
|
||||
// Use input.provider as needed
|
||||
Ok(SseChatCompletionIter::new(s.lines()))
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> TryFrom<&'a [u8]> for SseChatCompletionIter<str::Lines<'a>> {
|
||||
type Error = OpenAIError;
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue