more refactoring changes, getting close

This commit is contained in:
Salman Paracha 2025-08-09 20:44:26 -07:00
parent 63f23efda4
commit 58028bb7ae
21 changed files with 542 additions and 217 deletions

View file

@ -13,14 +13,14 @@
//!
//! ```rust
//! use hermesllm::apis::{
//! AnthropicMessagesRequest, ChatCompletionsRequest, MessagesRole, MessagesMessage,
//! MessagesRequest, ChatCompletionsRequest, MessagesRole, MessagesMessage,
//! MessagesMessageContent, MessagesSystemPrompt,
//! };
//! use hermesllm::clients::TransformError;
//! use std::convert::TryInto;
//!
//! // Transform Anthropic to OpenAI
//! let anthropic_req = AnthropicMessagesRequest {
//! let anthropic_req = MessagesRequest {
//! model: "claude-3-sonnet".to_string(),
//! system: None,
//! messages: vec![],

View file

@ -46,18 +46,18 @@ mod tests {
#[test]
fn test_provider_instance_creation() {
let provider = Provider::new(ProviderId::OpenAI);
assert!(provider.has_compatible_api("/v1/chat/completions"));
assert!(!provider.has_compatible_api("/v1/embeddings"));
assert!(provider.interface().has_compatible_api("/v1/chat/completions"));
assert!(!provider.interface().has_compatible_api("/v1/embeddings"));
}
#[test]
fn test_conversion_mode() {
let provider = Provider::new(ProviderId::OpenAI);
let compatible_mode = provider.get_interface(false);
let compatible_mode = provider.interface().get_interface(false);
assert!(matches!(compatible_mode, ConversionMode::Compatible));
let passthrough_mode = provider.get_interface(true);
let passthrough_mode = provider.interface().get_interface(true);
assert!(matches!(passthrough_mode, ConversionMode::Passthrough));
}
}

View file

@ -0,0 +1,6 @@
//! Arch provider implementation
//!
//! Arch uses OpenAI-compatible API format
pub mod provider;
pub use provider::ArchProvider;

View file

@ -0,0 +1,40 @@
//! Arch provider implementation
use crate::providers::{ProviderInterface, ConversionMode};
use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse};
use crate::providers::traits::{ProviderRequest, ProviderResponse};
/// Arch provider implementation.
///
/// A stateless unit type; every operation delegates to the OpenAI
/// chat-completions request/response types.
#[derive(Debug, Clone)]
pub struct ArchProvider;

impl ProviderInterface for ArchProvider {
    /// Returns `true` only when `api_path` is exactly `/v1/chat/completions`.
    fn has_compatible_api(&self, api_path: &str) -> bool {
        matches!(api_path, "/v1/chat/completions")
    }

    /// Lists the API paths this provider reports as supported.
    fn supported_apis(&self) -> Vec<&'static str> {
        vec!["/v1/chat/completions"]
    }

    /// Parses raw request bytes into a `ChatCompletionsRequest`.
    ///
    /// # Errors
    /// Returns the parse error from `try_from_bytes`, boxed.
    fn parse_request(&self, bytes: &[u8]) -> Result<ChatCompletionsRequest, Box<dyn std::error::Error + Send + Sync>> {
        // `?` boxes the concrete error via `From`, same as `Box::new(e)`.
        Ok(ChatCompletionsRequest::try_from_bytes(bytes)?)
    }

    /// Parses raw response bytes into a `ChatCompletionsResponse`, forwarding
    /// `provider_id` and `mode` to `try_from_bytes`.
    ///
    /// # Errors
    /// Returns the parse error from `try_from_bytes`, boxed.
    fn parse_response(&self, bytes: &[u8], provider_id: super::super::ProviderId, mode: ConversionMode) -> Result<ChatCompletionsResponse, Box<dyn std::error::Error + Send + Sync>> {
        Ok(ChatCompletionsResponse::try_from_bytes(bytes, &provider_id, mode)?)
    }

    /// Serializes `request` through `to_provider_bytes(provider_id, mode)`.
    ///
    /// # Errors
    /// Returns the serialization error from `to_provider_bytes`, boxed.
    fn request_to_bytes(&self, request: &ChatCompletionsRequest, provider_id: super::super::ProviderId, mode: ConversionMode) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
        Ok(request.to_provider_bytes(provider_id, mode)?)
    }
}

View file

@ -0,0 +1,7 @@
//! Claude provider implementation
//!
//! Claude will use a different API format in the future (/v1/messages).
//! For now, fall back to the OpenAI-compatible format.
pub mod provider;
pub use provider::ClaudeProvider;

View file

@ -0,0 +1,48 @@
//! Claude provider implementation
//!
//! TODO: Implement Claude-specific API format (/v1/messages) when needed
//! For now, uses OpenAI-compatible format as fallback
use crate::providers::{ProviderInterface, ConversionMode};
use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse};
use crate::providers::traits::{ProviderRequest, ProviderResponse};
/// Claude provider implementation.
///
/// A stateless unit type. Request/response handling currently goes through
/// the OpenAI chat-completions types (see the TODOs below).
#[derive(Debug, Clone)]
pub struct ClaudeProvider;

impl ProviderInterface for ClaudeProvider {
    /// Returns `true` for `/v1/chat/completions` and `/v1/messages`.
    ///
    /// NOTE(review): `/v1/messages` is accepted here although `parse_request`
    /// still parses the chat-completions shape — confirm this is intended
    /// until the TODO is resolved.
    fn has_compatible_api(&self, api_path: &str) -> bool {
        // TODO: Update when Claude API is fully implemented
        matches!(api_path, "/v1/chat/completions" | "/v1/messages")
    }

    /// Lists the API paths this provider reports as supported.
    fn supported_apis(&self) -> Vec<&'static str> {
        // TODO: Update when Claude API is fully implemented
        vec!["/v1/messages"]
    }

    /// Parses raw request bytes into a `ChatCompletionsRequest`.
    ///
    /// # Errors
    /// Returns the parse error from `try_from_bytes`, boxed.
    fn parse_request(&self, bytes: &[u8]) -> Result<ChatCompletionsRequest, Box<dyn std::error::Error + Send + Sync>> {
        // TODO: Implement Claude-specific request parsing
        // `?` boxes the concrete error via `From`, same as `Box::new(e)`.
        Ok(ChatCompletionsRequest::try_from_bytes(bytes)?)
    }

    /// Parses raw response bytes into a `ChatCompletionsResponse`, forwarding
    /// `provider_id` and `mode` to `try_from_bytes`.
    ///
    /// # Errors
    /// Returns the parse error from `try_from_bytes`, boxed.
    fn parse_response(&self, bytes: &[u8], provider_id: super::super::ProviderId, mode: ConversionMode) -> Result<ChatCompletionsResponse, Box<dyn std::error::Error + Send + Sync>> {
        // TODO: Implement Claude-specific response parsing
        Ok(ChatCompletionsResponse::try_from_bytes(bytes, &provider_id, mode)?)
    }

    /// Serializes `request` through `to_provider_bytes(provider_id, mode)`.
    ///
    /// # Errors
    /// Returns the serialization error from `to_provider_bytes`, boxed.
    fn request_to_bytes(&self, request: &ChatCompletionsRequest, provider_id: super::super::ProviderId, mode: ConversionMode) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
        // TODO: Implement Claude-specific request serialization
        Ok(request.to_provider_bytes(provider_id, mode)?)
    }
}

View file

@ -0,0 +1,6 @@
//! Deepseek provider implementation
//!
//! Deepseek uses OpenAI-compatible API format
pub mod provider;
pub use provider::DeepseekProvider;

View file

@ -0,0 +1,40 @@
//! Deepseek provider implementation
use crate::providers::{ProviderInterface, ConversionMode};
use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse};
use crate::providers::traits::{ProviderRequest, ProviderResponse};
/// Deepseek provider implementation.
///
/// A stateless unit type; every operation delegates to the OpenAI
/// chat-completions request/response types.
#[derive(Debug, Clone)]
pub struct DeepseekProvider;

impl ProviderInterface for DeepseekProvider {
    /// Returns `true` only when `api_path` is exactly `/v1/chat/completions`.
    fn has_compatible_api(&self, api_path: &str) -> bool {
        matches!(api_path, "/v1/chat/completions")
    }

    /// Lists the API paths this provider reports as supported.
    fn supported_apis(&self) -> Vec<&'static str> {
        vec!["/v1/chat/completions"]
    }

    /// Parses raw request bytes into a `ChatCompletionsRequest`.
    ///
    /// # Errors
    /// Returns the parse error from `try_from_bytes`, boxed.
    fn parse_request(&self, bytes: &[u8]) -> Result<ChatCompletionsRequest, Box<dyn std::error::Error + Send + Sync>> {
        // `?` boxes the concrete error via `From`, same as `Box::new(e)`.
        Ok(ChatCompletionsRequest::try_from_bytes(bytes)?)
    }

    /// Parses raw response bytes into a `ChatCompletionsResponse`, forwarding
    /// `provider_id` and `mode` to `try_from_bytes`.
    ///
    /// # Errors
    /// Returns the parse error from `try_from_bytes`, boxed.
    fn parse_response(&self, bytes: &[u8], provider_id: super::super::ProviderId, mode: ConversionMode) -> Result<ChatCompletionsResponse, Box<dyn std::error::Error + Send + Sync>> {
        Ok(ChatCompletionsResponse::try_from_bytes(bytes, &provider_id, mode)?)
    }

    /// Serializes `request` through `to_provider_bytes(provider_id, mode)`.
    ///
    /// # Errors
    /// Returns the serialization error from `to_provider_bytes`, boxed.
    fn request_to_bytes(&self, request: &ChatCompletionsRequest, provider_id: super::super::ProviderId, mode: ConversionMode) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
        Ok(request.to_provider_bytes(provider_id, mode)?)
    }
}

View file

@ -0,0 +1,7 @@
//! Gemini provider implementation
//!
//! Gemini will use a different API format in the future.
//! For now, fall back to the OpenAI-compatible format.
pub mod provider;
pub use provider::GeminiProvider;

View file

@ -0,0 +1,48 @@
//! Gemini provider implementation
//!
//! TODO: Implement Gemini-specific API format when needed
//! For now, uses OpenAI-compatible format as fallback
use crate::providers::{ProviderInterface, ConversionMode};
use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse};
use crate::providers::traits::{ProviderRequest, ProviderResponse};
/// Gemini provider implementation.
///
/// A stateless unit type. Request/response handling currently goes through
/// the OpenAI chat-completions types (see the TODOs below).
#[derive(Debug, Clone)]
pub struct GeminiProvider;

impl ProviderInterface for GeminiProvider {
    /// Returns `true` for `/v1/chat/completions` and `/v1/models`.
    fn has_compatible_api(&self, api_path: &str) -> bool {
        // TODO: Update when Gemini API is fully implemented
        matches!(api_path, "/v1/chat/completions" | "/v1/models")
    }

    /// Lists the API paths this provider reports as supported.
    fn supported_apis(&self) -> Vec<&'static str> {
        // TODO: Update when Gemini API is fully implemented
        vec!["/v1/models"]
    }

    /// Parses raw request bytes into a `ChatCompletionsRequest`.
    ///
    /// # Errors
    /// Returns the parse error from `try_from_bytes`, boxed.
    fn parse_request(&self, bytes: &[u8]) -> Result<ChatCompletionsRequest, Box<dyn std::error::Error + Send + Sync>> {
        // TODO: Implement Gemini-specific request parsing
        // `?` boxes the concrete error via `From`, same as `Box::new(e)`.
        Ok(ChatCompletionsRequest::try_from_bytes(bytes)?)
    }

    /// Parses raw response bytes into a `ChatCompletionsResponse`, forwarding
    /// `provider_id` and `mode` to `try_from_bytes`.
    ///
    /// # Errors
    /// Returns the parse error from `try_from_bytes`, boxed.
    fn parse_response(&self, bytes: &[u8], provider_id: super::super::ProviderId, mode: ConversionMode) -> Result<ChatCompletionsResponse, Box<dyn std::error::Error + Send + Sync>> {
        // TODO: Implement Gemini-specific response parsing
        Ok(ChatCompletionsResponse::try_from_bytes(bytes, &provider_id, mode)?)
    }

    /// Serializes `request` through `to_provider_bytes(provider_id, mode)`.
    ///
    /// # Errors
    /// Returns the serialization error from `to_provider_bytes`, boxed.
    fn request_to_bytes(&self, request: &ChatCompletionsRequest, provider_id: super::super::ProviderId, mode: ConversionMode) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
        // TODO: Implement Gemini-specific request serialization
        Ok(request.to_provider_bytes(provider_id, mode)?)
    }
}

View file

@ -0,0 +1,7 @@
//! GitHub provider implementation
//!
//! GitHub will use a different API format in the future (/models).
//! For now, fall back to the OpenAI-compatible format.
pub mod provider;
pub use provider::GitHubProvider;

View file

@ -0,0 +1,48 @@
//! GitHub provider implementation
//!
//! TODO: Implement GitHub-specific API format (/models) when needed
//! For now, uses OpenAI-compatible format as fallback
use crate::providers::{ProviderInterface, ConversionMode};
use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse};
use crate::providers::traits::{ProviderRequest, ProviderResponse};
/// GitHub provider implementation.
///
/// A stateless unit type. Request/response handling currently goes through
/// the OpenAI chat-completions types (see the TODOs below).
#[derive(Debug, Clone)]
pub struct GitHubProvider;

impl ProviderInterface for GitHubProvider {
    /// Returns `true` for `/v1/chat/completions` and `/models`.
    fn has_compatible_api(&self, api_path: &str) -> bool {
        // TODO: Update when GitHub API is fully implemented
        matches!(api_path, "/v1/chat/completions" | "/models")
    }

    /// Lists the API paths this provider reports as supported.
    fn supported_apis(&self) -> Vec<&'static str> {
        // TODO: Update when GitHub API is fully implemented
        vec!["/models"]
    }

    /// Parses raw request bytes into a `ChatCompletionsRequest`.
    ///
    /// # Errors
    /// Returns the parse error from `try_from_bytes`, boxed.
    fn parse_request(&self, bytes: &[u8]) -> Result<ChatCompletionsRequest, Box<dyn std::error::Error + Send + Sync>> {
        // TODO: Implement GitHub-specific request parsing
        // `?` boxes the concrete error via `From`, same as `Box::new(e)`.
        Ok(ChatCompletionsRequest::try_from_bytes(bytes)?)
    }

    /// Parses raw response bytes into a `ChatCompletionsResponse`, forwarding
    /// `provider_id` and `mode` to `try_from_bytes`.
    ///
    /// # Errors
    /// Returns the parse error from `try_from_bytes`, boxed.
    fn parse_response(&self, bytes: &[u8], provider_id: super::super::ProviderId, mode: ConversionMode) -> Result<ChatCompletionsResponse, Box<dyn std::error::Error + Send + Sync>> {
        // TODO: Implement GitHub-specific response parsing
        Ok(ChatCompletionsResponse::try_from_bytes(bytes, &provider_id, mode)?)
    }

    /// Serializes `request` through `to_provider_bytes(provider_id, mode)`.
    ///
    /// # Errors
    /// Returns the serialization error from `to_provider_bytes`, boxed.
    fn request_to_bytes(&self, request: &ChatCompletionsRequest, provider_id: super::super::ProviderId, mode: ConversionMode) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
        // TODO: Implement GitHub-specific request serialization
        Ok(request.to_provider_bytes(provider_id, mode)?)
    }
}

View file

@ -0,0 +1,6 @@
//! Groq provider implementation
//!
//! Groq uses OpenAI-compatible API format but with different endpoints
pub mod provider;
pub use provider::GroqProvider;

View file

@ -0,0 +1,43 @@
//! Groq provider implementation
//!
//! This module contains the Groq provider that handles Groq API format requests.
//! Groq uses OpenAI-compatible format but may have provider-specific nuances.
use crate::providers::{ProviderInterface, ConversionMode};
use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse};
use crate::providers::traits::{ProviderRequest, ProviderResponse};
/// Groq provider implementation.
///
/// A stateless unit type; every operation delegates to the OpenAI
/// chat-completions request/response types.
#[derive(Debug, Clone)]
pub struct GroqProvider;

impl ProviderInterface for GroqProvider {
    /// Returns `true` for `/v1/chat/completions` and
    /// `/openai/v1/chat/completions`.
    fn has_compatible_api(&self, api_path: &str) -> bool {
        matches!(api_path, "/v1/chat/completions" | "/openai/v1/chat/completions")
    }

    /// Lists the API paths this provider reports as supported.
    fn supported_apis(&self) -> Vec<&'static str> {
        vec!["/openai/v1/chat/completions"]
    }

    /// Parses raw request bytes into a `ChatCompletionsRequest`.
    ///
    /// # Errors
    /// Returns the parse error from `try_from_bytes`, boxed.
    fn parse_request(&self, bytes: &[u8]) -> Result<ChatCompletionsRequest, Box<dyn std::error::Error + Send + Sync>> {
        // `?` boxes the concrete error via `From`, same as `Box::new(e)`.
        Ok(ChatCompletionsRequest::try_from_bytes(bytes)?)
    }

    /// Parses raw response bytes into a `ChatCompletionsResponse`, forwarding
    /// `provider_id` and `mode` to `try_from_bytes`.
    ///
    /// # Errors
    /// Returns the parse error from `try_from_bytes`, boxed.
    fn parse_response(&self, bytes: &[u8], provider_id: super::super::ProviderId, mode: ConversionMode) -> Result<ChatCompletionsResponse, Box<dyn std::error::Error + Send + Sync>> {
        Ok(ChatCompletionsResponse::try_from_bytes(bytes, &provider_id, mode)?)
    }

    /// Serializes `request` through `to_provider_bytes(provider_id, mode)`.
    ///
    /// # Errors
    /// Returns the serialization error from `to_provider_bytes`, boxed.
    fn request_to_bytes(&self, request: &ChatCompletionsRequest, provider_id: super::super::ProviderId, mode: ConversionMode) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
        Ok(request.to_provider_bytes(provider_id, mode)?)
    }
}

View file

@ -1,109 +0,0 @@
//! Provider interface trait definitions
//!
//! This module defines the core traits that all LLM providers must implement.
//! The interface is designed around v1/chat/completions API for simplicity.
use std::error::Error;
/// Conversion mode for provider requests/responses.
///
/// Derives `PartialEq`/`Eq` so modes can be compared with `==` and
/// `assert_eq!` instead of requiring `matches!`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConversionMode {
    /// Compatible: Convert between different provider formats to ensure compatibility
    Compatible,
    /// Passthrough: Pass requests/responses through with minimal modification
    Passthrough,
}
/// Token usage information.
///
/// Accessor-only trait exposing three token counts as `usize`.
pub trait TokenUsage {
    /// Completion-token count (presumably tokens produced by the model —
    /// confirm against implementors).
    fn completion_tokens(&self) -> usize;
    /// Prompt-token count (presumably tokens in the request — confirm
    /// against implementors).
    fn prompt_tokens(&self) -> usize;
    /// Total-token count; whether this always equals prompt + completion is
    /// up to the implementor.
    fn total_tokens(&self) -> usize;
}
/// Error type for provider operations.
///
/// Marker trait with no methods of its own: it only bundles the
/// `Error + Send + Sync + 'static` bounds used by the associated `Error`
/// types of the provider traits.
pub trait ProviderError: Error + Send + Sync + 'static {}
/// Request type that can be converted to/from provider-specific formats.
pub trait ProviderRequest: Sized {
    /// Error returned by the fallible conversions below.
    type Error: ProviderError;
    /// Parse request from raw bytes (typically JSON).
    fn from_bytes(bytes: &[u8]) -> Result<Self, Self::Error>;
    /// Convert to bytes for sending to upstream API, using the given
    /// conversion `mode`.
    fn to_bytes(&self, mode: ConversionMode) -> Result<Vec<u8>, Self::Error>;
    /// Extract the model name from the request.
    fn model(&self) -> &str;
    /// Check if this is a streaming request.
    fn is_streaming(&self) -> bool;
    /// Set streaming options (e.g., include_usage); mutates the request in place.
    fn set_streaming_options(&mut self);
    /// Extract text content from messages for token counting.
    fn extract_text(&self) -> String;
}
/// Response type that can be converted to/from provider-specific formats.
pub trait ProviderResponse: Sized {
    /// Error returned by the fallible conversions below.
    type Error: ProviderError;
    /// Concrete usage type exposed by [`ProviderResponse::usage`].
    type Usage: TokenUsage;
    /// Parse response from raw bytes (typically JSON), using the given
    /// conversion `mode`.
    fn from_bytes(bytes: &[u8], mode: ConversionMode) -> Result<Self, Self::Error>;
    /// Convert to bytes for sending to client.
    fn to_bytes(&self) -> Result<Vec<u8>, Self::Error>;
    /// Get usage information if available; `None` otherwise.
    fn usage(&self) -> Option<&Self::Usage>;
}
/// Streaming response chunk.
pub trait StreamChunk: Sized {
    /// Error returned by the fallible conversions below.
    type Error: ProviderError;
    /// Concrete usage type exposed by [`StreamChunk::usage`].
    type Usage: TokenUsage;
    /// Parse chunk from a line of streaming data.
    ///
    /// NOTE(review): `Ok(None)` presumably means the line carried no chunk
    /// (e.g. a blank or keep-alive line) — confirm against implementors.
    fn from_line(line: &str, mode: ConversionMode) -> Result<Option<Self>, Self::Error>;
    /// Convert to line for sending to client.
    fn to_line(&self) -> Result<String, Self::Error>;
    /// Get usage information if available (usually only in final chunk).
    fn usage(&self) -> Option<&Self::Usage>;
    /// Check if this is the final chunk in the stream.
    fn is_final(&self) -> bool;
}
/// Top-level provider abstraction tying together the request, response,
/// stream-chunk and error types of one LLM backend.
pub trait LLMProvider {
    /// Request type understood by this provider.
    type Request: ProviderRequest;
    /// Non-streaming response type produced by this provider.
    type Response: ProviderResponse;
    /// Chunk type produced while streaming.
    type StreamChunk: StreamChunk;
    /// Error type returned by the parsing methods.
    type Error: ProviderError;

    /// Construct a fresh provider instance.
    fn new() -> Self;

    /// API endpoint paths this provider supports.
    fn supported_apis(&self) -> Vec<&'static str>;

    /// Whether `/v1/chat/completions` is among [`LLMProvider::supported_apis`].
    fn supports_chat_completions(&self) -> bool {
        self.supported_apis()
            .iter()
            .any(|api| *api == "/v1/chat/completions")
    }

    /// Decode a request from raw bytes.
    fn parse_request(&self, bytes: &[u8]) -> Result<Self::Request, Self::Error>;

    /// Decode a response from raw bytes under the given conversion `mode`.
    fn parse_response(&self, bytes: &[u8], mode: ConversionMode) -> Result<Self::Response, Self::Error>;

    /// Decode one streaming line under the given conversion `mode`; the
    /// result is optional because a line may not yield a chunk.
    fn parse_stream_chunk(&self, line: &str, mode: ConversionMode) -> Result<Option<Self::StreamChunk>, Self::Error>;
}

View file

@ -0,0 +1,6 @@
//! Mistral provider implementation
//!
//! Mistral uses OpenAI-compatible API format
pub mod provider;
pub use provider::MistralProvider;

View file

@ -0,0 +1,40 @@
//! Mistral provider implementation
use crate::providers::{ProviderInterface, ConversionMode};
use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse};
use crate::providers::traits::{ProviderRequest, ProviderResponse};
/// Mistral provider implementation.
///
/// A stateless unit type; every operation delegates to the OpenAI
/// chat-completions request/response types.
#[derive(Debug, Clone)]
pub struct MistralProvider;

impl ProviderInterface for MistralProvider {
    /// Returns `true` only when `api_path` is exactly `/v1/chat/completions`.
    fn has_compatible_api(&self, api_path: &str) -> bool {
        matches!(api_path, "/v1/chat/completions")
    }

    /// Lists the API paths this provider reports as supported.
    fn supported_apis(&self) -> Vec<&'static str> {
        vec!["/v1/chat/completions"]
    }

    /// Parses raw request bytes into a `ChatCompletionsRequest`.
    ///
    /// # Errors
    /// Returns the parse error from `try_from_bytes`, boxed.
    fn parse_request(&self, bytes: &[u8]) -> Result<ChatCompletionsRequest, Box<dyn std::error::Error + Send + Sync>> {
        // `?` boxes the concrete error via `From`, same as `Box::new(e)`.
        Ok(ChatCompletionsRequest::try_from_bytes(bytes)?)
    }

    /// Parses raw response bytes into a `ChatCompletionsResponse`, forwarding
    /// `provider_id` and `mode` to `try_from_bytes`.
    ///
    /// # Errors
    /// Returns the parse error from `try_from_bytes`, boxed.
    fn parse_response(&self, bytes: &[u8], provider_id: super::super::ProviderId, mode: ConversionMode) -> Result<ChatCompletionsResponse, Box<dyn std::error::Error + Send + Sync>> {
        Ok(ChatCompletionsResponse::try_from_bytes(bytes, &provider_id, mode)?)
    }

    /// Serializes `request` through `to_provider_bytes(provider_id, mode)`.
    ///
    /// # Errors
    /// Returns the serialization error from `to_provider_bytes`, boxed.
    fn request_to_bytes(&self, request: &ChatCompletionsRequest, provider_id: super::super::ProviderId, mode: ConversionMode) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
        Ok(request.to_provider_bytes(provider_id, mode)?)
    }
}

View file

@ -5,10 +5,24 @@
pub mod traits;
pub mod openai;
pub mod groq;
pub mod mistral;
pub mod deepseek;
pub mod arch;
pub mod gemini;
pub mod claude;
pub mod github;
// Re-export the main interfaces
pub use traits::*;
pub use openai::OpenAIProvider;
pub use groq::GroqProvider;
pub use mistral::MistralProvider;
pub use deepseek::DeepseekProvider;
pub use arch::ArchProvider;
pub use gemini::GeminiProvider;
pub use claude::ClaudeProvider;
pub use github::GitHubProvider;
use std::fmt::Display;
@ -81,27 +95,29 @@ impl ProviderId {
}
/// Enum for dynamic dispatch of provider instances
/// For now, most providers use OpenAI-compatible format
pub enum Provider {
OpenAI(OpenAIProvider, ProviderId),
// TODO: Add specific implementations when providers have different APIs
// Mistral(MistralProvider, ProviderId),
// Groq(GroqProvider, ProviderId),
// etc.
Groq(GroqProvider, ProviderId),
Mistral(MistralProvider, ProviderId),
Deepseek(DeepseekProvider, ProviderId),
Arch(ArchProvider, ProviderId),
Gemini(GeminiProvider, ProviderId),
Claude(ClaudeProvider, ProviderId),
GitHub(GitHubProvider, ProviderId),
}
impl Provider {
/// Create a provider instance from a provider ID
pub fn new(id: ProviderId) -> Self {
match id {
// For now, all providers that support v1/chat/completions use OpenAI format
ProviderId::OpenAI | ProviderId::Groq | ProviderId::Mistral | ProviderId::Deepseek | ProviderId::Arch => {
Provider::OpenAI(OpenAIProvider, id)
}
// TODO: Implement specific providers when they have different APIs
ProviderId::Gemini | ProviderId::Claude | ProviderId::GitHub => {
Provider::OpenAI(OpenAIProvider, id) // Fallback to OpenAI for now
}
ProviderId::OpenAI => Provider::OpenAI(OpenAIProvider, id),
ProviderId::Groq => Provider::Groq(GroqProvider, id),
ProviderId::Mistral => Provider::Mistral(MistralProvider, id),
ProviderId::Deepseek => Provider::Deepseek(DeepseekProvider, id),
ProviderId::Arch => Provider::Arch(ArchProvider, id),
ProviderId::Gemini => Provider::Gemini(GeminiProvider, id),
ProviderId::Claude => Provider::Claude(ClaudeProvider, id),
ProviderId::GitHub => Provider::GitHub(GitHubProvider, id),
}
}
@ -109,66 +125,27 @@ impl Provider {
pub fn id(&self) -> ProviderId {
match self {
Provider::OpenAI(_, id) => *id,
Provider::Groq(_, id) => *id,
Provider::Mistral(_, id) => *id,
Provider::Deepseek(_, id) => *id,
Provider::Arch(_, id) => *id,
Provider::Gemini(_, id) => *id,
Provider::Claude(_, id) => *id,
Provider::GitHub(_, id) => *id,
}
}
/// Check if this provider has a compatible API with the client request
pub fn has_compatible_api(&self, api_path: &str) -> bool {
/// Get the provider interface implementation
pub fn interface(&self) -> &dyn ProviderInterface {
match self {
Provider::OpenAI(provider, _) => provider.has_compatible_api(api_path),
}
}
/// Get the interface implementation for this provider
pub fn get_interface(&self, passthrough: bool) -> ConversionMode {
match self {
Provider::OpenAI(provider, _) => provider.get_interface(passthrough),
}
}
/// Parse a request from raw bytes - returns the concrete OpenAI request type for now
pub fn parse_request(&self, bytes: &[u8]) -> Result<crate::apis::openai::ChatCompletionsRequest, Box<dyn std::error::Error + Send + Sync>> {
match self {
Provider::OpenAI(_, _) => {
use crate::apis::openai::ChatCompletionsRequest;
use crate::providers::traits::ProviderRequest;
match ChatCompletionsRequest::try_from_bytes(bytes) {
Ok(req) => Ok(req),
Err(e) => Err(Box::new(e)),
}
}
}
}
/// Parse a response from raw bytes - returns the concrete OpenAI response type for now
pub fn parse_response(&self, bytes: &[u8], mode: ConversionMode) -> Result<crate::apis::openai::ChatCompletionsResponse, Box<dyn std::error::Error + Send + Sync>> {
match self {
Provider::OpenAI(_, _) => {
use crate::apis::openai::ChatCompletionsResponse;
use crate::providers::traits::ProviderResponse;
let provider_id = self.id();
match ChatCompletionsResponse::try_from_bytes(bytes, &provider_id, mode) {
Ok(resp) => Ok(resp),
Err(e) => Err(Box::new(e)),
}
}
}
}
/// Convert a request to bytes for sending to upstream API
pub fn request_to_bytes(&self, request: &crate::apis::openai::ChatCompletionsRequest, mode: ConversionMode) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
match self {
Provider::OpenAI(_, _) => {
use crate::providers::traits::ProviderRequest;
let provider_id = self.id();
match request.to_provider_bytes(provider_id, mode) {
Ok(bytes) => Ok(bytes),
Err(e) => Err(Box::new(e)),
}
}
Provider::OpenAI(provider, _) => provider,
Provider::Groq(provider, _) => provider,
Provider::Mistral(provider, _) => provider,
Provider::Deepseek(provider, _) => provider,
Provider::Arch(provider, _) => provider,
Provider::Gemini(provider, _) => provider,
Provider::Claude(provider, _) => provider,
Provider::GitHub(provider, _) => provider,
}
}
}

View file

@ -68,11 +68,6 @@ impl Iterator for OpenAIStreamingResponse {
}
impl ProviderInterface for OpenAIProvider {
type Request = ChatCompletionsRequest;
type Response = ChatCompletionsResponse;
type StreamingResponse = OpenAIStreamingResponse;
type Usage = Usage;
fn has_compatible_api(&self, api_path: &str) -> bool {
api_path == "/v1/chat/completions"
}
@ -80,6 +75,30 @@ impl ProviderInterface for OpenAIProvider {
fn supported_apis(&self) -> Vec<&'static str> {
vec!["/v1/chat/completions"]
}
fn parse_request(&self, bytes: &[u8]) -> Result<crate::apis::openai::ChatCompletionsRequest, Box<dyn std::error::Error + Send + Sync>> {
use crate::providers::traits::ProviderRequest;
match ChatCompletionsRequest::try_from_bytes(bytes) {
Ok(req) => Ok(req),
Err(e) => Err(Box::new(e)),
}
}
fn parse_response(&self, bytes: &[u8], provider_id: super::super::ProviderId, mode: ConversionMode) -> Result<crate::apis::openai::ChatCompletionsResponse, Box<dyn std::error::Error + Send + Sync>> {
use crate::providers::traits::ProviderResponse;
match ChatCompletionsResponse::try_from_bytes(bytes, &provider_id, mode) {
Ok(resp) => Ok(resp),
Err(e) => Err(Box::new(e)),
}
}
fn request_to_bytes(&self, request: &crate::apis::openai::ChatCompletionsRequest, provider_id: super::super::ProviderId, mode: ConversionMode) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
use crate::providers::traits::ProviderRequest;
match request.to_provider_bytes(provider_id, mode) {
Ok(bytes) => Ok(bytes),
Err(e) => Err(Box::new(e)),
}
}
}
// ============================================================================
@ -131,6 +150,9 @@ impl ProviderRequest for ChatCompletionsRequest {
}
}
// Implement the helper trait for stream context integration
impl crate::providers::traits::StreamContextHelpers for ChatCompletionsRequest {}
impl TokenUsage for Usage {
fn completion_tokens(&self) -> usize {
self.completion_tokens as usize

View file

@ -54,6 +54,31 @@ pub trait ProviderResponse: Sized {
/// Get usage information if available
fn usage(&self) -> Option<&Self::Usage>;
/// Extract token counts for metrics
fn extract_usage_counts(&self) -> Option<(usize, usize, usize)> {
self.usage().map(|u| (u.prompt_tokens(), u.completion_tokens(), u.total_tokens()))
}
}
/// Helper trait for stream context integration.
///
/// Every method has a default body, so an implementor can opt in with an
/// empty `impl` block. The defaults only call other methods on `self`
/// (`extract_model`, `extract_messages_text`, `is_streaming`,
/// `set_streaming_options`), presumably supplied by the `ProviderRequest`
/// supertrait — confirm against its definition.
pub trait StreamContextHelpers: ProviderRequest {
    /// Get the model name for routing and metrics.
    ///
    /// Returns the result of `extract_model()` converted to an owned `String`.
    fn get_model_for_routing(&self) -> String {
        self.extract_model().to_string()
    }
    /// Get text for token counting and rate limiting.
    ///
    /// Forwards directly to `extract_messages_text()`.
    fn get_text_for_tokenization(&self) -> String {
        self.extract_messages_text()
    }
    /// Prepare for streaming by setting appropriate options.
    ///
    /// Does nothing unless `is_streaming()` returns `true`.
    fn prepare_for_streaming(&mut self) {
        if self.is_streaming() {
            self.set_streaming_options();
        }
    }
}
/// Trait for streaming response chunks
@ -75,11 +100,6 @@ pub trait StreamingResponse: Iterator<Item = Result<Self::Chunk, Self::Error>> +
/// Main provider interface trait
pub trait ProviderInterface {
type Request: ProviderRequest;
type Response: ProviderResponse;
type StreamingResponse: StreamingResponse;
type Usage: TokenUsage;
/// Check if this provider has a compatible API with the client request
fn has_compatible_api(&self, api_path: &str) -> bool;
@ -93,6 +113,47 @@ pub trait ProviderInterface {
}
}
/// Parse a request from raw bytes - returns concrete ChatCompletionsRequest
fn parse_request(&self, bytes: &[u8]) -> Result<crate::apis::openai::ChatCompletionsRequest, Box<dyn std::error::Error + Send + Sync>>;
/// Parse a response from raw bytes - returns concrete ChatCompletionsResponse
fn parse_response(&self, bytes: &[u8], provider_id: super::ProviderId, mode: ConversionMode) -> Result<crate::apis::openai::ChatCompletionsResponse, Box<dyn std::error::Error + Send + Sync>>;
/// Convert a request to bytes for sending to upstream API
fn request_to_bytes(&self, request: &crate::apis::openai::ChatCompletionsRequest, provider_id: super::ProviderId, mode: ConversionMode) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>>;
/// Extract model name from request for routing (convenience method for stream_context)
fn extract_model_from_request(&self, request: &crate::apis::openai::ChatCompletionsRequest) -> String {
use ProviderRequest;
request.extract_model().to_string()
}
/// Check if request is streaming (convenience method for stream_context)
fn is_request_streaming(&self, request: &crate::apis::openai::ChatCompletionsRequest) -> bool {
use ProviderRequest;
request.is_streaming()
}
/// Prepare request for streaming (convenience method for stream_context)
fn prepare_request_for_streaming(&self, request: &mut crate::apis::openai::ChatCompletionsRequest) {
use ProviderRequest;
if request.is_streaming() {
request.set_streaming_options();
}
}
/// Extract text for tokenization (convenience method for stream_context)
fn extract_text_for_tokenization(&self, request: &crate::apis::openai::ChatCompletionsRequest) -> String {
use ProviderRequest;
request.extract_messages_text()
}
/// Extract usage information from response (convenience method for stream_context)
fn extract_usage_from_response(&self, response: &crate::apis::openai::ChatCompletionsResponse) -> Option<(usize, usize, usize)> {
use ProviderResponse;
response.extract_usage_counts()
}
/// Get supported API endpoints for this provider
fn supported_apis(&self) -> Vec<&'static str>;
}

View file

@ -10,7 +10,7 @@ use common::ratelimit::Header;
use common::stats::{IncrementingMetric, RecordingMetric};
use common::tracing::{Event, Span, TraceData, Traceparent};
use common::{ratelimit, routing, tokenizer};
use hermesllm::{ConversionMode, Provider, ProviderId, ProviderRequest};
use hermesllm::{ConversionMode, Provider, ProviderId};
use http::StatusCode;
use log::{debug, info, warn};
use proxy_wasm::hostcalls::get_current_time;
@ -295,7 +295,7 @@ impl HttpContext for StreamContext {
let provider = self.get_provider();
let mut deserialized_body = match provider.parse_request(&body_bytes) {
let mut deserialized_body = match provider.interface().parse_request(&body_bytes) {
Ok(deserialized) => deserialized,
Err(e) => {
debug!(
@ -310,8 +310,8 @@ impl HttpContext for StreamContext {
}
};
// TODO: For now, we'll need to handle user_message extraction differently since it's OpenAI-specific
// This could be made generic by adding a trait method later
// TODO: For now, we'll work with the concrete ChatCompletionsRequest type
// In the future, this could be made more generic using trait objects
let model_name = match self.llm_provider.as_ref() {
Some(llm_provider) => llm_provider.model.as_ref(),
@ -323,9 +323,10 @@ impl HttpContext for StreamContext {
None => false,
};
let model_requested = deserialized_body.extract_model().to_string();
// Note: We can't directly modify the model field through the trait,
// this would need to be handled differently in a full generic implementation
// Use the provider interface methods for cleaner interaction
let model_requested = provider
.interface()
.extract_model_from_request(&deserialized_body);
info!(
"on_http_request_body: provider: {}, model requested (in body): {}, model selected: {}",
@ -334,15 +335,21 @@ impl HttpContext for StreamContext {
model_name.unwrap_or(&"None".to_string()),
);
if deserialized_body.is_streaming() {
// Use provider interface for streaming detection and setup
if provider
.interface()
.is_request_streaming(&deserialized_body)
{
self.streaming_response = true;
}
if deserialized_body.is_streaming() {
deserialized_body.set_streaming_options();
provider
.interface()
.prepare_request_for_streaming(&mut deserialized_body);
}
// only use the tokens from the messages, excluding the metadata and json tags
let input_tokens_str = deserialized_body.extract_messages_text();
// Use provider interface for text extraction
let input_tokens_str = provider
.interface()
.extract_text_for_tokenization(&deserialized_body);
// enforce ratelimits on ingress
if let Err(e) = self.enforce_ratelimits(&model_requested, input_tokens_str.as_str()) {
self.send_server_error(
@ -354,12 +361,14 @@ impl HttpContext for StreamContext {
}
let llm_provider_str = self.llm_provider().provider_interface.to_string();
let hermes_llm_provider_id = ProviderId::from(llm_provider_str.as_str());
let _hermes_llm_provider_id = ProviderId::from(llm_provider_str.as_str());
// convert chat completion request to llm provider specific request
let deserialized_body_bytes = match deserialized_body
.to_provider_bytes(hermes_llm_provider_id, ConversionMode::Compatible)
{
// Convert chat completion request to llm provider specific request using provider interface
let deserialized_body_bytes = match provider.interface().request_to_bytes(
&deserialized_body,
provider.id(),
ConversionMode::Compatible,
) {
Ok(bytes) => bytes,
Err(e) => {
warn!("Failed to serialize request body: {}", e);
@ -558,8 +567,12 @@ impl HttpContext for StreamContext {
} else {
debug!("non streaming response");
let provider = self.get_provider();
let _response = match provider.parse_response(&body, ConversionMode::Compatible) {
Ok(response_box) => response_box,
let response = match provider.interface().parse_response(
&body,
provider.id(),
ConversionMode::Compatible,
) {
Ok(response) => response,
Err(e) => {
warn!(
"could not parse response: {}, body str: {}",
@ -579,9 +592,18 @@ impl HttpContext for StreamContext {
}
};
// TODO: Extract usage information from the response box
// For now, we'll skip this until we have a better way to handle Any types
warn!("Response token counting not yet implemented with new provider structure");
// Use provider interface to extract usage information
if let Some((prompt_tokens, completion_tokens, total_tokens)) =
provider.interface().extract_usage_from_response(&response)
{
debug!(
"Response usage: prompt={}, completion={}, total={}",
prompt_tokens, completion_tokens, total_tokens
);
self.response_tokens = completion_tokens;
} else {
warn!("No usage information found in response");
}
}
debug!(