mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
more changes
This commit is contained in:
parent
22fde1f333
commit
25f1b72e7c
8 changed files with 122 additions and 41 deletions
5
crates/Cargo.lock
generated
5
crates/Cargo.lock
generated
|
|
@ -1076,6 +1076,7 @@ checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8"
|
|||
name = "hermesllm"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"log",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"serde_with",
|
||||
|
|
@ -1642,9 +1643,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "log"
|
||||
version = "0.4.22"
|
||||
version = "0.4.27"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24"
|
||||
checksum = "13dc2df351e3202783a1fe0d44375f7295ffb4049267b0f3018346dc122a1d94"
|
||||
|
||||
[[package]]
|
||||
name = "mach2"
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ version = "0.1.0"
|
|||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
log = "0.4.27"
|
||||
serde = "1.0.219"
|
||||
serde_json = "1.0.140"
|
||||
serde_with = "3.12.0"
|
||||
|
|
|
|||
|
|
@ -1,8 +1,46 @@
|
|||
//! hermesllm: A library for translating LLM API requests and responses
|
||||
//! between Mistral, Grok, Gemini, and OpenAI-compliant formats.
|
||||
|
||||
use std::fmt::Display;
|
||||
|
||||
pub mod providers;
|
||||
|
||||
pub enum Provider {
|
||||
Mistral,
|
||||
Groq,
|
||||
Gemini,
|
||||
OpenAI,
|
||||
Claude,
|
||||
Github
|
||||
}
|
||||
|
||||
impl From<&str> for Provider {
|
||||
fn from(value: &str) -> Self {
|
||||
match value.to_lowercase().as_str() {
|
||||
"mistral" => Provider::Mistral,
|
||||
"groq" => Provider::Groq,
|
||||
"gemini" => Provider::Gemini,
|
||||
"openai" => Provider::OpenAI,
|
||||
"claude" => Provider::Claude,
|
||||
"github" => Provider::Github,
|
||||
_ => panic!("Unknown provider: {}", value),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for Provider {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Provider::Mistral => write!(f, "Mistral"),
|
||||
Provider::Groq => write!(f, "Groq"),
|
||||
Provider::Gemini => write!(f, "Gemini"),
|
||||
Provider::OpenAI => write!(f, "OpenAI"),
|
||||
Provider::Claude => write!(f, "Claude"),
|
||||
Provider::Github => write!(f, "Github"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::providers::openai::types::ChatCompletionsRequest;
|
||||
|
|
|
|||
|
|
@ -1 +0,0 @@
|
|||
pub mod types;
|
||||
|
|
@ -1,19 +0,0 @@
|
|||
use crate::providers::openai::types::{ChatCompletionsRequest, ChatCompletionsResponse};
|
||||
pub use crate::providers::openai::types::{Choice, Message, Usage};
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_with::skip_serializing_none;
|
||||
|
||||
#[skip_serializing_none]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct GroqRequest {
|
||||
#[serde(flatten)]
|
||||
pub base: ChatCompletionsRequest,
|
||||
}
|
||||
|
||||
#[skip_serializing_none]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct GroqResponse {
|
||||
#[serde(flatten)]
|
||||
pub base: ChatCompletionsResponse,
|
||||
}
|
||||
|
|
@ -1,3 +1,2 @@
|
|||
pub mod deepseek;
|
||||
pub mod groq;
|
||||
pub mod openai;
|
||||
|
|
|
|||
|
|
@ -1,18 +1,22 @@
|
|||
use std::fmt::Display;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use serde_json::{Value};
|
||||
use serde_with::skip_serializing_none;
|
||||
use std::convert::TryFrom;
|
||||
use std::str;
|
||||
use thiserror::Error;
|
||||
|
||||
use crate::Provider;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum OpenAIError {
|
||||
#[error("json error: {0}")]
|
||||
JsonParseError(#[from] serde_json::Error),
|
||||
#[error("utf8 parsing error: {0}")]
|
||||
Utf8Error(#[from] std::str::Utf8Error),
|
||||
#[error("unsupported provider: {provider}")]
|
||||
UnsupportedProvider { provider: String },
|
||||
}
|
||||
|
||||
type Result<T> = std::result::Result<T, OpenAIError>;
|
||||
|
|
@ -117,6 +121,30 @@ impl TryFrom<&[u8]> for ChatCompletionsResponse {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'a> TryFrom<(&'a [u8], &'a Provider)> for ChatCompletionsResponse {
|
||||
type Error = OpenAIError;
|
||||
|
||||
fn try_from(input: (&'a [u8], &'a Provider)) -> Result<Self> {
|
||||
// Use input.provider as needed, if necessary
|
||||
serde_json::from_slice(input.0).map_err(OpenAIError::from)
|
||||
}
|
||||
}
|
||||
|
||||
impl ChatCompletionsRequest {
|
||||
pub fn to_bytes(&self, provider: Provider) -> Result<Vec<u8>> {
|
||||
match provider {
|
||||
Provider::OpenAI
|
||||
| Provider::Mistral
|
||||
| Provider::Groq
|
||||
| Provider::Gemini
|
||||
| Provider::Claude => serde_json::to_vec(self).map_err(OpenAIError::from),
|
||||
_ => Err(OpenAIError::UnsupportedProvider {
|
||||
provider: provider.to_string(),
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[skip_serializing_none]
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct Choice {
|
||||
|
|
@ -133,10 +161,17 @@ pub struct Usage {
|
|||
pub total_tokens: usize,
|
||||
}
|
||||
|
||||
#[skip_serializing_none]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct DeltaMessage {
|
||||
pub role: Option<String>,
|
||||
pub content: Option<ContentType>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct StreamChoice {
|
||||
pub index: u32,
|
||||
pub delta: Message,
|
||||
pub delta: DeltaMessage,
|
||||
pub finish_reason: Option<String>,
|
||||
}
|
||||
|
||||
|
|
@ -193,6 +228,16 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
impl<'a> TryFrom<(&'a [u8], &'a Provider)> for SseChatCompletionIter<str::Lines<'a>> {
|
||||
type Error = OpenAIError;
|
||||
|
||||
fn try_from(input: (&'a [u8], &'a Provider)) -> Result<Self> {
|
||||
let s = std::str::from_utf8(input.0)?;
|
||||
// Use input.provider as needed
|
||||
Ok(SseChatCompletionIter::new(s.lines()))
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> TryFrom<&'a [u8]> for SseChatCompletionIter<str::Lines<'a>> {
|
||||
type Error = OpenAIError;
|
||||
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ use hermesllm::providers::openai::types::{ChatCompletionsRequest, SseChatComplet
|
|||
use hermesllm::providers::openai::types::{
|
||||
ChatCompletionsResponse, ContentType, Message, StreamOptions,
|
||||
};
|
||||
use hermesllm::Provider;
|
||||
use http::StatusCode;
|
||||
use log::{debug, info, warn};
|
||||
use proxy_wasm::hostcalls::get_current_time;
|
||||
|
|
@ -338,13 +339,6 @@ impl HttpContext for StreamContext {
|
|||
model_name.unwrap_or(&"None".to_string()),
|
||||
);
|
||||
|
||||
let chat_completion_request_str = serde_json::to_string(&deserialized_body).unwrap();
|
||||
|
||||
debug!(
|
||||
"on_http_request_body: request body: {}",
|
||||
chat_completion_request_str
|
||||
);
|
||||
|
||||
if deserialized_body.stream.unwrap_or_default() {
|
||||
self.streaming_response = true;
|
||||
}
|
||||
|
|
@ -379,7 +373,23 @@ impl HttpContext for StreamContext {
|
|||
return Action::Continue;
|
||||
}
|
||||
|
||||
self.set_http_request_body(0, body_size, chat_completion_request_str.as_bytes());
|
||||
// convert chat completion request to llm provider specific request
|
||||
let deserialized_body_bytes = match deserialized_body.to_bytes(hermesllm::Provider::OpenAI)
|
||||
{
|
||||
Ok(bytes) => bytes,
|
||||
Err(e) => {
|
||||
warn!("Failed to serialize request body: {}", e);
|
||||
self.send_server_error(ServerError::OpenAIPError(e), Some(StatusCode::BAD_REQUEST));
|
||||
return Action::Pause;
|
||||
}
|
||||
};
|
||||
|
||||
debug!(
|
||||
"on_http_request_body: request body string: {}",
|
||||
String::from_utf8_lossy(&deserialized_body_bytes)
|
||||
);
|
||||
|
||||
self.set_http_request_body(0, body_size, &deserialized_body_bytes);
|
||||
|
||||
Action::Continue
|
||||
}
|
||||
|
|
@ -534,9 +544,12 @@ impl HttpContext for StreamContext {
|
|||
}
|
||||
};
|
||||
|
||||
let llm_provider_str = self.llm_provider().provider_interface.to_string();
|
||||
let hermes_llm_provider = Provider::from(llm_provider_str.as_str());
|
||||
|
||||
if self.streaming_response {
|
||||
let chat_completions_chunk_response_events =
|
||||
match SseChatCompletionIter::try_from(body.as_slice()) {
|
||||
match SseChatCompletionIter::try_from((body.as_slice(), &hermes_llm_provider)) {
|
||||
Ok(events) => events,
|
||||
Err(e) => {
|
||||
warn!("could not parse response: {}", e);
|
||||
|
|
@ -580,14 +593,18 @@ impl HttpContext for StreamContext {
|
|||
}
|
||||
} else {
|
||||
debug!("non streaming response");
|
||||
let chat_completions_response: ChatCompletionsResponse =
|
||||
match serde_json::from_slice(body.as_slice()) {
|
||||
let chat_completions_response =
|
||||
match ChatCompletionsResponse::try_from((body.as_slice(), &hermes_llm_provider)) {
|
||||
Ok(de) => de,
|
||||
Err(err) => {
|
||||
info!(
|
||||
"non chat-completion compliant response received err: {}, body: {:?}",
|
||||
err,
|
||||
String::from_utf8(body)
|
||||
Err(e) => {
|
||||
warn!("could not parse response: {}", e);
|
||||
debug!(
|
||||
"on_http_response_body: response body: {}",
|
||||
String::from_utf8_lossy(&body)
|
||||
);
|
||||
self.send_server_error(
|
||||
ServerError::OpenAIPError(e),
|
||||
Some(StatusCode::BAD_REQUEST),
|
||||
);
|
||||
return Action::Continue;
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue