mirror of
https://github.com/katanemo/plano.git
synced 2026-07-02 15:51:02 +02:00
more refactoring to clean code and make stream_context.rs work
This commit is contained in:
parent
d4ca70d177
commit
df3aa17d67
23 changed files with 545 additions and 1321 deletions
|
|
@ -178,13 +178,10 @@ impl Display for LlmProviderType {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl LlmProviderType {
|
impl LlmProviderType {
|
||||||
/// Create a Provider from this LlmProviderType
|
/// Get the ProviderId for this LlmProviderType
|
||||||
/// This is the main method for stream_context to get provider-specific interfaces
|
/// Used with the new function-based hermesllm API
|
||||||
pub fn create_provider(&self) -> hermesllm::Provider {
|
pub fn to_provider_id(&self) -> hermesllm::ProviderId {
|
||||||
use hermesllm::{ProviderId, Provider};
|
hermesllm::ProviderId::from(self.to_string().as_str())
|
||||||
|
|
||||||
let provider_id = ProviderId::from(self.to_string().as_str());
|
|
||||||
Provider::new(provider_id)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -264,10 +261,10 @@ impl Display for LlmProvider {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl LlmProvider {
|
impl LlmProvider {
|
||||||
/// Create a Provider for this LlmProvider
|
/// Get the ProviderId for this LlmProvider
|
||||||
/// This is a convenience method that delegates to the provider_interface
|
/// Used with the new function-based hermesllm API
|
||||||
pub fn create_provider(&self) -> hermesllm::Provider {
|
pub fn to_provider_id(&self) -> hermesllm::ProviderId {
|
||||||
self.provider_interface.create_provider()
|
self.provider_interface.to_provider_id()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -35,15 +35,17 @@ _Replace the path with the appropriate location if using as a workspace member o
|
||||||
Construct a chat completion request using the builder pattern:
|
Construct a chat completion request using the builder pattern:
|
||||||
|
|
||||||
```rust
|
```rust
|
||||||
use hermesllm::Provider;
|
use hermesllm::{create_provider, ProviderId};
|
||||||
use hermesllm::providers::openai::types::ChatCompletionsRequest;
|
use hermesllm::providers::openai::types::ChatCompletionsRequest;
|
||||||
|
|
||||||
let request = ChatCompletionsRequest::builder("gpt-3.5-turbo", vec![Message::new("Hi".to_string())])
|
let request = ChatCompletionsRequest::builder("gpt-3.5-turbo", vec![Message::new("Hi".to_string())])
|
||||||
.build()
|
.build()
|
||||||
.expect("Failed to build OpenAIRequest");
|
.expect("Failed to build OpenAIRequest");
|
||||||
|
|
||||||
// Convert to bytes for a specific provider
|
// Create a provider and convert request to bytes
|
||||||
let bytes = request.to_bytes(Provider::OpenAI)?;
|
let provider = create_provider(ProviderId::OpenAI);
|
||||||
|
let bytes = serde_json::to_vec(&request)?;
|
||||||
|
let parsed_request = provider.try_request_from_bytes(&bytes)?;
|
||||||
```
|
```
|
||||||
|
|
||||||
## API Overview
|
## API Overview
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@ use serde_json::Value;
|
||||||
use serde_with::skip_serializing_none;
|
use serde_with::skip_serializing_none;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
use crate::{providers::ProviderRequestError, ConversionMode, ProviderRequest};
|
||||||
use super::ApiDefinition;
|
use super::ApiDefinition;
|
||||||
|
|
||||||
// ============================================================================
|
// ============================================================================
|
||||||
|
|
@ -424,6 +425,212 @@ pub struct StreamOptions {
|
||||||
pub include_usage: Option<bool>,
|
pub include_usage: Option<bool>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// ============================================================================
|
||||||
|
/// OpenAI Provider Request Wrapper
|
||||||
|
/// ============================================================================
|
||||||
|
impl ProviderRequest for ChatCompletionsRequest {
|
||||||
|
fn model(&self) -> &str {
|
||||||
|
&self.model
|
||||||
|
}
|
||||||
|
|
||||||
|
fn is_streaming(&self) -> bool {
|
||||||
|
self.stream.unwrap_or_default()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn set_streaming_options(&mut self) {
|
||||||
|
self.stream = Some(true);
|
||||||
|
if self.stream_options.is_none() {
|
||||||
|
self.stream_options = Some(StreamOptions { include_usage: Some(true) });
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn extract_messages_text(&self) -> String {
|
||||||
|
self.messages.iter().fold(String::new(), |acc, m| {
|
||||||
|
acc + " " + &match &m.content {
|
||||||
|
MessageContent::Text(text) => text.clone(),
|
||||||
|
MessageContent::Parts(parts) => parts.iter().map(|part| match part {
|
||||||
|
ContentPart::Text { text } => text.clone(),
|
||||||
|
ContentPart::ImageUrl { .. } => "[Image]".to_string(),
|
||||||
|
}).collect::<Vec<_>>().join(" ")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn extract_user_message(&self) -> Option<String> {
|
||||||
|
self.messages.last().and_then(|msg| {
|
||||||
|
match &msg.content {
|
||||||
|
MessageContent::Text(text) => Some(text.clone()),
|
||||||
|
MessageContent::Parts(_) => None, // No user message in parts
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn to_provider_bytes(&self, mode: ConversionMode) -> Result<Vec<u8>, ProviderRequestError> {
|
||||||
|
match mode {
|
||||||
|
ConversionMode::Compatible | ConversionMode::Passthrough => {
|
||||||
|
serde_json::to_vec(&self).map_err(|e| ProviderRequestError {
|
||||||
|
message: format!("Failed to serialize OpenAI request: {}", e),
|
||||||
|
source: Some(Box::new(e)),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// STREAMING SUPPORT
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
use crate::providers::traits::{ProviderResponse, ProviderStreamResponse, ProviderStreamResponseIter, TokenUsage};
|
||||||
|
|
||||||
|
// Direct implementation of ProviderResponse on ChatCompletionsResponse
|
||||||
|
impl ProviderResponse for ChatCompletionsResponse {
|
||||||
|
fn usage(&self) -> Option<&dyn TokenUsage> {
|
||||||
|
Some(&self.usage)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Implementation of TokenUsage for OpenAI Usage type
|
||||||
|
impl TokenUsage for Usage {
|
||||||
|
fn completion_tokens(&self) -> usize {
|
||||||
|
self.completion_tokens as usize
|
||||||
|
}
|
||||||
|
|
||||||
|
fn prompt_tokens(&self) -> usize {
|
||||||
|
self.prompt_tokens as usize
|
||||||
|
}
|
||||||
|
|
||||||
|
fn total_tokens(&self) -> usize {
|
||||||
|
self.total_tokens as usize
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Direct implementation of ProviderStreamResponse trait on ChatCompletionsStreamResponse
|
||||||
|
impl ProviderStreamResponse for ChatCompletionsStreamResponse {
|
||||||
|
fn content_delta(&self) -> Option<&str> {
|
||||||
|
self.choices
|
||||||
|
.first()
|
||||||
|
.and_then(|choice| choice.delta.content.as_deref())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn is_final(&self) -> bool {
|
||||||
|
self.choices
|
||||||
|
.first()
|
||||||
|
.map(|choice| choice.finish_reason.is_some())
|
||||||
|
.unwrap_or(false)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn role(&self) -> Option<&str> {
|
||||||
|
self.choices
|
||||||
|
.first()
|
||||||
|
.and_then(|choice| choice.delta.role.as_ref().map(|r| match r {
|
||||||
|
Role::System => "system",
|
||||||
|
Role::User => "user",
|
||||||
|
Role::Assistant => "assistant",
|
||||||
|
Role::Tool => "tool",
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error type for streaming operations
|
||||||
|
#[derive(Debug, thiserror::Error)]
|
||||||
|
pub enum OpenAIStreamError {
|
||||||
|
#[error("JSON parsing error: {0}")]
|
||||||
|
JsonError(#[from] serde_json::Error),
|
||||||
|
#[error("UTF-8 parsing error: {0}")]
|
||||||
|
Utf8Error(#[from] std::str::Utf8Error),
|
||||||
|
#[error("Invalid streaming data: {0}")]
|
||||||
|
InvalidStreamingData(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
/// SSE-based streaming iterator for OpenAI chat completions
|
||||||
|
/// Implements ProviderStreamResponseIter directly
|
||||||
|
pub struct SseChatCompletionIter<I>
|
||||||
|
where
|
||||||
|
I: Iterator,
|
||||||
|
I::Item: AsRef<str>,
|
||||||
|
{
|
||||||
|
lines: I,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<I> SseChatCompletionIter<I>
|
||||||
|
where
|
||||||
|
I: Iterator,
|
||||||
|
I::Item: AsRef<str>,
|
||||||
|
{
|
||||||
|
pub fn new(lines: I) -> Self {
|
||||||
|
Self { lines }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<I> Iterator for SseChatCompletionIter<I>
|
||||||
|
where
|
||||||
|
I: Iterator,
|
||||||
|
I::Item: AsRef<str>,
|
||||||
|
{
|
||||||
|
type Item = Result<Box<dyn ProviderStreamResponse>, Box<dyn std::error::Error + Send + Sync>>;
|
||||||
|
|
||||||
|
fn next(&mut self) -> Option<Self::Item> {
|
||||||
|
for line in &mut self.lines {
|
||||||
|
let line = line.as_ref();
|
||||||
|
if line.is_empty() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if line.starts_with("data: ") {
|
||||||
|
let data = &line[6..]; // Remove "data: " prefix
|
||||||
|
if data == "[DONE]" {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
if data == r#"{"type": "ping"}"# {
|
||||||
|
continue; // Skip ping messages - that is usually from anthropic
|
||||||
|
}
|
||||||
|
|
||||||
|
match serde_json::from_str::<ChatCompletionsStreamResponse>(data) {
|
||||||
|
Ok(response) => return Some(Ok(Box::new(response))),
|
||||||
|
Err(e) => return Some(Err(Box::new(
|
||||||
|
OpenAIStreamError::InvalidStreamingData(format!("Error parsing: {}, data: {}", e, data))
|
||||||
|
))),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<I> ProviderStreamResponseIter for SseChatCompletionIter<I>
|
||||||
|
where
|
||||||
|
I: Iterator + Send + Sync,
|
||||||
|
I::Item: AsRef<str>,
|
||||||
|
{
|
||||||
|
// Just marking that this type implements the trait - no additional methods needed
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// PARAMETERIZED CONVERSIONS FOR PROVIDER FUNCTIONS
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
use crate::providers::ProviderId;
|
||||||
|
|
||||||
|
/// Parameterized conversion for ChatCompletionsRequest
|
||||||
|
impl TryFrom<(&[u8], &ProviderId)> for ChatCompletionsRequest {
|
||||||
|
type Error = OpenAIStreamError;
|
||||||
|
|
||||||
|
fn try_from((bytes, _provider_id): (&[u8], &ProviderId)) -> Result<Self, Self::Error> {
|
||||||
|
serde_json::from_slice(bytes).map_err(OpenAIStreamError::from)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Parameterized conversion for ChatCompletionsResponse
|
||||||
|
impl TryFrom<(&[u8], &ProviderId)> for ChatCompletionsResponse {
|
||||||
|
type Error = OpenAIStreamError;
|
||||||
|
|
||||||
|
fn try_from((bytes, _provider_id): (&[u8], &ProviderId)) -> Result<Self, Self::Error> {
|
||||||
|
serde_json::from_slice(bytes).map_err(OpenAIStreamError::from)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
|
||||||
|
|
@ -7,10 +7,11 @@ pub mod clients;
|
||||||
|
|
||||||
// Re-export important types and traits
|
// Re-export important types and traits
|
||||||
pub use providers::{
|
pub use providers::{
|
||||||
ProviderId, Provider, ConversionMode,
|
ProviderId, ConversionMode,
|
||||||
ProviderInterface, ProviderRequest, ProviderResponse,
|
ProviderRequest, ProviderResponse, ProviderStreamResponse, ProviderStreamResponseIter,
|
||||||
TokenUsage, StreamChunk, StreamingResponse,
|
TokenUsage,
|
||||||
OpenAIProvider
|
try_request_from_bytes, try_response_from_bytes, try_streaming_from_bytes,
|
||||||
|
has_compatible_api, supported_apis
|
||||||
};
|
};
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
|
@ -26,70 +27,71 @@ mod tests {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_provider_api_paths() {
|
fn test_provider_api_compatibility() {
|
||||||
assert_eq!(ProviderId::OpenAI.api_path(), "/v1/chat/completions");
|
assert!(has_compatible_api(&ProviderId::OpenAI, "/v1/chat/completions"));
|
||||||
assert_eq!(ProviderId::Groq.api_path(), "/openai/v1/chat/completions");
|
assert!(!has_compatible_api(&ProviderId::OpenAI, "/v1/embeddings"));
|
||||||
assert_eq!(ProviderId::Mistral.api_path(), "/v1/chat/completions");
|
|
||||||
assert_eq!(ProviderId::Arch.api_path(), "/v1/chat/completions");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_provider_openai_format_support() {
|
|
||||||
assert!(ProviderId::OpenAI.supports_openai_format());
|
|
||||||
assert!(ProviderId::Groq.supports_openai_format());
|
|
||||||
assert!(ProviderId::Mistral.supports_openai_format());
|
|
||||||
assert!(ProviderId::Arch.supports_openai_format());
|
|
||||||
assert!(!ProviderId::Gemini.supports_openai_format());
|
|
||||||
assert!(!ProviderId::Claude.supports_openai_format());
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_provider_instance_creation() {
|
|
||||||
let provider = Provider::new(ProviderId::OpenAI);
|
|
||||||
assert!(provider.has_compatible_api("/v1/chat/completions"));
|
|
||||||
assert!(!provider.has_compatible_api("/v1/embeddings"));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_provider_supported_apis() {
|
fn test_provider_supported_apis() {
|
||||||
let provider = Provider::new(ProviderId::OpenAI);
|
let apis = supported_apis(&ProviderId::OpenAI);
|
||||||
|
assert!(apis.contains(&"/v1/chat/completions"));
|
||||||
let supported_apis = provider.supported_apis();
|
|
||||||
assert!(supported_apis.contains(&"/v1/chat/completions"));
|
|
||||||
|
|
||||||
// Test that provider supports the expected API endpoints
|
// Test that provider supports the expected API endpoints
|
||||||
assert!(provider.has_compatible_api("/v1/chat/completions"));
|
assert!(has_compatible_api(&ProviderId::OpenAI, "/v1/chat/completions"));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_provider_extract_user_message() {
|
fn test_provider_request_parsing() {
|
||||||
use crate::apis::openai::{ChatCompletionsRequest, Message, MessageContent};
|
// Test with a sample JSON request
|
||||||
|
let json_request = r#"{
|
||||||
let provider = Provider::new(ProviderId::OpenAI);
|
"model": "gpt-4",
|
||||||
|
"messages": [
|
||||||
// Test with text message
|
{
|
||||||
let request = ChatCompletionsRequest {
|
"role": "system",
|
||||||
model: "gpt-4".to_string(),
|
"content": "You are a helpful assistant"
|
||||||
messages: vec![
|
|
||||||
Message {
|
|
||||||
role: crate::apis::openai::Role::System,
|
|
||||||
content: MessageContent::Text("You are a helpful assistant".to_string()),
|
|
||||||
name: None,
|
|
||||||
tool_calls: None,
|
|
||||||
tool_call_id: None,
|
|
||||||
},
|
},
|
||||||
Message {
|
{
|
||||||
role: crate::apis::openai::Role::User,
|
"role": "user",
|
||||||
content: MessageContent::Text("Hello, world!".to_string()),
|
"content": "Hello!"
|
||||||
name: None,
|
}
|
||||||
tool_calls: None,
|
]
|
||||||
tool_call_id: None,
|
}"#;
|
||||||
},
|
|
||||||
],
|
|
||||||
..Default::default()
|
|
||||||
};
|
|
||||||
|
|
||||||
let user_message = provider.extract_user_message(&request);
|
let result = try_request_from_bytes(json_request.as_bytes(), &ProviderId::OpenAI);
|
||||||
assert_eq!(user_message, Some("Hello, world!".to_string()));
|
assert!(result.is_ok());
|
||||||
|
|
||||||
|
let request = result.unwrap();
|
||||||
|
assert_eq!(request.model(), "gpt-4");
|
||||||
|
assert_eq!(request.extract_user_message(), Some("Hello!".to_string()));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_provider_streaming_response() {
|
||||||
|
// Test streaming response parsing with sample SSE data
|
||||||
|
let sse_data = r#"data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"gpt-4","choices":[{"index":0,"delta":{"role":"assistant","content":"Hello"},"finish_reason":null}]}
|
||||||
|
|
||||||
|
data: [DONE]
|
||||||
|
"#;
|
||||||
|
|
||||||
|
let result = try_streaming_from_bytes(sse_data.as_bytes(), &ProviderId::OpenAI, ConversionMode::Passthrough);
|
||||||
|
assert!(result.is_ok());
|
||||||
|
|
||||||
|
let mut streaming_response = result.unwrap();
|
||||||
|
|
||||||
|
// Test that we can iterate over chunks - it's just an iterator now!
|
||||||
|
let first_chunk = streaming_response.next();
|
||||||
|
assert!(first_chunk.is_some());
|
||||||
|
|
||||||
|
let chunk_result = first_chunk.unwrap();
|
||||||
|
assert!(chunk_result.is_ok());
|
||||||
|
|
||||||
|
let chunk = chunk_result.unwrap();
|
||||||
|
assert_eq!(chunk.content_delta(), Some("Hello"));
|
||||||
|
assert!(!chunk.is_final());
|
||||||
|
|
||||||
|
// Test that stream ends properly
|
||||||
|
let final_chunk = streaming_response.next();
|
||||||
|
assert!(final_chunk.is_none());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,6 +0,0 @@
|
||||||
//! Arch provider implementation
|
|
||||||
//!
|
|
||||||
//! Arch uses OpenAI-compatible API format
|
|
||||||
|
|
||||||
pub mod provider;
|
|
||||||
pub use provider::ArchProvider;
|
|
||||||
|
|
@ -1,88 +0,0 @@
|
||||||
//! Arch provider implementation
|
|
||||||
|
|
||||||
use crate::providers::{ProviderInterface, ConversionMode};
|
|
||||||
use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse, Usage};
|
|
||||||
use crate::providers::traits::{ProviderRequest, ProviderResponse, StreamingResponse};
|
|
||||||
use crate::providers::openai::provider::{OpenAIProvider, OpenAIStreamingResponse, OpenAIApiError};
|
|
||||||
|
|
||||||
/// Arch provider implementation
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub struct ArchProvider;
|
|
||||||
|
|
||||||
// Trait implementations that delegate to OpenAI
|
|
||||||
impl ProviderRequest for ArchProvider {
|
|
||||||
type Error = OpenAIApiError;
|
|
||||||
|
|
||||||
fn try_from_bytes(&self, bytes: &[u8]) -> Result<ChatCompletionsRequest, Self::Error> {
|
|
||||||
let openai_provider = OpenAIProvider;
|
|
||||||
ProviderRequest::try_from_bytes(&openai_provider, bytes)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn to_provider_bytes(&self, request: &ChatCompletionsRequest, provider: super::super::ProviderId, mode: ConversionMode) -> Result<Vec<u8>, Self::Error> {
|
|
||||||
let openai_provider = OpenAIProvider;
|
|
||||||
ProviderRequest::to_provider_bytes(&openai_provider, request, provider, mode)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn extract_model<'a>(&self, request: &'a ChatCompletionsRequest) -> &'a str {
|
|
||||||
&request.model
|
|
||||||
}
|
|
||||||
|
|
||||||
fn is_streaming(&self, request: &ChatCompletionsRequest) -> bool {
|
|
||||||
request.stream.unwrap_or_default()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn set_streaming_options(&self, request: &mut ChatCompletionsRequest) {
|
|
||||||
let openai_provider = OpenAIProvider;
|
|
||||||
ProviderRequest::set_streaming_options(&openai_provider, request)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn extract_messages_text(&self, request: &ChatCompletionsRequest) -> String {
|
|
||||||
let openai_provider = OpenAIProvider;
|
|
||||||
ProviderRequest::extract_messages_text(&openai_provider, request)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn extract_user_message(&self, request: &ChatCompletionsRequest) -> Option<String> {
|
|
||||||
let openai_provider = OpenAIProvider;
|
|
||||||
ProviderRequest::extract_user_message(&openai_provider, request)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ProviderResponse for ArchProvider {
|
|
||||||
type Error = OpenAIApiError;
|
|
||||||
type Usage = Usage;
|
|
||||||
|
|
||||||
fn try_from_bytes(&self, bytes: &[u8], provider: &super::super::ProviderId, mode: ConversionMode) -> Result<ChatCompletionsResponse, Self::Error> {
|
|
||||||
let openai_provider = OpenAIProvider;
|
|
||||||
ProviderResponse::try_from_bytes(&openai_provider, bytes, provider, mode)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn usage<'a>(&self, response: &'a ChatCompletionsResponse) -> Option<&'a Self::Usage> {
|
|
||||||
Some(&response.usage)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn extract_usage_counts(&self, response: &ChatCompletionsResponse) -> Option<(usize, usize, usize)> {
|
|
||||||
let openai_provider = OpenAIProvider;
|
|
||||||
ProviderResponse::extract_usage_counts(&openai_provider, response)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl StreamingResponse for ArchProvider {
|
|
||||||
type Error = OpenAIApiError;
|
|
||||||
type StreamChunk = crate::apis::openai::ChatCompletionsStreamResponse;
|
|
||||||
type StreamingIter = OpenAIStreamingResponse;
|
|
||||||
|
|
||||||
fn try_from_bytes(&self, bytes: &[u8], provider: &super::super::ProviderId, mode: ConversionMode) -> Result<Self::StreamingIter, Self::Error> {
|
|
||||||
let openai_provider = OpenAIProvider;
|
|
||||||
StreamingResponse::try_from_bytes(&openai_provider, bytes, provider, mode)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ProviderInterface for ArchProvider {
|
|
||||||
fn has_compatible_api(&self, api_path: &str) -> bool {
|
|
||||||
matches!(api_path, "/v1/chat/completions")
|
|
||||||
}
|
|
||||||
|
|
||||||
fn supported_apis(&self) -> Vec<&'static str> {
|
|
||||||
vec!["/v1/chat/completions"]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -1,7 +0,0 @@
|
||||||
//! Claude provider implementation
|
|
||||||
//!
|
|
||||||
//! Claude will use a different API format in the future (/v1/messages)
|
|
||||||
//! For now, fallback to OpenAI-compatible format
|
|
||||||
|
|
||||||
pub mod provider;
|
|
||||||
pub use provider::ClaudeProvider;
|
|
||||||
|
|
@ -1,93 +0,0 @@
|
||||||
//! Claude provider implementation
|
|
||||||
//!
|
|
||||||
//! TODO: Implement Claude-specific API format (/v1/messages) when needed
|
|
||||||
//! For now, uses OpenAI-compatible format as fallback
|
|
||||||
|
|
||||||
use crate::providers::{ProviderInterface, ConversionMode};
|
|
||||||
use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse, Usage};
|
|
||||||
use crate::providers::traits::{ProviderRequest, ProviderResponse, StreamingResponse};
|
|
||||||
use crate::providers::openai::provider::{OpenAIProvider, OpenAIStreamingResponse, OpenAIApiError};
|
|
||||||
|
|
||||||
/// Claude provider implementation
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub struct ClaudeProvider;
|
|
||||||
|
|
||||||
// Trait implementations that delegate to OpenAI
|
|
||||||
impl ProviderRequest for ClaudeProvider {
|
|
||||||
type Error = OpenAIApiError;
|
|
||||||
|
|
||||||
fn try_from_bytes(&self, bytes: &[u8]) -> Result<ChatCompletionsRequest, Self::Error> {
|
|
||||||
let openai_provider = OpenAIProvider;
|
|
||||||
ProviderRequest::try_from_bytes(&openai_provider, bytes)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn to_provider_bytes(&self, request: &ChatCompletionsRequest, provider: super::super::ProviderId, mode: ConversionMode) -> Result<Vec<u8>, Self::Error> {
|
|
||||||
let openai_provider = OpenAIProvider;
|
|
||||||
ProviderRequest::to_provider_bytes(&openai_provider, request, provider, mode)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn extract_model<'a>(&self, request: &'a ChatCompletionsRequest) -> &'a str {
|
|
||||||
&request.model
|
|
||||||
}
|
|
||||||
|
|
||||||
fn is_streaming(&self, request: &ChatCompletionsRequest) -> bool {
|
|
||||||
request.stream.unwrap_or_default()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn set_streaming_options(&self, request: &mut ChatCompletionsRequest) {
|
|
||||||
let openai_provider = OpenAIProvider;
|
|
||||||
ProviderRequest::set_streaming_options(&openai_provider, request)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn extract_messages_text(&self, request: &ChatCompletionsRequest) -> String {
|
|
||||||
let openai_provider = OpenAIProvider;
|
|
||||||
ProviderRequest::extract_messages_text(&openai_provider, request)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn extract_user_message(&self, request: &ChatCompletionsRequest) -> Option<String> {
|
|
||||||
let openai_provider = OpenAIProvider;
|
|
||||||
ProviderRequest::extract_user_message(&openai_provider, request)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ProviderResponse for ClaudeProvider {
|
|
||||||
type Error = OpenAIApiError;
|
|
||||||
type Usage = Usage;
|
|
||||||
|
|
||||||
fn try_from_bytes(&self, bytes: &[u8], provider: &super::super::ProviderId, mode: ConversionMode) -> Result<ChatCompletionsResponse, Self::Error> {
|
|
||||||
let openai_provider = OpenAIProvider;
|
|
||||||
ProviderResponse::try_from_bytes(&openai_provider, bytes, provider, mode)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn usage<'a>(&self, response: &'a ChatCompletionsResponse) -> Option<&'a Self::Usage> {
|
|
||||||
Some(&response.usage)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn extract_usage_counts(&self, response: &ChatCompletionsResponse) -> Option<(usize, usize, usize)> {
|
|
||||||
let openai_provider = OpenAIProvider;
|
|
||||||
ProviderResponse::extract_usage_counts(&openai_provider, response)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl StreamingResponse for ClaudeProvider {
|
|
||||||
type Error = OpenAIApiError;
|
|
||||||
type StreamChunk = crate::apis::openai::ChatCompletionsStreamResponse;
|
|
||||||
type StreamingIter = OpenAIStreamingResponse;
|
|
||||||
|
|
||||||
fn try_from_bytes(&self, bytes: &[u8], provider: &super::super::ProviderId, mode: ConversionMode) -> Result<Self::StreamingIter, Self::Error> {
|
|
||||||
let openai_provider = OpenAIProvider;
|
|
||||||
StreamingResponse::try_from_bytes(&openai_provider, bytes, provider, mode)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ProviderInterface for ClaudeProvider {
|
|
||||||
fn has_compatible_api(&self, api_path: &str) -> bool {
|
|
||||||
// TODO: Update when Claude API is fully implemented
|
|
||||||
matches!(api_path, "/v1/chat/completions" | "/v1/messages")
|
|
||||||
}
|
|
||||||
|
|
||||||
fn supported_apis(&self) -> Vec<&'static str> {
|
|
||||||
// TODO: Update when Claude API is fully implemented
|
|
||||||
vec!["/v1/messages"]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -1,6 +0,0 @@
|
||||||
//! Deepseek provider implementation
|
|
||||||
//!
|
|
||||||
//! Deepseek uses OpenAI-compatible API format
|
|
||||||
|
|
||||||
pub mod provider;
|
|
||||||
pub use provider::DeepseekProvider;
|
|
||||||
|
|
@ -1,88 +0,0 @@
|
||||||
//! Deepseek provider implementation
|
|
||||||
|
|
||||||
use crate::providers::{ProviderInterface, ConversionMode};
|
|
||||||
use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse, Usage};
|
|
||||||
use crate::providers::traits::{ProviderRequest, ProviderResponse, StreamingResponse};
|
|
||||||
use crate::providers::openai::provider::{OpenAIProvider, OpenAIStreamingResponse, OpenAIApiError};
|
|
||||||
|
|
||||||
/// Deepseek provider implementation
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub struct DeepseekProvider;
|
|
||||||
|
|
||||||
// Trait implementations that delegate to OpenAI
|
|
||||||
impl ProviderRequest for DeepseekProvider {
|
|
||||||
type Error = OpenAIApiError;
|
|
||||||
|
|
||||||
fn try_from_bytes(&self, bytes: &[u8]) -> Result<ChatCompletionsRequest, Self::Error> {
|
|
||||||
let openai_provider = OpenAIProvider;
|
|
||||||
ProviderRequest::try_from_bytes(&openai_provider, bytes)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn to_provider_bytes(&self, request: &ChatCompletionsRequest, provider: super::super::ProviderId, mode: ConversionMode) -> Result<Vec<u8>, Self::Error> {
|
|
||||||
let openai_provider = OpenAIProvider;
|
|
||||||
ProviderRequest::to_provider_bytes(&openai_provider, request, provider, mode)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn extract_model<'a>(&self, request: &'a ChatCompletionsRequest) -> &'a str {
|
|
||||||
&request.model
|
|
||||||
}
|
|
||||||
|
|
||||||
fn is_streaming(&self, request: &ChatCompletionsRequest) -> bool {
|
|
||||||
request.stream.unwrap_or_default()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn set_streaming_options(&self, request: &mut ChatCompletionsRequest) {
|
|
||||||
let openai_provider = OpenAIProvider;
|
|
||||||
ProviderRequest::set_streaming_options(&openai_provider, request)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn extract_messages_text(&self, request: &ChatCompletionsRequest) -> String {
|
|
||||||
let openai_provider = OpenAIProvider;
|
|
||||||
ProviderRequest::extract_messages_text(&openai_provider, request)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn extract_user_message(&self, request: &ChatCompletionsRequest) -> Option<String> {
|
|
||||||
let openai_provider = OpenAIProvider;
|
|
||||||
ProviderRequest::extract_user_message(&openai_provider, request)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ProviderResponse for DeepseekProvider {
|
|
||||||
type Error = OpenAIApiError;
|
|
||||||
type Usage = Usage;
|
|
||||||
|
|
||||||
fn try_from_bytes(&self, bytes: &[u8], provider: &super::super::ProviderId, mode: ConversionMode) -> Result<ChatCompletionsResponse, Self::Error> {
|
|
||||||
let openai_provider = OpenAIProvider;
|
|
||||||
ProviderResponse::try_from_bytes(&openai_provider, bytes, provider, mode)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn usage<'a>(&self, response: &'a ChatCompletionsResponse) -> Option<&'a Self::Usage> {
|
|
||||||
Some(&response.usage)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn extract_usage_counts(&self, response: &ChatCompletionsResponse) -> Option<(usize, usize, usize)> {
|
|
||||||
let openai_provider = OpenAIProvider;
|
|
||||||
ProviderResponse::extract_usage_counts(&openai_provider, response)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl StreamingResponse for DeepseekProvider {
|
|
||||||
type Error = OpenAIApiError;
|
|
||||||
type StreamChunk = crate::apis::openai::ChatCompletionsStreamResponse;
|
|
||||||
type StreamingIter = OpenAIStreamingResponse;
|
|
||||||
|
|
||||||
fn try_from_bytes(&self, bytes: &[u8], provider: &super::super::ProviderId, mode: ConversionMode) -> Result<Self::StreamingIter, Self::Error> {
|
|
||||||
let openai_provider = OpenAIProvider;
|
|
||||||
StreamingResponse::try_from_bytes(&openai_provider, bytes, provider, mode)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ProviderInterface for DeepseekProvider {
|
|
||||||
fn has_compatible_api(&self, api_path: &str) -> bool {
|
|
||||||
matches!(api_path, "/v1/chat/completions")
|
|
||||||
}
|
|
||||||
|
|
||||||
fn supported_apis(&self) -> Vec<&'static str> {
|
|
||||||
vec!["/v1/chat/completions"]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -1,7 +0,0 @@
|
||||||
//! Gemini provider implementation
|
|
||||||
//!
|
|
||||||
//! Gemini will use a different API format in the future
|
|
||||||
//! For now, fallback to OpenAI-compatible format
|
|
||||||
|
|
||||||
pub mod provider;
|
|
||||||
pub use provider::GeminiProvider;
|
|
||||||
|
|
@ -1,93 +0,0 @@
|
||||||
//! Gemini provider implementation
|
|
||||||
//!
|
|
||||||
//! This module contains the Gemini provider that handles Google's Gemini API format
|
|
||||||
//! requests in OpenAI-compatible format.
|
|
||||||
|
|
||||||
use crate::providers::{ProviderInterface, ConversionMode};
|
|
||||||
use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse, Usage};
|
|
||||||
use crate::providers::traits::{ProviderRequest, ProviderResponse, StreamingResponse};
|
|
||||||
use crate::providers::openai::provider::{OpenAIProvider, OpenAIStreamingResponse, OpenAIApiError};
|
|
||||||
|
|
||||||
/// Gemini provider implementation
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub struct GeminiProvider;
|
|
||||||
|
|
||||||
// Trait implementations that delegate to OpenAI
|
|
||||||
impl ProviderRequest for GeminiProvider {
|
|
||||||
type Error = OpenAIApiError;
|
|
||||||
|
|
||||||
fn try_from_bytes(&self, bytes: &[u8]) -> Result<ChatCompletionsRequest, Self::Error> {
|
|
||||||
let openai_provider = OpenAIProvider;
|
|
||||||
ProviderRequest::try_from_bytes(&openai_provider, bytes)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn to_provider_bytes(&self, request: &ChatCompletionsRequest, provider: super::super::ProviderId, mode: ConversionMode) -> Result<Vec<u8>, Self::Error> {
|
|
||||||
let openai_provider = OpenAIProvider;
|
|
||||||
ProviderRequest::to_provider_bytes(&openai_provider, request, provider, mode)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn extract_model<'a>(&self, request: &'a ChatCompletionsRequest) -> &'a str {
|
|
||||||
&request.model
|
|
||||||
}
|
|
||||||
|
|
||||||
fn is_streaming(&self, request: &ChatCompletionsRequest) -> bool {
|
|
||||||
request.stream.unwrap_or_default()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn set_streaming_options(&self, request: &mut ChatCompletionsRequest) {
|
|
||||||
let openai_provider = OpenAIProvider;
|
|
||||||
ProviderRequest::set_streaming_options(&openai_provider, request)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn extract_messages_text(&self, request: &ChatCompletionsRequest) -> String {
|
|
||||||
let openai_provider = OpenAIProvider;
|
|
||||||
ProviderRequest::extract_messages_text(&openai_provider, request)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn extract_user_message(&self, request: &ChatCompletionsRequest) -> Option<String> {
|
|
||||||
let openai_provider = OpenAIProvider;
|
|
||||||
ProviderRequest::extract_user_message(&openai_provider, request)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ProviderResponse for GeminiProvider {
|
|
||||||
type Error = OpenAIApiError;
|
|
||||||
type Usage = Usage;
|
|
||||||
|
|
||||||
fn try_from_bytes(&self, bytes: &[u8], provider: &super::super::ProviderId, mode: ConversionMode) -> Result<ChatCompletionsResponse, Self::Error> {
|
|
||||||
let openai_provider = OpenAIProvider;
|
|
||||||
ProviderResponse::try_from_bytes(&openai_provider, bytes, provider, mode)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn usage<'a>(&self, response: &'a ChatCompletionsResponse) -> Option<&'a Self::Usage> {
|
|
||||||
Some(&response.usage)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn extract_usage_counts(&self, response: &ChatCompletionsResponse) -> Option<(usize, usize, usize)> {
|
|
||||||
let openai_provider = OpenAIProvider;
|
|
||||||
ProviderResponse::extract_usage_counts(&openai_provider, response)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl StreamingResponse for GeminiProvider {
|
|
||||||
type Error = OpenAIApiError;
|
|
||||||
type StreamChunk = crate::apis::openai::ChatCompletionsStreamResponse;
|
|
||||||
type StreamingIter = OpenAIStreamingResponse;
|
|
||||||
|
|
||||||
fn try_from_bytes(&self, bytes: &[u8], provider: &super::super::ProviderId, mode: ConversionMode) -> Result<Self::StreamingIter, Self::Error> {
|
|
||||||
let openai_provider = OpenAIProvider;
|
|
||||||
StreamingResponse::try_from_bytes(&openai_provider, bytes, provider, mode)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ProviderInterface for GeminiProvider {
|
|
||||||
fn has_compatible_api(&self, api_path: &str) -> bool {
|
|
||||||
// TODO: Update when Gemini API is fully implemented
|
|
||||||
matches!(api_path, "/v1/chat/completions" | "/v1/models")
|
|
||||||
}
|
|
||||||
|
|
||||||
fn supported_apis(&self) -> Vec<&'static str> {
|
|
||||||
// TODO: Update when Gemini API is fully implemented
|
|
||||||
vec!["/v1/models"]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -1,7 +0,0 @@
|
||||||
//! GitHub provider implementation
|
|
||||||
//!
|
|
||||||
//! GitHub will use a different API format in the future (/models)
|
|
||||||
//! For now, fallback to OpenAI-compatible format
|
|
||||||
|
|
||||||
pub mod provider;
|
|
||||||
pub use provider::GitHubProvider;
|
|
||||||
|
|
@ -1,93 +0,0 @@
|
||||||
//! GitHub provider implementation
|
|
||||||
//!
|
|
||||||
//! This module contains the GitHub provider that handles GitHub API format
|
|
||||||
//! requests in OpenAI-compatible format.
|
|
||||||
|
|
||||||
use crate::providers::{ProviderInterface, ConversionMode};
|
|
||||||
use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse, Usage};
|
|
||||||
use crate::providers::traits::{ProviderRequest, ProviderResponse, StreamingResponse};
|
|
||||||
use crate::providers::openai::provider::{OpenAIProvider, OpenAIStreamingResponse, OpenAIApiError};
|
|
||||||
|
|
||||||
/// GitHub provider implementation
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub struct GitHubProvider;
|
|
||||||
|
|
||||||
// Trait implementations that delegate to OpenAI
|
|
||||||
impl ProviderRequest for GitHubProvider {
|
|
||||||
type Error = OpenAIApiError;
|
|
||||||
|
|
||||||
fn try_from_bytes(&self, bytes: &[u8]) -> Result<ChatCompletionsRequest, Self::Error> {
|
|
||||||
let openai_provider = OpenAIProvider;
|
|
||||||
ProviderRequest::try_from_bytes(&openai_provider, bytes)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn to_provider_bytes(&self, request: &ChatCompletionsRequest, provider: super::super::ProviderId, mode: ConversionMode) -> Result<Vec<u8>, Self::Error> {
|
|
||||||
let openai_provider = OpenAIProvider;
|
|
||||||
ProviderRequest::to_provider_bytes(&openai_provider, request, provider, mode)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn extract_model<'a>(&self, request: &'a ChatCompletionsRequest) -> &'a str {
|
|
||||||
&request.model
|
|
||||||
}
|
|
||||||
|
|
||||||
fn is_streaming(&self, request: &ChatCompletionsRequest) -> bool {
|
|
||||||
request.stream.unwrap_or_default()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn set_streaming_options(&self, request: &mut ChatCompletionsRequest) {
|
|
||||||
let openai_provider = OpenAIProvider;
|
|
||||||
ProviderRequest::set_streaming_options(&openai_provider, request)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn extract_messages_text(&self, request: &ChatCompletionsRequest) -> String {
|
|
||||||
let openai_provider = OpenAIProvider;
|
|
||||||
ProviderRequest::extract_messages_text(&openai_provider, request)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn extract_user_message(&self, request: &ChatCompletionsRequest) -> Option<String> {
|
|
||||||
let openai_provider = OpenAIProvider;
|
|
||||||
ProviderRequest::extract_user_message(&openai_provider, request)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ProviderResponse for GitHubProvider {
|
|
||||||
type Error = OpenAIApiError;
|
|
||||||
type Usage = Usage;
|
|
||||||
|
|
||||||
fn try_from_bytes(&self, bytes: &[u8], provider: &super::super::ProviderId, mode: ConversionMode) -> Result<ChatCompletionsResponse, Self::Error> {
|
|
||||||
let openai_provider = OpenAIProvider;
|
|
||||||
ProviderResponse::try_from_bytes(&openai_provider, bytes, provider, mode)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn usage<'a>(&self, response: &'a ChatCompletionsResponse) -> Option<&'a Self::Usage> {
|
|
||||||
Some(&response.usage)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn extract_usage_counts(&self, response: &ChatCompletionsResponse) -> Option<(usize, usize, usize)> {
|
|
||||||
let openai_provider = OpenAIProvider;
|
|
||||||
ProviderResponse::extract_usage_counts(&openai_provider, response)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl StreamingResponse for GitHubProvider {
|
|
||||||
type Error = OpenAIApiError;
|
|
||||||
type StreamChunk = crate::apis::openai::ChatCompletionsStreamResponse;
|
|
||||||
type StreamingIter = OpenAIStreamingResponse;
|
|
||||||
|
|
||||||
fn try_from_bytes(&self, bytes: &[u8], provider: &super::super::ProviderId, mode: ConversionMode) -> Result<Self::StreamingIter, Self::Error> {
|
|
||||||
let openai_provider = OpenAIProvider;
|
|
||||||
StreamingResponse::try_from_bytes(&openai_provider, bytes, provider, mode)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ProviderInterface for GitHubProvider {
|
|
||||||
fn has_compatible_api(&self, api_path: &str) -> bool {
|
|
||||||
// TODO: Update when GitHub API is fully implemented
|
|
||||||
matches!(api_path, "/v1/chat/completions" | "/models")
|
|
||||||
}
|
|
||||||
|
|
||||||
fn supported_apis(&self) -> Vec<&'static str> {
|
|
||||||
// TODO: Update when GitHub API is fully implemented
|
|
||||||
vec!["/models"]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -1,6 +0,0 @@
|
||||||
//! Groq provider implementation
|
|
||||||
//!
|
|
||||||
//! Groq uses OpenAI-compatible API format but with different endpoints
|
|
||||||
|
|
||||||
pub mod provider;
|
|
||||||
pub use provider::GroqProvider;
|
|
||||||
|
|
@ -1,91 +0,0 @@
|
||||||
//! Groq provider implementation
|
|
||||||
//!
|
|
||||||
//! This module contains the Groq provider that handles Groq API format requests.
|
|
||||||
//! Groq uses OpenAI-compatible format but may have provider-specific nuances.
|
|
||||||
|
|
||||||
use crate::providers::{ProviderInterface, ConversionMode};
|
|
||||||
use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse, Usage};
|
|
||||||
use crate::providers::traits::{ProviderRequest, ProviderResponse, StreamingResponse};
|
|
||||||
use crate::providers::openai::provider::{OpenAIProvider, OpenAIStreamingResponse, OpenAIApiError};
|
|
||||||
|
|
||||||
/// Groq provider implementation
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub struct GroqProvider;
|
|
||||||
|
|
||||||
// Trait implementations that delegate to OpenAI
|
|
||||||
impl ProviderRequest for GroqProvider {
|
|
||||||
type Error = OpenAIApiError;
|
|
||||||
|
|
||||||
fn try_from_bytes(&self, bytes: &[u8]) -> Result<ChatCompletionsRequest, Self::Error> {
|
|
||||||
let openai_provider = OpenAIProvider;
|
|
||||||
ProviderRequest::try_from_bytes(&openai_provider, bytes)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn to_provider_bytes(&self, request: &ChatCompletionsRequest, provider: super::super::ProviderId, mode: ConversionMode) -> Result<Vec<u8>, Self::Error> {
|
|
||||||
let openai_provider = OpenAIProvider;
|
|
||||||
ProviderRequest::to_provider_bytes(&openai_provider, request, provider, mode)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn extract_model<'a>(&self, request: &'a ChatCompletionsRequest) -> &'a str {
|
|
||||||
&request.model
|
|
||||||
}
|
|
||||||
|
|
||||||
fn is_streaming(&self, request: &ChatCompletionsRequest) -> bool {
|
|
||||||
request.stream.unwrap_or_default()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn set_streaming_options(&self, request: &mut ChatCompletionsRequest) {
|
|
||||||
let openai_provider = OpenAIProvider;
|
|
||||||
ProviderRequest::set_streaming_options(&openai_provider, request)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn extract_messages_text(&self, request: &ChatCompletionsRequest) -> String {
|
|
||||||
let openai_provider = OpenAIProvider;
|
|
||||||
ProviderRequest::extract_messages_text(&openai_provider, request)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn extract_user_message(&self, request: &ChatCompletionsRequest) -> Option<String> {
|
|
||||||
let openai_provider = OpenAIProvider;
|
|
||||||
ProviderRequest::extract_user_message(&openai_provider, request)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ProviderResponse for GroqProvider {
|
|
||||||
type Error = OpenAIApiError;
|
|
||||||
type Usage = Usage;
|
|
||||||
|
|
||||||
fn try_from_bytes(&self, bytes: &[u8], provider: &super::super::ProviderId, mode: ConversionMode) -> Result<ChatCompletionsResponse, Self::Error> {
|
|
||||||
let openai_provider = OpenAIProvider;
|
|
||||||
ProviderResponse::try_from_bytes(&openai_provider, bytes, provider, mode)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn usage<'a>(&self, response: &'a ChatCompletionsResponse) -> Option<&'a Self::Usage> {
|
|
||||||
Some(&response.usage)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn extract_usage_counts(&self, response: &ChatCompletionsResponse) -> Option<(usize, usize, usize)> {
|
|
||||||
let openai_provider = OpenAIProvider;
|
|
||||||
ProviderResponse::extract_usage_counts(&openai_provider, response)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl StreamingResponse for GroqProvider {
|
|
||||||
type Error = OpenAIApiError;
|
|
||||||
type StreamChunk = crate::apis::openai::ChatCompletionsStreamResponse;
|
|
||||||
type StreamingIter = OpenAIStreamingResponse;
|
|
||||||
|
|
||||||
fn try_from_bytes(&self, bytes: &[u8], provider: &super::super::ProviderId, mode: ConversionMode) -> Result<Self::StreamingIter, Self::Error> {
|
|
||||||
let openai_provider = OpenAIProvider;
|
|
||||||
StreamingResponse::try_from_bytes(&openai_provider, bytes, provider, mode)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ProviderInterface for GroqProvider {
|
|
||||||
fn has_compatible_api(&self, api_path: &str) -> bool {
|
|
||||||
matches!(api_path, "/v1/chat/completions" | "/openai/v1/chat/completions")
|
|
||||||
}
|
|
||||||
|
|
||||||
fn supported_apis(&self) -> Vec<&'static str> {
|
|
||||||
vec!["/openai/v1/chat/completions"]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -1,6 +0,0 @@
|
||||||
//! Mistral provider implementation
|
|
||||||
//!
|
|
||||||
//! Mistral uses OpenAI-compatible API format
|
|
||||||
|
|
||||||
pub mod provider;
|
|
||||||
pub use provider::MistralProvider;
|
|
||||||
|
|
@ -1,88 +0,0 @@
|
||||||
//! Mistral provider implementation
|
|
||||||
|
|
||||||
use crate::providers::{ProviderInterface, ConversionMode};
|
|
||||||
use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse, Usage};
|
|
||||||
use crate::providers::traits::{ProviderRequest, ProviderResponse, StreamingResponse};
|
|
||||||
use crate::providers::openai::provider::{OpenAIProvider, OpenAIStreamingResponse, OpenAIApiError};
|
|
||||||
|
|
||||||
/// Mistral provider implementation
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub struct MistralProvider;
|
|
||||||
|
|
||||||
// Trait implementations that delegate to OpenAI
|
|
||||||
impl ProviderRequest for MistralProvider {
|
|
||||||
type Error = OpenAIApiError;
|
|
||||||
|
|
||||||
fn try_from_bytes(&self, bytes: &[u8]) -> Result<ChatCompletionsRequest, Self::Error> {
|
|
||||||
let openai_provider = OpenAIProvider;
|
|
||||||
ProviderRequest::try_from_bytes(&openai_provider, bytes)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn to_provider_bytes(&self, request: &ChatCompletionsRequest, provider: super::super::ProviderId, mode: ConversionMode) -> Result<Vec<u8>, Self::Error> {
|
|
||||||
let openai_provider = OpenAIProvider;
|
|
||||||
ProviderRequest::to_provider_bytes(&openai_provider, request, provider, mode)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn extract_model<'a>(&self, request: &'a ChatCompletionsRequest) -> &'a str {
|
|
||||||
&request.model
|
|
||||||
}
|
|
||||||
|
|
||||||
fn is_streaming(&self, request: &ChatCompletionsRequest) -> bool {
|
|
||||||
request.stream.unwrap_or_default()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn set_streaming_options(&self, request: &mut ChatCompletionsRequest) {
|
|
||||||
let openai_provider = OpenAIProvider;
|
|
||||||
ProviderRequest::set_streaming_options(&openai_provider, request)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn extract_messages_text(&self, request: &ChatCompletionsRequest) -> String {
|
|
||||||
let openai_provider = OpenAIProvider;
|
|
||||||
ProviderRequest::extract_messages_text(&openai_provider, request)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn extract_user_message(&self, request: &ChatCompletionsRequest) -> Option<String> {
|
|
||||||
let openai_provider = OpenAIProvider;
|
|
||||||
ProviderRequest::extract_user_message(&openai_provider, request)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ProviderResponse for MistralProvider {
|
|
||||||
type Error = OpenAIApiError;
|
|
||||||
type Usage = Usage;
|
|
||||||
|
|
||||||
fn try_from_bytes(&self, bytes: &[u8], provider: &super::super::ProviderId, mode: ConversionMode) -> Result<ChatCompletionsResponse, Self::Error> {
|
|
||||||
let openai_provider = OpenAIProvider;
|
|
||||||
ProviderResponse::try_from_bytes(&openai_provider, bytes, provider, mode)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn usage<'a>(&self, response: &'a ChatCompletionsResponse) -> Option<&'a Self::Usage> {
|
|
||||||
Some(&response.usage)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn extract_usage_counts(&self, response: &ChatCompletionsResponse) -> Option<(usize, usize, usize)> {
|
|
||||||
let openai_provider = OpenAIProvider;
|
|
||||||
ProviderResponse::extract_usage_counts(&openai_provider, response)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl StreamingResponse for MistralProvider {
|
|
||||||
type Error = OpenAIApiError;
|
|
||||||
type StreamChunk = crate::apis::openai::ChatCompletionsStreamResponse;
|
|
||||||
type StreamingIter = OpenAIStreamingResponse;
|
|
||||||
|
|
||||||
fn try_from_bytes(&self, bytes: &[u8], provider: &super::super::ProviderId, mode: ConversionMode) -> Result<Self::StreamingIter, Self::Error> {
|
|
||||||
let openai_provider = OpenAIProvider;
|
|
||||||
StreamingResponse::try_from_bytes(&openai_provider, bytes, provider, mode)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ProviderInterface for MistralProvider {
|
|
||||||
fn has_compatible_api(&self, api_path: &str) -> bool {
|
|
||||||
matches!(api_path, "/v1/chat/completions")
|
|
||||||
}
|
|
||||||
|
|
||||||
fn supported_apis(&self) -> Vec<&'static str> {
|
|
||||||
vec!["/v1/chat/completions"]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -5,24 +5,11 @@
|
||||||
|
|
||||||
pub mod traits;
|
pub mod traits;
|
||||||
pub mod openai;
|
pub mod openai;
|
||||||
pub mod groq;
|
|
||||||
pub mod mistral;
|
|
||||||
pub mod deepseek;
|
|
||||||
pub mod arch;
|
|
||||||
pub mod gemini;
|
|
||||||
pub mod claude;
|
|
||||||
pub mod github;
|
|
||||||
|
|
||||||
// Re-export the main interfaces
|
// Re-export the main interfaces
|
||||||
pub use traits::*;
|
pub use traits::*;
|
||||||
pub use openai::OpenAIProvider;
|
// Note: OpenAIProvider has been deprecated in favor of function-based approach
|
||||||
pub use groq::GroqProvider;
|
// OpenAI functionality is accessed through openai::builder and openai::types modules
|
||||||
pub use mistral::MistralProvider;
|
|
||||||
pub use deepseek::DeepseekProvider;
|
|
||||||
pub use arch::ArchProvider;
|
|
||||||
pub use gemini::GeminiProvider;
|
|
||||||
pub use claude::ClaudeProvider;
|
|
||||||
pub use github::GitHubProvider;
|
|
||||||
|
|
||||||
use std::fmt::Display;
|
use std::fmt::Display;
|
||||||
|
|
||||||
|
|
@ -69,219 +56,3 @@ impl Display for ProviderId {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ProviderId {
|
|
||||||
/// Get the API endpoint path for this provider
|
|
||||||
pub fn api_path(&self) -> &'static str {
|
|
||||||
match self {
|
|
||||||
ProviderId::OpenAI => "/v1/chat/completions",
|
|
||||||
ProviderId::Groq => "/openai/v1/chat/completions",
|
|
||||||
ProviderId::Gemini => "/v1/models", // TODO: Update when Gemini API is implemented
|
|
||||||
ProviderId::Claude => "/v1/messages", // TODO: Update when Claude API is implemented
|
|
||||||
ProviderId::Mistral => "/v1/chat/completions",
|
|
||||||
ProviderId::Deepseek => "/v1/chat/completions",
|
|
||||||
ProviderId::GitHub => "/models", // TODO: Update when GitHub models API is implemented
|
|
||||||
ProviderId::Arch => "/v1/chat/completions",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Check if this provider supports OpenAI v1/chat/completions API format
|
|
||||||
pub fn supports_openai_format(&self) -> bool {
|
|
||||||
matches!(
|
|
||||||
self,
|
|
||||||
ProviderId::OpenAI | ProviderId::Groq | ProviderId::Mistral | ProviderId::Deepseek | ProviderId::Arch
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Enum for dynamic dispatch of provider instances
|
|
||||||
pub enum Provider {
|
|
||||||
OpenAI(OpenAIProvider, ProviderId),
|
|
||||||
Groq(GroqProvider, ProviderId),
|
|
||||||
Mistral(MistralProvider, ProviderId),
|
|
||||||
Deepseek(DeepseekProvider, ProviderId),
|
|
||||||
Arch(ArchProvider, ProviderId),
|
|
||||||
Gemini(GeminiProvider, ProviderId),
|
|
||||||
Claude(ClaudeProvider, ProviderId),
|
|
||||||
GitHub(GitHubProvider, ProviderId),
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Provider {
|
|
||||||
/// Create a provider instance from a provider ID
|
|
||||||
pub fn new(id: ProviderId) -> Self {
|
|
||||||
match id {
|
|
||||||
ProviderId::OpenAI => Provider::OpenAI(OpenAIProvider, id),
|
|
||||||
ProviderId::Groq => Provider::Groq(GroqProvider, id),
|
|
||||||
ProviderId::Mistral => Provider::Mistral(MistralProvider, id),
|
|
||||||
ProviderId::Deepseek => Provider::Deepseek(DeepseekProvider, id),
|
|
||||||
ProviderId::Arch => Provider::Arch(ArchProvider, id),
|
|
||||||
ProviderId::Gemini => Provider::Gemini(GeminiProvider, id),
|
|
||||||
ProviderId::Claude => Provider::Claude(ClaudeProvider, id),
|
|
||||||
ProviderId::GitHub => Provider::GitHub(GitHubProvider, id),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get the provider ID
|
|
||||||
pub fn id(&self) -> ProviderId {
|
|
||||||
match self {
|
|
||||||
Provider::OpenAI(_, id) => *id,
|
|
||||||
Provider::Groq(_, id) => *id,
|
|
||||||
Provider::Mistral(_, id) => *id,
|
|
||||||
Provider::Deepseek(_, id) => *id,
|
|
||||||
Provider::Arch(_, id) => *id,
|
|
||||||
Provider::Gemini(_, id) => *id,
|
|
||||||
Provider::Claude(_, id) => *id,
|
|
||||||
Provider::GitHub(_, id) => *id,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Implement traits directly on the Provider enum
|
|
||||||
impl ProviderRequest for Provider {
|
|
||||||
type Error = openai::provider::OpenAIApiError;
|
|
||||||
|
|
||||||
fn try_from_bytes(&self, bytes: &[u8]) -> Result<crate::apis::openai::ChatCompletionsRequest, Self::Error> {
|
|
||||||
match self {
|
|
||||||
Provider::OpenAI(provider, _) => ProviderRequest::try_from_bytes(provider, bytes),
|
|
||||||
Provider::Groq(provider, _) => ProviderRequest::try_from_bytes(provider, bytes),
|
|
||||||
Provider::Mistral(provider, _) => ProviderRequest::try_from_bytes(provider, bytes),
|
|
||||||
Provider::Deepseek(provider, _) => ProviderRequest::try_from_bytes(provider, bytes),
|
|
||||||
Provider::Arch(provider, _) => ProviderRequest::try_from_bytes(provider, bytes),
|
|
||||||
Provider::Gemini(provider, _) => ProviderRequest::try_from_bytes(provider, bytes),
|
|
||||||
Provider::Claude(provider, _) => ProviderRequest::try_from_bytes(provider, bytes),
|
|
||||||
Provider::GitHub(provider, _) => ProviderRequest::try_from_bytes(provider, bytes),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn to_provider_bytes(&self, request: &crate::apis::openai::ChatCompletionsRequest, provider_id: super::ProviderId, mode: ConversionMode) -> Result<Vec<u8>, Self::Error> {
|
|
||||||
match self {
|
|
||||||
Provider::OpenAI(provider, _) => ProviderRequest::to_provider_bytes(provider, request, provider_id, mode),
|
|
||||||
Provider::Groq(provider, _) => ProviderRequest::to_provider_bytes(provider, request, provider_id, mode),
|
|
||||||
Provider::Mistral(provider, _) => ProviderRequest::to_provider_bytes(provider, request, provider_id, mode),
|
|
||||||
Provider::Deepseek(provider, _) => ProviderRequest::to_provider_bytes(provider, request, provider_id, mode),
|
|
||||||
Provider::Arch(provider, _) => ProviderRequest::to_provider_bytes(provider, request, provider_id, mode),
|
|
||||||
Provider::Gemini(provider, _) => ProviderRequest::to_provider_bytes(provider, request, provider_id, mode),
|
|
||||||
Provider::Claude(provider, _) => ProviderRequest::to_provider_bytes(provider, request, provider_id, mode),
|
|
||||||
Provider::GitHub(provider, _) => ProviderRequest::to_provider_bytes(provider, request, provider_id, mode),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn extract_model<'a>(&self, request: &'a crate::apis::openai::ChatCompletionsRequest) -> &'a str {
|
|
||||||
// Since all providers use the same implementation, just use the first one
|
|
||||||
&request.model
|
|
||||||
}
|
|
||||||
|
|
||||||
fn is_streaming(&self, request: &crate::apis::openai::ChatCompletionsRequest) -> bool {
|
|
||||||
// Since all providers use the same implementation, just use the first one
|
|
||||||
request.stream.unwrap_or_default()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn set_streaming_options(&self, request: &mut crate::apis::openai::ChatCompletionsRequest) {
|
|
||||||
match self {
|
|
||||||
Provider::OpenAI(provider, _) => ProviderRequest::set_streaming_options(provider, request),
|
|
||||||
Provider::Groq(provider, _) => ProviderRequest::set_streaming_options(provider, request),
|
|
||||||
Provider::Mistral(provider, _) => ProviderRequest::set_streaming_options(provider, request),
|
|
||||||
Provider::Deepseek(provider, _) => ProviderRequest::set_streaming_options(provider, request),
|
|
||||||
Provider::Arch(provider, _) => ProviderRequest::set_streaming_options(provider, request),
|
|
||||||
Provider::Gemini(provider, _) => ProviderRequest::set_streaming_options(provider, request),
|
|
||||||
Provider::Claude(provider, _) => ProviderRequest::set_streaming_options(provider, request),
|
|
||||||
Provider::GitHub(provider, _) => ProviderRequest::set_streaming_options(provider, request),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn extract_messages_text(&self, request: &crate::apis::openai::ChatCompletionsRequest) -> String {
|
|
||||||
match self {
|
|
||||||
Provider::OpenAI(provider, _) => ProviderRequest::extract_messages_text(provider, request),
|
|
||||||
Provider::Groq(provider, _) => ProviderRequest::extract_messages_text(provider, request),
|
|
||||||
Provider::Mistral(provider, _) => ProviderRequest::extract_messages_text(provider, request),
|
|
||||||
Provider::Deepseek(provider, _) => ProviderRequest::extract_messages_text(provider, request),
|
|
||||||
Provider::Arch(provider, _) => ProviderRequest::extract_messages_text(provider, request),
|
|
||||||
Provider::Gemini(provider, _) => ProviderRequest::extract_messages_text(provider, request),
|
|
||||||
Provider::Claude(provider, _) => ProviderRequest::extract_messages_text(provider, request),
|
|
||||||
Provider::GitHub(provider, _) => ProviderRequest::extract_messages_text(provider, request),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn extract_user_message(&self, request: &crate::apis::openai::ChatCompletionsRequest) -> Option<String> {
|
|
||||||
match self {
|
|
||||||
Provider::OpenAI(provider, _) => ProviderRequest::extract_user_message(provider, request),
|
|
||||||
Provider::Groq(provider, _) => ProviderRequest::extract_user_message(provider, request),
|
|
||||||
Provider::Mistral(provider, _) => ProviderRequest::extract_user_message(provider, request),
|
|
||||||
Provider::Deepseek(provider, _) => ProviderRequest::extract_user_message(provider, request),
|
|
||||||
Provider::Arch(provider, _) => ProviderRequest::extract_user_message(provider, request),
|
|
||||||
Provider::Gemini(provider, _) => ProviderRequest::extract_user_message(provider, request),
|
|
||||||
Provider::Claude(provider, _) => ProviderRequest::extract_user_message(provider, request),
|
|
||||||
Provider::GitHub(provider, _) => ProviderRequest::extract_user_message(provider, request),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ProviderResponse for Provider {
|
|
||||||
type Error = openai::provider::OpenAIApiError;
|
|
||||||
type Usage = crate::apis::openai::Usage;
|
|
||||||
|
|
||||||
fn try_from_bytes(&self, bytes: &[u8], provider_id: &super::ProviderId, mode: ConversionMode) -> Result<crate::apis::openai::ChatCompletionsResponse, Self::Error> {
|
|
||||||
match self {
|
|
||||||
Provider::OpenAI(provider, _) => ProviderResponse::try_from_bytes(provider, bytes, provider_id, mode),
|
|
||||||
Provider::Groq(provider, _) => ProviderResponse::try_from_bytes(provider, bytes, provider_id, mode),
|
|
||||||
Provider::Mistral(provider, _) => ProviderResponse::try_from_bytes(provider, bytes, provider_id, mode),
|
|
||||||
Provider::Deepseek(provider, _) => ProviderResponse::try_from_bytes(provider, bytes, provider_id, mode),
|
|
||||||
Provider::Arch(provider, _) => ProviderResponse::try_from_bytes(provider, bytes, provider_id, mode),
|
|
||||||
Provider::Gemini(provider, _) => ProviderResponse::try_from_bytes(provider, bytes, provider_id, mode),
|
|
||||||
Provider::Claude(provider, _) => ProviderResponse::try_from_bytes(provider, bytes, provider_id, mode),
|
|
||||||
Provider::GitHub(provider, _) => ProviderResponse::try_from_bytes(provider, bytes, provider_id, mode),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn usage<'a>(&self, response: &'a crate::apis::openai::ChatCompletionsResponse) -> Option<&'a Self::Usage> {
|
|
||||||
// Since all providers use the same implementation, just use the direct implementation
|
|
||||||
Some(&response.usage)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl StreamingResponse for Provider {
|
|
||||||
type Error = openai::provider::OpenAIApiError;
|
|
||||||
type StreamChunk = crate::apis::openai::ChatCompletionsStreamResponse;
|
|
||||||
type StreamingIter = openai::provider::OpenAIStreamingResponse;
|
|
||||||
|
|
||||||
fn try_from_bytes(&self, bytes: &[u8], provider_id: &super::ProviderId, mode: ConversionMode) -> Result<Self::StreamingIter, Self::Error> {
|
|
||||||
match self {
|
|
||||||
Provider::OpenAI(provider, _) => StreamingResponse::try_from_bytes(provider, bytes, provider_id, mode),
|
|
||||||
Provider::Groq(provider, _) => StreamingResponse::try_from_bytes(provider, bytes, provider_id, mode),
|
|
||||||
Provider::Mistral(provider, _) => StreamingResponse::try_from_bytes(provider, bytes, provider_id, mode),
|
|
||||||
Provider::Deepseek(provider, _) => StreamingResponse::try_from_bytes(provider, bytes, provider_id, mode),
|
|
||||||
Provider::Arch(provider, _) => StreamingResponse::try_from_bytes(provider, bytes, provider_id, mode),
|
|
||||||
Provider::Gemini(provider, _) => StreamingResponse::try_from_bytes(provider, bytes, provider_id, mode),
|
|
||||||
Provider::Claude(provider, _) => StreamingResponse::try_from_bytes(provider, bytes, provider_id, mode),
|
|
||||||
Provider::GitHub(provider, _) => StreamingResponse::try_from_bytes(provider, bytes, provider_id, mode),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ProviderInterface for Provider {
|
|
||||||
fn has_compatible_api(&self, api_path: &str) -> bool {
|
|
||||||
match self {
|
|
||||||
Provider::OpenAI(provider, _) => provider.has_compatible_api(api_path),
|
|
||||||
Provider::Groq(provider, _) => provider.has_compatible_api(api_path),
|
|
||||||
Provider::Mistral(provider, _) => provider.has_compatible_api(api_path),
|
|
||||||
Provider::Deepseek(provider, _) => provider.has_compatible_api(api_path),
|
|
||||||
Provider::Arch(provider, _) => provider.has_compatible_api(api_path),
|
|
||||||
Provider::Gemini(provider, _) => provider.has_compatible_api(api_path),
|
|
||||||
Provider::Claude(provider, _) => provider.has_compatible_api(api_path),
|
|
||||||
Provider::GitHub(provider, _) => provider.has_compatible_api(api_path),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn supported_apis(&self) -> Vec<&'static str> {
|
|
||||||
match self {
|
|
||||||
Provider::OpenAI(provider, _) => provider.supported_apis(),
|
|
||||||
Provider::Groq(provider, _) => provider.supported_apis(),
|
|
||||||
Provider::Mistral(provider, _) => provider.supported_apis(),
|
|
||||||
Provider::Deepseek(provider, _) => provider.supported_apis(),
|
|
||||||
Provider::Arch(provider, _) => provider.supported_apis(),
|
|
||||||
Provider::Gemini(provider, _) => provider.supported_apis(),
|
|
||||||
Provider::Claude(provider, _) => provider.supported_apis(),
|
|
||||||
Provider::GitHub(provider, _) => provider.supported_apis(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,10 @@
|
||||||
pub mod builder;
|
pub mod builder;
|
||||||
pub mod types;
|
pub mod types;
|
||||||
pub mod provider;
|
|
||||||
|
|
||||||
// Re-export the main provider
|
// Re-export the main types and builder functionality
|
||||||
pub use provider::OpenAIProvider;
|
pub use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse, ChatCompletionsStreamResponse};
|
||||||
|
pub use builder::*;
|
||||||
|
pub use types::*;
|
||||||
|
|
||||||
|
// Note: The OpenAIProvider struct has been deprecated in favor of the function-based approach in traits.rs
|
||||||
|
// All provider functionality is now accessed through try_request_from_bytes, try_response_from_bytes, etc.
|
||||||
|
|
|
||||||
|
|
@ -1,217 +0,0 @@
|
||||||
//! OpenAI provider interface implementations
|
|
||||||
|
|
||||||
use crate::apis::openai::*;
|
|
||||||
use crate::providers::traits::*;
|
|
||||||
|
|
||||||
// Simple error type for OpenAI API operations
|
|
||||||
#[derive(Debug, thiserror::Error)]
|
|
||||||
pub enum OpenAIApiError {
|
|
||||||
#[error("JSON parsing error: {0}")]
|
|
||||||
JsonError(#[from] serde_json::Error),
|
|
||||||
#[error("UTF-8 parsing error: {0}")]
|
|
||||||
Utf8Error(#[from] std::str::Utf8Error),
|
|
||||||
#[error("Invalid streaming data: {0}")]
|
|
||||||
InvalidStreamingData(String),
|
|
||||||
#[error("Request conversion error: {0}")]
|
|
||||||
RequestConversionError(String),
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============================================================================
|
|
||||||
// OpenAI Provider Definition
|
|
||||||
// ============================================================================
|
|
||||||
|
|
||||||
pub struct OpenAIProvider;
|
|
||||||
|
|
||||||
// Create a concrete streaming response type to avoid lifetime issues
|
|
||||||
pub struct OpenAIStreamingResponse {
|
|
||||||
lines: Vec<String>,
|
|
||||||
current_index: usize,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl OpenAIStreamingResponse {
|
|
||||||
fn new(data: String) -> Self {
|
|
||||||
let lines: Vec<String> = data.lines().map(|s| s.to_string()).collect();
|
|
||||||
Self {
|
|
||||||
lines,
|
|
||||||
current_index: 0,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Iterator for OpenAIStreamingResponse {
|
|
||||||
type Item = Result<ChatCompletionsStreamResponse, OpenAIApiError>;
|
|
||||||
|
|
||||||
fn next(&mut self) -> Option<Self::Item> {
|
|
||||||
while self.current_index < self.lines.len() {
|
|
||||||
let line = &self.lines[self.current_index];
|
|
||||||
self.current_index += 1;
|
|
||||||
|
|
||||||
if let Some(data) = line.strip_prefix("data: ") {
|
|
||||||
let data = data.trim();
|
|
||||||
if data == "[DONE]" {
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
|
|
||||||
if data == r#"{"type": "ping"}"# {
|
|
||||||
continue; // Skip ping messages
|
|
||||||
}
|
|
||||||
|
|
||||||
return Some(
|
|
||||||
serde_json::from_str::<ChatCompletionsStreamResponse>(data).map_err(|e| {
|
|
||||||
OpenAIApiError::InvalidStreamingData(format!("Error parsing: {}, data: {}", e, data))
|
|
||||||
}),
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
None
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ProviderInterface for OpenAIProvider {
|
|
||||||
fn has_compatible_api(&self, api_path: &str) -> bool {
|
|
||||||
api_path == "/v1/chat/completions"
|
|
||||||
}
|
|
||||||
|
|
||||||
fn supported_apis(&self) -> Vec<&'static str> {
|
|
||||||
vec!["/v1/chat/completions"]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Direct trait implementations on OpenAIProvider
|
|
||||||
impl ProviderRequest for OpenAIProvider {
|
|
||||||
type Error = OpenAIApiError;
|
|
||||||
|
|
||||||
fn try_from_bytes(&self, bytes: &[u8]) -> Result<ChatCompletionsRequest, Self::Error> {
|
|
||||||
let s = std::str::from_utf8(bytes)?;
|
|
||||||
Ok(serde_json::from_str(s)?)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn to_provider_bytes(&self, request: &ChatCompletionsRequest, _provider: super::super::ProviderId, _mode: ConversionMode) -> Result<Vec<u8>, Self::Error> {
|
|
||||||
Ok(serde_json::to_vec(request)?)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn extract_model<'a>(&self, request: &'a ChatCompletionsRequest) -> &'a str {
|
|
||||||
&request.model
|
|
||||||
}
|
|
||||||
|
|
||||||
fn is_streaming(&self, request: &ChatCompletionsRequest) -> bool {
|
|
||||||
request.stream.unwrap_or_default()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn set_streaming_options(&self, request: &mut ChatCompletionsRequest) {
|
|
||||||
if request.stream_options.is_none() {
|
|
||||||
request.stream_options = Some(StreamOptions {
|
|
||||||
include_usage: Some(true),
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn extract_messages_text(&self, request: &ChatCompletionsRequest) -> String {
|
|
||||||
request.messages
|
|
||||||
.iter()
|
|
||||||
.fold(String::new(), |acc, m| {
|
|
||||||
acc + " " + &match &m.content {
|
|
||||||
MessageContent::Text(text) => text.clone(),
|
|
||||||
MessageContent::Parts(parts) => {
|
|
||||||
parts.iter().map(|part| match part {
|
|
||||||
ContentPart::Text { text } => text.clone(),
|
|
||||||
ContentPart::ImageUrl { .. } => "[Image]".to_string(),
|
|
||||||
}).collect::<Vec<_>>().join(" ")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
fn extract_user_message(&self, request: &ChatCompletionsRequest) -> Option<String> {
|
|
||||||
request.messages.last().and_then(|msg| {
|
|
||||||
match &msg.content {
|
|
||||||
MessageContent::Text(text) => Some(text.clone()),
|
|
||||||
MessageContent::Parts(parts) => {
|
|
||||||
// Extract text from content parts, ignoring images
|
|
||||||
let text_parts: Vec<String> = parts
|
|
||||||
.iter()
|
|
||||||
.filter_map(|part| match part {
|
|
||||||
ContentPart::Text { text } => Some(text.clone()),
|
|
||||||
ContentPart::ImageUrl { .. } => None,
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
if text_parts.is_empty() {
|
|
||||||
None
|
|
||||||
} else {
|
|
||||||
Some(text_parts.join(" "))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ProviderResponse for OpenAIProvider {
|
|
||||||
type Error = OpenAIApiError;
|
|
||||||
type Usage = Usage;
|
|
||||||
|
|
||||||
fn try_from_bytes(&self, bytes: &[u8], _provider: &super::super::ProviderId, _mode: ConversionMode) -> Result<ChatCompletionsResponse, Self::Error> {
|
|
||||||
let s = std::str::from_utf8(bytes)?;
|
|
||||||
Ok(serde_json::from_str(s)?)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn usage<'a>(&self, response: &'a ChatCompletionsResponse) -> Option<&'a Self::Usage> {
|
|
||||||
Some(&response.usage)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn extract_usage_counts(&self, response: &ChatCompletionsResponse) -> Option<(usize, usize, usize)> {
|
|
||||||
Some((
|
|
||||||
response.usage.prompt_tokens as usize,
|
|
||||||
response.usage.completion_tokens as usize,
|
|
||||||
response.usage.total_tokens as usize,
|
|
||||||
))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl StreamingResponse for OpenAIProvider {
|
|
||||||
type Error = OpenAIApiError;
|
|
||||||
type StreamChunk = ChatCompletionsStreamResponse;
|
|
||||||
type StreamingIter = OpenAIStreamingResponse;
|
|
||||||
|
|
||||||
fn try_from_bytes(&self, bytes: &[u8], _provider: &super::super::ProviderId, _mode: ConversionMode) -> Result<Self::StreamingIter, Self::Error> {
|
|
||||||
let s = std::str::from_utf8(bytes)?;
|
|
||||||
Ok(OpenAIStreamingResponse::new(s.to_string()))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============================================================================
|
|
||||||
// Trait Implementations for OpenAI Types (Keep for TokenUsage only)
|
|
||||||
// ============================================================================
|
|
||||||
|
|
||||||
impl TokenUsage for Usage {
|
|
||||||
fn completion_tokens(&self) -> usize {
|
|
||||||
self.completion_tokens as usize
|
|
||||||
}
|
|
||||||
|
|
||||||
fn prompt_tokens(&self) -> usize {
|
|
||||||
self.prompt_tokens as usize
|
|
||||||
}
|
|
||||||
|
|
||||||
fn total_tokens(&self) -> usize {
|
|
||||||
self.total_tokens as usize
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl StreamChunk for ChatCompletionsStreamResponse {
|
|
||||||
type Usage = Usage;
|
|
||||||
|
|
||||||
fn usage(&self) -> Option<&Self::Usage> {
|
|
||||||
self.usage.as_ref()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl StreamingResponse for OpenAIStreamingResponse {
|
|
||||||
type Error = OpenAIApiError;
|
|
||||||
type StreamChunk = ChatCompletionsStreamResponse;
|
|
||||||
type StreamingIter = OpenAIStreamingResponse;
|
|
||||||
|
|
||||||
fn try_from_bytes(&self, bytes: &[u8], _provider: &super::super::ProviderId, _mode: ConversionMode) -> Result<Self::StreamingIter, Self::Error> {
|
|
||||||
let s = std::str::from_utf8(bytes)?;
|
|
||||||
Ok(OpenAIStreamingResponse::new(s.to_string()))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -4,6 +4,7 @@
|
||||||
//! handling of LLM requests and responses in the gateway.
|
//! handling of LLM requests and responses in the gateway.
|
||||||
|
|
||||||
use std::error::Error;
|
use std::error::Error;
|
||||||
|
use std::fmt;
|
||||||
|
|
||||||
/// Conversion mode for provider requests/responses
|
/// Conversion mode for provider requests/responses
|
||||||
#[derive(Debug, Clone, Copy)]
|
#[derive(Debug, Clone, Copy)]
|
||||||
|
|
@ -14,30 +15,41 @@ pub enum ConversionMode {
|
||||||
Passthrough,
|
Passthrough,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Trait for provider-specific request types
|
/// Error types for provider operations
|
||||||
pub trait ProviderRequest {
|
#[derive(Debug)]
|
||||||
type Error: Error + Send + Sync + 'static;
|
pub struct ProviderRequestError {
|
||||||
|
pub message: String,
|
||||||
|
pub source: Option<Box<dyn Error + Send + Sync>>,
|
||||||
|
}
|
||||||
|
|
||||||
/// Parse request from raw bytes
|
#[derive(Debug)]
|
||||||
fn try_from_bytes(&self, bytes: &[u8]) -> Result<crate::apis::openai::ChatCompletionsRequest, Self::Error>;
|
pub struct ProviderResponseError {
|
||||||
|
pub message: String,
|
||||||
|
pub source: Option<Box<dyn Error + Send + Sync>>,
|
||||||
|
}
|
||||||
|
|
||||||
/// Convert to provider-specific format
|
impl fmt::Display for ProviderRequestError {
|
||||||
fn to_provider_bytes(&self, request: &crate::apis::openai::ChatCompletionsRequest, provider: super::ProviderId, mode: ConversionMode) -> Result<Vec<u8>, Self::Error>;
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
|
write!(f, "Provider request error: {}", self.message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Extract the model name from the request
|
impl fmt::Display for ProviderResponseError {
|
||||||
fn extract_model<'a>(&self, request: &'a crate::apis::openai::ChatCompletionsRequest) -> &'a str;
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
|
write!(f, "Provider response error: {}", self.message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Check if this is a streaming request
|
impl Error for ProviderRequestError {
|
||||||
fn is_streaming(&self, request: &crate::apis::openai::ChatCompletionsRequest) -> bool;
|
fn source(&self) -> Option<&(dyn Error + 'static)> {
|
||||||
|
self.source.as_ref().map(|e| e.as_ref() as &(dyn Error + 'static))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Set streaming options (e.g., include_usage)
|
impl Error for ProviderResponseError {
|
||||||
fn set_streaming_options(&self, request: &mut crate::apis::openai::ChatCompletionsRequest);
|
fn source(&self) -> Option<&(dyn Error + 'static)> {
|
||||||
|
self.source.as_ref().map(|e| e.as_ref() as &(dyn Error + 'static))
|
||||||
/// Extract text content from messages for token counting
|
}
|
||||||
fn extract_messages_text(&self, request: &crate::apis::openai::ChatCompletionsRequest) -> String;
|
|
||||||
|
|
||||||
/// Extract the user message for tracing/logging purposes
|
|
||||||
fn extract_user_message(&self, request: &crate::apis::openai::ChatCompletionsRequest) -> Option<String>;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Trait for token usage information
|
/// Trait for token usage information
|
||||||
|
|
@ -47,46 +59,178 @@ pub trait TokenUsage {
|
||||||
fn total_tokens(&self) -> usize;
|
fn total_tokens(&self) -> usize;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Trait for provider-specific request types
|
||||||
|
pub trait ProviderRequest: Send + Sync {
|
||||||
|
/// Extract the model name from the request
|
||||||
|
fn model(&self) -> &str;
|
||||||
|
|
||||||
|
/// Check if this is a streaming request
|
||||||
|
fn is_streaming(&self) -> bool;
|
||||||
|
|
||||||
|
/// Set streaming options (e.g., include_usage)
|
||||||
|
fn set_streaming_options(&mut self);
|
||||||
|
|
||||||
|
/// Extract text content from messages for token counting
|
||||||
|
fn extract_messages_text(&self) -> String;
|
||||||
|
|
||||||
|
/// Extract the user message for tracing/logging purposes
|
||||||
|
fn extract_user_message(&self) -> Option<String>;
|
||||||
|
|
||||||
|
/// Convert to provider-specific format
|
||||||
|
fn to_provider_bytes(&self, mode: ConversionMode) -> Result<Vec<u8>, ProviderRequestError>;
|
||||||
|
}
|
||||||
|
|
||||||
/// Trait for provider-specific response types
|
/// Trait for provider-specific response types
|
||||||
pub trait ProviderResponse {
|
pub trait ProviderResponse: Send + Sync {
|
||||||
type Error: Error + Send + Sync + 'static;
|
/// Get usage information if available - returns dynamic trait object
|
||||||
type Usage: TokenUsage;
|
fn usage(&self) -> Option<&dyn TokenUsage>;
|
||||||
|
|
||||||
/// Parse response from raw bytes
|
|
||||||
fn try_from_bytes(&self, bytes: &[u8], provider: &super::ProviderId, mode: ConversionMode) -> Result<crate::apis::openai::ChatCompletionsResponse, Self::Error>;
|
|
||||||
|
|
||||||
/// Get usage information if available
|
|
||||||
fn usage<'a>(&self, response: &'a crate::apis::openai::ChatCompletionsResponse) -> Option<&'a Self::Usage>;
|
|
||||||
|
|
||||||
/// Extract token counts for metrics
|
/// Extract token counts for metrics
|
||||||
fn extract_usage_counts(&self, response: &crate::apis::openai::ChatCompletionsResponse) -> Option<(usize, usize, usize)> {
|
fn extract_usage_counts(&self) -> Option<(usize, usize, usize)> {
|
||||||
self.usage(response).map(|u| (u.prompt_tokens(), u.completion_tokens(), u.total_tokens()))
|
self.usage().map(|u| (u.prompt_tokens(), u.completion_tokens(), u.total_tokens()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Trait for streaming response chunks
|
/// Trait for provider-specific streaming response types
|
||||||
pub trait StreamChunk {
|
pub trait ProviderStreamResponse: Send + Sync {
|
||||||
type Usage: TokenUsage;
|
/// Get the content delta for this chunk
|
||||||
|
fn content_delta(&self) -> Option<&str>;
|
||||||
|
|
||||||
/// Get usage information if available
|
/// Check if this is the final chunk in the stream
|
||||||
fn usage(&self) -> Option<&Self::Usage>;
|
fn is_final(&self) -> bool;
|
||||||
|
|
||||||
|
/// Get role information if available
|
||||||
|
fn role(&self) -> Option<&str>;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Trait for streaming response iterators
|
/// Trait for streaming response iterators
|
||||||
pub trait StreamingResponse {
|
///
|
||||||
type Error: Error + Send + Sync + 'static;
|
/// This trait ensures that implementing types are iterators that yield
|
||||||
type StreamChunk: StreamChunk;
|
/// ProviderStreamResponse results.
|
||||||
type StreamingIter: Iterator<Item = Result<Self::StreamChunk, Self::Error>>;
|
pub trait ProviderStreamResponseIter: Iterator<Item = Result<Box<dyn ProviderStreamResponse>, Box<dyn std::error::Error + Send + Sync>>> + Send + Sync {
|
||||||
|
// No additional methods needed - just the Iterator constraint with proper bounds
|
||||||
/// Parse streaming response from raw bytes
|
|
||||||
fn try_from_bytes(&self, bytes: &[u8], provider: &super::ProviderId, mode: ConversionMode) -> Result<Self::StreamingIter, Self::Error>;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Main provider interface trait - simplified to only essential methods
|
// ============================================================================
|
||||||
pub trait ProviderInterface: ProviderRequest + ProviderResponse + StreamingResponse {
|
// PROVIDER FUNCTIONS - NO TRAITS, JUST PARAMETERIZED CONVERSION
|
||||||
/// Check if this provider has a compatible API with the client request
|
// ============================================================================
|
||||||
fn has_compatible_api(&self, api_path: &str) -> bool;
|
//
|
||||||
|
// ARCHITECTURAL DECISION: Function-based Provider API
|
||||||
|
//
|
||||||
|
// We chose this function-based approach over the original ProviderInterface trait
|
||||||
|
// for several critical reasons:
|
||||||
|
//
|
||||||
|
// 1. TRAIT OBJECT LIMITATION:
|
||||||
|
// - The original ProviderInterface had associated types (Request, Response, etc.)
|
||||||
|
// - Traits with associated types cannot be used as trait objects (Box<dyn ProviderInterface>)
|
||||||
|
// - This prevented dynamic provider selection at runtime based on request headers
|
||||||
|
// - Error: "the trait `ProviderInterface` cannot be made into an object"
|
||||||
|
//
|
||||||
|
// 2. DYNAMIC PROVIDER SELECTION REQUIREMENT:
|
||||||
|
// - The gateway needs to select providers dynamically based on incoming headers
|
||||||
|
// - Cannot know provider type at compile time - must dispatch at runtime
|
||||||
|
// - Need ability to return generic trait objects that work polymorphically
|
||||||
|
//
|
||||||
|
// 3. WRAPPER TYPE ELIMINATION:
|
||||||
|
// - Original design required wrapper types like OpenAIRequestWrapper, OpenAIResponseWrapper
|
||||||
|
// - User wanted to implement traits directly on concrete types (ChatCompletionsRequest, etc.)
|
||||||
|
// - Function-based approach allows direct trait implementations without wrappers
|
||||||
|
//
|
||||||
|
// 4. PARAMETERIZED CONVERSION PATTERN:
|
||||||
|
// - Follows existing codebase pattern: TryFrom<(&[u8], &ProviderId)>
|
||||||
|
// - Enables runtime provider selection while maintaining type safety
|
||||||
|
// - Single implementation can handle multiple OpenAI-compatible providers
|
||||||
|
//
|
||||||
|
// 5. TYPE ERASURE FOR GENERIC INTERFACE:
|
||||||
|
// - Functions return Box<dyn ProviderRequest/Response> - works as trait objects
|
||||||
|
// - stream_context.rs can work with generic interfaces without knowing concrete types
|
||||||
|
// - Maintains polymorphism while enabling dynamic dispatch
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
/// Get supported API endpoints for this provider
|
use crate::ProviderId;
|
||||||
fn supported_apis(&self) -> Vec<&'static str>;
|
|
||||||
|
/// Parse request from bytes using provider ID - returns generic ProviderRequest trait object
|
||||||
|
pub fn try_request_from_bytes(bytes: &[u8], provider_id: &ProviderId) -> Result<Box<dyn ProviderRequest>, ProviderRequestError> {
|
||||||
|
match provider_id {
|
||||||
|
// All these providers currently use OpenAI-compatible chat completions API
|
||||||
|
// In the future, we can add provider-specific handling in separate match arms
|
||||||
|
ProviderId::OpenAI | ProviderId::Groq | ProviderId::Mistral | ProviderId::Deepseek
|
||||||
|
| ProviderId::Arch | ProviderId::Gemini | ProviderId::Claude | ProviderId::GitHub => {
|
||||||
|
|
||||||
|
let request = crate::apis::openai::ChatCompletionsRequest::try_from((bytes, provider_id))
|
||||||
|
.map_err(|e| ProviderRequestError {
|
||||||
|
message: format!("Failed to parse request: {}", e),
|
||||||
|
source: Some(Box::new(e)),
|
||||||
|
})?;
|
||||||
|
|
||||||
|
// Return as trait object - this enables polymorphic usage
|
||||||
|
// ChatCompletionsRequest implements ProviderRequest directly (no wrapper needed)
|
||||||
|
Ok(Box::new(request) as Box<dyn ProviderRequest>)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Parse response from bytes using provider ID - returns generic ProviderResponse trait object
|
||||||
|
pub fn try_response_from_bytes(bytes: &[u8], provider_id: &ProviderId, _mode: ConversionMode) -> Result<Box<dyn ProviderResponse>, ProviderResponseError> {
|
||||||
|
match provider_id {
|
||||||
|
ProviderId::OpenAI | ProviderId::Groq | ProviderId::Mistral | ProviderId::Deepseek
|
||||||
|
| ProviderId::Arch | ProviderId::Gemini | ProviderId::Claude | ProviderId::GitHub => {
|
||||||
|
// Parameterized conversion allows provider-specific response parsing
|
||||||
|
let response = crate::apis::openai::ChatCompletionsResponse::try_from((bytes, provider_id))
|
||||||
|
.map_err(|e| ProviderResponseError {
|
||||||
|
message: format!("Failed to parse response: {}", e),
|
||||||
|
source: Some(Box::new(e)),
|
||||||
|
})?;
|
||||||
|
|
||||||
|
// ChatCompletionsResponse implements ProviderResponse directly - no wrapper needed!
|
||||||
|
Ok(Box::new(response) as Box<dyn ProviderResponse>)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create streaming response using provider ID - returns clean ProviderStreamResponseIter trait object
|
||||||
|
///
|
||||||
|
/// This function returns a ProviderStreamResponseIter that's just an iterator,
|
||||||
|
/// eliminating the complex nested Result<Box<dyn Iterator<...>>> type completely.
|
||||||
|
pub fn try_streaming_from_bytes(bytes: &[u8], provider_id: &ProviderId, _mode: ConversionMode) -> Result<Box<dyn ProviderStreamResponseIter>, Box<dyn std::error::Error + Send + Sync>> {
|
||||||
|
match provider_id {
|
||||||
|
ProviderId::OpenAI | ProviderId::Groq | ProviderId::Mistral | ProviderId::Deepseek
|
||||||
|
| ProviderId::Arch | ProviderId::Gemini | ProviderId::Claude | ProviderId::GitHub => {
|
||||||
|
// Parse SSE (Server-Sent Events) streaming data
|
||||||
|
let s = std::str::from_utf8(bytes)?;
|
||||||
|
let lines: Vec<String> = s.lines().map(|line| line.to_string()).collect();
|
||||||
|
let iter = crate::apis::openai::SseChatCompletionIter::new(lines.into_iter());
|
||||||
|
|
||||||
|
// Return the iterator directly - it implements ProviderStreamResponseIter
|
||||||
|
Ok(Box::new(iter))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if provider has compatible API
|
||||||
|
///
|
||||||
|
/// Replaces the old ProviderInterface::has_compatible_api method.
|
||||||
|
/// This function enables runtime API compatibility checking without needing a provider instance.
|
||||||
|
pub fn has_compatible_api(provider_id: &ProviderId, api_path: &str) -> bool {
|
||||||
|
match provider_id {
|
||||||
|
// Currently all these providers support OpenAI chat completions API
|
||||||
|
// Future providers with different APIs will get their own match arms
|
||||||
|
ProviderId::OpenAI | ProviderId::Groq | ProviderId::Mistral | ProviderId::Deepseek
|
||||||
|
| ProviderId::Arch | ProviderId::Gemini | ProviderId::Claude | ProviderId::GitHub => {
|
||||||
|
api_path == "/v1/chat/completions"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get supported APIs for provider
|
||||||
|
///
|
||||||
|
/// Replaces the old ProviderInterface::supported_apis method.
|
||||||
|
/// Returns a static list of supported API endpoints for the given provider.
|
||||||
|
pub fn supported_apis(provider_id: &ProviderId) -> Vec<&'static str> {
|
||||||
|
match provider_id {
|
||||||
|
ProviderId::OpenAI | ProviderId::Groq | ProviderId::Mistral | ProviderId::Deepseek
|
||||||
|
| ProviderId::Arch | ProviderId::Gemini | ProviderId::Claude | ProviderId::GitHub => {
|
||||||
|
vec!["/v1/chat/completions"]
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -10,10 +10,10 @@ use common::ratelimit::Header;
|
||||||
use common::stats::{IncrementingMetric, RecordingMetric};
|
use common::stats::{IncrementingMetric, RecordingMetric};
|
||||||
use common::tracing::{Event, Span, TraceData, Traceparent};
|
use common::tracing::{Event, Span, TraceData, Traceparent};
|
||||||
use common::{ratelimit, routing, tokenizer};
|
use common::{ratelimit, routing, tokenizer};
|
||||||
use hermesllm::providers::traits::{
|
use hermesllm::{
|
||||||
ProviderRequest, ProviderResponse, StreamChunk, StreamingResponse, TokenUsage,
|
try_request_from_bytes, try_response_from_bytes, try_streaming_from_bytes, ConversionMode,
|
||||||
|
ProviderId,
|
||||||
};
|
};
|
||||||
use hermesllm::{ConversionMode, Provider, ProviderId};
|
|
||||||
use http::StatusCode;
|
use http::StatusCode;
|
||||||
use log::{debug, info, warn};
|
use log::{debug, info, warn};
|
||||||
use proxy_wasm::hostcalls::get_current_time;
|
use proxy_wasm::hostcalls::get_current_time;
|
||||||
|
|
@ -79,8 +79,8 @@ impl StreamContext {
|
||||||
.expect("the provider should be set when asked for it")
|
.expect("the provider should be set when asked for it")
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_provider(&self) -> Provider {
|
fn get_provider_id(&self) -> ProviderId {
|
||||||
self.llm_provider().create_provider()
|
self.llm_provider().to_provider_id()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn select_llm_provider(&mut self) {
|
fn select_llm_provider(&mut self) {
|
||||||
|
|
@ -298,9 +298,9 @@ impl HttpContext for StreamContext {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let provider = self.get_provider();
|
let provider_id = self.get_provider_id();
|
||||||
|
|
||||||
let mut deserialized_body = match ProviderRequest::try_from_bytes(&provider, &body_bytes) {
|
let mut deserialized_body = match try_request_from_bytes(&body_bytes, &provider_id) {
|
||||||
Ok(deserialized) => deserialized,
|
Ok(deserialized) => deserialized,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
debug!(
|
debug!(
|
||||||
|
|
@ -329,10 +329,10 @@ impl HttpContext for StreamContext {
|
||||||
};
|
};
|
||||||
|
|
||||||
// Use the provider interface methods for cleaner interaction
|
// Use the provider interface methods for cleaner interaction
|
||||||
let model_requested = provider.extract_model(&deserialized_body).to_string(); // Convert to owned string
|
let model_requested = deserialized_body.model().to_string(); // Convert to owned string
|
||||||
|
|
||||||
// Extract user message for tracing
|
// Extract user message for tracing
|
||||||
self.user_message = provider.extract_user_message(&deserialized_body);
|
self.user_message = deserialized_body.extract_user_message();
|
||||||
|
|
||||||
info!(
|
info!(
|
||||||
"on_http_request_body: provider: {}, model requested (in body): {}, model selected: {}",
|
"on_http_request_body: provider: {}, model requested (in body): {}, model selected: {}",
|
||||||
|
|
@ -342,15 +342,15 @@ impl HttpContext for StreamContext {
|
||||||
);
|
);
|
||||||
|
|
||||||
// Use provider interface for streaming detection and setup
|
// Use provider interface for streaming detection and setup
|
||||||
self.streaming_response = provider.is_streaming(&deserialized_body);
|
self.streaming_response = deserialized_body.is_streaming();
|
||||||
|
|
||||||
// Set streaming options if needed
|
// Set streaming options if needed
|
||||||
if self.streaming_response {
|
if self.streaming_response {
|
||||||
provider.set_streaming_options(&mut deserialized_body);
|
deserialized_body.set_streaming_options();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Use provider interface for text extraction (after potential mutation)
|
// Use provider interface for text extraction (after potential mutation)
|
||||||
let input_tokens_str = provider.extract_messages_text(&deserialized_body);
|
let input_tokens_str = deserialized_body.extract_messages_text();
|
||||||
// enforce ratelimits on ingress
|
// enforce ratelimits on ingress
|
||||||
if let Err(e) = self.enforce_ratelimits(&model_requested, input_tokens_str.as_str()) {
|
if let Err(e) = self.enforce_ratelimits(&model_requested, input_tokens_str.as_str()) {
|
||||||
self.send_server_error(
|
self.send_server_error(
|
||||||
|
|
@ -365,21 +365,18 @@ impl HttpContext for StreamContext {
|
||||||
let _hermes_llm_provider_id = ProviderId::from(llm_provider_str.as_str());
|
let _hermes_llm_provider_id = ProviderId::from(llm_provider_str.as_str());
|
||||||
|
|
||||||
// Convert chat completion request to llm provider specific request using provider interface
|
// Convert chat completion request to llm provider specific request using provider interface
|
||||||
let deserialized_body_bytes = match provider.to_provider_bytes(
|
let deserialized_body_bytes =
|
||||||
&deserialized_body,
|
match deserialized_body.to_provider_bytes(ConversionMode::Compatible) {
|
||||||
provider.id(),
|
Ok(bytes) => bytes,
|
||||||
ConversionMode::Compatible,
|
Err(e) => {
|
||||||
) {
|
warn!("Failed to serialize request body: {}", e);
|
||||||
Ok(bytes) => bytes,
|
self.send_server_error(
|
||||||
Err(e) => {
|
ServerError::LogicError(format!("Request serialization error: {}", e)),
|
||||||
warn!("Failed to serialize request body: {}", e);
|
Some(StatusCode::BAD_REQUEST),
|
||||||
self.send_server_error(
|
);
|
||||||
ServerError::LogicError(format!("Request serialization error: {}", e)),
|
return Action::Pause;
|
||||||
Some(StatusCode::BAD_REQUEST),
|
}
|
||||||
);
|
};
|
||||||
return Action::Pause;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
self.set_http_request_body(0, body_size, &deserialized_body_bytes);
|
self.set_http_request_body(0, body_size, &deserialized_body_bytes);
|
||||||
|
|
||||||
|
|
@ -550,16 +547,9 @@ impl HttpContext for StreamContext {
|
||||||
|
|
||||||
// Parse streaming response using OpenAI-compatible format
|
// Parse streaming response using OpenAI-compatible format
|
||||||
// Since all providers use OpenAI-compatible streaming format
|
// Since all providers use OpenAI-compatible streaming format
|
||||||
let provider = self.get_provider();
|
let provider_id = self.get_provider_id();
|
||||||
let provider_id =
|
|
||||||
ProviderId::from(self.llm_provider().provider_interface.to_string().as_str());
|
|
||||||
|
|
||||||
match StreamingResponse::try_from_bytes(
|
match try_streaming_from_bytes(&body, &provider_id, ConversionMode::Compatible) {
|
||||||
&provider,
|
|
||||||
&body,
|
|
||||||
&provider_id,
|
|
||||||
ConversionMode::Compatible,
|
|
||||||
) {
|
|
||||||
Ok(mut streaming_response) => {
|
Ok(mut streaming_response) => {
|
||||||
// Process each streaming chunk
|
// Process each streaming chunk
|
||||||
while let Some(chunk_result) = streaming_response.next() {
|
while let Some(chunk_result) = streaming_response.next() {
|
||||||
|
|
@ -587,14 +577,20 @@ impl HttpContext for StreamContext {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extract usage information if available
|
// For streaming responses, we handle token counting differently
|
||||||
if let Some(usage) = chunk.usage() {
|
// The ProviderStreamResponse trait provides content_delta, is_final, and role
|
||||||
let completion_tokens = usage.completion_tokens();
|
// Token counting for streaming responses typically happens with final usage chunk
|
||||||
self.response_tokens += completion_tokens;
|
if chunk.is_final() {
|
||||||
debug!(
|
// For now, we'll implement basic token estimation
|
||||||
"Streaming chunk completion tokens: {}",
|
// In a complete implementation, the final chunk would contain usage information
|
||||||
completion_tokens
|
debug!("Received final streaming chunk");
|
||||||
);
|
}
|
||||||
|
|
||||||
|
// For now, estimate tokens from content delta
|
||||||
|
if let Some(content) = chunk.content_delta() {
|
||||||
|
// Rough estimation: ~4 characters per token
|
||||||
|
let estimated_tokens = content.len() / 4;
|
||||||
|
self.response_tokens += estimated_tokens.max(1);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
|
|
@ -605,40 +601,37 @@ impl HttpContext for StreamContext {
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
warn!("Failed to parse streaming response: {}", e);
|
warn!("Failed to parse streaming response: {}", e);
|
||||||
|
return Action::Continue;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
debug!("non streaming response");
|
debug!("non streaming response");
|
||||||
let provider = self.get_provider();
|
let provider_id = self.get_provider_id();
|
||||||
let response = match ProviderResponse::try_from_bytes(
|
let response =
|
||||||
&provider,
|
match try_response_from_bytes(&body, &provider_id, ConversionMode::Compatible) {
|
||||||
&body,
|
Ok(response) => response,
|
||||||
&provider.id(),
|
Err(e) => {
|
||||||
ConversionMode::Compatible,
|
warn!(
|
||||||
) {
|
"could not parse response: {}, body str: {}",
|
||||||
Ok(response) => response,
|
e,
|
||||||
Err(e) => {
|
String::from_utf8_lossy(&body)
|
||||||
warn!(
|
);
|
||||||
"could not parse response: {}, body str: {}",
|
debug!(
|
||||||
e,
|
"on_http_response_body: S[{}], response body: {}",
|
||||||
String::from_utf8_lossy(&body)
|
self.context_id,
|
||||||
);
|
String::from_utf8_lossy(&body)
|
||||||
debug!(
|
);
|
||||||
"on_http_response_body: S[{}], response body: {}",
|
self.send_server_error(
|
||||||
self.context_id,
|
ServerError::LogicError(format!("Response parsing error: {}", e)),
|
||||||
String::from_utf8_lossy(&body)
|
Some(StatusCode::BAD_REQUEST),
|
||||||
);
|
);
|
||||||
self.send_server_error(
|
return Action::Continue;
|
||||||
ServerError::LogicError(format!("Response parsing error: {}", e)),
|
}
|
||||||
Some(StatusCode::BAD_REQUEST),
|
};
|
||||||
);
|
|
||||||
return Action::Continue;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// Use provider interface to extract usage information
|
// Use provider interface to extract usage information
|
||||||
if let Some((prompt_tokens, completion_tokens, total_tokens)) =
|
if let Some((prompt_tokens, completion_tokens, total_tokens)) =
|
||||||
provider.extract_usage_counts(&response)
|
response.extract_usage_counts()
|
||||||
{
|
{
|
||||||
debug!(
|
debug!(
|
||||||
"Response usage: prompt={}, completion={}, total={}",
|
"Response usage: prompt={}, completion={}, total={}",
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue