more refactoring changes to avoid unecessary re-direction and duplication

This commit is contained in:
Salman Paracha 2025-08-09 21:40:33 -07:00
parent 58028bb7ae
commit 9c09a18fd0
12 changed files with 809 additions and 225 deletions

View file

@ -46,18 +46,18 @@ mod tests {
#[test]
fn test_provider_instance_creation() {
let provider = Provider::new(ProviderId::OpenAI);
assert!(provider.interface().has_compatible_api("/v1/chat/completions"));
assert!(!provider.interface().has_compatible_api("/v1/embeddings"));
assert!(provider.has_compatible_api("/v1/chat/completions"));
assert!(!provider.has_compatible_api("/v1/embeddings"));
}
#[test]
fn test_conversion_mode() {
fn test_provider_supported_apis() {
let provider = Provider::new(ProviderId::OpenAI);
let compatible_mode = provider.interface().get_interface(false);
assert!(matches!(compatible_mode, ConversionMode::Compatible));
let supported_apis = provider.supported_apis();
assert!(supported_apis.contains(&"/v1/chat/completions"));
let passthrough_mode = provider.interface().get_interface(true);
assert!(matches!(passthrough_mode, ConversionMode::Passthrough));
// Test that provider supports the expected API endpoints
assert!(provider.has_compatible_api("/v1/chat/completions"));
}
}

View file

@ -1,13 +1,76 @@
//! Arch provider implementation
use crate::providers::{ProviderInterface, ConversionMode};
use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse};
use crate::providers::traits::{ProviderRequest, ProviderResponse};
use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse, Usage};
use crate::providers::traits::{ProviderRequest, ProviderResponse, StreamingResponse};
use crate::providers::openai::provider::{OpenAIProvider, OpenAIStreamingResponse, OpenAIApiError};
/// Arch provider implementation
#[derive(Debug, Clone)]
pub struct ArchProvider;
// Trait implementations that delegate to OpenAI
impl ProviderRequest for ArchProvider {
type Error = OpenAIApiError;
fn try_from_bytes(&self, bytes: &[u8]) -> Result<ChatCompletionsRequest, Self::Error> {
let openai_provider = OpenAIProvider;
ProviderRequest::try_from_bytes(&openai_provider, bytes)
}
fn to_provider_bytes(&self, request: &ChatCompletionsRequest, provider: super::super::ProviderId, mode: ConversionMode) -> Result<Vec<u8>, Self::Error> {
let openai_provider = OpenAIProvider;
ProviderRequest::to_provider_bytes(&openai_provider, request, provider, mode)
}
fn extract_model<'a>(&self, request: &'a ChatCompletionsRequest) -> &'a str {
&request.model
}
fn is_streaming(&self, request: &ChatCompletionsRequest) -> bool {
request.stream.unwrap_or_default()
}
fn set_streaming_options(&self, request: &mut ChatCompletionsRequest) {
let openai_provider = OpenAIProvider;
ProviderRequest::set_streaming_options(&openai_provider, request)
}
fn extract_messages_text(&self, request: &ChatCompletionsRequest) -> String {
let openai_provider = OpenAIProvider;
ProviderRequest::extract_messages_text(&openai_provider, request)
}
}
impl ProviderResponse for ArchProvider {
type Error = OpenAIApiError;
type Usage = Usage;
fn try_from_bytes(&self, bytes: &[u8], provider: &super::super::ProviderId, mode: ConversionMode) -> Result<ChatCompletionsResponse, Self::Error> {
let openai_provider = OpenAIProvider;
ProviderResponse::try_from_bytes(&openai_provider, bytes, provider, mode)
}
fn usage<'a>(&self, response: &'a ChatCompletionsResponse) -> Option<&'a Self::Usage> {
Some(&response.usage)
}
fn extract_usage_counts(&self, response: &ChatCompletionsResponse) -> Option<(usize, usize, usize)> {
let openai_provider = OpenAIProvider;
ProviderResponse::extract_usage_counts(&openai_provider, response)
}
}
impl StreamingResponse for ArchProvider {
type Error = OpenAIApiError;
type StreamingIter = OpenAIStreamingResponse;
fn try_from_bytes(&self, bytes: &[u8], provider: &super::super::ProviderId, mode: ConversionMode) -> Result<Self::StreamingIter, Self::Error> {
let openai_provider = OpenAIProvider;
StreamingResponse::try_from_bytes(&openai_provider, bytes, provider, mode)
}
}
impl ProviderInterface for ArchProvider {
fn has_compatible_api(&self, api_path: &str) -> bool {
matches!(api_path, "/v1/chat/completions")
@ -18,21 +81,21 @@ impl ProviderInterface for ArchProvider {
}
fn parse_request(&self, bytes: &[u8]) -> Result<ChatCompletionsRequest, Box<dyn std::error::Error + Send + Sync>> {
match ChatCompletionsRequest::try_from_bytes(bytes) {
match ProviderRequest::try_from_bytes(self, bytes) {
Ok(req) => Ok(req),
Err(e) => Err(Box::new(e)),
}
}
fn parse_response(&self, bytes: &[u8], provider_id: super::super::ProviderId, mode: ConversionMode) -> Result<ChatCompletionsResponse, Box<dyn std::error::Error + Send + Sync>> {
match ChatCompletionsResponse::try_from_bytes(bytes, &provider_id, mode) {
match ProviderResponse::try_from_bytes(self, bytes, &provider_id, mode) {
Ok(resp) => Ok(resp),
Err(e) => Err(Box::new(e)),
}
}
fn request_to_bytes(&self, request: &ChatCompletionsRequest, provider_id: super::super::ProviderId, mode: ConversionMode) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
match request.to_provider_bytes(provider_id, mode) {
match ProviderRequest::to_provider_bytes(self, request, provider_id, mode) {
Ok(bytes) => Ok(bytes),
Err(e) => Err(Box::new(e)),
}

View file

@ -4,13 +4,76 @@
//! For now, uses OpenAI-compatible format as fallback
use crate::providers::{ProviderInterface, ConversionMode};
use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse};
use crate::providers::traits::{ProviderRequest, ProviderResponse};
use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse, Usage};
use crate::providers::traits::{ProviderRequest, ProviderResponse, StreamingResponse};
use crate::providers::openai::provider::{OpenAIProvider, OpenAIStreamingResponse, OpenAIApiError};
/// Claude provider implementation
#[derive(Debug, Clone)]
pub struct ClaudeProvider;
// Trait implementations that delegate to OpenAI
impl ProviderRequest for ClaudeProvider {
type Error = OpenAIApiError;
fn try_from_bytes(&self, bytes: &[u8]) -> Result<ChatCompletionsRequest, Self::Error> {
let openai_provider = OpenAIProvider;
ProviderRequest::try_from_bytes(&openai_provider, bytes)
}
fn to_provider_bytes(&self, request: &ChatCompletionsRequest, provider: super::super::ProviderId, mode: ConversionMode) -> Result<Vec<u8>, Self::Error> {
let openai_provider = OpenAIProvider;
ProviderRequest::to_provider_bytes(&openai_provider, request, provider, mode)
}
fn extract_model<'a>(&self, request: &'a ChatCompletionsRequest) -> &'a str {
&request.model
}
fn is_streaming(&self, request: &ChatCompletionsRequest) -> bool {
request.stream.unwrap_or_default()
}
fn set_streaming_options(&self, request: &mut ChatCompletionsRequest) {
let openai_provider = OpenAIProvider;
ProviderRequest::set_streaming_options(&openai_provider, request)
}
fn extract_messages_text(&self, request: &ChatCompletionsRequest) -> String {
let openai_provider = OpenAIProvider;
ProviderRequest::extract_messages_text(&openai_provider, request)
}
}
impl ProviderResponse for ClaudeProvider {
type Error = OpenAIApiError;
type Usage = Usage;
fn try_from_bytes(&self, bytes: &[u8], provider: &super::super::ProviderId, mode: ConversionMode) -> Result<ChatCompletionsResponse, Self::Error> {
let openai_provider = OpenAIProvider;
ProviderResponse::try_from_bytes(&openai_provider, bytes, provider, mode)
}
fn usage<'a>(&self, response: &'a ChatCompletionsResponse) -> Option<&'a Self::Usage> {
Some(&response.usage)
}
fn extract_usage_counts(&self, response: &ChatCompletionsResponse) -> Option<(usize, usize, usize)> {
let openai_provider = OpenAIProvider;
ProviderResponse::extract_usage_counts(&openai_provider, response)
}
}
impl StreamingResponse for ClaudeProvider {
type Error = OpenAIApiError;
type StreamingIter = OpenAIStreamingResponse;
fn try_from_bytes(&self, bytes: &[u8], provider: &super::super::ProviderId, mode: ConversionMode) -> Result<Self::StreamingIter, Self::Error> {
let openai_provider = OpenAIProvider;
StreamingResponse::try_from_bytes(&openai_provider, bytes, provider, mode)
}
}
impl ProviderInterface for ClaudeProvider {
fn has_compatible_api(&self, api_path: &str) -> bool {
// TODO: Update when Claude API is fully implemented
@ -24,7 +87,7 @@ impl ProviderInterface for ClaudeProvider {
fn parse_request(&self, bytes: &[u8]) -> Result<ChatCompletionsRequest, Box<dyn std::error::Error + Send + Sync>> {
// TODO: Implement Claude-specific request parsing
match ChatCompletionsRequest::try_from_bytes(bytes) {
match ProviderRequest::try_from_bytes(self, bytes) {
Ok(req) => Ok(req),
Err(e) => Err(Box::new(e)),
}
@ -32,7 +95,7 @@ impl ProviderInterface for ClaudeProvider {
fn parse_response(&self, bytes: &[u8], provider_id: super::super::ProviderId, mode: ConversionMode) -> Result<ChatCompletionsResponse, Box<dyn std::error::Error + Send + Sync>> {
// TODO: Implement Claude-specific response parsing
match ChatCompletionsResponse::try_from_bytes(bytes, &provider_id, mode) {
match ProviderResponse::try_from_bytes(self, bytes, &provider_id, mode) {
Ok(resp) => Ok(resp),
Err(e) => Err(Box::new(e)),
}
@ -40,7 +103,7 @@ impl ProviderInterface for ClaudeProvider {
fn request_to_bytes(&self, request: &ChatCompletionsRequest, provider_id: super::super::ProviderId, mode: ConversionMode) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
// TODO: Implement Claude-specific request serialization
match request.to_provider_bytes(provider_id, mode) {
match ProviderRequest::to_provider_bytes(self, request, provider_id, mode) {
Ok(bytes) => Ok(bytes),
Err(e) => Err(Box::new(e)),
}

View file

@ -1,13 +1,76 @@
//! Deepseek provider implementation
use crate::providers::{ProviderInterface, ConversionMode};
use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse};
use crate::providers::traits::{ProviderRequest, ProviderResponse};
use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse, Usage};
use crate::providers::traits::{ProviderRequest, ProviderResponse, StreamingResponse};
use crate::providers::openai::provider::{OpenAIProvider, OpenAIStreamingResponse, OpenAIApiError};
/// Deepseek provider implementation
#[derive(Debug, Clone)]
pub struct DeepseekProvider;
// Trait implementations that delegate to OpenAI
impl ProviderRequest for DeepseekProvider {
type Error = OpenAIApiError;
fn try_from_bytes(&self, bytes: &[u8]) -> Result<ChatCompletionsRequest, Self::Error> {
let openai_provider = OpenAIProvider;
ProviderRequest::try_from_bytes(&openai_provider, bytes)
}
fn to_provider_bytes(&self, request: &ChatCompletionsRequest, provider: super::super::ProviderId, mode: ConversionMode) -> Result<Vec<u8>, Self::Error> {
let openai_provider = OpenAIProvider;
ProviderRequest::to_provider_bytes(&openai_provider, request, provider, mode)
}
fn extract_model<'a>(&self, request: &'a ChatCompletionsRequest) -> &'a str {
&request.model
}
fn is_streaming(&self, request: &ChatCompletionsRequest) -> bool {
request.stream.unwrap_or_default()
}
fn set_streaming_options(&self, request: &mut ChatCompletionsRequest) {
let openai_provider = OpenAIProvider;
ProviderRequest::set_streaming_options(&openai_provider, request)
}
fn extract_messages_text(&self, request: &ChatCompletionsRequest) -> String {
let openai_provider = OpenAIProvider;
ProviderRequest::extract_messages_text(&openai_provider, request)
}
}
impl ProviderResponse for DeepseekProvider {
type Error = OpenAIApiError;
type Usage = Usage;
fn try_from_bytes(&self, bytes: &[u8], provider: &super::super::ProviderId, mode: ConversionMode) -> Result<ChatCompletionsResponse, Self::Error> {
let openai_provider = OpenAIProvider;
ProviderResponse::try_from_bytes(&openai_provider, bytes, provider, mode)
}
fn usage<'a>(&self, response: &'a ChatCompletionsResponse) -> Option<&'a Self::Usage> {
Some(&response.usage)
}
fn extract_usage_counts(&self, response: &ChatCompletionsResponse) -> Option<(usize, usize, usize)> {
let openai_provider = OpenAIProvider;
ProviderResponse::extract_usage_counts(&openai_provider, response)
}
}
impl StreamingResponse for DeepseekProvider {
type Error = OpenAIApiError;
type StreamingIter = OpenAIStreamingResponse;
fn try_from_bytes(&self, bytes: &[u8], provider: &super::super::ProviderId, mode: ConversionMode) -> Result<Self::StreamingIter, Self::Error> {
let openai_provider = OpenAIProvider;
StreamingResponse::try_from_bytes(&openai_provider, bytes, provider, mode)
}
}
impl ProviderInterface for DeepseekProvider {
fn has_compatible_api(&self, api_path: &str) -> bool {
matches!(api_path, "/v1/chat/completions")
@ -18,21 +81,21 @@ impl ProviderInterface for DeepseekProvider {
}
fn parse_request(&self, bytes: &[u8]) -> Result<ChatCompletionsRequest, Box<dyn std::error::Error + Send + Sync>> {
match ChatCompletionsRequest::try_from_bytes(bytes) {
match ProviderRequest::try_from_bytes(self, bytes) {
Ok(req) => Ok(req),
Err(e) => Err(Box::new(e)),
}
}
fn parse_response(&self, bytes: &[u8], provider_id: super::super::ProviderId, mode: ConversionMode) -> Result<ChatCompletionsResponse, Box<dyn std::error::Error + Send + Sync>> {
match ChatCompletionsResponse::try_from_bytes(bytes, &provider_id, mode) {
match ProviderResponse::try_from_bytes(self, bytes, &provider_id, mode) {
Ok(resp) => Ok(resp),
Err(e) => Err(Box::new(e)),
}
}
fn request_to_bytes(&self, request: &ChatCompletionsRequest, provider_id: super::super::ProviderId, mode: ConversionMode) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
match request.to_provider_bytes(provider_id, mode) {
match ProviderRequest::to_provider_bytes(self, request, provider_id, mode) {
Ok(bytes) => Ok(bytes),
Err(e) => Err(Box::new(e)),
}

View file

@ -1,16 +1,79 @@
//! Gemini provider implementation
//!
//! TODO: Implement Gemini-specific API format when needed
//! For now, uses OpenAI-compatible format as fallback
//! This module contains the Gemini provider that handles Google's Gemini API format
//! requests in OpenAI-compatible format.
use crate::providers::{ProviderInterface, ConversionMode};
use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse};
use crate::providers::traits::{ProviderRequest, ProviderResponse};
use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse, Usage};
use crate::providers::traits::{ProviderRequest, ProviderResponse, StreamingResponse};
use crate::providers::openai::provider::{OpenAIProvider, OpenAIStreamingResponse, OpenAIApiError};
/// Gemini provider implementation
#[derive(Debug, Clone)]
pub struct GeminiProvider;
// Trait implementations that delegate to OpenAI
impl ProviderRequest for GeminiProvider {
type Error = OpenAIApiError;
fn try_from_bytes(&self, bytes: &[u8]) -> Result<ChatCompletionsRequest, Self::Error> {
let openai_provider = OpenAIProvider;
ProviderRequest::try_from_bytes(&openai_provider, bytes)
}
fn to_provider_bytes(&self, request: &ChatCompletionsRequest, provider: super::super::ProviderId, mode: ConversionMode) -> Result<Vec<u8>, Self::Error> {
let openai_provider = OpenAIProvider;
ProviderRequest::to_provider_bytes(&openai_provider, request, provider, mode)
}
fn extract_model<'a>(&self, request: &'a ChatCompletionsRequest) -> &'a str {
&request.model
}
fn is_streaming(&self, request: &ChatCompletionsRequest) -> bool {
request.stream.unwrap_or_default()
}
fn set_streaming_options(&self, request: &mut ChatCompletionsRequest) {
let openai_provider = OpenAIProvider;
ProviderRequest::set_streaming_options(&openai_provider, request)
}
fn extract_messages_text(&self, request: &ChatCompletionsRequest) -> String {
let openai_provider = OpenAIProvider;
ProviderRequest::extract_messages_text(&openai_provider, request)
}
}
impl ProviderResponse for GeminiProvider {
type Error = OpenAIApiError;
type Usage = Usage;
fn try_from_bytes(&self, bytes: &[u8], provider: &super::super::ProviderId, mode: ConversionMode) -> Result<ChatCompletionsResponse, Self::Error> {
let openai_provider = OpenAIProvider;
ProviderResponse::try_from_bytes(&openai_provider, bytes, provider, mode)
}
fn usage<'a>(&self, response: &'a ChatCompletionsResponse) -> Option<&'a Self::Usage> {
Some(&response.usage)
}
fn extract_usage_counts(&self, response: &ChatCompletionsResponse) -> Option<(usize, usize, usize)> {
let openai_provider = OpenAIProvider;
ProviderResponse::extract_usage_counts(&openai_provider, response)
}
}
impl StreamingResponse for GeminiProvider {
type Error = OpenAIApiError;
type StreamingIter = OpenAIStreamingResponse;
fn try_from_bytes(&self, bytes: &[u8], provider: &super::super::ProviderId, mode: ConversionMode) -> Result<Self::StreamingIter, Self::Error> {
let openai_provider = OpenAIProvider;
StreamingResponse::try_from_bytes(&openai_provider, bytes, provider, mode)
}
}
impl ProviderInterface for GeminiProvider {
fn has_compatible_api(&self, api_path: &str) -> bool {
// TODO: Update when Gemini API is fully implemented
@ -24,7 +87,7 @@ impl ProviderInterface for GeminiProvider {
fn parse_request(&self, bytes: &[u8]) -> Result<ChatCompletionsRequest, Box<dyn std::error::Error + Send + Sync>> {
// TODO: Implement Gemini-specific request parsing
match ChatCompletionsRequest::try_from_bytes(bytes) {
match ProviderRequest::try_from_bytes(self, bytes) {
Ok(req) => Ok(req),
Err(e) => Err(Box::new(e)),
}
@ -32,7 +95,7 @@ impl ProviderInterface for GeminiProvider {
fn parse_response(&self, bytes: &[u8], provider_id: super::super::ProviderId, mode: ConversionMode) -> Result<ChatCompletionsResponse, Box<dyn std::error::Error + Send + Sync>> {
// TODO: Implement Gemini-specific response parsing
match ChatCompletionsResponse::try_from_bytes(bytes, &provider_id, mode) {
match ProviderResponse::try_from_bytes(self, bytes, &provider_id, mode) {
Ok(resp) => Ok(resp),
Err(e) => Err(Box::new(e)),
}
@ -40,7 +103,7 @@ impl ProviderInterface for GeminiProvider {
fn request_to_bytes(&self, request: &ChatCompletionsRequest, provider_id: super::super::ProviderId, mode: ConversionMode) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
// TODO: Implement Gemini-specific request serialization
match request.to_provider_bytes(provider_id, mode) {
match ProviderRequest::to_provider_bytes(self, request, provider_id, mode) {
Ok(bytes) => Ok(bytes),
Err(e) => Err(Box::new(e)),
}

View file

@ -1,16 +1,79 @@
//! GitHub provider implementation
//!
//! TODO: Implement GitHub-specific API format (/models) when needed
//! For now, uses OpenAI-compatible format as fallback
//! This module contains the GitHub provider that handles GitHub API format
//! requests in OpenAI-compatible format.
use crate::providers::{ProviderInterface, ConversionMode};
use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse};
use crate::providers::traits::{ProviderRequest, ProviderResponse};
use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse, Usage};
use crate::providers::traits::{ProviderRequest, ProviderResponse, StreamingResponse};
use crate::providers::openai::provider::{OpenAIProvider, OpenAIStreamingResponse, OpenAIApiError};
/// GitHub provider implementation
#[derive(Debug, Clone)]
pub struct GitHubProvider;
// Trait implementations that delegate to OpenAI
impl ProviderRequest for GitHubProvider {
type Error = OpenAIApiError;
fn try_from_bytes(&self, bytes: &[u8]) -> Result<ChatCompletionsRequest, Self::Error> {
let openai_provider = OpenAIProvider;
ProviderRequest::try_from_bytes(&openai_provider, bytes)
}
fn to_provider_bytes(&self, request: &ChatCompletionsRequest, provider: super::super::ProviderId, mode: ConversionMode) -> Result<Vec<u8>, Self::Error> {
let openai_provider = OpenAIProvider;
ProviderRequest::to_provider_bytes(&openai_provider, request, provider, mode)
}
fn extract_model<'a>(&self, request: &'a ChatCompletionsRequest) -> &'a str {
&request.model
}
fn is_streaming(&self, request: &ChatCompletionsRequest) -> bool {
request.stream.unwrap_or_default()
}
fn set_streaming_options(&self, request: &mut ChatCompletionsRequest) {
let openai_provider = OpenAIProvider;
ProviderRequest::set_streaming_options(&openai_provider, request)
}
fn extract_messages_text(&self, request: &ChatCompletionsRequest) -> String {
let openai_provider = OpenAIProvider;
ProviderRequest::extract_messages_text(&openai_provider, request)
}
}
impl ProviderResponse for GitHubProvider {
type Error = OpenAIApiError;
type Usage = Usage;
fn try_from_bytes(&self, bytes: &[u8], provider: &super::super::ProviderId, mode: ConversionMode) -> Result<ChatCompletionsResponse, Self::Error> {
let openai_provider = OpenAIProvider;
ProviderResponse::try_from_bytes(&openai_provider, bytes, provider, mode)
}
fn usage<'a>(&self, response: &'a ChatCompletionsResponse) -> Option<&'a Self::Usage> {
Some(&response.usage)
}
fn extract_usage_counts(&self, response: &ChatCompletionsResponse) -> Option<(usize, usize, usize)> {
let openai_provider = OpenAIProvider;
ProviderResponse::extract_usage_counts(&openai_provider, response)
}
}
impl StreamingResponse for GitHubProvider {
type Error = OpenAIApiError;
type StreamingIter = OpenAIStreamingResponse;
fn try_from_bytes(&self, bytes: &[u8], provider: &super::super::ProviderId, mode: ConversionMode) -> Result<Self::StreamingIter, Self::Error> {
let openai_provider = OpenAIProvider;
StreamingResponse::try_from_bytes(&openai_provider, bytes, provider, mode)
}
}
impl ProviderInterface for GitHubProvider {
fn has_compatible_api(&self, api_path: &str) -> bool {
// TODO: Update when GitHub API is fully implemented
@ -24,7 +87,7 @@ impl ProviderInterface for GitHubProvider {
fn parse_request(&self, bytes: &[u8]) -> Result<ChatCompletionsRequest, Box<dyn std::error::Error + Send + Sync>> {
// TODO: Implement GitHub-specific request parsing
match ChatCompletionsRequest::try_from_bytes(bytes) {
match ProviderRequest::try_from_bytes(self, bytes) {
Ok(req) => Ok(req),
Err(e) => Err(Box::new(e)),
}
@ -32,7 +95,7 @@ impl ProviderInterface for GitHubProvider {
fn parse_response(&self, bytes: &[u8], provider_id: super::super::ProviderId, mode: ConversionMode) -> Result<ChatCompletionsResponse, Box<dyn std::error::Error + Send + Sync>> {
// TODO: Implement GitHub-specific response parsing
match ChatCompletionsResponse::try_from_bytes(bytes, &provider_id, mode) {
match ProviderResponse::try_from_bytes(self, bytes, &provider_id, mode) {
Ok(resp) => Ok(resp),
Err(e) => Err(Box::new(e)),
}
@ -40,7 +103,7 @@ impl ProviderInterface for GitHubProvider {
fn request_to_bytes(&self, request: &ChatCompletionsRequest, provider_id: super::super::ProviderId, mode: ConversionMode) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
// TODO: Implement GitHub-specific request serialization
match request.to_provider_bytes(provider_id, mode) {
match ProviderRequest::to_provider_bytes(self, request, provider_id, mode) {
Ok(bytes) => Ok(bytes),
Err(e) => Err(Box::new(e)),
}

View file

@ -4,13 +4,76 @@
//! Groq uses OpenAI-compatible format but may have provider-specific nuances.
use crate::providers::{ProviderInterface, ConversionMode};
use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse};
use crate::providers::traits::{ProviderRequest, ProviderResponse};
use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse, Usage};
use crate::providers::traits::{ProviderRequest, ProviderResponse, StreamingResponse};
use crate::providers::openai::provider::{OpenAIProvider, OpenAIStreamingResponse, OpenAIApiError};
/// Groq provider implementation
#[derive(Debug, Clone)]
pub struct GroqProvider;
// Trait implementations that delegate to OpenAI
impl ProviderRequest for GroqProvider {
type Error = OpenAIApiError;
fn try_from_bytes(&self, bytes: &[u8]) -> Result<ChatCompletionsRequest, Self::Error> {
let openai_provider = OpenAIProvider;
ProviderRequest::try_from_bytes(&openai_provider, bytes)
}
fn to_provider_bytes(&self, request: &ChatCompletionsRequest, provider: super::super::ProviderId, mode: ConversionMode) -> Result<Vec<u8>, Self::Error> {
let openai_provider = OpenAIProvider;
ProviderRequest::to_provider_bytes(&openai_provider, request, provider, mode)
}
fn extract_model<'a>(&self, request: &'a ChatCompletionsRequest) -> &'a str {
&request.model
}
fn is_streaming(&self, request: &ChatCompletionsRequest) -> bool {
request.stream.unwrap_or_default()
}
fn set_streaming_options(&self, request: &mut ChatCompletionsRequest) {
let openai_provider = OpenAIProvider;
ProviderRequest::set_streaming_options(&openai_provider, request)
}
fn extract_messages_text(&self, request: &ChatCompletionsRequest) -> String {
let openai_provider = OpenAIProvider;
ProviderRequest::extract_messages_text(&openai_provider, request)
}
}
impl ProviderResponse for GroqProvider {
type Error = OpenAIApiError;
type Usage = Usage;
fn try_from_bytes(&self, bytes: &[u8], provider: &super::super::ProviderId, mode: ConversionMode) -> Result<ChatCompletionsResponse, Self::Error> {
let openai_provider = OpenAIProvider;
ProviderResponse::try_from_bytes(&openai_provider, bytes, provider, mode)
}
fn usage<'a>(&self, response: &'a ChatCompletionsResponse) -> Option<&'a Self::Usage> {
Some(&response.usage)
}
fn extract_usage_counts(&self, response: &ChatCompletionsResponse) -> Option<(usize, usize, usize)> {
let openai_provider = OpenAIProvider;
ProviderResponse::extract_usage_counts(&openai_provider, response)
}
}
impl StreamingResponse for GroqProvider {
type Error = OpenAIApiError;
type StreamingIter = OpenAIStreamingResponse;
fn try_from_bytes(&self, bytes: &[u8], provider: &super::super::ProviderId, mode: ConversionMode) -> Result<Self::StreamingIter, Self::Error> {
let openai_provider = OpenAIProvider;
StreamingResponse::try_from_bytes(&openai_provider, bytes, provider, mode)
}
}
impl ProviderInterface for GroqProvider {
fn has_compatible_api(&self, api_path: &str) -> bool {
matches!(api_path, "/v1/chat/completions" | "/openai/v1/chat/completions")
@ -21,21 +84,21 @@ impl ProviderInterface for GroqProvider {
}
fn parse_request(&self, bytes: &[u8]) -> Result<ChatCompletionsRequest, Box<dyn std::error::Error + Send + Sync>> {
match ChatCompletionsRequest::try_from_bytes(bytes) {
match ProviderRequest::try_from_bytes(self, bytes) {
Ok(req) => Ok(req),
Err(e) => Err(Box::new(e)),
}
}
fn parse_response(&self, bytes: &[u8], provider_id: super::super::ProviderId, mode: ConversionMode) -> Result<ChatCompletionsResponse, Box<dyn std::error::Error + Send + Sync>> {
match ChatCompletionsResponse::try_from_bytes(bytes, &provider_id, mode) {
match ProviderResponse::try_from_bytes(self, bytes, &provider_id, mode) {
Ok(resp) => Ok(resp),
Err(e) => Err(Box::new(e)),
}
}
fn request_to_bytes(&self, request: &ChatCompletionsRequest, provider_id: super::super::ProviderId, mode: ConversionMode) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
match request.to_provider_bytes(provider_id, mode) {
match ProviderRequest::to_provider_bytes(self, request, provider_id, mode) {
Ok(bytes) => Ok(bytes),
Err(e) => Err(Box::new(e)),
}

View file

@ -1,13 +1,76 @@
//! Mistral provider implementation
use crate::providers::{ProviderInterface, ConversionMode};
use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse};
use crate::providers::traits::{ProviderRequest, ProviderResponse};
use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse, Usage};
use crate::providers::traits::{ProviderRequest, ProviderResponse, StreamingResponse};
use crate::providers::openai::provider::{OpenAIProvider, OpenAIStreamingResponse, OpenAIApiError};
/// Mistral provider implementation
#[derive(Debug, Clone)]
pub struct MistralProvider;
// Trait implementations that delegate to OpenAI
impl ProviderRequest for MistralProvider {
type Error = OpenAIApiError;
fn try_from_bytes(&self, bytes: &[u8]) -> Result<ChatCompletionsRequest, Self::Error> {
let openai_provider = OpenAIProvider;
ProviderRequest::try_from_bytes(&openai_provider, bytes)
}
fn to_provider_bytes(&self, request: &ChatCompletionsRequest, provider: super::super::ProviderId, mode: ConversionMode) -> Result<Vec<u8>, Self::Error> {
let openai_provider = OpenAIProvider;
ProviderRequest::to_provider_bytes(&openai_provider, request, provider, mode)
}
fn extract_model<'a>(&self, request: &'a ChatCompletionsRequest) -> &'a str {
&request.model
}
fn is_streaming(&self, request: &ChatCompletionsRequest) -> bool {
request.stream.unwrap_or_default()
}
fn set_streaming_options(&self, request: &mut ChatCompletionsRequest) {
let openai_provider = OpenAIProvider;
ProviderRequest::set_streaming_options(&openai_provider, request)
}
fn extract_messages_text(&self, request: &ChatCompletionsRequest) -> String {
let openai_provider = OpenAIProvider;
ProviderRequest::extract_messages_text(&openai_provider, request)
}
}
impl ProviderResponse for MistralProvider {
type Error = OpenAIApiError;
type Usage = Usage;
fn try_from_bytes(&self, bytes: &[u8], provider: &super::super::ProviderId, mode: ConversionMode) -> Result<ChatCompletionsResponse, Self::Error> {
let openai_provider = OpenAIProvider;
ProviderResponse::try_from_bytes(&openai_provider, bytes, provider, mode)
}
fn usage<'a>(&self, response: &'a ChatCompletionsResponse) -> Option<&'a Self::Usage> {
Some(&response.usage)
}
fn extract_usage_counts(&self, response: &ChatCompletionsResponse) -> Option<(usize, usize, usize)> {
let openai_provider = OpenAIProvider;
ProviderResponse::extract_usage_counts(&openai_provider, response)
}
}
impl StreamingResponse for MistralProvider {
type Error = OpenAIApiError;
type StreamingIter = OpenAIStreamingResponse;
fn try_from_bytes(&self, bytes: &[u8], provider: &super::super::ProviderId, mode: ConversionMode) -> Result<Self::StreamingIter, Self::Error> {
let openai_provider = OpenAIProvider;
StreamingResponse::try_from_bytes(&openai_provider, bytes, provider, mode)
}
}
impl ProviderInterface for MistralProvider {
fn has_compatible_api(&self, api_path: &str) -> bool {
matches!(api_path, "/v1/chat/completions")
@ -18,21 +81,21 @@ impl ProviderInterface for MistralProvider {
}
fn parse_request(&self, bytes: &[u8]) -> Result<ChatCompletionsRequest, Box<dyn std::error::Error + Send + Sync>> {
match ChatCompletionsRequest::try_from_bytes(bytes) {
match ProviderRequest::try_from_bytes(self, bytes) {
Ok(req) => Ok(req),
Err(e) => Err(Box::new(e)),
}
}
fn parse_response(&self, bytes: &[u8], provider_id: super::super::ProviderId, mode: ConversionMode) -> Result<ChatCompletionsResponse, Box<dyn std::error::Error + Send + Sync>> {
match ChatCompletionsResponse::try_from_bytes(bytes, &provider_id, mode) {
match ProviderResponse::try_from_bytes(self, bytes, &provider_id, mode) {
Ok(resp) => Ok(resp),
Err(e) => Err(Box::new(e)),
}
}
fn request_to_bytes(&self, request: &ChatCompletionsRequest, provider_id: super::super::ProviderId, mode: ConversionMode) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
match request.to_provider_bytes(provider_id, mode) {
match ProviderRequest::to_provider_bytes(self, request, provider_id, mode) {
Ok(bytes) => Ok(bytes),
Err(e) => Err(Box::new(e)),
}

View file

@ -134,18 +134,140 @@ impl Provider {
Provider::GitHub(_, id) => *id,
}
}
}
/// Get the provider interface implementation
pub fn interface(&self) -> &dyn ProviderInterface {
// Implement traits directly on the Provider enum
impl ProviderRequest for Provider {
type Error = openai::provider::OpenAIApiError;
fn try_from_bytes(&self, bytes: &[u8]) -> Result<crate::apis::openai::ChatCompletionsRequest, Self::Error> {
match self {
Provider::OpenAI(provider, _) => provider,
Provider::Groq(provider, _) => provider,
Provider::Mistral(provider, _) => provider,
Provider::Deepseek(provider, _) => provider,
Provider::Arch(provider, _) => provider,
Provider::Gemini(provider, _) => provider,
Provider::Claude(provider, _) => provider,
Provider::GitHub(provider, _) => provider,
Provider::OpenAI(provider, _) => ProviderRequest::try_from_bytes(provider, bytes),
Provider::Groq(provider, _) => ProviderRequest::try_from_bytes(provider, bytes),
Provider::Mistral(provider, _) => ProviderRequest::try_from_bytes(provider, bytes),
Provider::Deepseek(provider, _) => ProviderRequest::try_from_bytes(provider, bytes),
Provider::Arch(provider, _) => ProviderRequest::try_from_bytes(provider, bytes),
Provider::Gemini(provider, _) => ProviderRequest::try_from_bytes(provider, bytes),
Provider::Claude(provider, _) => ProviderRequest::try_from_bytes(provider, bytes),
Provider::GitHub(provider, _) => ProviderRequest::try_from_bytes(provider, bytes),
}
}
fn to_provider_bytes(&self, request: &crate::apis::openai::ChatCompletionsRequest, provider_id: super::ProviderId, mode: ConversionMode) -> Result<Vec<u8>, Self::Error> {
match self {
Provider::OpenAI(provider, _) => ProviderRequest::to_provider_bytes(provider, request, provider_id, mode),
Provider::Groq(provider, _) => ProviderRequest::to_provider_bytes(provider, request, provider_id, mode),
Provider::Mistral(provider, _) => ProviderRequest::to_provider_bytes(provider, request, provider_id, mode),
Provider::Deepseek(provider, _) => ProviderRequest::to_provider_bytes(provider, request, provider_id, mode),
Provider::Arch(provider, _) => ProviderRequest::to_provider_bytes(provider, request, provider_id, mode),
Provider::Gemini(provider, _) => ProviderRequest::to_provider_bytes(provider, request, provider_id, mode),
Provider::Claude(provider, _) => ProviderRequest::to_provider_bytes(provider, request, provider_id, mode),
Provider::GitHub(provider, _) => ProviderRequest::to_provider_bytes(provider, request, provider_id, mode),
}
}
fn extract_model<'a>(&self, request: &'a crate::apis::openai::ChatCompletionsRequest) -> &'a str {
// Since all providers use the same implementation, just use the first one
&request.model
}
fn is_streaming(&self, request: &crate::apis::openai::ChatCompletionsRequest) -> bool {
// Since all providers use the same implementation, just use the first one
request.stream.unwrap_or_default()
}
fn set_streaming_options(&self, request: &mut crate::apis::openai::ChatCompletionsRequest) {
match self {
Provider::OpenAI(provider, _) => ProviderRequest::set_streaming_options(provider, request),
Provider::Groq(provider, _) => ProviderRequest::set_streaming_options(provider, request),
Provider::Mistral(provider, _) => ProviderRequest::set_streaming_options(provider, request),
Provider::Deepseek(provider, _) => ProviderRequest::set_streaming_options(provider, request),
Provider::Arch(provider, _) => ProviderRequest::set_streaming_options(provider, request),
Provider::Gemini(provider, _) => ProviderRequest::set_streaming_options(provider, request),
Provider::Claude(provider, _) => ProviderRequest::set_streaming_options(provider, request),
Provider::GitHub(provider, _) => ProviderRequest::set_streaming_options(provider, request),
}
}
fn extract_messages_text(&self, request: &crate::apis::openai::ChatCompletionsRequest) -> String {
match self {
Provider::OpenAI(provider, _) => ProviderRequest::extract_messages_text(provider, request),
Provider::Groq(provider, _) => ProviderRequest::extract_messages_text(provider, request),
Provider::Mistral(provider, _) => ProviderRequest::extract_messages_text(provider, request),
Provider::Deepseek(provider, _) => ProviderRequest::extract_messages_text(provider, request),
Provider::Arch(provider, _) => ProviderRequest::extract_messages_text(provider, request),
Provider::Gemini(provider, _) => ProviderRequest::extract_messages_text(provider, request),
Provider::Claude(provider, _) => ProviderRequest::extract_messages_text(provider, request),
Provider::GitHub(provider, _) => ProviderRequest::extract_messages_text(provider, request),
}
}
}
impl ProviderResponse for Provider {
type Error = openai::provider::OpenAIApiError;
type Usage = crate::apis::openai::Usage;
fn try_from_bytes(&self, bytes: &[u8], provider_id: &super::ProviderId, mode: ConversionMode) -> Result<crate::apis::openai::ChatCompletionsResponse, Self::Error> {
match self {
Provider::OpenAI(provider, _) => ProviderResponse::try_from_bytes(provider, bytes, provider_id, mode),
Provider::Groq(provider, _) => ProviderResponse::try_from_bytes(provider, bytes, provider_id, mode),
Provider::Mistral(provider, _) => ProviderResponse::try_from_bytes(provider, bytes, provider_id, mode),
Provider::Deepseek(provider, _) => ProviderResponse::try_from_bytes(provider, bytes, provider_id, mode),
Provider::Arch(provider, _) => ProviderResponse::try_from_bytes(provider, bytes, provider_id, mode),
Provider::Gemini(provider, _) => ProviderResponse::try_from_bytes(provider, bytes, provider_id, mode),
Provider::Claude(provider, _) => ProviderResponse::try_from_bytes(provider, bytes, provider_id, mode),
Provider::GitHub(provider, _) => ProviderResponse::try_from_bytes(provider, bytes, provider_id, mode),
}
}
fn usage<'a>(&self, response: &'a crate::apis::openai::ChatCompletionsResponse) -> Option<&'a Self::Usage> {
// Since all providers use the same implementation, just use the direct implementation
Some(&response.usage)
}
}
impl StreamingResponse for Provider {
type Error = openai::provider::OpenAIApiError;
type StreamingIter = openai::provider::OpenAIStreamingResponse;
fn try_from_bytes(&self, bytes: &[u8], provider_id: &super::ProviderId, mode: ConversionMode) -> Result<Self::StreamingIter, Self::Error> {
match self {
Provider::OpenAI(provider, _) => StreamingResponse::try_from_bytes(provider, bytes, provider_id, mode),
Provider::Groq(provider, _) => StreamingResponse::try_from_bytes(provider, bytes, provider_id, mode),
Provider::Mistral(provider, _) => StreamingResponse::try_from_bytes(provider, bytes, provider_id, mode),
Provider::Deepseek(provider, _) => StreamingResponse::try_from_bytes(provider, bytes, provider_id, mode),
Provider::Arch(provider, _) => StreamingResponse::try_from_bytes(provider, bytes, provider_id, mode),
Provider::Gemini(provider, _) => StreamingResponse::try_from_bytes(provider, bytes, provider_id, mode),
Provider::Claude(provider, _) => StreamingResponse::try_from_bytes(provider, bytes, provider_id, mode),
Provider::GitHub(provider, _) => StreamingResponse::try_from_bytes(provider, bytes, provider_id, mode),
}
}
}
impl ProviderInterface for Provider {
fn has_compatible_api(&self, api_path: &str) -> bool {
match self {
Provider::OpenAI(provider, _) => provider.has_compatible_api(api_path),
Provider::Groq(provider, _) => provider.has_compatible_api(api_path),
Provider::Mistral(provider, _) => provider.has_compatible_api(api_path),
Provider::Deepseek(provider, _) => provider.has_compatible_api(api_path),
Provider::Arch(provider, _) => provider.has_compatible_api(api_path),
Provider::Gemini(provider, _) => provider.has_compatible_api(api_path),
Provider::Claude(provider, _) => provider.has_compatible_api(api_path),
Provider::GitHub(provider, _) => provider.has_compatible_api(api_path),
}
}
fn supported_apis(&self) -> Vec<&'static str> {
match self {
Provider::OpenAI(provider, _) => provider.supported_apis(),
Provider::Groq(provider, _) => provider.supported_apis(),
Provider::Mistral(provider, _) => provider.supported_apis(),
Provider::Deepseek(provider, _) => provider.supported_apis(),
Provider::Arch(provider, _) => provider.supported_apis(),
Provider::Gemini(provider, _) => provider.supported_apis(),
Provider::Claude(provider, _) => provider.supported_apis(),
Provider::GitHub(provider, _) => provider.supported_apis(),
}
}
}

View file

@ -77,64 +77,58 @@ impl ProviderInterface for OpenAIProvider {
}
fn parse_request(&self, bytes: &[u8]) -> Result<crate::apis::openai::ChatCompletionsRequest, Box<dyn std::error::Error + Send + Sync>> {
use crate::providers::traits::ProviderRequest;
match ChatCompletionsRequest::try_from_bytes(bytes) {
match ProviderRequest::try_from_bytes(self, bytes) {
Ok(req) => Ok(req),
Err(e) => Err(Box::new(e)),
}
}
fn parse_response(&self, bytes: &[u8], provider_id: super::super::ProviderId, mode: ConversionMode) -> Result<crate::apis::openai::ChatCompletionsResponse, Box<dyn std::error::Error + Send + Sync>> {
use crate::providers::traits::ProviderResponse;
match ChatCompletionsResponse::try_from_bytes(bytes, &provider_id, mode) {
match ProviderResponse::try_from_bytes(self, bytes, &provider_id, mode) {
Ok(resp) => Ok(resp),
Err(e) => Err(Box::new(e)),
}
}
fn request_to_bytes(&self, request: &crate::apis::openai::ChatCompletionsRequest, provider_id: super::super::ProviderId, mode: ConversionMode) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
use crate::providers::traits::ProviderRequest;
match request.to_provider_bytes(provider_id, mode) {
match ProviderRequest::to_provider_bytes(self, request, provider_id, mode) {
Ok(bytes) => Ok(bytes),
Err(e) => Err(Box::new(e)),
}
}
}
// ============================================================================
// Trait Implementations for OpenAI Types
// ============================================================================
impl ProviderRequest for ChatCompletionsRequest {
// Direct trait implementations on OpenAIProvider
impl ProviderRequest for OpenAIProvider {
type Error = OpenAIApiError;
fn try_from_bytes(bytes: &[u8]) -> Result<Self, Self::Error> {
fn try_from_bytes(&self, bytes: &[u8]) -> Result<ChatCompletionsRequest, Self::Error> {
let s = std::str::from_utf8(bytes)?;
Ok(serde_json::from_str(s)?)
}
fn to_provider_bytes(&self, _provider: super::super::ProviderId, _mode: ConversionMode) -> Result<Vec<u8>, Self::Error> {
Ok(serde_json::to_vec(self)?)
fn to_provider_bytes(&self, request: &ChatCompletionsRequest, _provider: super::super::ProviderId, _mode: ConversionMode) -> Result<Vec<u8>, Self::Error> {
Ok(serde_json::to_vec(request)?)
}
fn extract_model(&self) -> &str {
&self.model
fn extract_model<'a>(&self, request: &'a ChatCompletionsRequest) -> &'a str {
&request.model
}
fn is_streaming(&self) -> bool {
self.stream.unwrap_or_default()
fn is_streaming(&self, request: &ChatCompletionsRequest) -> bool {
request.stream.unwrap_or_default()
}
fn set_streaming_options(&mut self) {
if self.stream_options.is_none() {
self.stream_options = Some(StreamOptions {
fn set_streaming_options(&self, request: &mut ChatCompletionsRequest) {
if request.stream_options.is_none() {
request.stream_options = Some(StreamOptions {
include_usage: Some(true),
});
}
}
fn extract_messages_text(&self) -> String {
self.messages
fn extract_messages_text(&self, request: &ChatCompletionsRequest) -> String {
request.messages
.iter()
.fold(String::new(), |acc, m| {
acc + " " + &match &m.content {
@ -150,8 +144,41 @@ impl ProviderRequest for ChatCompletionsRequest {
}
}
// Implement the helper trait for stream context integration
impl crate::providers::traits::StreamContextHelpers for ChatCompletionsRequest {}
impl ProviderResponse for OpenAIProvider {
type Error = OpenAIApiError;
type Usage = Usage;
fn try_from_bytes(&self, bytes: &[u8], _provider: &super::super::ProviderId, _mode: ConversionMode) -> Result<ChatCompletionsResponse, Self::Error> {
let s = std::str::from_utf8(bytes)?;
Ok(serde_json::from_str(s)?)
}
fn usage<'a>(&self, response: &'a ChatCompletionsResponse) -> Option<&'a Self::Usage> {
Some(&response.usage)
}
fn extract_usage_counts(&self, response: &ChatCompletionsResponse) -> Option<(usize, usize, usize)> {
Some((
response.usage.prompt_tokens as usize,
response.usage.completion_tokens as usize,
response.usage.total_tokens as usize,
))
}
}
impl StreamingResponse for OpenAIProvider {
type Error = OpenAIApiError;
type StreamingIter = OpenAIStreamingResponse;
fn try_from_bytes(&self, bytes: &[u8], _provider: &super::super::ProviderId, _mode: ConversionMode) -> Result<Self::StreamingIter, Self::Error> {
let s = std::str::from_utf8(bytes)?;
Ok(OpenAIStreamingResponse::new(s.to_string()))
}
}
// ============================================================================
// Trait Implementations for OpenAI Types (Keep for TokenUsage only)
// ============================================================================
impl TokenUsage for Usage {
fn completion_tokens(&self) -> usize {
@ -167,20 +194,6 @@ impl TokenUsage for Usage {
}
}
impl ProviderResponse for ChatCompletionsResponse {
type Error = OpenAIApiError;
type Usage = Usage;
fn try_from_bytes(bytes: &[u8], _provider: &super::super::ProviderId, _mode: ConversionMode) -> Result<Self, Self::Error> {
let s = std::str::from_utf8(bytes)?;
Ok(serde_json::from_str(s)?)
}
fn usage(&self) -> Option<&Self::Usage> {
Some(&self.usage)
}
}
impl StreamChunk for ChatCompletionsStreamResponse {
type Usage = Usage;
@ -191,9 +204,9 @@ impl StreamChunk for ChatCompletionsStreamResponse {
impl StreamingResponse for OpenAIStreamingResponse {
type Error = OpenAIApiError;
type Chunk = ChatCompletionsStreamResponse;
type StreamingIter = OpenAIStreamingResponse;
fn try_from_bytes(bytes: &[u8], _provider: &super::super::ProviderId, _mode: ConversionMode) -> Result<Self, Self::Error> {
fn try_from_bytes(&self, bytes: &[u8], _provider: &super::super::ProviderId, _mode: ConversionMode) -> Result<Self::StreamingIter, Self::Error> {
let s = std::str::from_utf8(bytes)?;
Ok(OpenAIStreamingResponse::new(s.to_string()))
}

View file

@ -15,26 +15,26 @@ pub enum ConversionMode {
}
/// Trait for provider-specific request types
pub trait ProviderRequest: Sized {
pub trait ProviderRequest {
type Error: Error + Send + Sync + 'static;
/// Parse request from raw bytes
fn try_from_bytes(bytes: &[u8]) -> Result<Self, Self::Error>;
fn try_from_bytes(&self, bytes: &[u8]) -> Result<crate::apis::openai::ChatCompletionsRequest, Self::Error>;
/// Convert to provider-specific format
fn to_provider_bytes(&self, provider: super::ProviderId, mode: ConversionMode) -> Result<Vec<u8>, Self::Error>;
fn to_provider_bytes(&self, request: &crate::apis::openai::ChatCompletionsRequest, provider: super::ProviderId, mode: ConversionMode) -> Result<Vec<u8>, Self::Error>;
/// Extract the model name from the request
fn extract_model(&self) -> &str;
fn extract_model<'a>(&self, request: &'a crate::apis::openai::ChatCompletionsRequest) -> &'a str;
/// Check if this is a streaming request
fn is_streaming(&self) -> bool;
fn is_streaming(&self, request: &crate::apis::openai::ChatCompletionsRequest) -> bool;
/// Set streaming options (e.g., include_usage)
fn set_streaming_options(&mut self);
fn set_streaming_options(&self, request: &mut crate::apis::openai::ChatCompletionsRequest);
/// Extract text content from messages for token counting
fn extract_messages_text(&self) -> String;
fn extract_messages_text(&self, request: &crate::apis::openai::ChatCompletionsRequest) -> String;
}
/// Trait for token usage information
@ -45,39 +45,19 @@ pub trait TokenUsage {
}
/// Trait for provider-specific response types
pub trait ProviderResponse: Sized {
pub trait ProviderResponse {
type Error: Error + Send + Sync + 'static;
type Usage: TokenUsage;
/// Parse response from raw bytes
fn try_from_bytes(bytes: &[u8], provider: &super::ProviderId, mode: ConversionMode) -> Result<Self, Self::Error>;
fn try_from_bytes(&self, bytes: &[u8], provider: &super::ProviderId, mode: ConversionMode) -> Result<crate::apis::openai::ChatCompletionsResponse, Self::Error>;
/// Get usage information if available
fn usage(&self) -> Option<&Self::Usage>;
fn usage<'a>(&self, response: &'a crate::apis::openai::ChatCompletionsResponse) -> Option<&'a Self::Usage>;
/// Extract token counts for metrics
fn extract_usage_counts(&self) -> Option<(usize, usize, usize)> {
self.usage().map(|u| (u.prompt_tokens(), u.completion_tokens(), u.total_tokens()))
}
}
/// Helper trait for stream context integration
pub trait StreamContextHelpers: ProviderRequest {
/// Get the model name for routing and metrics
fn get_model_for_routing(&self) -> String {
self.extract_model().to_string()
}
/// Get text for token counting and rate limiting
fn get_text_for_tokenization(&self) -> String {
self.extract_messages_text()
}
/// Prepare for streaming by setting appropriate options
fn prepare_for_streaming(&mut self) {
if self.is_streaming() {
self.set_streaming_options();
}
fn extract_usage_counts(&self, response: &crate::apis::openai::ChatCompletionsResponse) -> Option<(usize, usize, usize)> {
self.usage(response).map(|u| (u.prompt_tokens(), u.completion_tokens(), u.total_tokens()))
}
}
@ -90,70 +70,34 @@ pub trait StreamChunk {
}
/// Trait for streaming response iterators
pub trait StreamingResponse: Iterator<Item = Result<Self::Chunk, Self::Error>> + Sized {
pub trait StreamingResponse {
type Error: Error + Send + Sync + 'static;
type Chunk: StreamChunk;
type StreamingIter: Iterator<Item = Result<crate::apis::openai::ChatCompletionsStreamResponse, Self::Error>>;
/// Parse streaming response from raw bytes
fn try_from_bytes(bytes: &[u8], provider: &super::ProviderId, mode: ConversionMode) -> Result<Self, Self::Error>;
fn try_from_bytes(&self, bytes: &[u8], provider: &super::ProviderId, mode: ConversionMode) -> Result<Self::StreamingIter, Self::Error>;
}
/// Main provider interface trait
pub trait ProviderInterface {
/// Main provider interface trait - simplified to only essential methods
pub trait ProviderInterface: ProviderRequest + ProviderResponse + StreamingResponse {
/// Check if this provider has a compatible API with the client request
fn has_compatible_api(&self, api_path: &str) -> bool;
/// Get the interface implementation for this provider
/// passthrough: if true, use provider-specific format; if false, use compatible format
fn get_interface(&self, passthrough: bool) -> ConversionMode {
if passthrough {
ConversionMode::Passthrough
} else {
ConversionMode::Compatible
}
}
/// Parse a request from raw bytes - returns concrete ChatCompletionsRequest
fn parse_request(&self, bytes: &[u8]) -> Result<crate::apis::openai::ChatCompletionsRequest, Box<dyn std::error::Error + Send + Sync>>;
/// Parse a response from raw bytes - returns concrete ChatCompletionsResponse
fn parse_response(&self, bytes: &[u8], provider_id: super::ProviderId, mode: ConversionMode) -> Result<crate::apis::openai::ChatCompletionsResponse, Box<dyn std::error::Error + Send + Sync>>;
/// Convert a request to bytes for sending to upstream API
fn request_to_bytes(&self, request: &crate::apis::openai::ChatCompletionsRequest, provider_id: super::ProviderId, mode: ConversionMode) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>>;
/// Extract model name from request for routing (convenience method for stream_context)
fn extract_model_from_request(&self, request: &crate::apis::openai::ChatCompletionsRequest) -> String {
use ProviderRequest;
request.extract_model().to_string()
}
/// Check if request is streaming (convenience method for stream_context)
fn is_request_streaming(&self, request: &crate::apis::openai::ChatCompletionsRequest) -> bool {
use ProviderRequest;
request.is_streaming()
}
/// Prepare request for streaming (convenience method for stream_context)
fn prepare_request_for_streaming(&self, request: &mut crate::apis::openai::ChatCompletionsRequest) {
use ProviderRequest;
if request.is_streaming() {
request.set_streaming_options();
}
}
/// Extract text for tokenization (convenience method for stream_context)
fn extract_text_for_tokenization(&self, request: &crate::apis::openai::ChatCompletionsRequest) -> String {
use ProviderRequest;
request.extract_messages_text()
}
/// Extract usage information from response (convenience method for stream_context)
fn extract_usage_from_response(&self, response: &crate::apis::openai::ChatCompletionsResponse) -> Option<(usize, usize, usize)> {
use ProviderResponse;
response.extract_usage_counts()
}
/// Get supported API endpoints for this provider
fn supported_apis(&self) -> Vec<&'static str>;
/// Parse a request from raw bytes - delegates to ProviderRequest
fn parse_request(&self, bytes: &[u8]) -> Result<crate::apis::openai::ChatCompletionsRequest, Box<dyn std::error::Error + Send + Sync>> {
ProviderRequest::try_from_bytes(self, bytes).map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
}
/// Parse a response from raw bytes - delegates to ProviderResponse
fn parse_response(&self, bytes: &[u8], provider_id: super::ProviderId, mode: ConversionMode) -> Result<crate::apis::openai::ChatCompletionsResponse, Box<dyn std::error::Error + Send + Sync>> {
ProviderResponse::try_from_bytes(self, bytes, &provider_id, mode).map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
}
/// Convert a request to bytes - delegates to ProviderRequest
fn request_to_bytes(&self, request: &crate::apis::openai::ChatCompletionsRequest, provider_id: super::ProviderId, mode: ConversionMode) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
ProviderRequest::to_provider_bytes(self, request, provider_id, mode).map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
}
}

View file

@ -10,6 +10,10 @@ use common::ratelimit::Header;
use common::stats::{IncrementingMetric, RecordingMetric};
use common::tracing::{Event, Span, TraceData, Traceparent};
use common::{ratelimit, routing, tokenizer};
use hermesllm::apis::openai::{ContentPart, MessageContent};
use hermesllm::providers::traits::{
ProviderRequest, ProviderResponse, StreamChunk, StreamingResponse, TokenUsage,
};
use hermesllm::{ConversionMode, Provider, ProviderId};
use http::StatusCode;
use log::{debug, info, warn};
@ -39,6 +43,7 @@ pub struct StreamContext {
request_body_sent_time: Option<u128>,
traces_queue: Arc<Mutex<VecDeque<TraceData>>>,
overrides: Rc<Option<Overrides>>,
user_message: Option<String>,
}
impl StreamContext {
@ -66,6 +71,7 @@ impl StreamContext {
ttft_time: None,
traces_queue,
request_body_sent_time: None,
user_message: None,
}
}
fn llm_provider(&self) -> &LlmProvider {
@ -295,7 +301,7 @@ impl HttpContext for StreamContext {
let provider = self.get_provider();
let mut deserialized_body = match provider.interface().parse_request(&body_bytes) {
let mut deserialized_body = match ProviderRequest::try_from_bytes(&provider, &body_bytes) {
Ok(deserialized) => deserialized,
Err(e) => {
debug!(
@ -324,9 +330,29 @@ impl HttpContext for StreamContext {
};
// Use the provider interface methods for cleaner interaction
let model_requested = provider
.interface()
.extract_model_from_request(&deserialized_body);
let model_requested = provider.extract_model(&deserialized_body).to_string(); // Convert to owned string
// Extract user message for tracing
self.user_message = deserialized_body.messages.last().and_then(|msg| {
match &msg.content {
MessageContent::Text(text) => Some(text.clone()),
MessageContent::Parts(parts) => {
// Extract text from content parts, ignoring images
let text_parts: Vec<String> = parts
.iter()
.filter_map(|part| match part {
ContentPart::Text { text } => Some(text.clone()),
ContentPart::ImageUrl { .. } => None,
})
.collect();
if text_parts.is_empty() {
None
} else {
Some(text_parts.join(" "))
}
}
}
});
info!(
"on_http_request_body: provider: {}, model requested (in body): {}, model selected: {}",
@ -336,20 +362,15 @@ impl HttpContext for StreamContext {
);
// Use provider interface for streaming detection and setup
if provider
.interface()
.is_request_streaming(&deserialized_body)
{
self.streaming_response = true;
provider
.interface()
.prepare_request_for_streaming(&mut deserialized_body);
self.streaming_response = provider.is_streaming(&deserialized_body);
// Set streaming options if needed
if self.streaming_response {
provider.set_streaming_options(&mut deserialized_body);
}
// Use provider interface for text extraction
let input_tokens_str = provider
.interface()
.extract_text_for_tokenization(&deserialized_body);
// Use provider interface for text extraction (after potential mutation)
let input_tokens_str = provider.extract_messages_text(&deserialized_body);
// enforce ratelimits on ingress
if let Err(e) = self.enforce_ratelimits(&model_requested, input_tokens_str.as_str()) {
self.send_server_error(
@ -364,7 +385,7 @@ impl HttpContext for StreamContext {
let _hermes_llm_provider_id = ProviderId::from(llm_provider_str.as_str());
// Convert chat completion request to llm provider specific request using provider interface
let deserialized_body_bytes = match provider.interface().request_to_bytes(
let deserialized_body_bytes = match provider.to_provider_bytes(
&deserialized_body,
provider.id(),
ConversionMode::Compatible,
@ -473,6 +494,11 @@ impl HttpContext for StreamContext {
self.llm_provider().name.to_string(),
);
if let Some(user_message) = &self.user_message {
llm_span
.add_attribute("user_message".to_string(), user_message.clone());
}
if self.ttft_time.is_some() {
llm_span.add_event(Event::new(
"time_to_first_token".to_string(),
@ -540,36 +566,74 @@ impl HttpContext for StreamContext {
let _provider_id = ProviderId::from(llm_provider_str.as_str());
if self.streaming_response {
// TODO: Implement streaming response parsing with new provider structure
warn!(
"Streaming response parsing not yet fully implemented with new provider structure"
);
debug!("processing streaming response");
// For now, just compute TTFT and continue
if self.ttft_duration.is_none() {
let current_time = get_current_time().unwrap();
self.ttft_time = Some(current_time_ns());
match current_time.duration_since(self.start_time) {
Ok(duration) => {
let duration_ms = duration.as_millis();
info!(
"on_http_response_body: time to first token: {}ms",
duration_ms
);
self.ttft_duration = Some(duration);
self.metrics.time_to_first_token.record(duration_ms as u64);
}
Err(e) => {
warn!("SystemTime error: {:?}", e);
// Parse streaming response using OpenAI-compatible format
// Since all providers use OpenAI-compatible streaming format
let provider = self.get_provider();
let provider_id =
ProviderId::from(self.llm_provider().provider_interface.to_string().as_str());
match StreamingResponse::try_from_bytes(
&provider,
&body,
&provider_id,
ConversionMode::Compatible,
) {
Ok(mut streaming_response) => {
// Process each streaming chunk
while let Some(chunk_result) = streaming_response.next() {
match chunk_result {
Ok(chunk) => {
// Compute TTFT on first chunk
if self.ttft_duration.is_none() {
let current_time = get_current_time().unwrap();
self.ttft_time = Some(current_time_ns());
match current_time.duration_since(self.start_time) {
Ok(duration) => {
let duration_ms = duration.as_millis();
info!(
"on_http_response_body: time to first token: {}ms",
duration_ms
);
self.ttft_duration = Some(duration);
self.metrics
.time_to_first_token
.record(duration_ms as u64);
}
Err(e) => {
warn!("SystemTime error: {:?}", e);
}
}
}
// Extract usage information if available
if let Some(usage) = chunk.usage() {
let completion_tokens = usage.completion_tokens();
self.response_tokens += completion_tokens;
debug!(
"Streaming chunk completion tokens: {}",
completion_tokens
);
}
}
Err(e) => {
warn!("Error processing streaming chunk: {}", e);
}
}
}
}
Err(e) => {
warn!("Failed to parse streaming response: {}", e);
}
}
} else {
debug!("non streaming response");
let provider = self.get_provider();
let response = match provider.interface().parse_response(
let response = match ProviderResponse::try_from_bytes(
&provider,
&body,
provider.id(),
&provider.id(),
ConversionMode::Compatible,
) {
Ok(response) => response,
@ -594,7 +658,7 @@ impl HttpContext for StreamContext {
// Use provider interface to extract usage information
if let Some((prompt_tokens, completion_tokens, total_tokens)) =
provider.interface().extract_usage_from_response(&response)
provider.extract_usage_counts(&response)
{
debug!(
"Response usage: prompt={}, completion={}, total={}",