mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
more refactoring changes to avoid unecessary re-direction and duplication
This commit is contained in:
parent
58028bb7ae
commit
9c09a18fd0
12 changed files with 809 additions and 225 deletions
|
|
@ -46,18 +46,18 @@ mod tests {
|
|||
#[test]
|
||||
fn test_provider_instance_creation() {
|
||||
let provider = Provider::new(ProviderId::OpenAI);
|
||||
assert!(provider.interface().has_compatible_api("/v1/chat/completions"));
|
||||
assert!(!provider.interface().has_compatible_api("/v1/embeddings"));
|
||||
assert!(provider.has_compatible_api("/v1/chat/completions"));
|
||||
assert!(!provider.has_compatible_api("/v1/embeddings"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conversion_mode() {
|
||||
fn test_provider_supported_apis() {
|
||||
let provider = Provider::new(ProviderId::OpenAI);
|
||||
|
||||
let compatible_mode = provider.interface().get_interface(false);
|
||||
assert!(matches!(compatible_mode, ConversionMode::Compatible));
|
||||
let supported_apis = provider.supported_apis();
|
||||
assert!(supported_apis.contains(&"/v1/chat/completions"));
|
||||
|
||||
let passthrough_mode = provider.interface().get_interface(true);
|
||||
assert!(matches!(passthrough_mode, ConversionMode::Passthrough));
|
||||
// Test that provider supports the expected API endpoints
|
||||
assert!(provider.has_compatible_api("/v1/chat/completions"));
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,13 +1,76 @@
|
|||
//! Arch provider implementation
|
||||
|
||||
use crate::providers::{ProviderInterface, ConversionMode};
|
||||
use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse};
|
||||
use crate::providers::traits::{ProviderRequest, ProviderResponse};
|
||||
use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse, Usage};
|
||||
use crate::providers::traits::{ProviderRequest, ProviderResponse, StreamingResponse};
|
||||
use crate::providers::openai::provider::{OpenAIProvider, OpenAIStreamingResponse, OpenAIApiError};
|
||||
|
||||
/// Arch provider implementation
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ArchProvider;
|
||||
|
||||
// Trait implementations that delegate to OpenAI
|
||||
impl ProviderRequest for ArchProvider {
|
||||
type Error = OpenAIApiError;
|
||||
|
||||
fn try_from_bytes(&self, bytes: &[u8]) -> Result<ChatCompletionsRequest, Self::Error> {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderRequest::try_from_bytes(&openai_provider, bytes)
|
||||
}
|
||||
|
||||
fn to_provider_bytes(&self, request: &ChatCompletionsRequest, provider: super::super::ProviderId, mode: ConversionMode) -> Result<Vec<u8>, Self::Error> {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderRequest::to_provider_bytes(&openai_provider, request, provider, mode)
|
||||
}
|
||||
|
||||
fn extract_model<'a>(&self, request: &'a ChatCompletionsRequest) -> &'a str {
|
||||
&request.model
|
||||
}
|
||||
|
||||
fn is_streaming(&self, request: &ChatCompletionsRequest) -> bool {
|
||||
request.stream.unwrap_or_default()
|
||||
}
|
||||
|
||||
fn set_streaming_options(&self, request: &mut ChatCompletionsRequest) {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderRequest::set_streaming_options(&openai_provider, request)
|
||||
}
|
||||
|
||||
fn extract_messages_text(&self, request: &ChatCompletionsRequest) -> String {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderRequest::extract_messages_text(&openai_provider, request)
|
||||
}
|
||||
}
|
||||
|
||||
impl ProviderResponse for ArchProvider {
|
||||
type Error = OpenAIApiError;
|
||||
type Usage = Usage;
|
||||
|
||||
fn try_from_bytes(&self, bytes: &[u8], provider: &super::super::ProviderId, mode: ConversionMode) -> Result<ChatCompletionsResponse, Self::Error> {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderResponse::try_from_bytes(&openai_provider, bytes, provider, mode)
|
||||
}
|
||||
|
||||
fn usage<'a>(&self, response: &'a ChatCompletionsResponse) -> Option<&'a Self::Usage> {
|
||||
Some(&response.usage)
|
||||
}
|
||||
|
||||
fn extract_usage_counts(&self, response: &ChatCompletionsResponse) -> Option<(usize, usize, usize)> {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderResponse::extract_usage_counts(&openai_provider, response)
|
||||
}
|
||||
}
|
||||
|
||||
impl StreamingResponse for ArchProvider {
|
||||
type Error = OpenAIApiError;
|
||||
type StreamingIter = OpenAIStreamingResponse;
|
||||
|
||||
fn try_from_bytes(&self, bytes: &[u8], provider: &super::super::ProviderId, mode: ConversionMode) -> Result<Self::StreamingIter, Self::Error> {
|
||||
let openai_provider = OpenAIProvider;
|
||||
StreamingResponse::try_from_bytes(&openai_provider, bytes, provider, mode)
|
||||
}
|
||||
}
|
||||
|
||||
impl ProviderInterface for ArchProvider {
|
||||
fn has_compatible_api(&self, api_path: &str) -> bool {
|
||||
matches!(api_path, "/v1/chat/completions")
|
||||
|
|
@ -18,21 +81,21 @@ impl ProviderInterface for ArchProvider {
|
|||
}
|
||||
|
||||
fn parse_request(&self, bytes: &[u8]) -> Result<ChatCompletionsRequest, Box<dyn std::error::Error + Send + Sync>> {
|
||||
match ChatCompletionsRequest::try_from_bytes(bytes) {
|
||||
match ProviderRequest::try_from_bytes(self, bytes) {
|
||||
Ok(req) => Ok(req),
|
||||
Err(e) => Err(Box::new(e)),
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_response(&self, bytes: &[u8], provider_id: super::super::ProviderId, mode: ConversionMode) -> Result<ChatCompletionsResponse, Box<dyn std::error::Error + Send + Sync>> {
|
||||
match ChatCompletionsResponse::try_from_bytes(bytes, &provider_id, mode) {
|
||||
match ProviderResponse::try_from_bytes(self, bytes, &provider_id, mode) {
|
||||
Ok(resp) => Ok(resp),
|
||||
Err(e) => Err(Box::new(e)),
|
||||
}
|
||||
}
|
||||
|
||||
fn request_to_bytes(&self, request: &ChatCompletionsRequest, provider_id: super::super::ProviderId, mode: ConversionMode) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
match request.to_provider_bytes(provider_id, mode) {
|
||||
match ProviderRequest::to_provider_bytes(self, request, provider_id, mode) {
|
||||
Ok(bytes) => Ok(bytes),
|
||||
Err(e) => Err(Box::new(e)),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,13 +4,76 @@
|
|||
//! For now, uses OpenAI-compatible format as fallback
|
||||
|
||||
use crate::providers::{ProviderInterface, ConversionMode};
|
||||
use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse};
|
||||
use crate::providers::traits::{ProviderRequest, ProviderResponse};
|
||||
use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse, Usage};
|
||||
use crate::providers::traits::{ProviderRequest, ProviderResponse, StreamingResponse};
|
||||
use crate::providers::openai::provider::{OpenAIProvider, OpenAIStreamingResponse, OpenAIApiError};
|
||||
|
||||
/// Claude provider implementation
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ClaudeProvider;
|
||||
|
||||
// Trait implementations that delegate to OpenAI
|
||||
impl ProviderRequest for ClaudeProvider {
|
||||
type Error = OpenAIApiError;
|
||||
|
||||
fn try_from_bytes(&self, bytes: &[u8]) -> Result<ChatCompletionsRequest, Self::Error> {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderRequest::try_from_bytes(&openai_provider, bytes)
|
||||
}
|
||||
|
||||
fn to_provider_bytes(&self, request: &ChatCompletionsRequest, provider: super::super::ProviderId, mode: ConversionMode) -> Result<Vec<u8>, Self::Error> {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderRequest::to_provider_bytes(&openai_provider, request, provider, mode)
|
||||
}
|
||||
|
||||
fn extract_model<'a>(&self, request: &'a ChatCompletionsRequest) -> &'a str {
|
||||
&request.model
|
||||
}
|
||||
|
||||
fn is_streaming(&self, request: &ChatCompletionsRequest) -> bool {
|
||||
request.stream.unwrap_or_default()
|
||||
}
|
||||
|
||||
fn set_streaming_options(&self, request: &mut ChatCompletionsRequest) {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderRequest::set_streaming_options(&openai_provider, request)
|
||||
}
|
||||
|
||||
fn extract_messages_text(&self, request: &ChatCompletionsRequest) -> String {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderRequest::extract_messages_text(&openai_provider, request)
|
||||
}
|
||||
}
|
||||
|
||||
impl ProviderResponse for ClaudeProvider {
|
||||
type Error = OpenAIApiError;
|
||||
type Usage = Usage;
|
||||
|
||||
fn try_from_bytes(&self, bytes: &[u8], provider: &super::super::ProviderId, mode: ConversionMode) -> Result<ChatCompletionsResponse, Self::Error> {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderResponse::try_from_bytes(&openai_provider, bytes, provider, mode)
|
||||
}
|
||||
|
||||
fn usage<'a>(&self, response: &'a ChatCompletionsResponse) -> Option<&'a Self::Usage> {
|
||||
Some(&response.usage)
|
||||
}
|
||||
|
||||
fn extract_usage_counts(&self, response: &ChatCompletionsResponse) -> Option<(usize, usize, usize)> {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderResponse::extract_usage_counts(&openai_provider, response)
|
||||
}
|
||||
}
|
||||
|
||||
impl StreamingResponse for ClaudeProvider {
|
||||
type Error = OpenAIApiError;
|
||||
type StreamingIter = OpenAIStreamingResponse;
|
||||
|
||||
fn try_from_bytes(&self, bytes: &[u8], provider: &super::super::ProviderId, mode: ConversionMode) -> Result<Self::StreamingIter, Self::Error> {
|
||||
let openai_provider = OpenAIProvider;
|
||||
StreamingResponse::try_from_bytes(&openai_provider, bytes, provider, mode)
|
||||
}
|
||||
}
|
||||
|
||||
impl ProviderInterface for ClaudeProvider {
|
||||
fn has_compatible_api(&self, api_path: &str) -> bool {
|
||||
// TODO: Update when Claude API is fully implemented
|
||||
|
|
@ -24,7 +87,7 @@ impl ProviderInterface for ClaudeProvider {
|
|||
|
||||
fn parse_request(&self, bytes: &[u8]) -> Result<ChatCompletionsRequest, Box<dyn std::error::Error + Send + Sync>> {
|
||||
// TODO: Implement Claude-specific request parsing
|
||||
match ChatCompletionsRequest::try_from_bytes(bytes) {
|
||||
match ProviderRequest::try_from_bytes(self, bytes) {
|
||||
Ok(req) => Ok(req),
|
||||
Err(e) => Err(Box::new(e)),
|
||||
}
|
||||
|
|
@ -32,7 +95,7 @@ impl ProviderInterface for ClaudeProvider {
|
|||
|
||||
fn parse_response(&self, bytes: &[u8], provider_id: super::super::ProviderId, mode: ConversionMode) -> Result<ChatCompletionsResponse, Box<dyn std::error::Error + Send + Sync>> {
|
||||
// TODO: Implement Claude-specific response parsing
|
||||
match ChatCompletionsResponse::try_from_bytes(bytes, &provider_id, mode) {
|
||||
match ProviderResponse::try_from_bytes(self, bytes, &provider_id, mode) {
|
||||
Ok(resp) => Ok(resp),
|
||||
Err(e) => Err(Box::new(e)),
|
||||
}
|
||||
|
|
@ -40,7 +103,7 @@ impl ProviderInterface for ClaudeProvider {
|
|||
|
||||
fn request_to_bytes(&self, request: &ChatCompletionsRequest, provider_id: super::super::ProviderId, mode: ConversionMode) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
// TODO: Implement Claude-specific request serialization
|
||||
match request.to_provider_bytes(provider_id, mode) {
|
||||
match ProviderRequest::to_provider_bytes(self, request, provider_id, mode) {
|
||||
Ok(bytes) => Ok(bytes),
|
||||
Err(e) => Err(Box::new(e)),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,13 +1,76 @@
|
|||
//! Deepseek provider implementation
|
||||
|
||||
use crate::providers::{ProviderInterface, ConversionMode};
|
||||
use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse};
|
||||
use crate::providers::traits::{ProviderRequest, ProviderResponse};
|
||||
use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse, Usage};
|
||||
use crate::providers::traits::{ProviderRequest, ProviderResponse, StreamingResponse};
|
||||
use crate::providers::openai::provider::{OpenAIProvider, OpenAIStreamingResponse, OpenAIApiError};
|
||||
|
||||
/// Deepseek provider implementation
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct DeepseekProvider;
|
||||
|
||||
// Trait implementations that delegate to OpenAI
|
||||
impl ProviderRequest for DeepseekProvider {
|
||||
type Error = OpenAIApiError;
|
||||
|
||||
fn try_from_bytes(&self, bytes: &[u8]) -> Result<ChatCompletionsRequest, Self::Error> {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderRequest::try_from_bytes(&openai_provider, bytes)
|
||||
}
|
||||
|
||||
fn to_provider_bytes(&self, request: &ChatCompletionsRequest, provider: super::super::ProviderId, mode: ConversionMode) -> Result<Vec<u8>, Self::Error> {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderRequest::to_provider_bytes(&openai_provider, request, provider, mode)
|
||||
}
|
||||
|
||||
fn extract_model<'a>(&self, request: &'a ChatCompletionsRequest) -> &'a str {
|
||||
&request.model
|
||||
}
|
||||
|
||||
fn is_streaming(&self, request: &ChatCompletionsRequest) -> bool {
|
||||
request.stream.unwrap_or_default()
|
||||
}
|
||||
|
||||
fn set_streaming_options(&self, request: &mut ChatCompletionsRequest) {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderRequest::set_streaming_options(&openai_provider, request)
|
||||
}
|
||||
|
||||
fn extract_messages_text(&self, request: &ChatCompletionsRequest) -> String {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderRequest::extract_messages_text(&openai_provider, request)
|
||||
}
|
||||
}
|
||||
|
||||
impl ProviderResponse for DeepseekProvider {
|
||||
type Error = OpenAIApiError;
|
||||
type Usage = Usage;
|
||||
|
||||
fn try_from_bytes(&self, bytes: &[u8], provider: &super::super::ProviderId, mode: ConversionMode) -> Result<ChatCompletionsResponse, Self::Error> {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderResponse::try_from_bytes(&openai_provider, bytes, provider, mode)
|
||||
}
|
||||
|
||||
fn usage<'a>(&self, response: &'a ChatCompletionsResponse) -> Option<&'a Self::Usage> {
|
||||
Some(&response.usage)
|
||||
}
|
||||
|
||||
fn extract_usage_counts(&self, response: &ChatCompletionsResponse) -> Option<(usize, usize, usize)> {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderResponse::extract_usage_counts(&openai_provider, response)
|
||||
}
|
||||
}
|
||||
|
||||
impl StreamingResponse for DeepseekProvider {
|
||||
type Error = OpenAIApiError;
|
||||
type StreamingIter = OpenAIStreamingResponse;
|
||||
|
||||
fn try_from_bytes(&self, bytes: &[u8], provider: &super::super::ProviderId, mode: ConversionMode) -> Result<Self::StreamingIter, Self::Error> {
|
||||
let openai_provider = OpenAIProvider;
|
||||
StreamingResponse::try_from_bytes(&openai_provider, bytes, provider, mode)
|
||||
}
|
||||
}
|
||||
|
||||
impl ProviderInterface for DeepseekProvider {
|
||||
fn has_compatible_api(&self, api_path: &str) -> bool {
|
||||
matches!(api_path, "/v1/chat/completions")
|
||||
|
|
@ -18,21 +81,21 @@ impl ProviderInterface for DeepseekProvider {
|
|||
}
|
||||
|
||||
fn parse_request(&self, bytes: &[u8]) -> Result<ChatCompletionsRequest, Box<dyn std::error::Error + Send + Sync>> {
|
||||
match ChatCompletionsRequest::try_from_bytes(bytes) {
|
||||
match ProviderRequest::try_from_bytes(self, bytes) {
|
||||
Ok(req) => Ok(req),
|
||||
Err(e) => Err(Box::new(e)),
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_response(&self, bytes: &[u8], provider_id: super::super::ProviderId, mode: ConversionMode) -> Result<ChatCompletionsResponse, Box<dyn std::error::Error + Send + Sync>> {
|
||||
match ChatCompletionsResponse::try_from_bytes(bytes, &provider_id, mode) {
|
||||
match ProviderResponse::try_from_bytes(self, bytes, &provider_id, mode) {
|
||||
Ok(resp) => Ok(resp),
|
||||
Err(e) => Err(Box::new(e)),
|
||||
}
|
||||
}
|
||||
|
||||
fn request_to_bytes(&self, request: &ChatCompletionsRequest, provider_id: super::super::ProviderId, mode: ConversionMode) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
match request.to_provider_bytes(provider_id, mode) {
|
||||
match ProviderRequest::to_provider_bytes(self, request, provider_id, mode) {
|
||||
Ok(bytes) => Ok(bytes),
|
||||
Err(e) => Err(Box::new(e)),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,16 +1,79 @@
|
|||
//! Gemini provider implementation
|
||||
//!
|
||||
//! TODO: Implement Gemini-specific API format when needed
|
||||
//! For now, uses OpenAI-compatible format as fallback
|
||||
//! This module contains the Gemini provider that handles Google's Gemini API format
|
||||
//! requests in OpenAI-compatible format.
|
||||
|
||||
use crate::providers::{ProviderInterface, ConversionMode};
|
||||
use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse};
|
||||
use crate::providers::traits::{ProviderRequest, ProviderResponse};
|
||||
use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse, Usage};
|
||||
use crate::providers::traits::{ProviderRequest, ProviderResponse, StreamingResponse};
|
||||
use crate::providers::openai::provider::{OpenAIProvider, OpenAIStreamingResponse, OpenAIApiError};
|
||||
|
||||
/// Gemini provider implementation
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct GeminiProvider;
|
||||
|
||||
// Trait implementations that delegate to OpenAI
|
||||
impl ProviderRequest for GeminiProvider {
|
||||
type Error = OpenAIApiError;
|
||||
|
||||
fn try_from_bytes(&self, bytes: &[u8]) -> Result<ChatCompletionsRequest, Self::Error> {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderRequest::try_from_bytes(&openai_provider, bytes)
|
||||
}
|
||||
|
||||
fn to_provider_bytes(&self, request: &ChatCompletionsRequest, provider: super::super::ProviderId, mode: ConversionMode) -> Result<Vec<u8>, Self::Error> {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderRequest::to_provider_bytes(&openai_provider, request, provider, mode)
|
||||
}
|
||||
|
||||
fn extract_model<'a>(&self, request: &'a ChatCompletionsRequest) -> &'a str {
|
||||
&request.model
|
||||
}
|
||||
|
||||
fn is_streaming(&self, request: &ChatCompletionsRequest) -> bool {
|
||||
request.stream.unwrap_or_default()
|
||||
}
|
||||
|
||||
fn set_streaming_options(&self, request: &mut ChatCompletionsRequest) {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderRequest::set_streaming_options(&openai_provider, request)
|
||||
}
|
||||
|
||||
fn extract_messages_text(&self, request: &ChatCompletionsRequest) -> String {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderRequest::extract_messages_text(&openai_provider, request)
|
||||
}
|
||||
}
|
||||
|
||||
impl ProviderResponse for GeminiProvider {
|
||||
type Error = OpenAIApiError;
|
||||
type Usage = Usage;
|
||||
|
||||
fn try_from_bytes(&self, bytes: &[u8], provider: &super::super::ProviderId, mode: ConversionMode) -> Result<ChatCompletionsResponse, Self::Error> {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderResponse::try_from_bytes(&openai_provider, bytes, provider, mode)
|
||||
}
|
||||
|
||||
fn usage<'a>(&self, response: &'a ChatCompletionsResponse) -> Option<&'a Self::Usage> {
|
||||
Some(&response.usage)
|
||||
}
|
||||
|
||||
fn extract_usage_counts(&self, response: &ChatCompletionsResponse) -> Option<(usize, usize, usize)> {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderResponse::extract_usage_counts(&openai_provider, response)
|
||||
}
|
||||
}
|
||||
|
||||
impl StreamingResponse for GeminiProvider {
|
||||
type Error = OpenAIApiError;
|
||||
type StreamingIter = OpenAIStreamingResponse;
|
||||
|
||||
fn try_from_bytes(&self, bytes: &[u8], provider: &super::super::ProviderId, mode: ConversionMode) -> Result<Self::StreamingIter, Self::Error> {
|
||||
let openai_provider = OpenAIProvider;
|
||||
StreamingResponse::try_from_bytes(&openai_provider, bytes, provider, mode)
|
||||
}
|
||||
}
|
||||
|
||||
impl ProviderInterface for GeminiProvider {
|
||||
fn has_compatible_api(&self, api_path: &str) -> bool {
|
||||
// TODO: Update when Gemini API is fully implemented
|
||||
|
|
@ -24,7 +87,7 @@ impl ProviderInterface for GeminiProvider {
|
|||
|
||||
fn parse_request(&self, bytes: &[u8]) -> Result<ChatCompletionsRequest, Box<dyn std::error::Error + Send + Sync>> {
|
||||
// TODO: Implement Gemini-specific request parsing
|
||||
match ChatCompletionsRequest::try_from_bytes(bytes) {
|
||||
match ProviderRequest::try_from_bytes(self, bytes) {
|
||||
Ok(req) => Ok(req),
|
||||
Err(e) => Err(Box::new(e)),
|
||||
}
|
||||
|
|
@ -32,7 +95,7 @@ impl ProviderInterface for GeminiProvider {
|
|||
|
||||
fn parse_response(&self, bytes: &[u8], provider_id: super::super::ProviderId, mode: ConversionMode) -> Result<ChatCompletionsResponse, Box<dyn std::error::Error + Send + Sync>> {
|
||||
// TODO: Implement Gemini-specific response parsing
|
||||
match ChatCompletionsResponse::try_from_bytes(bytes, &provider_id, mode) {
|
||||
match ProviderResponse::try_from_bytes(self, bytes, &provider_id, mode) {
|
||||
Ok(resp) => Ok(resp),
|
||||
Err(e) => Err(Box::new(e)),
|
||||
}
|
||||
|
|
@ -40,7 +103,7 @@ impl ProviderInterface for GeminiProvider {
|
|||
|
||||
fn request_to_bytes(&self, request: &ChatCompletionsRequest, provider_id: super::super::ProviderId, mode: ConversionMode) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
// TODO: Implement Gemini-specific request serialization
|
||||
match request.to_provider_bytes(provider_id, mode) {
|
||||
match ProviderRequest::to_provider_bytes(self, request, provider_id, mode) {
|
||||
Ok(bytes) => Ok(bytes),
|
||||
Err(e) => Err(Box::new(e)),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,16 +1,79 @@
|
|||
//! GitHub provider implementation
|
||||
//!
|
||||
//! TODO: Implement GitHub-specific API format (/models) when needed
|
||||
//! For now, uses OpenAI-compatible format as fallback
|
||||
//! This module contains the GitHub provider that handles GitHub API format
|
||||
//! requests in OpenAI-compatible format.
|
||||
|
||||
use crate::providers::{ProviderInterface, ConversionMode};
|
||||
use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse};
|
||||
use crate::providers::traits::{ProviderRequest, ProviderResponse};
|
||||
use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse, Usage};
|
||||
use crate::providers::traits::{ProviderRequest, ProviderResponse, StreamingResponse};
|
||||
use crate::providers::openai::provider::{OpenAIProvider, OpenAIStreamingResponse, OpenAIApiError};
|
||||
|
||||
/// GitHub provider implementation
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct GitHubProvider;
|
||||
|
||||
// Trait implementations that delegate to OpenAI
|
||||
impl ProviderRequest for GitHubProvider {
|
||||
type Error = OpenAIApiError;
|
||||
|
||||
fn try_from_bytes(&self, bytes: &[u8]) -> Result<ChatCompletionsRequest, Self::Error> {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderRequest::try_from_bytes(&openai_provider, bytes)
|
||||
}
|
||||
|
||||
fn to_provider_bytes(&self, request: &ChatCompletionsRequest, provider: super::super::ProviderId, mode: ConversionMode) -> Result<Vec<u8>, Self::Error> {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderRequest::to_provider_bytes(&openai_provider, request, provider, mode)
|
||||
}
|
||||
|
||||
fn extract_model<'a>(&self, request: &'a ChatCompletionsRequest) -> &'a str {
|
||||
&request.model
|
||||
}
|
||||
|
||||
fn is_streaming(&self, request: &ChatCompletionsRequest) -> bool {
|
||||
request.stream.unwrap_or_default()
|
||||
}
|
||||
|
||||
fn set_streaming_options(&self, request: &mut ChatCompletionsRequest) {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderRequest::set_streaming_options(&openai_provider, request)
|
||||
}
|
||||
|
||||
fn extract_messages_text(&self, request: &ChatCompletionsRequest) -> String {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderRequest::extract_messages_text(&openai_provider, request)
|
||||
}
|
||||
}
|
||||
|
||||
impl ProviderResponse for GitHubProvider {
|
||||
type Error = OpenAIApiError;
|
||||
type Usage = Usage;
|
||||
|
||||
fn try_from_bytes(&self, bytes: &[u8], provider: &super::super::ProviderId, mode: ConversionMode) -> Result<ChatCompletionsResponse, Self::Error> {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderResponse::try_from_bytes(&openai_provider, bytes, provider, mode)
|
||||
}
|
||||
|
||||
fn usage<'a>(&self, response: &'a ChatCompletionsResponse) -> Option<&'a Self::Usage> {
|
||||
Some(&response.usage)
|
||||
}
|
||||
|
||||
fn extract_usage_counts(&self, response: &ChatCompletionsResponse) -> Option<(usize, usize, usize)> {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderResponse::extract_usage_counts(&openai_provider, response)
|
||||
}
|
||||
}
|
||||
|
||||
impl StreamingResponse for GitHubProvider {
|
||||
type Error = OpenAIApiError;
|
||||
type StreamingIter = OpenAIStreamingResponse;
|
||||
|
||||
fn try_from_bytes(&self, bytes: &[u8], provider: &super::super::ProviderId, mode: ConversionMode) -> Result<Self::StreamingIter, Self::Error> {
|
||||
let openai_provider = OpenAIProvider;
|
||||
StreamingResponse::try_from_bytes(&openai_provider, bytes, provider, mode)
|
||||
}
|
||||
}
|
||||
|
||||
impl ProviderInterface for GitHubProvider {
|
||||
fn has_compatible_api(&self, api_path: &str) -> bool {
|
||||
// TODO: Update when GitHub API is fully implemented
|
||||
|
|
@ -24,7 +87,7 @@ impl ProviderInterface for GitHubProvider {
|
|||
|
||||
fn parse_request(&self, bytes: &[u8]) -> Result<ChatCompletionsRequest, Box<dyn std::error::Error + Send + Sync>> {
|
||||
// TODO: Implement GitHub-specific request parsing
|
||||
match ChatCompletionsRequest::try_from_bytes(bytes) {
|
||||
match ProviderRequest::try_from_bytes(self, bytes) {
|
||||
Ok(req) => Ok(req),
|
||||
Err(e) => Err(Box::new(e)),
|
||||
}
|
||||
|
|
@ -32,7 +95,7 @@ impl ProviderInterface for GitHubProvider {
|
|||
|
||||
fn parse_response(&self, bytes: &[u8], provider_id: super::super::ProviderId, mode: ConversionMode) -> Result<ChatCompletionsResponse, Box<dyn std::error::Error + Send + Sync>> {
|
||||
// TODO: Implement GitHub-specific response parsing
|
||||
match ChatCompletionsResponse::try_from_bytes(bytes, &provider_id, mode) {
|
||||
match ProviderResponse::try_from_bytes(self, bytes, &provider_id, mode) {
|
||||
Ok(resp) => Ok(resp),
|
||||
Err(e) => Err(Box::new(e)),
|
||||
}
|
||||
|
|
@ -40,7 +103,7 @@ impl ProviderInterface for GitHubProvider {
|
|||
|
||||
fn request_to_bytes(&self, request: &ChatCompletionsRequest, provider_id: super::super::ProviderId, mode: ConversionMode) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
// TODO: Implement GitHub-specific request serialization
|
||||
match request.to_provider_bytes(provider_id, mode) {
|
||||
match ProviderRequest::to_provider_bytes(self, request, provider_id, mode) {
|
||||
Ok(bytes) => Ok(bytes),
|
||||
Err(e) => Err(Box::new(e)),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,13 +4,76 @@
|
|||
//! Groq uses OpenAI-compatible format but may have provider-specific nuances.
|
||||
|
||||
use crate::providers::{ProviderInterface, ConversionMode};
|
||||
use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse};
|
||||
use crate::providers::traits::{ProviderRequest, ProviderResponse};
|
||||
use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse, Usage};
|
||||
use crate::providers::traits::{ProviderRequest, ProviderResponse, StreamingResponse};
|
||||
use crate::providers::openai::provider::{OpenAIProvider, OpenAIStreamingResponse, OpenAIApiError};
|
||||
|
||||
/// Groq provider implementation
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct GroqProvider;
|
||||
|
||||
// Trait implementations that delegate to OpenAI
|
||||
impl ProviderRequest for GroqProvider {
|
||||
type Error = OpenAIApiError;
|
||||
|
||||
fn try_from_bytes(&self, bytes: &[u8]) -> Result<ChatCompletionsRequest, Self::Error> {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderRequest::try_from_bytes(&openai_provider, bytes)
|
||||
}
|
||||
|
||||
fn to_provider_bytes(&self, request: &ChatCompletionsRequest, provider: super::super::ProviderId, mode: ConversionMode) -> Result<Vec<u8>, Self::Error> {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderRequest::to_provider_bytes(&openai_provider, request, provider, mode)
|
||||
}
|
||||
|
||||
fn extract_model<'a>(&self, request: &'a ChatCompletionsRequest) -> &'a str {
|
||||
&request.model
|
||||
}
|
||||
|
||||
fn is_streaming(&self, request: &ChatCompletionsRequest) -> bool {
|
||||
request.stream.unwrap_or_default()
|
||||
}
|
||||
|
||||
fn set_streaming_options(&self, request: &mut ChatCompletionsRequest) {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderRequest::set_streaming_options(&openai_provider, request)
|
||||
}
|
||||
|
||||
fn extract_messages_text(&self, request: &ChatCompletionsRequest) -> String {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderRequest::extract_messages_text(&openai_provider, request)
|
||||
}
|
||||
}
|
||||
|
||||
impl ProviderResponse for GroqProvider {
|
||||
type Error = OpenAIApiError;
|
||||
type Usage = Usage;
|
||||
|
||||
fn try_from_bytes(&self, bytes: &[u8], provider: &super::super::ProviderId, mode: ConversionMode) -> Result<ChatCompletionsResponse, Self::Error> {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderResponse::try_from_bytes(&openai_provider, bytes, provider, mode)
|
||||
}
|
||||
|
||||
fn usage<'a>(&self, response: &'a ChatCompletionsResponse) -> Option<&'a Self::Usage> {
|
||||
Some(&response.usage)
|
||||
}
|
||||
|
||||
fn extract_usage_counts(&self, response: &ChatCompletionsResponse) -> Option<(usize, usize, usize)> {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderResponse::extract_usage_counts(&openai_provider, response)
|
||||
}
|
||||
}
|
||||
|
||||
impl StreamingResponse for GroqProvider {
|
||||
type Error = OpenAIApiError;
|
||||
type StreamingIter = OpenAIStreamingResponse;
|
||||
|
||||
fn try_from_bytes(&self, bytes: &[u8], provider: &super::super::ProviderId, mode: ConversionMode) -> Result<Self::StreamingIter, Self::Error> {
|
||||
let openai_provider = OpenAIProvider;
|
||||
StreamingResponse::try_from_bytes(&openai_provider, bytes, provider, mode)
|
||||
}
|
||||
}
|
||||
|
||||
impl ProviderInterface for GroqProvider {
|
||||
fn has_compatible_api(&self, api_path: &str) -> bool {
|
||||
matches!(api_path, "/v1/chat/completions" | "/openai/v1/chat/completions")
|
||||
|
|
@ -21,21 +84,21 @@ impl ProviderInterface for GroqProvider {
|
|||
}
|
||||
|
||||
fn parse_request(&self, bytes: &[u8]) -> Result<ChatCompletionsRequest, Box<dyn std::error::Error + Send + Sync>> {
|
||||
match ChatCompletionsRequest::try_from_bytes(bytes) {
|
||||
match ProviderRequest::try_from_bytes(self, bytes) {
|
||||
Ok(req) => Ok(req),
|
||||
Err(e) => Err(Box::new(e)),
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_response(&self, bytes: &[u8], provider_id: super::super::ProviderId, mode: ConversionMode) -> Result<ChatCompletionsResponse, Box<dyn std::error::Error + Send + Sync>> {
|
||||
match ChatCompletionsResponse::try_from_bytes(bytes, &provider_id, mode) {
|
||||
match ProviderResponse::try_from_bytes(self, bytes, &provider_id, mode) {
|
||||
Ok(resp) => Ok(resp),
|
||||
Err(e) => Err(Box::new(e)),
|
||||
}
|
||||
}
|
||||
|
||||
fn request_to_bytes(&self, request: &ChatCompletionsRequest, provider_id: super::super::ProviderId, mode: ConversionMode) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
match request.to_provider_bytes(provider_id, mode) {
|
||||
match ProviderRequest::to_provider_bytes(self, request, provider_id, mode) {
|
||||
Ok(bytes) => Ok(bytes),
|
||||
Err(e) => Err(Box::new(e)),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,13 +1,76 @@
|
|||
//! Mistral provider implementation
|
||||
|
||||
use crate::providers::{ProviderInterface, ConversionMode};
|
||||
use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse};
|
||||
use crate::providers::traits::{ProviderRequest, ProviderResponse};
|
||||
use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse, Usage};
|
||||
use crate::providers::traits::{ProviderRequest, ProviderResponse, StreamingResponse};
|
||||
use crate::providers::openai::provider::{OpenAIProvider, OpenAIStreamingResponse, OpenAIApiError};
|
||||
|
||||
/// Mistral provider implementation
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MistralProvider;
|
||||
|
||||
// Trait implementations that delegate to OpenAI
|
||||
impl ProviderRequest for MistralProvider {
|
||||
type Error = OpenAIApiError;
|
||||
|
||||
fn try_from_bytes(&self, bytes: &[u8]) -> Result<ChatCompletionsRequest, Self::Error> {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderRequest::try_from_bytes(&openai_provider, bytes)
|
||||
}
|
||||
|
||||
fn to_provider_bytes(&self, request: &ChatCompletionsRequest, provider: super::super::ProviderId, mode: ConversionMode) -> Result<Vec<u8>, Self::Error> {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderRequest::to_provider_bytes(&openai_provider, request, provider, mode)
|
||||
}
|
||||
|
||||
fn extract_model<'a>(&self, request: &'a ChatCompletionsRequest) -> &'a str {
|
||||
&request.model
|
||||
}
|
||||
|
||||
fn is_streaming(&self, request: &ChatCompletionsRequest) -> bool {
|
||||
request.stream.unwrap_or_default()
|
||||
}
|
||||
|
||||
fn set_streaming_options(&self, request: &mut ChatCompletionsRequest) {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderRequest::set_streaming_options(&openai_provider, request)
|
||||
}
|
||||
|
||||
fn extract_messages_text(&self, request: &ChatCompletionsRequest) -> String {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderRequest::extract_messages_text(&openai_provider, request)
|
||||
}
|
||||
}
|
||||
|
||||
impl ProviderResponse for MistralProvider {
|
||||
type Error = OpenAIApiError;
|
||||
type Usage = Usage;
|
||||
|
||||
fn try_from_bytes(&self, bytes: &[u8], provider: &super::super::ProviderId, mode: ConversionMode) -> Result<ChatCompletionsResponse, Self::Error> {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderResponse::try_from_bytes(&openai_provider, bytes, provider, mode)
|
||||
}
|
||||
|
||||
fn usage<'a>(&self, response: &'a ChatCompletionsResponse) -> Option<&'a Self::Usage> {
|
||||
Some(&response.usage)
|
||||
}
|
||||
|
||||
fn extract_usage_counts(&self, response: &ChatCompletionsResponse) -> Option<(usize, usize, usize)> {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderResponse::extract_usage_counts(&openai_provider, response)
|
||||
}
|
||||
}
|
||||
|
||||
impl StreamingResponse for MistralProvider {
|
||||
type Error = OpenAIApiError;
|
||||
type StreamingIter = OpenAIStreamingResponse;
|
||||
|
||||
fn try_from_bytes(&self, bytes: &[u8], provider: &super::super::ProviderId, mode: ConversionMode) -> Result<Self::StreamingIter, Self::Error> {
|
||||
let openai_provider = OpenAIProvider;
|
||||
StreamingResponse::try_from_bytes(&openai_provider, bytes, provider, mode)
|
||||
}
|
||||
}
|
||||
|
||||
impl ProviderInterface for MistralProvider {
|
||||
fn has_compatible_api(&self, api_path: &str) -> bool {
|
||||
matches!(api_path, "/v1/chat/completions")
|
||||
|
|
@ -18,21 +81,21 @@ impl ProviderInterface for MistralProvider {
|
|||
}
|
||||
|
||||
fn parse_request(&self, bytes: &[u8]) -> Result<ChatCompletionsRequest, Box<dyn std::error::Error + Send + Sync>> {
|
||||
match ChatCompletionsRequest::try_from_bytes(bytes) {
|
||||
match ProviderRequest::try_from_bytes(self, bytes) {
|
||||
Ok(req) => Ok(req),
|
||||
Err(e) => Err(Box::new(e)),
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_response(&self, bytes: &[u8], provider_id: super::super::ProviderId, mode: ConversionMode) -> Result<ChatCompletionsResponse, Box<dyn std::error::Error + Send + Sync>> {
|
||||
match ChatCompletionsResponse::try_from_bytes(bytes, &provider_id, mode) {
|
||||
match ProviderResponse::try_from_bytes(self, bytes, &provider_id, mode) {
|
||||
Ok(resp) => Ok(resp),
|
||||
Err(e) => Err(Box::new(e)),
|
||||
}
|
||||
}
|
||||
|
||||
fn request_to_bytes(&self, request: &ChatCompletionsRequest, provider_id: super::super::ProviderId, mode: ConversionMode) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
match request.to_provider_bytes(provider_id, mode) {
|
||||
match ProviderRequest::to_provider_bytes(self, request, provider_id, mode) {
|
||||
Ok(bytes) => Ok(bytes),
|
||||
Err(e) => Err(Box::new(e)),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -134,18 +134,140 @@ impl Provider {
|
|||
Provider::GitHub(_, id) => *id,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the provider interface implementation
|
||||
pub fn interface(&self) -> &dyn ProviderInterface {
|
||||
// Implement traits directly on the Provider enum
|
||||
impl ProviderRequest for Provider {
|
||||
type Error = openai::provider::OpenAIApiError;
|
||||
|
||||
fn try_from_bytes(&self, bytes: &[u8]) -> Result<crate::apis::openai::ChatCompletionsRequest, Self::Error> {
|
||||
match self {
|
||||
Provider::OpenAI(provider, _) => provider,
|
||||
Provider::Groq(provider, _) => provider,
|
||||
Provider::Mistral(provider, _) => provider,
|
||||
Provider::Deepseek(provider, _) => provider,
|
||||
Provider::Arch(provider, _) => provider,
|
||||
Provider::Gemini(provider, _) => provider,
|
||||
Provider::Claude(provider, _) => provider,
|
||||
Provider::GitHub(provider, _) => provider,
|
||||
Provider::OpenAI(provider, _) => ProviderRequest::try_from_bytes(provider, bytes),
|
||||
Provider::Groq(provider, _) => ProviderRequest::try_from_bytes(provider, bytes),
|
||||
Provider::Mistral(provider, _) => ProviderRequest::try_from_bytes(provider, bytes),
|
||||
Provider::Deepseek(provider, _) => ProviderRequest::try_from_bytes(provider, bytes),
|
||||
Provider::Arch(provider, _) => ProviderRequest::try_from_bytes(provider, bytes),
|
||||
Provider::Gemini(provider, _) => ProviderRequest::try_from_bytes(provider, bytes),
|
||||
Provider::Claude(provider, _) => ProviderRequest::try_from_bytes(provider, bytes),
|
||||
Provider::GitHub(provider, _) => ProviderRequest::try_from_bytes(provider, bytes),
|
||||
}
|
||||
}
|
||||
|
||||
fn to_provider_bytes(&self, request: &crate::apis::openai::ChatCompletionsRequest, provider_id: super::ProviderId, mode: ConversionMode) -> Result<Vec<u8>, Self::Error> {
|
||||
match self {
|
||||
Provider::OpenAI(provider, _) => ProviderRequest::to_provider_bytes(provider, request, provider_id, mode),
|
||||
Provider::Groq(provider, _) => ProviderRequest::to_provider_bytes(provider, request, provider_id, mode),
|
||||
Provider::Mistral(provider, _) => ProviderRequest::to_provider_bytes(provider, request, provider_id, mode),
|
||||
Provider::Deepseek(provider, _) => ProviderRequest::to_provider_bytes(provider, request, provider_id, mode),
|
||||
Provider::Arch(provider, _) => ProviderRequest::to_provider_bytes(provider, request, provider_id, mode),
|
||||
Provider::Gemini(provider, _) => ProviderRequest::to_provider_bytes(provider, request, provider_id, mode),
|
||||
Provider::Claude(provider, _) => ProviderRequest::to_provider_bytes(provider, request, provider_id, mode),
|
||||
Provider::GitHub(provider, _) => ProviderRequest::to_provider_bytes(provider, request, provider_id, mode),
|
||||
}
|
||||
}
|
||||
|
||||
fn extract_model<'a>(&self, request: &'a crate::apis::openai::ChatCompletionsRequest) -> &'a str {
|
||||
// Since all providers use the same implementation, just use the first one
|
||||
&request.model
|
||||
}
|
||||
|
||||
fn is_streaming(&self, request: &crate::apis::openai::ChatCompletionsRequest) -> bool {
|
||||
// Since all providers use the same implementation, just use the first one
|
||||
request.stream.unwrap_or_default()
|
||||
}
|
||||
|
||||
fn set_streaming_options(&self, request: &mut crate::apis::openai::ChatCompletionsRequest) {
|
||||
match self {
|
||||
Provider::OpenAI(provider, _) => ProviderRequest::set_streaming_options(provider, request),
|
||||
Provider::Groq(provider, _) => ProviderRequest::set_streaming_options(provider, request),
|
||||
Provider::Mistral(provider, _) => ProviderRequest::set_streaming_options(provider, request),
|
||||
Provider::Deepseek(provider, _) => ProviderRequest::set_streaming_options(provider, request),
|
||||
Provider::Arch(provider, _) => ProviderRequest::set_streaming_options(provider, request),
|
||||
Provider::Gemini(provider, _) => ProviderRequest::set_streaming_options(provider, request),
|
||||
Provider::Claude(provider, _) => ProviderRequest::set_streaming_options(provider, request),
|
||||
Provider::GitHub(provider, _) => ProviderRequest::set_streaming_options(provider, request),
|
||||
}
|
||||
}
|
||||
|
||||
fn extract_messages_text(&self, request: &crate::apis::openai::ChatCompletionsRequest) -> String {
|
||||
match self {
|
||||
Provider::OpenAI(provider, _) => ProviderRequest::extract_messages_text(provider, request),
|
||||
Provider::Groq(provider, _) => ProviderRequest::extract_messages_text(provider, request),
|
||||
Provider::Mistral(provider, _) => ProviderRequest::extract_messages_text(provider, request),
|
||||
Provider::Deepseek(provider, _) => ProviderRequest::extract_messages_text(provider, request),
|
||||
Provider::Arch(provider, _) => ProviderRequest::extract_messages_text(provider, request),
|
||||
Provider::Gemini(provider, _) => ProviderRequest::extract_messages_text(provider, request),
|
||||
Provider::Claude(provider, _) => ProviderRequest::extract_messages_text(provider, request),
|
||||
Provider::GitHub(provider, _) => ProviderRequest::extract_messages_text(provider, request),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ProviderResponse for Provider {
|
||||
type Error = openai::provider::OpenAIApiError;
|
||||
type Usage = crate::apis::openai::Usage;
|
||||
|
||||
fn try_from_bytes(&self, bytes: &[u8], provider_id: &super::ProviderId, mode: ConversionMode) -> Result<crate::apis::openai::ChatCompletionsResponse, Self::Error> {
|
||||
match self {
|
||||
Provider::OpenAI(provider, _) => ProviderResponse::try_from_bytes(provider, bytes, provider_id, mode),
|
||||
Provider::Groq(provider, _) => ProviderResponse::try_from_bytes(provider, bytes, provider_id, mode),
|
||||
Provider::Mistral(provider, _) => ProviderResponse::try_from_bytes(provider, bytes, provider_id, mode),
|
||||
Provider::Deepseek(provider, _) => ProviderResponse::try_from_bytes(provider, bytes, provider_id, mode),
|
||||
Provider::Arch(provider, _) => ProviderResponse::try_from_bytes(provider, bytes, provider_id, mode),
|
||||
Provider::Gemini(provider, _) => ProviderResponse::try_from_bytes(provider, bytes, provider_id, mode),
|
||||
Provider::Claude(provider, _) => ProviderResponse::try_from_bytes(provider, bytes, provider_id, mode),
|
||||
Provider::GitHub(provider, _) => ProviderResponse::try_from_bytes(provider, bytes, provider_id, mode),
|
||||
}
|
||||
}
|
||||
|
||||
fn usage<'a>(&self, response: &'a crate::apis::openai::ChatCompletionsResponse) -> Option<&'a Self::Usage> {
|
||||
// Since all providers use the same implementation, just use the direct implementation
|
||||
Some(&response.usage)
|
||||
}
|
||||
}
|
||||
|
||||
impl StreamingResponse for Provider {
|
||||
type Error = openai::provider::OpenAIApiError;
|
||||
type StreamingIter = openai::provider::OpenAIStreamingResponse;
|
||||
|
||||
fn try_from_bytes(&self, bytes: &[u8], provider_id: &super::ProviderId, mode: ConversionMode) -> Result<Self::StreamingIter, Self::Error> {
|
||||
match self {
|
||||
Provider::OpenAI(provider, _) => StreamingResponse::try_from_bytes(provider, bytes, provider_id, mode),
|
||||
Provider::Groq(provider, _) => StreamingResponse::try_from_bytes(provider, bytes, provider_id, mode),
|
||||
Provider::Mistral(provider, _) => StreamingResponse::try_from_bytes(provider, bytes, provider_id, mode),
|
||||
Provider::Deepseek(provider, _) => StreamingResponse::try_from_bytes(provider, bytes, provider_id, mode),
|
||||
Provider::Arch(provider, _) => StreamingResponse::try_from_bytes(provider, bytes, provider_id, mode),
|
||||
Provider::Gemini(provider, _) => StreamingResponse::try_from_bytes(provider, bytes, provider_id, mode),
|
||||
Provider::Claude(provider, _) => StreamingResponse::try_from_bytes(provider, bytes, provider_id, mode),
|
||||
Provider::GitHub(provider, _) => StreamingResponse::try_from_bytes(provider, bytes, provider_id, mode),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ProviderInterface for Provider {
|
||||
fn has_compatible_api(&self, api_path: &str) -> bool {
|
||||
match self {
|
||||
Provider::OpenAI(provider, _) => provider.has_compatible_api(api_path),
|
||||
Provider::Groq(provider, _) => provider.has_compatible_api(api_path),
|
||||
Provider::Mistral(provider, _) => provider.has_compatible_api(api_path),
|
||||
Provider::Deepseek(provider, _) => provider.has_compatible_api(api_path),
|
||||
Provider::Arch(provider, _) => provider.has_compatible_api(api_path),
|
||||
Provider::Gemini(provider, _) => provider.has_compatible_api(api_path),
|
||||
Provider::Claude(provider, _) => provider.has_compatible_api(api_path),
|
||||
Provider::GitHub(provider, _) => provider.has_compatible_api(api_path),
|
||||
}
|
||||
}
|
||||
|
||||
fn supported_apis(&self) -> Vec<&'static str> {
|
||||
match self {
|
||||
Provider::OpenAI(provider, _) => provider.supported_apis(),
|
||||
Provider::Groq(provider, _) => provider.supported_apis(),
|
||||
Provider::Mistral(provider, _) => provider.supported_apis(),
|
||||
Provider::Deepseek(provider, _) => provider.supported_apis(),
|
||||
Provider::Arch(provider, _) => provider.supported_apis(),
|
||||
Provider::Gemini(provider, _) => provider.supported_apis(),
|
||||
Provider::Claude(provider, _) => provider.supported_apis(),
|
||||
Provider::GitHub(provider, _) => provider.supported_apis(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -77,64 +77,58 @@ impl ProviderInterface for OpenAIProvider {
|
|||
}
|
||||
|
||||
fn parse_request(&self, bytes: &[u8]) -> Result<crate::apis::openai::ChatCompletionsRequest, Box<dyn std::error::Error + Send + Sync>> {
|
||||
use crate::providers::traits::ProviderRequest;
|
||||
match ChatCompletionsRequest::try_from_bytes(bytes) {
|
||||
match ProviderRequest::try_from_bytes(self, bytes) {
|
||||
Ok(req) => Ok(req),
|
||||
Err(e) => Err(Box::new(e)),
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_response(&self, bytes: &[u8], provider_id: super::super::ProviderId, mode: ConversionMode) -> Result<crate::apis::openai::ChatCompletionsResponse, Box<dyn std::error::Error + Send + Sync>> {
|
||||
use crate::providers::traits::ProviderResponse;
|
||||
match ChatCompletionsResponse::try_from_bytes(bytes, &provider_id, mode) {
|
||||
match ProviderResponse::try_from_bytes(self, bytes, &provider_id, mode) {
|
||||
Ok(resp) => Ok(resp),
|
||||
Err(e) => Err(Box::new(e)),
|
||||
}
|
||||
}
|
||||
|
||||
fn request_to_bytes(&self, request: &crate::apis::openai::ChatCompletionsRequest, provider_id: super::super::ProviderId, mode: ConversionMode) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
use crate::providers::traits::ProviderRequest;
|
||||
match request.to_provider_bytes(provider_id, mode) {
|
||||
match ProviderRequest::to_provider_bytes(self, request, provider_id, mode) {
|
||||
Ok(bytes) => Ok(bytes),
|
||||
Err(e) => Err(Box::new(e)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Trait Implementations for OpenAI Types
|
||||
// ============================================================================
|
||||
|
||||
impl ProviderRequest for ChatCompletionsRequest {
|
||||
// Direct trait implementations on OpenAIProvider
|
||||
impl ProviderRequest for OpenAIProvider {
|
||||
type Error = OpenAIApiError;
|
||||
|
||||
fn try_from_bytes(bytes: &[u8]) -> Result<Self, Self::Error> {
|
||||
fn try_from_bytes(&self, bytes: &[u8]) -> Result<ChatCompletionsRequest, Self::Error> {
|
||||
let s = std::str::from_utf8(bytes)?;
|
||||
Ok(serde_json::from_str(s)?)
|
||||
}
|
||||
|
||||
fn to_provider_bytes(&self, _provider: super::super::ProviderId, _mode: ConversionMode) -> Result<Vec<u8>, Self::Error> {
|
||||
Ok(serde_json::to_vec(self)?)
|
||||
fn to_provider_bytes(&self, request: &ChatCompletionsRequest, _provider: super::super::ProviderId, _mode: ConversionMode) -> Result<Vec<u8>, Self::Error> {
|
||||
Ok(serde_json::to_vec(request)?)
|
||||
}
|
||||
|
||||
fn extract_model(&self) -> &str {
|
||||
&self.model
|
||||
fn extract_model<'a>(&self, request: &'a ChatCompletionsRequest) -> &'a str {
|
||||
&request.model
|
||||
}
|
||||
|
||||
fn is_streaming(&self) -> bool {
|
||||
self.stream.unwrap_or_default()
|
||||
fn is_streaming(&self, request: &ChatCompletionsRequest) -> bool {
|
||||
request.stream.unwrap_or_default()
|
||||
}
|
||||
|
||||
fn set_streaming_options(&mut self) {
|
||||
if self.stream_options.is_none() {
|
||||
self.stream_options = Some(StreamOptions {
|
||||
fn set_streaming_options(&self, request: &mut ChatCompletionsRequest) {
|
||||
if request.stream_options.is_none() {
|
||||
request.stream_options = Some(StreamOptions {
|
||||
include_usage: Some(true),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
fn extract_messages_text(&self) -> String {
|
||||
self.messages
|
||||
fn extract_messages_text(&self, request: &ChatCompletionsRequest) -> String {
|
||||
request.messages
|
||||
.iter()
|
||||
.fold(String::new(), |acc, m| {
|
||||
acc + " " + &match &m.content {
|
||||
|
|
@ -150,8 +144,41 @@ impl ProviderRequest for ChatCompletionsRequest {
|
|||
}
|
||||
}
|
||||
|
||||
// Implement the helper trait for stream context integration
|
||||
impl crate::providers::traits::StreamContextHelpers for ChatCompletionsRequest {}
|
||||
impl ProviderResponse for OpenAIProvider {
|
||||
type Error = OpenAIApiError;
|
||||
type Usage = Usage;
|
||||
|
||||
fn try_from_bytes(&self, bytes: &[u8], _provider: &super::super::ProviderId, _mode: ConversionMode) -> Result<ChatCompletionsResponse, Self::Error> {
|
||||
let s = std::str::from_utf8(bytes)?;
|
||||
Ok(serde_json::from_str(s)?)
|
||||
}
|
||||
|
||||
fn usage<'a>(&self, response: &'a ChatCompletionsResponse) -> Option<&'a Self::Usage> {
|
||||
Some(&response.usage)
|
||||
}
|
||||
|
||||
fn extract_usage_counts(&self, response: &ChatCompletionsResponse) -> Option<(usize, usize, usize)> {
|
||||
Some((
|
||||
response.usage.prompt_tokens as usize,
|
||||
response.usage.completion_tokens as usize,
|
||||
response.usage.total_tokens as usize,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
impl StreamingResponse for OpenAIProvider {
|
||||
type Error = OpenAIApiError;
|
||||
type StreamingIter = OpenAIStreamingResponse;
|
||||
|
||||
fn try_from_bytes(&self, bytes: &[u8], _provider: &super::super::ProviderId, _mode: ConversionMode) -> Result<Self::StreamingIter, Self::Error> {
|
||||
let s = std::str::from_utf8(bytes)?;
|
||||
Ok(OpenAIStreamingResponse::new(s.to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Trait Implementations for OpenAI Types (Keep for TokenUsage only)
|
||||
// ============================================================================
|
||||
|
||||
impl TokenUsage for Usage {
|
||||
fn completion_tokens(&self) -> usize {
|
||||
|
|
@ -167,20 +194,6 @@ impl TokenUsage for Usage {
|
|||
}
|
||||
}
|
||||
|
||||
impl ProviderResponse for ChatCompletionsResponse {
|
||||
type Error = OpenAIApiError;
|
||||
type Usage = Usage;
|
||||
|
||||
fn try_from_bytes(bytes: &[u8], _provider: &super::super::ProviderId, _mode: ConversionMode) -> Result<Self, Self::Error> {
|
||||
let s = std::str::from_utf8(bytes)?;
|
||||
Ok(serde_json::from_str(s)?)
|
||||
}
|
||||
|
||||
fn usage(&self) -> Option<&Self::Usage> {
|
||||
Some(&self.usage)
|
||||
}
|
||||
}
|
||||
|
||||
impl StreamChunk for ChatCompletionsStreamResponse {
|
||||
type Usage = Usage;
|
||||
|
||||
|
|
@ -191,9 +204,9 @@ impl StreamChunk for ChatCompletionsStreamResponse {
|
|||
|
||||
impl StreamingResponse for OpenAIStreamingResponse {
|
||||
type Error = OpenAIApiError;
|
||||
type Chunk = ChatCompletionsStreamResponse;
|
||||
type StreamingIter = OpenAIStreamingResponse;
|
||||
|
||||
fn try_from_bytes(bytes: &[u8], _provider: &super::super::ProviderId, _mode: ConversionMode) -> Result<Self, Self::Error> {
|
||||
fn try_from_bytes(&self, bytes: &[u8], _provider: &super::super::ProviderId, _mode: ConversionMode) -> Result<Self::StreamingIter, Self::Error> {
|
||||
let s = std::str::from_utf8(bytes)?;
|
||||
Ok(OpenAIStreamingResponse::new(s.to_string()))
|
||||
}
|
||||
|
|
|
|||
|
|
@ -15,26 +15,26 @@ pub enum ConversionMode {
|
|||
}
|
||||
|
||||
/// Trait for provider-specific request types
|
||||
pub trait ProviderRequest: Sized {
|
||||
pub trait ProviderRequest {
|
||||
type Error: Error + Send + Sync + 'static;
|
||||
|
||||
/// Parse request from raw bytes
|
||||
fn try_from_bytes(bytes: &[u8]) -> Result<Self, Self::Error>;
|
||||
fn try_from_bytes(&self, bytes: &[u8]) -> Result<crate::apis::openai::ChatCompletionsRequest, Self::Error>;
|
||||
|
||||
/// Convert to provider-specific format
|
||||
fn to_provider_bytes(&self, provider: super::ProviderId, mode: ConversionMode) -> Result<Vec<u8>, Self::Error>;
|
||||
fn to_provider_bytes(&self, request: &crate::apis::openai::ChatCompletionsRequest, provider: super::ProviderId, mode: ConversionMode) -> Result<Vec<u8>, Self::Error>;
|
||||
|
||||
/// Extract the model name from the request
|
||||
fn extract_model(&self) -> &str;
|
||||
fn extract_model<'a>(&self, request: &'a crate::apis::openai::ChatCompletionsRequest) -> &'a str;
|
||||
|
||||
/// Check if this is a streaming request
|
||||
fn is_streaming(&self) -> bool;
|
||||
fn is_streaming(&self, request: &crate::apis::openai::ChatCompletionsRequest) -> bool;
|
||||
|
||||
/// Set streaming options (e.g., include_usage)
|
||||
fn set_streaming_options(&mut self);
|
||||
fn set_streaming_options(&self, request: &mut crate::apis::openai::ChatCompletionsRequest);
|
||||
|
||||
/// Extract text content from messages for token counting
|
||||
fn extract_messages_text(&self) -> String;
|
||||
fn extract_messages_text(&self, request: &crate::apis::openai::ChatCompletionsRequest) -> String;
|
||||
}
|
||||
|
||||
/// Trait for token usage information
|
||||
|
|
@ -45,39 +45,19 @@ pub trait TokenUsage {
|
|||
}
|
||||
|
||||
/// Trait for provider-specific response types
|
||||
pub trait ProviderResponse: Sized {
|
||||
pub trait ProviderResponse {
|
||||
type Error: Error + Send + Sync + 'static;
|
||||
type Usage: TokenUsage;
|
||||
|
||||
/// Parse response from raw bytes
|
||||
fn try_from_bytes(bytes: &[u8], provider: &super::ProviderId, mode: ConversionMode) -> Result<Self, Self::Error>;
|
||||
fn try_from_bytes(&self, bytes: &[u8], provider: &super::ProviderId, mode: ConversionMode) -> Result<crate::apis::openai::ChatCompletionsResponse, Self::Error>;
|
||||
|
||||
/// Get usage information if available
|
||||
fn usage(&self) -> Option<&Self::Usage>;
|
||||
fn usage<'a>(&self, response: &'a crate::apis::openai::ChatCompletionsResponse) -> Option<&'a Self::Usage>;
|
||||
|
||||
/// Extract token counts for metrics
|
||||
fn extract_usage_counts(&self) -> Option<(usize, usize, usize)> {
|
||||
self.usage().map(|u| (u.prompt_tokens(), u.completion_tokens(), u.total_tokens()))
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper trait for stream context integration
|
||||
pub trait StreamContextHelpers: ProviderRequest {
|
||||
/// Get the model name for routing and metrics
|
||||
fn get_model_for_routing(&self) -> String {
|
||||
self.extract_model().to_string()
|
||||
}
|
||||
|
||||
/// Get text for token counting and rate limiting
|
||||
fn get_text_for_tokenization(&self) -> String {
|
||||
self.extract_messages_text()
|
||||
}
|
||||
|
||||
/// Prepare for streaming by setting appropriate options
|
||||
fn prepare_for_streaming(&mut self) {
|
||||
if self.is_streaming() {
|
||||
self.set_streaming_options();
|
||||
}
|
||||
fn extract_usage_counts(&self, response: &crate::apis::openai::ChatCompletionsResponse) -> Option<(usize, usize, usize)> {
|
||||
self.usage(response).map(|u| (u.prompt_tokens(), u.completion_tokens(), u.total_tokens()))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -90,70 +70,34 @@ pub trait StreamChunk {
|
|||
}
|
||||
|
||||
/// Trait for streaming response iterators
|
||||
pub trait StreamingResponse: Iterator<Item = Result<Self::Chunk, Self::Error>> + Sized {
|
||||
pub trait StreamingResponse {
|
||||
type Error: Error + Send + Sync + 'static;
|
||||
type Chunk: StreamChunk;
|
||||
type StreamingIter: Iterator<Item = Result<crate::apis::openai::ChatCompletionsStreamResponse, Self::Error>>;
|
||||
|
||||
/// Parse streaming response from raw bytes
|
||||
fn try_from_bytes(bytes: &[u8], provider: &super::ProviderId, mode: ConversionMode) -> Result<Self, Self::Error>;
|
||||
fn try_from_bytes(&self, bytes: &[u8], provider: &super::ProviderId, mode: ConversionMode) -> Result<Self::StreamingIter, Self::Error>;
|
||||
}
|
||||
|
||||
/// Main provider interface trait
|
||||
pub trait ProviderInterface {
|
||||
/// Main provider interface trait - simplified to only essential methods
|
||||
pub trait ProviderInterface: ProviderRequest + ProviderResponse + StreamingResponse {
|
||||
/// Check if this provider has a compatible API with the client request
|
||||
fn has_compatible_api(&self, api_path: &str) -> bool;
|
||||
|
||||
/// Get the interface implementation for this provider
|
||||
/// passthrough: if true, use provider-specific format; if false, use compatible format
|
||||
fn get_interface(&self, passthrough: bool) -> ConversionMode {
|
||||
if passthrough {
|
||||
ConversionMode::Passthrough
|
||||
} else {
|
||||
ConversionMode::Compatible
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse a request from raw bytes - returns concrete ChatCompletionsRequest
|
||||
fn parse_request(&self, bytes: &[u8]) -> Result<crate::apis::openai::ChatCompletionsRequest, Box<dyn std::error::Error + Send + Sync>>;
|
||||
|
||||
/// Parse a response from raw bytes - returns concrete ChatCompletionsResponse
|
||||
fn parse_response(&self, bytes: &[u8], provider_id: super::ProviderId, mode: ConversionMode) -> Result<crate::apis::openai::ChatCompletionsResponse, Box<dyn std::error::Error + Send + Sync>>;
|
||||
|
||||
/// Convert a request to bytes for sending to upstream API
|
||||
fn request_to_bytes(&self, request: &crate::apis::openai::ChatCompletionsRequest, provider_id: super::ProviderId, mode: ConversionMode) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>>;
|
||||
|
||||
/// Extract model name from request for routing (convenience method for stream_context)
|
||||
fn extract_model_from_request(&self, request: &crate::apis::openai::ChatCompletionsRequest) -> String {
|
||||
use ProviderRequest;
|
||||
request.extract_model().to_string()
|
||||
}
|
||||
|
||||
/// Check if request is streaming (convenience method for stream_context)
|
||||
fn is_request_streaming(&self, request: &crate::apis::openai::ChatCompletionsRequest) -> bool {
|
||||
use ProviderRequest;
|
||||
request.is_streaming()
|
||||
}
|
||||
|
||||
/// Prepare request for streaming (convenience method for stream_context)
|
||||
fn prepare_request_for_streaming(&self, request: &mut crate::apis::openai::ChatCompletionsRequest) {
|
||||
use ProviderRequest;
|
||||
if request.is_streaming() {
|
||||
request.set_streaming_options();
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract text for tokenization (convenience method for stream_context)
|
||||
fn extract_text_for_tokenization(&self, request: &crate::apis::openai::ChatCompletionsRequest) -> String {
|
||||
use ProviderRequest;
|
||||
request.extract_messages_text()
|
||||
}
|
||||
|
||||
/// Extract usage information from response (convenience method for stream_context)
|
||||
fn extract_usage_from_response(&self, response: &crate::apis::openai::ChatCompletionsResponse) -> Option<(usize, usize, usize)> {
|
||||
use ProviderResponse;
|
||||
response.extract_usage_counts()
|
||||
}
|
||||
|
||||
/// Get supported API endpoints for this provider
|
||||
fn supported_apis(&self) -> Vec<&'static str>;
|
||||
|
||||
/// Parse a request from raw bytes - delegates to ProviderRequest
|
||||
fn parse_request(&self, bytes: &[u8]) -> Result<crate::apis::openai::ChatCompletionsRequest, Box<dyn std::error::Error + Send + Sync>> {
|
||||
ProviderRequest::try_from_bytes(self, bytes).map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
|
||||
}
|
||||
|
||||
/// Parse a response from raw bytes - delegates to ProviderResponse
|
||||
fn parse_response(&self, bytes: &[u8], provider_id: super::ProviderId, mode: ConversionMode) -> Result<crate::apis::openai::ChatCompletionsResponse, Box<dyn std::error::Error + Send + Sync>> {
|
||||
ProviderResponse::try_from_bytes(self, bytes, &provider_id, mode).map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
|
||||
}
|
||||
|
||||
/// Convert a request to bytes - delegates to ProviderRequest
|
||||
fn request_to_bytes(&self, request: &crate::apis::openai::ChatCompletionsRequest, provider_id: super::ProviderId, mode: ConversionMode) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
ProviderRequest::to_provider_bytes(self, request, provider_id, mode).map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -10,6 +10,10 @@ use common::ratelimit::Header;
|
|||
use common::stats::{IncrementingMetric, RecordingMetric};
|
||||
use common::tracing::{Event, Span, TraceData, Traceparent};
|
||||
use common::{ratelimit, routing, tokenizer};
|
||||
use hermesllm::apis::openai::{ContentPart, MessageContent};
|
||||
use hermesllm::providers::traits::{
|
||||
ProviderRequest, ProviderResponse, StreamChunk, StreamingResponse, TokenUsage,
|
||||
};
|
||||
use hermesllm::{ConversionMode, Provider, ProviderId};
|
||||
use http::StatusCode;
|
||||
use log::{debug, info, warn};
|
||||
|
|
@ -39,6 +43,7 @@ pub struct StreamContext {
|
|||
request_body_sent_time: Option<u128>,
|
||||
traces_queue: Arc<Mutex<VecDeque<TraceData>>>,
|
||||
overrides: Rc<Option<Overrides>>,
|
||||
user_message: Option<String>,
|
||||
}
|
||||
|
||||
impl StreamContext {
|
||||
|
|
@ -66,6 +71,7 @@ impl StreamContext {
|
|||
ttft_time: None,
|
||||
traces_queue,
|
||||
request_body_sent_time: None,
|
||||
user_message: None,
|
||||
}
|
||||
}
|
||||
fn llm_provider(&self) -> &LlmProvider {
|
||||
|
|
@ -295,7 +301,7 @@ impl HttpContext for StreamContext {
|
|||
|
||||
let provider = self.get_provider();
|
||||
|
||||
let mut deserialized_body = match provider.interface().parse_request(&body_bytes) {
|
||||
let mut deserialized_body = match ProviderRequest::try_from_bytes(&provider, &body_bytes) {
|
||||
Ok(deserialized) => deserialized,
|
||||
Err(e) => {
|
||||
debug!(
|
||||
|
|
@ -324,9 +330,29 @@ impl HttpContext for StreamContext {
|
|||
};
|
||||
|
||||
// Use the provider interface methods for cleaner interaction
|
||||
let model_requested = provider
|
||||
.interface()
|
||||
.extract_model_from_request(&deserialized_body);
|
||||
let model_requested = provider.extract_model(&deserialized_body).to_string(); // Convert to owned string
|
||||
|
||||
// Extract user message for tracing
|
||||
self.user_message = deserialized_body.messages.last().and_then(|msg| {
|
||||
match &msg.content {
|
||||
MessageContent::Text(text) => Some(text.clone()),
|
||||
MessageContent::Parts(parts) => {
|
||||
// Extract text from content parts, ignoring images
|
||||
let text_parts: Vec<String> = parts
|
||||
.iter()
|
||||
.filter_map(|part| match part {
|
||||
ContentPart::Text { text } => Some(text.clone()),
|
||||
ContentPart::ImageUrl { .. } => None,
|
||||
})
|
||||
.collect();
|
||||
if text_parts.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(text_parts.join(" "))
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
info!(
|
||||
"on_http_request_body: provider: {}, model requested (in body): {}, model selected: {}",
|
||||
|
|
@ -336,20 +362,15 @@ impl HttpContext for StreamContext {
|
|||
);
|
||||
|
||||
// Use provider interface for streaming detection and setup
|
||||
if provider
|
||||
.interface()
|
||||
.is_request_streaming(&deserialized_body)
|
||||
{
|
||||
self.streaming_response = true;
|
||||
provider
|
||||
.interface()
|
||||
.prepare_request_for_streaming(&mut deserialized_body);
|
||||
self.streaming_response = provider.is_streaming(&deserialized_body);
|
||||
|
||||
// Set streaming options if needed
|
||||
if self.streaming_response {
|
||||
provider.set_streaming_options(&mut deserialized_body);
|
||||
}
|
||||
|
||||
// Use provider interface for text extraction
|
||||
let input_tokens_str = provider
|
||||
.interface()
|
||||
.extract_text_for_tokenization(&deserialized_body);
|
||||
// Use provider interface for text extraction (after potential mutation)
|
||||
let input_tokens_str = provider.extract_messages_text(&deserialized_body);
|
||||
// enforce ratelimits on ingress
|
||||
if let Err(e) = self.enforce_ratelimits(&model_requested, input_tokens_str.as_str()) {
|
||||
self.send_server_error(
|
||||
|
|
@ -364,7 +385,7 @@ impl HttpContext for StreamContext {
|
|||
let _hermes_llm_provider_id = ProviderId::from(llm_provider_str.as_str());
|
||||
|
||||
// Convert chat completion request to llm provider specific request using provider interface
|
||||
let deserialized_body_bytes = match provider.interface().request_to_bytes(
|
||||
let deserialized_body_bytes = match provider.to_provider_bytes(
|
||||
&deserialized_body,
|
||||
provider.id(),
|
||||
ConversionMode::Compatible,
|
||||
|
|
@ -473,6 +494,11 @@ impl HttpContext for StreamContext {
|
|||
self.llm_provider().name.to_string(),
|
||||
);
|
||||
|
||||
if let Some(user_message) = &self.user_message {
|
||||
llm_span
|
||||
.add_attribute("user_message".to_string(), user_message.clone());
|
||||
}
|
||||
|
||||
if self.ttft_time.is_some() {
|
||||
llm_span.add_event(Event::new(
|
||||
"time_to_first_token".to_string(),
|
||||
|
|
@ -540,36 +566,74 @@ impl HttpContext for StreamContext {
|
|||
let _provider_id = ProviderId::from(llm_provider_str.as_str());
|
||||
|
||||
if self.streaming_response {
|
||||
// TODO: Implement streaming response parsing with new provider structure
|
||||
warn!(
|
||||
"Streaming response parsing not yet fully implemented with new provider structure"
|
||||
);
|
||||
debug!("processing streaming response");
|
||||
|
||||
// For now, just compute TTFT and continue
|
||||
if self.ttft_duration.is_none() {
|
||||
let current_time = get_current_time().unwrap();
|
||||
self.ttft_time = Some(current_time_ns());
|
||||
match current_time.duration_since(self.start_time) {
|
||||
Ok(duration) => {
|
||||
let duration_ms = duration.as_millis();
|
||||
info!(
|
||||
"on_http_response_body: time to first token: {}ms",
|
||||
duration_ms
|
||||
);
|
||||
self.ttft_duration = Some(duration);
|
||||
self.metrics.time_to_first_token.record(duration_ms as u64);
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("SystemTime error: {:?}", e);
|
||||
// Parse streaming response using OpenAI-compatible format
|
||||
// Since all providers use OpenAI-compatible streaming format
|
||||
let provider = self.get_provider();
|
||||
let provider_id =
|
||||
ProviderId::from(self.llm_provider().provider_interface.to_string().as_str());
|
||||
|
||||
match StreamingResponse::try_from_bytes(
|
||||
&provider,
|
||||
&body,
|
||||
&provider_id,
|
||||
ConversionMode::Compatible,
|
||||
) {
|
||||
Ok(mut streaming_response) => {
|
||||
// Process each streaming chunk
|
||||
while let Some(chunk_result) = streaming_response.next() {
|
||||
match chunk_result {
|
||||
Ok(chunk) => {
|
||||
// Compute TTFT on first chunk
|
||||
if self.ttft_duration.is_none() {
|
||||
let current_time = get_current_time().unwrap();
|
||||
self.ttft_time = Some(current_time_ns());
|
||||
match current_time.duration_since(self.start_time) {
|
||||
Ok(duration) => {
|
||||
let duration_ms = duration.as_millis();
|
||||
info!(
|
||||
"on_http_response_body: time to first token: {}ms",
|
||||
duration_ms
|
||||
);
|
||||
self.ttft_duration = Some(duration);
|
||||
self.metrics
|
||||
.time_to_first_token
|
||||
.record(duration_ms as u64);
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("SystemTime error: {:?}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Extract usage information if available
|
||||
if let Some(usage) = chunk.usage() {
|
||||
let completion_tokens = usage.completion_tokens();
|
||||
self.response_tokens += completion_tokens;
|
||||
debug!(
|
||||
"Streaming chunk completion tokens: {}",
|
||||
completion_tokens
|
||||
);
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Error processing streaming chunk: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to parse streaming response: {}", e);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
debug!("non streaming response");
|
||||
let provider = self.get_provider();
|
||||
let response = match provider.interface().parse_response(
|
||||
let response = match ProviderResponse::try_from_bytes(
|
||||
&provider,
|
||||
&body,
|
||||
provider.id(),
|
||||
&provider.id(),
|
||||
ConversionMode::Compatible,
|
||||
) {
|
||||
Ok(response) => response,
|
||||
|
|
@ -594,7 +658,7 @@ impl HttpContext for StreamContext {
|
|||
|
||||
// Use provider interface to extract usage information
|
||||
if let Some((prompt_tokens, completion_tokens, total_tokens)) =
|
||||
provider.interface().extract_usage_from_response(&response)
|
||||
provider.extract_usage_counts(&response)
|
||||
{
|
||||
debug!(
|
||||
"Response usage: prompt={}, completion={}, total={}",
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue