mirror of
https://github.com/katanemo/plano.git
synced 2026-06-23 15:38:07 +02:00
refactored changes to support enum dispatch
This commit is contained in:
parent
7253a0f203
commit
327b29ec6f
10 changed files with 526 additions and 478 deletions
|
|
@ -5,7 +5,10 @@ use std::collections::HashMap;
|
||||||
use std::fmt::Display;
|
use std::fmt::Display;
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
|
|
||||||
use crate::{providers::ProviderRequestError, ConversionMode, ProviderRequest};
|
|
||||||
|
|
||||||
|
use crate::providers::request::{ProviderRequest, ProviderRequestError};
|
||||||
|
use crate::providers::response::{ProviderResponse, ProviderStreamResponse, ProviderStreamResponseIter, TokenUsage};
|
||||||
use super::ApiDefinition;
|
use super::ApiDefinition;
|
||||||
|
|
||||||
// ============================================================================
|
// ============================================================================
|
||||||
|
|
@ -127,8 +130,6 @@ pub struct Message {
|
||||||
pub tool_call_id: Option<String>,
|
pub tool_call_id: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
#[skip_serializing_none]
|
#[skip_serializing_none]
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||||
pub struct ResponseMessage {
|
pub struct ResponseMessage {
|
||||||
|
|
@ -449,9 +450,92 @@ pub struct StreamOptions {
|
||||||
pub include_usage: Option<bool>,
|
pub include_usage: Option<bool>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// ============================================================================
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
/// OpenAI Provider Request Wrapper
|
pub struct ModelDetail {
|
||||||
/// ============================================================================
|
pub id: String,
|
||||||
|
pub object: String,
|
||||||
|
pub created: usize,
|
||||||
|
pub owned_by: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub enum ModelObject {
|
||||||
|
#[serde(rename = "list")]
|
||||||
|
List,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct Models {
|
||||||
|
pub object: ModelObject,
|
||||||
|
pub data: Vec<ModelDetail>,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// Error type for streaming operations
|
||||||
|
#[derive(Debug, thiserror::Error)]
|
||||||
|
pub enum OpenAIStreamError {
|
||||||
|
#[error("JSON parsing error: {0}")]
|
||||||
|
JsonError(#[from] serde_json::Error),
|
||||||
|
#[error("UTF-8 parsing error: {0}")]
|
||||||
|
Utf8Error(#[from] std::str::Utf8Error),
|
||||||
|
#[error("Invalid streaming data: {0}")]
|
||||||
|
InvalidStreamingData(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Error)]
|
||||||
|
pub enum OpenAIError {
|
||||||
|
#[error("json error: {0}")]
|
||||||
|
JsonParseError(#[from] serde_json::Error),
|
||||||
|
#[error("utf8 parsing error: {0}")]
|
||||||
|
Utf8Error(#[from] std::str::Utf8Error),
|
||||||
|
#[error("invalid streaming data err {source}, data: {data}")]
|
||||||
|
InvalidStreamingData {
|
||||||
|
source: serde_json::Error,
|
||||||
|
data: String,
|
||||||
|
},
|
||||||
|
#[error("unsupported provider: {provider}")]
|
||||||
|
UnsupportedProvider { provider: String },
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
/// Trait Implementations
|
||||||
|
/// ===========================================================================
|
||||||
|
|
||||||
|
|
||||||
|
/// Parameterized conversion for ChatCompletionsRequest
|
||||||
|
impl TryFrom<&[u8]> for ChatCompletionsRequest {
|
||||||
|
type Error = OpenAIStreamError;
|
||||||
|
|
||||||
|
fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
|
||||||
|
serde_json::from_slice(bytes).map_err(OpenAIStreamError::from)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Parameterized conversion for ChatCompletionsResponse
|
||||||
|
impl TryFrom<&[u8]> for ChatCompletionsResponse {
|
||||||
|
type Error = OpenAIStreamError;
|
||||||
|
|
||||||
|
fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
|
||||||
|
serde_json::from_slice(bytes).map_err(OpenAIStreamError::from)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Implementation of TokenUsage for OpenAI Usage type
|
||||||
|
impl TokenUsage for Usage {
|
||||||
|
fn completion_tokens(&self) -> usize {
|
||||||
|
self.completion_tokens as usize
|
||||||
|
}
|
||||||
|
|
||||||
|
fn prompt_tokens(&self) -> usize {
|
||||||
|
self.prompt_tokens as usize
|
||||||
|
}
|
||||||
|
|
||||||
|
fn total_tokens(&self) -> usize {
|
||||||
|
self.total_tokens as usize
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Implementation of ProviderRequest for ChatCompletionsRequest
|
||||||
impl ProviderRequest for ChatCompletionsRequest {
|
impl ProviderRequest for ChatCompletionsRequest {
|
||||||
fn model(&self) -> &str {
|
fn model(&self) -> &str {
|
||||||
&self.model
|
&self.model
|
||||||
|
|
@ -493,144 +577,29 @@ impl ProviderRequest for ChatCompletionsRequest {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn to_provider_bytes(&self, mode: ConversionMode) -> Result<Vec<u8>, ProviderRequestError> {
|
fn to_bytes(&self) -> Result<Vec<u8>, ProviderRequestError> {
|
||||||
match mode {
|
serde_json::to_vec(&self).map_err(|e| ProviderRequestError {
|
||||||
ConversionMode::Compatible | ConversionMode::Passthrough => {
|
message: format!("Failed to serialize OpenAI request: {}", e),
|
||||||
serde_json::to_vec(&self).map_err(|e| ProviderRequestError {
|
source: Some(Box::new(e)),
|
||||||
message: format!("Failed to serialize OpenAI request: {}", e),
|
})
|
||||||
source: Some(Box::new(e)),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============================================================================
|
/// Implementation of ProviderResponse for ChatCompletionsResponse
|
||||||
// STREAMING SUPPORT
|
|
||||||
// ============================================================================
|
|
||||||
|
|
||||||
use crate::providers::traits::{ProviderResponse, ProviderStreamResponse, ProviderStreamResponseIter, TokenUsage};
|
|
||||||
|
|
||||||
// Direct implementation of ProviderResponse on ChatCompletionsResponse
|
|
||||||
impl ProviderResponse for ChatCompletionsResponse {
|
impl ProviderResponse for ChatCompletionsResponse {
|
||||||
fn usage(&self) -> Option<&dyn TokenUsage> {
|
fn usage(&self) -> Option<&dyn TokenUsage> {
|
||||||
Some(&self.usage)
|
Some(&self.usage)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
// ============================================================================
|
fn extract_usage_counts(&self) -> Option<(usize, usize, usize)> {
|
||||||
// PARAMETERIZED CONVERSIONS FOR PROVIDER FUNCTIONS
|
Some((
|
||||||
// ============================================================================
|
self.usage.prompt_tokens(),
|
||||||
|
self.usage.completion_tokens(),
|
||||||
use crate::providers::ProviderId;
|
self.usage.total_tokens(),
|
||||||
|
))
|
||||||
/// Parameterized conversion for ChatCompletionsRequest
|
|
||||||
impl TryFrom<(&[u8], &ProviderId)> for ChatCompletionsRequest {
|
|
||||||
type Error = OpenAIStreamError;
|
|
||||||
|
|
||||||
fn try_from((bytes, _provider_id): (&[u8], &ProviderId)) -> Result<Self, Self::Error> {
|
|
||||||
serde_json::from_slice(bytes).map_err(OpenAIStreamError::from)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Parameterized conversion for ChatCompletionsResponse
|
|
||||||
impl TryFrom<(&[u8], &ProviderId)> for ChatCompletionsResponse {
|
|
||||||
type Error = OpenAIStreamError;
|
|
||||||
|
|
||||||
fn try_from((bytes, _provider_id): (&[u8], &ProviderId)) -> Result<Self, Self::Error> {
|
|
||||||
serde_json::from_slice(bytes).map_err(OpenAIStreamError::from)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Direct implementation of ProviderStreamResponse trait on ChatCompletionsStreamResponse
|
|
||||||
impl ProviderStreamResponse for ChatCompletionsStreamResponse {
|
|
||||||
fn content_delta(&self) -> Option<&str> {
|
|
||||||
self.choices
|
|
||||||
.first()
|
|
||||||
.and_then(|choice| choice.delta.content.as_deref())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn is_final(&self) -> bool {
|
|
||||||
self.choices
|
|
||||||
.first()
|
|
||||||
.map(|choice| choice.finish_reason.is_some())
|
|
||||||
.unwrap_or(false)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn role(&self) -> Option<&str> {
|
|
||||||
self.choices
|
|
||||||
.first()
|
|
||||||
.and_then(|choice| choice.delta.role.as_ref().map(|r| match r {
|
|
||||||
Role::System => "system",
|
|
||||||
Role::User => "user",
|
|
||||||
Role::Assistant => "assistant",
|
|
||||||
Role::Tool => "tool",
|
|
||||||
}))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Implementation of TokenUsage for OpenAI Usage type
|
|
||||||
impl TokenUsage for Usage {
|
|
||||||
fn completion_tokens(&self) -> usize {
|
|
||||||
self.completion_tokens as usize
|
|
||||||
}
|
|
||||||
|
|
||||||
fn prompt_tokens(&self) -> usize {
|
|
||||||
self.prompt_tokens as usize
|
|
||||||
}
|
|
||||||
|
|
||||||
fn total_tokens(&self) -> usize {
|
|
||||||
self.total_tokens as usize
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
||||||
pub struct ModelDetail {
|
|
||||||
pub id: String,
|
|
||||||
pub object: String,
|
|
||||||
pub created: usize,
|
|
||||||
pub owned_by: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
||||||
pub enum ModelObject {
|
|
||||||
#[serde(rename = "list")]
|
|
||||||
List,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
||||||
pub struct Models {
|
|
||||||
pub object: ModelObject,
|
|
||||||
pub data: Vec<ModelDetail>,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Error type for streaming operations
|
|
||||||
#[derive(Debug, thiserror::Error)]
|
|
||||||
pub enum OpenAIStreamError {
|
|
||||||
#[error("JSON parsing error: {0}")]
|
|
||||||
JsonError(#[from] serde_json::Error),
|
|
||||||
#[error("UTF-8 parsing error: {0}")]
|
|
||||||
Utf8Error(#[from] std::str::Utf8Error),
|
|
||||||
#[error("Invalid streaming data: {0}")]
|
|
||||||
InvalidStreamingData(String),
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Error)]
|
|
||||||
pub enum OpenAIError {
|
|
||||||
#[error("json error: {0}")]
|
|
||||||
JsonParseError(#[from] serde_json::Error),
|
|
||||||
#[error("utf8 parsing error: {0}")]
|
|
||||||
Utf8Error(#[from] std::str::Utf8Error),
|
|
||||||
#[error("invalid streaming data err {source}, data: {data}")]
|
|
||||||
InvalidStreamingData {
|
|
||||||
source: serde_json::Error,
|
|
||||||
data: String,
|
|
||||||
},
|
|
||||||
#[error("unsupported provider: {provider}")]
|
|
||||||
UnsupportedProvider { provider: String },
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
/// SSE-based streaming iterator for OpenAI chat completions
|
/// SSE-based streaming iterator for OpenAI chat completions
|
||||||
/// Implements ProviderStreamResponseIter directly
|
/// Implements ProviderStreamResponseIter directly
|
||||||
pub struct SseChatCompletionIter<I>
|
pub struct SseChatCompletionIter<I>
|
||||||
|
|
@ -696,6 +665,34 @@ where
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// Direct implementation of ProviderStreamResponse trait on ChatCompletionsStreamResponse
|
||||||
|
impl ProviderStreamResponse for ChatCompletionsStreamResponse {
|
||||||
|
fn content_delta(&self) -> Option<&str> {
|
||||||
|
self.choices
|
||||||
|
.first()
|
||||||
|
.and_then(|choice| choice.delta.content.as_deref())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn is_final(&self) -> bool {
|
||||||
|
self.choices
|
||||||
|
.first()
|
||||||
|
.map(|choice| choice.finish_reason.is_some())
|
||||||
|
.unwrap_or(false)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn role(&self) -> Option<&str> {
|
||||||
|
self.choices
|
||||||
|
.first()
|
||||||
|
.and_then(|choice| choice.delta.role.as_ref().map(|r| match r {
|
||||||
|
Role::System => "system",
|
||||||
|
Role::User => "user",
|
||||||
|
Role::Assistant => "assistant",
|
||||||
|
Role::Tool => "tool",
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
|
||||||
|
|
@ -6,13 +6,10 @@ pub mod apis;
|
||||||
pub mod clients;
|
pub mod clients;
|
||||||
|
|
||||||
// Re-export important types and traits
|
// Re-export important types and traits
|
||||||
pub use providers::{
|
pub use providers::request::{ProviderRequestType, ProviderRequest, ProviderRequestError};
|
||||||
ProviderId, ConversionMode,
|
pub use providers::response::{ProviderResponseType, ProviderResponse, ProviderStreamResponse, ProviderResponseError, TokenUsage, try_streaming_from_bytes};
|
||||||
ProviderRequest, ProviderResponse, ProviderStreamResponse, ProviderStreamResponseIter,
|
pub use providers::id::ProviderId;
|
||||||
TokenUsage,
|
pub use providers::adapters::{has_compatible_api, supported_apis};
|
||||||
try_request_from_bytes, try_response_from_bytes, try_streaming_from_bytes,
|
|
||||||
has_compatible_api, supported_apis
|
|
||||||
};
|
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
|
|
@ -58,7 +55,7 @@ mod tests {
|
||||||
]
|
]
|
||||||
}"#;
|
}"#;
|
||||||
|
|
||||||
let result = try_request_from_bytes(json_request.as_bytes(), &ProviderId::OpenAI);
|
let result: Result<ProviderRequestType, std::io::Error> = ProviderRequestType::try_from(json_request.as_bytes());
|
||||||
assert!(result.is_ok());
|
assert!(result.is_ok());
|
||||||
|
|
||||||
let request = result.unwrap();
|
let request = result.unwrap();
|
||||||
|
|
@ -74,7 +71,7 @@ mod tests {
|
||||||
data: [DONE]
|
data: [DONE]
|
||||||
"#;
|
"#;
|
||||||
|
|
||||||
let result = try_streaming_from_bytes(sse_data.as_bytes(), &ProviderId::OpenAI, ConversionMode::Passthrough);
|
let result = try_streaming_from_bytes(sse_data.as_bytes(), &ProviderId::OpenAI);
|
||||||
assert!(result.is_ok());
|
assert!(result.is_ok());
|
||||||
|
|
||||||
let mut streaming_response = result.unwrap();
|
let mut streaming_response = result.unwrap();
|
||||||
|
|
|
||||||
39
crates/hermesllm/src/providers/adapters.rs
Normal file
39
crates/hermesllm/src/providers/adapters.rs
Normal file
|
|
@ -0,0 +1,39 @@
|
||||||
|
use crate::providers::id::ProviderId;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub enum AdapterType {
|
||||||
|
OpenAICompatible,
|
||||||
|
// Future: Claude, Gemini, etc.
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Provider adapter configuration
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct ProviderConfig {
|
||||||
|
pub supported_apis: &'static [&'static str],
|
||||||
|
pub adapter_type: AdapterType,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if provider has compatible API
|
||||||
|
pub fn has_compatible_api(provider_id: &ProviderId, api_path: &str) -> bool {
|
||||||
|
let config = get_provider_config(provider_id);
|
||||||
|
config.supported_apis.iter().any(|&supported| supported == api_path)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get supported APIs for provider
|
||||||
|
pub fn supported_apis(provider_id: &ProviderId) -> Vec<&'static str> {
|
||||||
|
let config = get_provider_config(provider_id);
|
||||||
|
config.supported_apis.to_vec()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get provider configuration
|
||||||
|
pub fn get_provider_config(provider_id: &ProviderId) -> ProviderConfig {
|
||||||
|
match provider_id {
|
||||||
|
ProviderId::OpenAI | ProviderId::Groq | ProviderId::Mistral | ProviderId::Deepseek
|
||||||
|
| ProviderId::Arch | ProviderId::Gemini | ProviderId::Claude | ProviderId::GitHub => {
|
||||||
|
ProviderConfig {
|
||||||
|
supported_apis: &["/v1/chat/completions"],
|
||||||
|
adapter_type: AdapterType::OpenAICompatible,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
45
crates/hermesllm/src/providers/id.rs
Normal file
45
crates/hermesllm/src/providers/id.rs
Normal file
|
|
@ -0,0 +1,45 @@
|
||||||
|
use std::fmt::Display;
|
||||||
|
|
||||||
|
/// Provider identifier enum - simple enum for identifying providers
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||||
|
pub enum ProviderId {
|
||||||
|
OpenAI,
|
||||||
|
Mistral,
|
||||||
|
Deepseek,
|
||||||
|
Groq,
|
||||||
|
Gemini,
|
||||||
|
Claude,
|
||||||
|
GitHub,
|
||||||
|
Arch,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<&str> for ProviderId {
|
||||||
|
fn from(value: &str) -> Self {
|
||||||
|
match value.to_lowercase().as_str() {
|
||||||
|
"openai" => ProviderId::OpenAI,
|
||||||
|
"mistral" => ProviderId::Mistral,
|
||||||
|
"deepseek" => ProviderId::Deepseek,
|
||||||
|
"groq" => ProviderId::Groq,
|
||||||
|
"gemini" => ProviderId::Gemini,
|
||||||
|
"claude" => ProviderId::Claude,
|
||||||
|
"github" => ProviderId::GitHub,
|
||||||
|
"arch" => ProviderId::Arch,
|
||||||
|
_ => panic!("Unknown provider: {}", value),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Display for ProviderId {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
match self {
|
||||||
|
ProviderId::OpenAI => write!(f, "OpenAI"),
|
||||||
|
ProviderId::Mistral => write!(f, "Mistral"),
|
||||||
|
ProviderId::Deepseek => write!(f, "Deepseek"),
|
||||||
|
ProviderId::Groq => write!(f, "Groq"),
|
||||||
|
ProviderId::Gemini => write!(f, "Gemini"),
|
||||||
|
ProviderId::Claude => write!(f, "Claude"),
|
||||||
|
ProviderId::GitHub => write!(f, "GitHub"),
|
||||||
|
ProviderId::Arch => write!(f, "Arch"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -2,57 +2,13 @@
|
||||||
//!
|
//!
|
||||||
//! This module contains provider-specific implementations that handle
|
//! This module contains provider-specific implementations that handle
|
||||||
//! request/response conversion for different LLM service APIs.
|
//! request/response conversion for different LLM service APIs.
|
||||||
|
//!
|
||||||
|
pub mod id;
|
||||||
|
pub mod request;
|
||||||
|
pub mod response;
|
||||||
|
pub mod adapters;
|
||||||
|
|
||||||
pub mod traits;
|
pub use id::ProviderId;
|
||||||
pub mod openai;
|
pub use request::{ProviderRequestType, ProviderRequest, ProviderRequestError} ;
|
||||||
|
pub use response::{ProviderResponseType, ProviderStreamResponseType, ProviderResponse, ProviderStreamResponse, TokenUsage };
|
||||||
// Re-export the main interfaces
|
pub use adapters::*;
|
||||||
pub use traits::*;
|
|
||||||
// Note: OpenAIProvider has been deprecated in favor of function-based approach
|
|
||||||
// OpenAI functionality is accessed through openai::builder and openai::types modules
|
|
||||||
|
|
||||||
use std::fmt::Display;
|
|
||||||
|
|
||||||
/// Provider identifier enum - simple enum for identifying providers
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
|
||||||
pub enum ProviderId {
|
|
||||||
OpenAI,
|
|
||||||
Mistral,
|
|
||||||
Deepseek,
|
|
||||||
Groq,
|
|
||||||
Gemini,
|
|
||||||
Claude,
|
|
||||||
GitHub,
|
|
||||||
Arch,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl From<&str> for ProviderId {
|
|
||||||
fn from(value: &str) -> Self {
|
|
||||||
match value.to_lowercase().as_str() {
|
|
||||||
"openai" => ProviderId::OpenAI,
|
|
||||||
"mistral" => ProviderId::Mistral,
|
|
||||||
"deepseek" => ProviderId::Deepseek,
|
|
||||||
"groq" => ProviderId::Groq,
|
|
||||||
"gemini" => ProviderId::Gemini,
|
|
||||||
"claude" => ProviderId::Claude,
|
|
||||||
"github" => ProviderId::GitHub,
|
|
||||||
"arch" => ProviderId::Arch,
|
|
||||||
_ => panic!("Unknown provider: {}", value),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Display for ProviderId {
|
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
||||||
match self {
|
|
||||||
ProviderId::OpenAI => write!(f, "OpenAI"),
|
|
||||||
ProviderId::Mistral => write!(f, "Mistral"),
|
|
||||||
ProviderId::Deepseek => write!(f, "Deepseek"),
|
|
||||||
ProviderId::Groq => write!(f, "Groq"),
|
|
||||||
ProviderId::Gemini => write!(f, "Gemini"),
|
|
||||||
ProviderId::Claude => write!(f, "Claude"),
|
|
||||||
ProviderId::GitHub => write!(f, "GitHub"),
|
|
||||||
ProviderId::Arch => write!(f, "Arch"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
|
||||||
|
|
@ -1,5 +0,0 @@
|
||||||
// Re-export the main types and builder functionality
|
|
||||||
pub use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse, ChatCompletionsStreamResponse};
|
|
||||||
|
|
||||||
// Note: The OpenAIProvider struct has been deprecated in favor of the function-based approach in traits.rs
|
|
||||||
// All provider functionality is now accessed through try_request_from_bytes, try_response_from_bytes, etc.
|
|
||||||
124
crates/hermesllm/src/providers/request.rs
Normal file
124
crates/hermesllm/src/providers/request.rs
Normal file
|
|
@ -0,0 +1,124 @@
|
||||||
|
|
||||||
|
use crate::apis::openai::ChatCompletionsRequest;
|
||||||
|
use super::{ProviderId, get_provider_config, AdapterType};
|
||||||
|
use std::error::Error;
|
||||||
|
use std::fmt;
|
||||||
|
pub enum ProviderRequestType {
|
||||||
|
ChatCompletionsRequest(ChatCompletionsRequest),
|
||||||
|
//MessagesRequest(MessagesRequest),
|
||||||
|
//add more request types here
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TryFrom<&[u8]> for ProviderRequestType {
|
||||||
|
type Error = std::io::Error;
|
||||||
|
|
||||||
|
// if passing bytes without provider id we assume the request is in OpenAI format
|
||||||
|
fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
|
||||||
|
let chat_completion_request: ChatCompletionsRequest = ChatCompletionsRequest::try_from(bytes)
|
||||||
|
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||||
|
Ok(ProviderRequestType::ChatCompletionsRequest(chat_completion_request))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TryFrom<(&[u8], &ProviderId)> for ProviderRequestType {
|
||||||
|
type Error = std::io::Error;
|
||||||
|
|
||||||
|
fn try_from((bytes, provider_id): (&[u8], &ProviderId)) -> Result<Self, Self::Error> {
|
||||||
|
let config = get_provider_config(provider_id);
|
||||||
|
match config.adapter_type {
|
||||||
|
AdapterType::OpenAICompatible => {
|
||||||
|
let chat_completion_request: ChatCompletionsRequest = ChatCompletionsRequest::try_from(bytes)
|
||||||
|
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||||
|
Ok(ProviderRequestType::ChatCompletionsRequest(chat_completion_request))
|
||||||
|
}
|
||||||
|
// Future: handle other adapter types like Claude
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait ProviderRequest: Send + Sync {
|
||||||
|
/// Extract the model name from the request
|
||||||
|
fn model(&self) -> &str;
|
||||||
|
|
||||||
|
/// Set the model name for the request
|
||||||
|
fn set_model(&mut self, model: String);
|
||||||
|
|
||||||
|
/// Check if this is a streaming request
|
||||||
|
fn is_streaming(&self) -> bool;
|
||||||
|
|
||||||
|
/// Set streaming options (e.g., include_usage)
|
||||||
|
fn set_streaming_options(&mut self);
|
||||||
|
|
||||||
|
/// Extract text content from messages for token counting
|
||||||
|
fn extract_messages_text(&self) -> String;
|
||||||
|
|
||||||
|
/// Extract the user message for tracing/logging purposes
|
||||||
|
fn extract_user_message(&self) -> Option<String>;
|
||||||
|
|
||||||
|
/// Convert the request to bytes for transmission
|
||||||
|
fn to_bytes(&self) -> Result<Vec<u8>, ProviderRequestError>;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ProviderRequest for ProviderRequestType {
|
||||||
|
fn model(&self) -> &str {
|
||||||
|
match self {
|
||||||
|
Self::ChatCompletionsRequest(r) => r.model(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn set_model(&mut self, model: String) {
|
||||||
|
match self {
|
||||||
|
Self::ChatCompletionsRequest(r) => r.set_model(model),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn is_streaming(&self) -> bool {
|
||||||
|
match self {
|
||||||
|
Self::ChatCompletionsRequest(r) => r.is_streaming(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn set_streaming_options(&mut self) {
|
||||||
|
match self {
|
||||||
|
Self::ChatCompletionsRequest(r) => r.set_streaming_options(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn extract_messages_text(&self) -> String {
|
||||||
|
match self {
|
||||||
|
Self::ChatCompletionsRequest(r) => r.extract_messages_text(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn extract_user_message(&self) -> Option<String> {
|
||||||
|
match self {
|
||||||
|
Self::ChatCompletionsRequest(r) => r.extract_user_message(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn to_bytes(&self) -> Result<Vec<u8>, ProviderRequestError> {
|
||||||
|
match self {
|
||||||
|
Self::ChatCompletionsRequest(r) => r.to_bytes(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/// Error types for provider operations
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct ProviderRequestError {
|
||||||
|
pub message: String,
|
||||||
|
pub source: Option<Box<dyn Error + Send + Sync>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl fmt::Display for ProviderRequestError {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
|
write!(f, "Provider request error: {}", self.message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Error for ProviderRequestError {
|
||||||
|
fn source(&self) -> Option<&(dyn Error + 'static)> {
|
||||||
|
self.source.as_ref().map(|e| e.as_ref() as &(dyn Error + 'static))
|
||||||
|
}
|
||||||
|
}
|
||||||
142
crates/hermesllm/src/providers/response.rs
Normal file
142
crates/hermesllm/src/providers/response.rs
Normal file
|
|
@ -0,0 +1,142 @@
|
||||||
|
use std::error::Error;
|
||||||
|
use std::fmt;
|
||||||
|
|
||||||
|
use crate::apis::openai::ChatCompletionsResponse;
|
||||||
|
use crate::apis::openai::ChatCompletionsStreamResponse;
|
||||||
|
use crate::providers::id::ProviderId;
|
||||||
|
use crate::providers::adapters::{get_provider_config, AdapterType};
|
||||||
|
|
||||||
|
pub enum ProviderResponseType {
|
||||||
|
ChatCompletionsResponse(ChatCompletionsResponse),
|
||||||
|
//MessagesResponse(MessagesResponse),
|
||||||
|
}
|
||||||
|
|
||||||
|
pub enum ProviderStreamResponseType {
|
||||||
|
ChatCompletionsStreamResponse(ChatCompletionsStreamResponse),
|
||||||
|
//MessagesStreamResponse(MessagesStreamMessage),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TryFrom<(&[u8], ProviderId)> for ProviderResponseType {
|
||||||
|
type Error = std::io::Error;
|
||||||
|
|
||||||
|
fn try_from((bytes, provider_id): (&[u8], ProviderId)) -> Result<Self, Self::Error> {
|
||||||
|
let config = get_provider_config(&provider_id);
|
||||||
|
match config.adapter_type {
|
||||||
|
AdapterType::OpenAICompatible => {
|
||||||
|
let chat_completions_response: ChatCompletionsResponse = ChatCompletionsResponse::try_from(bytes)
|
||||||
|
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||||
|
Ok(ProviderResponseType::ChatCompletionsResponse(chat_completions_response))
|
||||||
|
}
|
||||||
|
// Future: handle other adapter types like Claude
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
pub trait ProviderResponse: Send + Sync {
|
||||||
|
/// Get usage information if available - returns dynamic trait object
|
||||||
|
fn usage(&self) -> Option<&dyn TokenUsage>;
|
||||||
|
|
||||||
|
/// Extract token counts for metrics
|
||||||
|
fn extract_usage_counts(&self) -> Option<(usize, usize, usize)> {
|
||||||
|
self.usage().map(|u| (u.prompt_tokens(), u.completion_tokens(), u.total_tokens()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait ProviderStreamResponse: Send + Sync {
|
||||||
|
/// Get the content delta for this chunk
|
||||||
|
fn content_delta(&self) -> Option<&str>;
|
||||||
|
|
||||||
|
/// Check if this is the final chunk in the stream
|
||||||
|
fn is_final(&self) -> bool;
|
||||||
|
|
||||||
|
/// Get role information if available
|
||||||
|
fn role(&self) -> Option<&str>;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Trait for streaming response iterators
|
||||||
|
pub trait ProviderStreamResponseIter: Iterator<Item = Result<Box<dyn ProviderStreamResponse>, Box<dyn std::error::Error + Send + Sync>>> + Send + Sync {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
impl ProviderResponse for ProviderResponseType {
|
||||||
|
fn usage(&self) -> Option<&dyn TokenUsage> {
|
||||||
|
match self {
|
||||||
|
ProviderResponseType::ChatCompletionsResponse(resp) => resp.usage(),
|
||||||
|
// Future: ProviderResponseType::MessagesResponse(resp) => resp.usage(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn extract_usage_counts(&self) -> Option<(usize, usize, usize)> {
|
||||||
|
match self {
|
||||||
|
ProviderResponseType::ChatCompletionsResponse(resp) => resp.extract_usage_counts(),
|
||||||
|
// Future: ProviderResponseType::MessagesResponse(resp) => resp.extract_usage_counts(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ProviderStreamResponse for ProviderStreamResponseType {
|
||||||
|
fn content_delta(&self) -> Option<&str> {
|
||||||
|
match self {
|
||||||
|
ProviderStreamResponseType::ChatCompletionsStreamResponse(resp) => resp.content_delta(),
|
||||||
|
// Future: ProviderStreamResponseType::MessagesStreamResponse(resp) => resp.content_delta(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn is_final(&self) -> bool {
|
||||||
|
match self {
|
||||||
|
ProviderStreamResponseType::ChatCompletionsStreamResponse(resp) => resp.is_final(),
|
||||||
|
// Future: ProviderStreamResponseType::MessagesStreamResponse(resp) => resp.is_final(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn role(&self) -> Option<&str> {
|
||||||
|
match self {
|
||||||
|
ProviderStreamResponseType::ChatCompletionsStreamResponse(resp) => resp.role(),
|
||||||
|
// Future: ProviderStreamResponseType::MessagesStreamResponse(resp) => resp.role(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Trait for token usage information
|
||||||
|
pub trait TokenUsage {
|
||||||
|
fn completion_tokens(&self) -> usize;
|
||||||
|
fn prompt_tokens(&self) -> usize;
|
||||||
|
fn total_tokens(&self) -> usize;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct ProviderResponseError {
|
||||||
|
pub message: String,
|
||||||
|
pub source: Option<Box<dyn Error + Send + Sync>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
impl fmt::Display for ProviderResponseError {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
|
write!(f, "Provider response error: {}", self.message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Error for ProviderResponseError {
|
||||||
|
fn source(&self) -> Option<&(dyn Error + 'static)> {
|
||||||
|
self.source.as_ref().map(|e| e.as_ref() as &(dyn Error + 'static))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create streaming response using provider ID - returns clean ProviderStreamResponseIter trait object
|
||||||
|
pub fn try_streaming_from_bytes(bytes: &[u8], provider_id: &ProviderId) -> Result<Box<dyn ProviderStreamResponseIter>, Box<dyn std::error::Error + Send + Sync>> {
|
||||||
|
let config = get_provider_config(provider_id);
|
||||||
|
|
||||||
|
match config.adapter_type {
|
||||||
|
AdapterType::OpenAICompatible => {
|
||||||
|
// Parse SSE (Server-Sent Events) streaming data
|
||||||
|
let s = std::str::from_utf8(bytes)?;
|
||||||
|
let lines: Vec<String> = s.lines().map(|line| line.to_string()).collect();
|
||||||
|
let iter = crate::apis::openai::SseChatCompletionIter::new(lines.into_iter());
|
||||||
|
|
||||||
|
// Return the iterator directly - it implements ProviderStreamResponseIter
|
||||||
|
Ok(Box::new(iter))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -1,247 +0,0 @@
|
||||||
//! Provider traits for generic request/response handling
|
|
||||||
//!
|
|
||||||
//! This module defines the core traits that enable provider-agnostic
|
|
||||||
//! handling of LLM requests and responses in the gateway.
|
|
||||||
|
|
||||||
use std::error::Error;
|
|
||||||
use std::fmt;
|
|
||||||
|
|
||||||
/// Trait for provider-specific request types
|
|
||||||
pub trait ProviderRequest: Send + Sync {
|
|
||||||
/// Extract the model name from the request
|
|
||||||
fn model(&self) -> &str;
|
|
||||||
|
|
||||||
/// Set the model name for the request
|
|
||||||
fn set_model(&mut self, model: String);
|
|
||||||
|
|
||||||
/// Check if this is a streaming request
|
|
||||||
fn is_streaming(&self) -> bool;
|
|
||||||
|
|
||||||
/// Set streaming options (e.g., include_usage)
|
|
||||||
fn set_streaming_options(&mut self);
|
|
||||||
|
|
||||||
/// Extract text content from messages for token counting
|
|
||||||
fn extract_messages_text(&self) -> String;
|
|
||||||
|
|
||||||
/// Extract the user message for tracing/logging purposes
|
|
||||||
fn extract_user_message(&self) -> Option<String>;
|
|
||||||
|
|
||||||
/// Convert to provider-specific format
|
|
||||||
fn to_provider_bytes(&self, mode: ConversionMode) -> Result<Vec<u8>, ProviderRequestError>;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Trait for provider-specific response types
|
|
||||||
pub trait ProviderResponse: Send + Sync {
|
|
||||||
/// Get usage information if available - returns dynamic trait object
|
|
||||||
fn usage(&self) -> Option<&dyn TokenUsage>;
|
|
||||||
|
|
||||||
/// Extract token counts for metrics
|
|
||||||
fn extract_usage_counts(&self) -> Option<(usize, usize, usize)> {
|
|
||||||
self.usage().map(|u| (u.prompt_tokens(), u.completion_tokens(), u.total_tokens()))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Trait for provider-specific streaming response types
|
|
||||||
pub trait ProviderStreamResponse: Send + Sync {
|
|
||||||
/// Get the content delta for this chunk
|
|
||||||
fn content_delta(&self) -> Option<&str>;
|
|
||||||
|
|
||||||
/// Check if this is the final chunk in the stream
|
|
||||||
fn is_final(&self) -> bool;
|
|
||||||
|
|
||||||
/// Get role information if available
|
|
||||||
fn role(&self) -> Option<&str>;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Trait for streaming response iterators
|
|
||||||
pub trait ProviderStreamResponseIter: Iterator<Item = Result<Box<dyn ProviderStreamResponse>, Box<dyn std::error::Error + Send + Sync>>> + Send + Sync {
|
|
||||||
// No additional methods needed - just the Iterator constraint with proper bounds
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Conversion mode for provider requests/responses
|
|
||||||
#[derive(Debug, Clone, Copy)]
|
|
||||||
pub enum ConversionMode {
|
|
||||||
/// Compatible: Convert between different provider formats to ensure compatibility
|
|
||||||
Compatible,
|
|
||||||
/// Passthrough: Pass requests/responses through with minimal modification
|
|
||||||
Passthrough,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Trait for token usage information
|
|
||||||
pub trait TokenUsage {
|
|
||||||
fn completion_tokens(&self) -> usize;
|
|
||||||
fn prompt_tokens(&self) -> usize;
|
|
||||||
fn total_tokens(&self) -> usize;
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============================================================================
|
|
||||||
// PROVIDER FUNCTIONS - NO TRAITS, JUST PARAMETERIZED CONVERSION
|
|
||||||
// ============================================================================
|
|
||||||
//
|
|
||||||
// ARCHITECTURAL DECISION: Function-based Provider API
|
|
||||||
//
|
|
||||||
// We chose this function-based approach over the original ProviderInterface trait
|
|
||||||
// for several critical reasons:
|
|
||||||
//
|
|
||||||
// 1. TRAIT OBJECT LIMITATION:
|
|
||||||
// - The original ProviderInterface had associated types (Request, Response, etc.)
|
|
||||||
// - Traits with associated types cannot be used as trait objects (Box<dyn ProviderInterface>)
|
|
||||||
// - This prevented dynamic provider selection at runtime based on request headers
|
|
||||||
// - Error: "the trait `ProviderInterface` cannot be made into an object"
|
|
||||||
//
|
|
||||||
// 2. DYNAMIC PROVIDER SELECTION REQUIREMENT:
|
|
||||||
// - The gateway needs to select providers dynamically based on incoming headers
|
|
||||||
// - Cannot know provider type at compile time - must dispatch at runtime
|
|
||||||
// - Need ability to return generic trait objects that work polymorphically
|
|
||||||
//
|
|
||||||
// 3. WRAPPER TYPE ELIMINATION:
|
|
||||||
// - Original design required wrapper types like OpenAIRequestWrapper, OpenAIResponseWrapper
|
|
||||||
// - User wanted to implement traits directly on concrete types (ChatCompletionsRequest, etc.)
|
|
||||||
// - Function-based approach allows direct trait implementations without wrappers
|
|
||||||
//
|
|
||||||
// 4. PARAMETERIZED CONVERSION PATTERN:
|
|
||||||
// - Follows existing codebase pattern: TryFrom<(&[u8], &ProviderId)>
|
|
||||||
// - Enables runtime provider selection while maintaining type safety
|
|
||||||
// - Single implementation can handle multiple OpenAI-compatible providers
|
|
||||||
//
|
|
||||||
// 5. TYPE ERASURE FOR GENERIC INTERFACE:
|
|
||||||
// - Functions return Box<dyn ProviderRequest/Response> - works as trait objects
|
|
||||||
// - stream_context.rs can work with generic interfaces without knowing concrete types
|
|
||||||
// - Maintains polymorphism while enabling dynamic dispatch
|
|
||||||
// ============================================================================
|
|
||||||
|
|
||||||
use crate::ProviderId;
|
|
||||||
|
|
||||||
// ============================================================================
|
|
||||||
// PROVIDER ADAPTER REGISTRY (Organizational Enhancement)
|
|
||||||
// ============================================================================
|
|
||||||
|
|
||||||
/// Provider adapter configuration
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub struct ProviderConfig {
|
|
||||||
pub supported_apis: &'static [&'static str],
|
|
||||||
pub adapter_type: AdapterType,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub enum AdapterType {
|
|
||||||
OpenAICompatible,
|
|
||||||
// Future: Claude, Gemini, etc.
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get provider configuration
|
|
||||||
pub fn get_provider_config(provider_id: &ProviderId) -> ProviderConfig {
|
|
||||||
match provider_id {
|
|
||||||
ProviderId::OpenAI | ProviderId::Groq | ProviderId::Mistral | ProviderId::Deepseek
|
|
||||||
| ProviderId::Arch | ProviderId::Gemini | ProviderId::Claude | ProviderId::GitHub => {
|
|
||||||
ProviderConfig {
|
|
||||||
supported_apis: &["/v1/chat/completions"],
|
|
||||||
adapter_type: AdapterType::OpenAICompatible,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Parse request from bytes using provider ID - returns generic ProviderRequest trait object
|
|
||||||
pub fn try_request_from_bytes(bytes: &[u8], provider_id: &ProviderId) -> Result<Box<dyn ProviderRequest>, ProviderRequestError> {
|
|
||||||
let config = get_provider_config(provider_id);
|
|
||||||
|
|
||||||
match config.adapter_type {
|
|
||||||
AdapterType::OpenAICompatible => {
|
|
||||||
let request = crate::apis::openai::ChatCompletionsRequest::try_from((bytes, provider_id))
|
|
||||||
.map_err(|e| ProviderRequestError {
|
|
||||||
message: format!("Failed to parse request: {}", e),
|
|
||||||
source: Some(Box::new(e)),
|
|
||||||
})?;
|
|
||||||
|
|
||||||
// Return as trait object - this enables polymorphic usage
|
|
||||||
// ChatCompletionsRequest implements ProviderRequest directly (no wrapper needed)
|
|
||||||
Ok(Box::new(request) as Box<dyn ProviderRequest>)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Parse response from bytes using provider ID - returns generic ProviderResponse trait object
|
|
||||||
pub fn try_response_from_bytes(bytes: &[u8], provider_id: &ProviderId, _mode: ConversionMode) -> Result<Box<dyn ProviderResponse>, ProviderResponseError> {
|
|
||||||
let config = get_provider_config(provider_id);
|
|
||||||
|
|
||||||
match config.adapter_type {
|
|
||||||
AdapterType::OpenAICompatible => {
|
|
||||||
// Parameterized conversion allows provider-specific response parsing
|
|
||||||
let response = crate::apis::openai::ChatCompletionsResponse::try_from((bytes, provider_id))
|
|
||||||
.map_err(|e| ProviderResponseError {
|
|
||||||
message: format!("Failed to parse response: {}", e),
|
|
||||||
source: Some(Box::new(e)),
|
|
||||||
})?;
|
|
||||||
|
|
||||||
// ChatCompletionsResponse implements ProviderResponse directly - no wrapper needed!
|
|
||||||
Ok(Box::new(response) as Box<dyn ProviderResponse>)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Create streaming response using provider ID - returns clean ProviderStreamResponseIter trait object
|
|
||||||
pub fn try_streaming_from_bytes(bytes: &[u8], provider_id: &ProviderId, _mode: ConversionMode) -> Result<Box<dyn ProviderStreamResponseIter>, Box<dyn std::error::Error + Send + Sync>> {
|
|
||||||
let config = get_provider_config(provider_id);
|
|
||||||
|
|
||||||
match config.adapter_type {
|
|
||||||
AdapterType::OpenAICompatible => {
|
|
||||||
// Parse SSE (Server-Sent Events) streaming data
|
|
||||||
let s = std::str::from_utf8(bytes)?;
|
|
||||||
let lines: Vec<String> = s.lines().map(|line| line.to_string()).collect();
|
|
||||||
let iter = crate::apis::openai::SseChatCompletionIter::new(lines.into_iter());
|
|
||||||
|
|
||||||
// Return the iterator directly - it implements ProviderStreamResponseIter
|
|
||||||
Ok(Box::new(iter))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Check if provider has compatible API
|
|
||||||
pub fn has_compatible_api(provider_id: &ProviderId, api_path: &str) -> bool {
|
|
||||||
let config = get_provider_config(provider_id);
|
|
||||||
config.supported_apis.iter().any(|&supported| supported == api_path)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get supported APIs for provider
|
|
||||||
pub fn supported_apis(provider_id: &ProviderId) -> Vec<&'static str> {
|
|
||||||
let config = get_provider_config(provider_id);
|
|
||||||
config.supported_apis.to_vec()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Error types for provider operations
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub struct ProviderRequestError {
|
|
||||||
pub message: String,
|
|
||||||
pub source: Option<Box<dyn Error + Send + Sync>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub struct ProviderResponseError {
|
|
||||||
pub message: String,
|
|
||||||
pub source: Option<Box<dyn Error + Send + Sync>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl fmt::Display for ProviderRequestError {
|
|
||||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
|
||||||
write!(f, "Provider request error: {}", self.message)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl fmt::Display for ProviderResponseError {
|
|
||||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
|
||||||
write!(f, "Provider response error: {}", self.message)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Error for ProviderRequestError {
|
|
||||||
fn source(&self) -> Option<&(dyn Error + 'static)> {
|
|
||||||
self.source.as_ref().map(|e| e.as_ref() as &(dyn Error + 'static))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Error for ProviderResponseError {
|
|
||||||
fn source(&self) -> Option<&(dyn Error + 'static)> {
|
|
||||||
self.source.as_ref().map(|e| e.as_ref() as &(dyn Error + 'static))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -11,8 +11,8 @@ use common::stats::{IncrementingMetric, RecordingMetric};
|
||||||
use common::tracing::{Event, Span, TraceData, Traceparent};
|
use common::tracing::{Event, Span, TraceData, Traceparent};
|
||||||
use common::{ratelimit, routing, tokenizer};
|
use common::{ratelimit, routing, tokenizer};
|
||||||
use hermesllm::{
|
use hermesllm::{
|
||||||
try_request_from_bytes, try_response_from_bytes, try_streaming_from_bytes, ConversionMode,
|
try_streaming_from_bytes, ProviderId, ProviderRequest, ProviderRequestType, ProviderResponse,
|
||||||
ProviderId,
|
ProviderResponseType,
|
||||||
};
|
};
|
||||||
use http::StatusCode;
|
use http::StatusCode;
|
||||||
use log::{debug, info, warn};
|
use log::{debug, info, warn};
|
||||||
|
|
@ -300,20 +300,21 @@ impl HttpContext for StreamContext {
|
||||||
|
|
||||||
let provider_id = self.get_provider_id();
|
let provider_id = self.get_provider_id();
|
||||||
|
|
||||||
let mut deserialized_body = match try_request_from_bytes(&body_bytes, &provider_id) {
|
let mut deserialized_body =
|
||||||
Ok(deserialized) => deserialized,
|
match ProviderRequestType::try_from((&body_bytes[..], &provider_id)) {
|
||||||
Err(e) => {
|
Ok(deserialized) => deserialized,
|
||||||
debug!(
|
Err(e) => {
|
||||||
"on_http_request_body: request body: {}",
|
debug!(
|
||||||
String::from_utf8_lossy(&body_bytes)
|
"on_http_request_body: request body: {}",
|
||||||
);
|
String::from_utf8_lossy(&body_bytes)
|
||||||
self.send_server_error(
|
);
|
||||||
ServerError::LogicError(format!("Request parsing error: {}", e)),
|
self.send_server_error(
|
||||||
Some(StatusCode::BAD_REQUEST),
|
ServerError::LogicError(format!("Request parsing error: {}", e)),
|
||||||
);
|
Some(StatusCode::BAD_REQUEST),
|
||||||
return Action::Pause;
|
);
|
||||||
}
|
return Action::Pause;
|
||||||
};
|
}
|
||||||
|
};
|
||||||
|
|
||||||
let model_name = match self.llm_provider.as_ref() {
|
let model_name = match self.llm_provider.as_ref() {
|
||||||
Some(llm_provider) => llm_provider.model.as_ref(),
|
Some(llm_provider) => llm_provider.model.as_ref(),
|
||||||
|
|
@ -388,18 +389,17 @@ impl HttpContext for StreamContext {
|
||||||
let _hermes_llm_provider_id = ProviderId::from(llm_provider_str.as_str());
|
let _hermes_llm_provider_id = ProviderId::from(llm_provider_str.as_str());
|
||||||
|
|
||||||
// Convert chat completion request to llm provider specific request using provider interface
|
// Convert chat completion request to llm provider specific request using provider interface
|
||||||
let deserialized_body_bytes =
|
let deserialized_body_bytes = match deserialized_body.to_bytes() {
|
||||||
match deserialized_body.to_provider_bytes(ConversionMode::Compatible) {
|
Ok(bytes) => bytes,
|
||||||
Ok(bytes) => bytes,
|
Err(e) => {
|
||||||
Err(e) => {
|
warn!("Failed to serialize request body: {}", e);
|
||||||
warn!("Failed to serialize request body: {}", e);
|
self.send_server_error(
|
||||||
self.send_server_error(
|
ServerError::LogicError(format!("Request serialization error: {}", e)),
|
||||||
ServerError::LogicError(format!("Request serialization error: {}", e)),
|
Some(StatusCode::BAD_REQUEST),
|
||||||
Some(StatusCode::BAD_REQUEST),
|
);
|
||||||
);
|
return Action::Pause;
|
||||||
return Action::Pause;
|
}
|
||||||
}
|
};
|
||||||
};
|
|
||||||
|
|
||||||
self.set_http_request_body(0, body_size, &deserialized_body_bytes);
|
self.set_http_request_body(0, body_size, &deserialized_body_bytes);
|
||||||
|
|
||||||
|
|
@ -572,7 +572,7 @@ impl HttpContext for StreamContext {
|
||||||
// Since all providers use OpenAI-compatible streaming format
|
// Since all providers use OpenAI-compatible streaming format
|
||||||
let provider_id = self.get_provider_id();
|
let provider_id = self.get_provider_id();
|
||||||
|
|
||||||
match try_streaming_from_bytes(&body, &provider_id, ConversionMode::Compatible) {
|
match try_streaming_from_bytes(&body, &provider_id) {
|
||||||
Ok(mut streaming_response) => {
|
Ok(mut streaming_response) => {
|
||||||
// Process each streaming chunk
|
// Process each streaming chunk
|
||||||
while let Some(chunk_result) = streaming_response.next() {
|
while let Some(chunk_result) = streaming_response.next() {
|
||||||
|
|
@ -630,8 +630,8 @@ impl HttpContext for StreamContext {
|
||||||
} else {
|
} else {
|
||||||
debug!("non streaming response");
|
debug!("non streaming response");
|
||||||
let provider_id = self.get_provider_id();
|
let provider_id = self.get_provider_id();
|
||||||
let response =
|
let response: ProviderResponseType =
|
||||||
match try_response_from_bytes(&body, &provider_id, ConversionMode::Compatible) {
|
match ProviderResponseType::try_from((&body[..], provider_id)) {
|
||||||
Ok(response) => response,
|
Ok(response) => response,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
warn!(
|
warn!(
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue