mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
more refactoring to clean code and make stream_context.rs work
This commit is contained in:
parent
d4ca70d177
commit
df3aa17d67
23 changed files with 545 additions and 1321 deletions
|
|
@ -178,13 +178,10 @@ impl Display for LlmProviderType {
|
|||
}
|
||||
|
||||
impl LlmProviderType {
|
||||
/// Create a Provider from this LlmProviderType
|
||||
/// This is the main method for stream_context to get provider-specific interfaces
|
||||
pub fn create_provider(&self) -> hermesllm::Provider {
|
||||
use hermesllm::{ProviderId, Provider};
|
||||
|
||||
let provider_id = ProviderId::from(self.to_string().as_str());
|
||||
Provider::new(provider_id)
|
||||
/// Get the ProviderId for this LlmProviderType
|
||||
/// Used with the new function-based hermesllm API
|
||||
pub fn to_provider_id(&self) -> hermesllm::ProviderId {
|
||||
hermesllm::ProviderId::from(self.to_string().as_str())
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -264,10 +261,10 @@ impl Display for LlmProvider {
|
|||
}
|
||||
|
||||
impl LlmProvider {
|
||||
/// Create a Provider for this LlmProvider
|
||||
/// This is a convenience method that delegates to the provider_interface
|
||||
pub fn create_provider(&self) -> hermesllm::Provider {
|
||||
self.provider_interface.create_provider()
|
||||
/// Get the ProviderId for this LlmProvider
|
||||
/// Used with the new function-based hermesllm API
|
||||
pub fn to_provider_id(&self) -> hermesllm::ProviderId {
|
||||
self.provider_interface.to_provider_id()
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -35,15 +35,17 @@ _Replace the path with the appropriate location if using as a workspace member o
|
|||
Construct a chat completion request using the builder pattern:
|
||||
|
||||
```rust
|
||||
use hermesllm::Provider;
|
||||
use hermesllm::{create_provider, ProviderId};
|
||||
use hermesllm::providers::openai::types::ChatCompletionsRequest;
|
||||
|
||||
let request = ChatCompletionsRequest::builder("gpt-3.5-turbo", vec![Message::new("Hi".to_string())])
|
||||
.build()
|
||||
.expect("Failed to build OpenAIRequest");
|
||||
|
||||
// Convert to bytes for a specific provider
|
||||
let bytes = request.to_bytes(Provider::OpenAI)?;
|
||||
// Create a provider and convert request to bytes
|
||||
let provider = create_provider(ProviderId::OpenAI);
|
||||
let bytes = serde_json::to_vec(&request)?;
|
||||
let parsed_request = provider.try_request_from_bytes(&bytes)?;
|
||||
```
|
||||
|
||||
## API Overview
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ use serde_json::Value;
|
|||
use serde_with::skip_serializing_none;
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::{providers::ProviderRequestError, ConversionMode, ProviderRequest};
|
||||
use super::ApiDefinition;
|
||||
|
||||
// ============================================================================
|
||||
|
|
@ -424,6 +425,212 @@ pub struct StreamOptions {
|
|||
pub include_usage: Option<bool>,
|
||||
}
|
||||
|
||||
// ============================================================================
// OpenAI Provider Request Wrapper
// ============================================================================
|
||||
impl ProviderRequest for ChatCompletionsRequest {
|
||||
fn model(&self) -> &str {
|
||||
&self.model
|
||||
}
|
||||
|
||||
fn is_streaming(&self) -> bool {
|
||||
self.stream.unwrap_or_default()
|
||||
}
|
||||
|
||||
fn set_streaming_options(&mut self) {
|
||||
self.stream = Some(true);
|
||||
if self.stream_options.is_none() {
|
||||
self.stream_options = Some(StreamOptions { include_usage: Some(true) });
|
||||
}
|
||||
}
|
||||
|
||||
fn extract_messages_text(&self) -> String {
|
||||
self.messages.iter().fold(String::new(), |acc, m| {
|
||||
acc + " " + &match &m.content {
|
||||
MessageContent::Text(text) => text.clone(),
|
||||
MessageContent::Parts(parts) => parts.iter().map(|part| match part {
|
||||
ContentPart::Text { text } => text.clone(),
|
||||
ContentPart::ImageUrl { .. } => "[Image]".to_string(),
|
||||
}).collect::<Vec<_>>().join(" ")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn extract_user_message(&self) -> Option<String> {
|
||||
self.messages.last().and_then(|msg| {
|
||||
match &msg.content {
|
||||
MessageContent::Text(text) => Some(text.clone()),
|
||||
MessageContent::Parts(_) => None, // No user message in parts
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn to_provider_bytes(&self, mode: ConversionMode) -> Result<Vec<u8>, ProviderRequestError> {
|
||||
match mode {
|
||||
ConversionMode::Compatible | ConversionMode::Passthrough => {
|
||||
serde_json::to_vec(&self).map_err(|e| ProviderRequestError {
|
||||
message: format!("Failed to serialize OpenAI request: {}", e),
|
||||
source: Some(Box::new(e)),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// STREAMING SUPPORT
|
||||
// ============================================================================
|
||||
|
||||
use crate::providers::traits::{ProviderResponse, ProviderStreamResponse, ProviderStreamResponseIter, TokenUsage};
|
||||
|
||||
// Direct implementation of ProviderResponse on ChatCompletionsResponse
|
||||
impl ProviderResponse for ChatCompletionsResponse {
|
||||
fn usage(&self) -> Option<&dyn TokenUsage> {
|
||||
Some(&self.usage)
|
||||
}
|
||||
}
|
||||
|
||||
// Implementation of TokenUsage for OpenAI Usage type
|
||||
impl TokenUsage for Usage {
    /// Completion token count, cast to `usize`.
    fn completion_tokens(&self) -> usize {
        let Self { completion_tokens, .. } = self;
        *completion_tokens as usize
    }

    /// Prompt token count, cast to `usize`.
    fn prompt_tokens(&self) -> usize {
        let Self { prompt_tokens, .. } = self;
        *prompt_tokens as usize
    }

    /// Total token count, cast to `usize`.
    fn total_tokens(&self) -> usize {
        let Self { total_tokens, .. } = self;
        *total_tokens as usize
    }
}
|
||||
|
||||
// Direct implementation of ProviderStreamResponse trait on ChatCompletionsStreamResponse
|
||||
impl ProviderStreamResponse for ChatCompletionsStreamResponse {
|
||||
fn content_delta(&self) -> Option<&str> {
|
||||
self.choices
|
||||
.first()
|
||||
.and_then(|choice| choice.delta.content.as_deref())
|
||||
}
|
||||
|
||||
fn is_final(&self) -> bool {
|
||||
self.choices
|
||||
.first()
|
||||
.map(|choice| choice.finish_reason.is_some())
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
fn role(&self) -> Option<&str> {
|
||||
self.choices
|
||||
.first()
|
||||
.and_then(|choice| choice.delta.role.as_ref().map(|r| match r {
|
||||
Role::System => "system",
|
||||
Role::User => "user",
|
||||
Role::Assistant => "assistant",
|
||||
Role::Tool => "tool",
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
// Error type for streaming operations
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum OpenAIStreamError {
|
||||
#[error("JSON parsing error: {0}")]
|
||||
JsonError(#[from] serde_json::Error),
|
||||
#[error("UTF-8 parsing error: {0}")]
|
||||
Utf8Error(#[from] std::str::Utf8Error),
|
||||
#[error("Invalid streaming data: {0}")]
|
||||
InvalidStreamingData(String),
|
||||
}
|
||||
|
||||
/// SSE-based streaming iterator for OpenAI chat completions
|
||||
/// Implements ProviderStreamResponseIter directly
|
||||
pub struct SseChatCompletionIter<I>
|
||||
where
|
||||
I: Iterator,
|
||||
I::Item: AsRef<str>,
|
||||
{
|
||||
lines: I,
|
||||
}
|
||||
|
||||
impl<I> SseChatCompletionIter<I>
|
||||
where
|
||||
I: Iterator,
|
||||
I::Item: AsRef<str>,
|
||||
{
|
||||
pub fn new(lines: I) -> Self {
|
||||
Self { lines }
|
||||
}
|
||||
}
|
||||
|
||||
impl<I> Iterator for SseChatCompletionIter<I>
|
||||
where
|
||||
I: Iterator,
|
||||
I::Item: AsRef<str>,
|
||||
{
|
||||
type Item = Result<Box<dyn ProviderStreamResponse>, Box<dyn std::error::Error + Send + Sync>>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
for line in &mut self.lines {
|
||||
let line = line.as_ref();
|
||||
if line.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
if line.starts_with("data: ") {
|
||||
let data = &line[6..]; // Remove "data: " prefix
|
||||
if data == "[DONE]" {
|
||||
return None;
|
||||
}
|
||||
|
||||
if data == r#"{"type": "ping"}"# {
|
||||
continue; // Skip ping messages - that is usually from anthropic
|
||||
}
|
||||
|
||||
match serde_json::from_str::<ChatCompletionsStreamResponse>(data) {
|
||||
Ok(response) => return Some(Ok(Box::new(response))),
|
||||
Err(e) => return Some(Err(Box::new(
|
||||
OpenAIStreamError::InvalidStreamingData(format!("Error parsing: {}, data: {}", e, data))
|
||||
))),
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
impl<I> ProviderStreamResponseIter for SseChatCompletionIter<I>
|
||||
where
|
||||
I: Iterator + Send + Sync,
|
||||
I::Item: AsRef<str>,
|
||||
{
|
||||
// Just marking that this type implements the trait - no additional methods needed
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// PARAMETERIZED CONVERSIONS FOR PROVIDER FUNCTIONS
|
||||
// ============================================================================
|
||||
|
||||
use crate::providers::ProviderId;
|
||||
|
||||
/// Parameterized conversion for ChatCompletionsRequest
|
||||
impl TryFrom<(&[u8], &ProviderId)> for ChatCompletionsRequest {
|
||||
type Error = OpenAIStreamError;
|
||||
|
||||
fn try_from((bytes, _provider_id): (&[u8], &ProviderId)) -> Result<Self, Self::Error> {
|
||||
serde_json::from_slice(bytes).map_err(OpenAIStreamError::from)
|
||||
}
|
||||
}
|
||||
|
||||
/// Parameterized conversion for ChatCompletionsResponse
|
||||
impl TryFrom<(&[u8], &ProviderId)> for ChatCompletionsResponse {
|
||||
type Error = OpenAIStreamError;
|
||||
|
||||
fn try_from((bytes, _provider_id): (&[u8], &ProviderId)) -> Result<Self, Self::Error> {
|
||||
serde_json::from_slice(bytes).map_err(OpenAIStreamError::from)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
|
|
|||
|
|
@ -7,10 +7,11 @@ pub mod clients;
|
|||
|
||||
// Re-export important types and traits
|
||||
pub use providers::{
|
||||
ProviderId, Provider, ConversionMode,
|
||||
ProviderInterface, ProviderRequest, ProviderResponse,
|
||||
TokenUsage, StreamChunk, StreamingResponse,
|
||||
OpenAIProvider
|
||||
ProviderId, ConversionMode,
|
||||
ProviderRequest, ProviderResponse, ProviderStreamResponse, ProviderStreamResponseIter,
|
||||
TokenUsage,
|
||||
try_request_from_bytes, try_response_from_bytes, try_streaming_from_bytes,
|
||||
has_compatible_api, supported_apis
|
||||
};
|
||||
|
||||
#[cfg(test)]
|
||||
|
|
@ -26,70 +27,71 @@ mod tests {
|
|||
}
|
||||
|
||||
#[test]
|
||||
fn test_provider_api_paths() {
|
||||
assert_eq!(ProviderId::OpenAI.api_path(), "/v1/chat/completions");
|
||||
assert_eq!(ProviderId::Groq.api_path(), "/openai/v1/chat/completions");
|
||||
assert_eq!(ProviderId::Mistral.api_path(), "/v1/chat/completions");
|
||||
assert_eq!(ProviderId::Arch.api_path(), "/v1/chat/completions");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_provider_openai_format_support() {
|
||||
assert!(ProviderId::OpenAI.supports_openai_format());
|
||||
assert!(ProviderId::Groq.supports_openai_format());
|
||||
assert!(ProviderId::Mistral.supports_openai_format());
|
||||
assert!(ProviderId::Arch.supports_openai_format());
|
||||
assert!(!ProviderId::Gemini.supports_openai_format());
|
||||
assert!(!ProviderId::Claude.supports_openai_format());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_provider_instance_creation() {
|
||||
let provider = Provider::new(ProviderId::OpenAI);
|
||||
assert!(provider.has_compatible_api("/v1/chat/completions"));
|
||||
assert!(!provider.has_compatible_api("/v1/embeddings"));
|
||||
fn test_provider_api_compatibility() {
|
||||
assert!(has_compatible_api(&ProviderId::OpenAI, "/v1/chat/completions"));
|
||||
assert!(!has_compatible_api(&ProviderId::OpenAI, "/v1/embeddings"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_provider_supported_apis() {
|
||||
let provider = Provider::new(ProviderId::OpenAI);
|
||||
|
||||
let supported_apis = provider.supported_apis();
|
||||
assert!(supported_apis.contains(&"/v1/chat/completions"));
|
||||
let apis = supported_apis(&ProviderId::OpenAI);
|
||||
assert!(apis.contains(&"/v1/chat/completions"));
|
||||
|
||||
// Test that provider supports the expected API endpoints
|
||||
assert!(provider.has_compatible_api("/v1/chat/completions"));
|
||||
assert!(has_compatible_api(&ProviderId::OpenAI, "/v1/chat/completions"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_provider_extract_user_message() {
|
||||
use crate::apis::openai::{ChatCompletionsRequest, Message, MessageContent};
|
||||
|
||||
let provider = Provider::new(ProviderId::OpenAI);
|
||||
|
||||
// Test with text message
|
||||
let request = ChatCompletionsRequest {
|
||||
model: "gpt-4".to_string(),
|
||||
messages: vec![
|
||||
Message {
|
||||
role: crate::apis::openai::Role::System,
|
||||
content: MessageContent::Text("You are a helpful assistant".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
fn test_provider_request_parsing() {
|
||||
// Test with a sample JSON request
|
||||
let json_request = r#"{
|
||||
"model": "gpt-4",
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant"
|
||||
},
|
||||
Message {
|
||||
role: crate::apis::openai::Role::User,
|
||||
content: MessageContent::Text("Hello, world!".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
],
|
||||
..Default::default()
|
||||
};
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hello!"
|
||||
}
|
||||
]
|
||||
}"#;
|
||||
|
||||
let user_message = provider.extract_user_message(&request);
|
||||
assert_eq!(user_message, Some("Hello, world!".to_string()));
|
||||
let result = try_request_from_bytes(json_request.as_bytes(), &ProviderId::OpenAI);
|
||||
assert!(result.is_ok());
|
||||
|
||||
let request = result.unwrap();
|
||||
assert_eq!(request.model(), "gpt-4");
|
||||
assert_eq!(request.extract_user_message(), Some("Hello!".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_provider_streaming_response() {
|
||||
// Test streaming response parsing with sample SSE data
|
||||
let sse_data = r#"data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"gpt-4","choices":[{"index":0,"delta":{"role":"assistant","content":"Hello"},"finish_reason":null}]}
|
||||
|
||||
data: [DONE]
|
||||
"#;
|
||||
|
||||
let result = try_streaming_from_bytes(sse_data.as_bytes(), &ProviderId::OpenAI, ConversionMode::Passthrough);
|
||||
assert!(result.is_ok());
|
||||
|
||||
let mut streaming_response = result.unwrap();
|
||||
|
||||
// Test that we can iterate over chunks - it's just an iterator now!
|
||||
let first_chunk = streaming_response.next();
|
||||
assert!(first_chunk.is_some());
|
||||
|
||||
let chunk_result = first_chunk.unwrap();
|
||||
assert!(chunk_result.is_ok());
|
||||
|
||||
let chunk = chunk_result.unwrap();
|
||||
assert_eq!(chunk.content_delta(), Some("Hello"));
|
||||
assert!(!chunk.is_final());
|
||||
|
||||
// Test that stream ends properly
|
||||
let final_chunk = streaming_response.next();
|
||||
assert!(final_chunk.is_none());
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,6 +0,0 @@
|
|||
//! Arch provider implementation
|
||||
//!
|
||||
//! Arch uses OpenAI-compatible API format
|
||||
|
||||
pub mod provider;
|
||||
pub use provider::ArchProvider;
|
||||
|
|
@ -1,88 +0,0 @@
|
|||
//! Arch provider implementation
|
||||
|
||||
use crate::providers::{ProviderInterface, ConversionMode};
|
||||
use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse, Usage};
|
||||
use crate::providers::traits::{ProviderRequest, ProviderResponse, StreamingResponse};
|
||||
use crate::providers::openai::provider::{OpenAIProvider, OpenAIStreamingResponse, OpenAIApiError};
|
||||
|
||||
/// Arch provider implementation
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ArchProvider;
|
||||
|
||||
// Trait implementations that delegate to OpenAI
|
||||
impl ProviderRequest for ArchProvider {
|
||||
type Error = OpenAIApiError;
|
||||
|
||||
fn try_from_bytes(&self, bytes: &[u8]) -> Result<ChatCompletionsRequest, Self::Error> {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderRequest::try_from_bytes(&openai_provider, bytes)
|
||||
}
|
||||
|
||||
fn to_provider_bytes(&self, request: &ChatCompletionsRequest, provider: super::super::ProviderId, mode: ConversionMode) -> Result<Vec<u8>, Self::Error> {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderRequest::to_provider_bytes(&openai_provider, request, provider, mode)
|
||||
}
|
||||
|
||||
fn extract_model<'a>(&self, request: &'a ChatCompletionsRequest) -> &'a str {
|
||||
&request.model
|
||||
}
|
||||
|
||||
fn is_streaming(&self, request: &ChatCompletionsRequest) -> bool {
|
||||
request.stream.unwrap_or_default()
|
||||
}
|
||||
|
||||
fn set_streaming_options(&self, request: &mut ChatCompletionsRequest) {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderRequest::set_streaming_options(&openai_provider, request)
|
||||
}
|
||||
|
||||
fn extract_messages_text(&self, request: &ChatCompletionsRequest) -> String {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderRequest::extract_messages_text(&openai_provider, request)
|
||||
}
|
||||
|
||||
fn extract_user_message(&self, request: &ChatCompletionsRequest) -> Option<String> {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderRequest::extract_user_message(&openai_provider, request)
|
||||
}
|
||||
}
|
||||
|
||||
impl ProviderResponse for ArchProvider {
|
||||
type Error = OpenAIApiError;
|
||||
type Usage = Usage;
|
||||
|
||||
fn try_from_bytes(&self, bytes: &[u8], provider: &super::super::ProviderId, mode: ConversionMode) -> Result<ChatCompletionsResponse, Self::Error> {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderResponse::try_from_bytes(&openai_provider, bytes, provider, mode)
|
||||
}
|
||||
|
||||
fn usage<'a>(&self, response: &'a ChatCompletionsResponse) -> Option<&'a Self::Usage> {
|
||||
Some(&response.usage)
|
||||
}
|
||||
|
||||
fn extract_usage_counts(&self, response: &ChatCompletionsResponse) -> Option<(usize, usize, usize)> {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderResponse::extract_usage_counts(&openai_provider, response)
|
||||
}
|
||||
}
|
||||
|
||||
impl StreamingResponse for ArchProvider {
|
||||
type Error = OpenAIApiError;
|
||||
type StreamChunk = crate::apis::openai::ChatCompletionsStreamResponse;
|
||||
type StreamingIter = OpenAIStreamingResponse;
|
||||
|
||||
fn try_from_bytes(&self, bytes: &[u8], provider: &super::super::ProviderId, mode: ConversionMode) -> Result<Self::StreamingIter, Self::Error> {
|
||||
let openai_provider = OpenAIProvider;
|
||||
StreamingResponse::try_from_bytes(&openai_provider, bytes, provider, mode)
|
||||
}
|
||||
}
|
||||
|
||||
impl ProviderInterface for ArchProvider {
|
||||
fn has_compatible_api(&self, api_path: &str) -> bool {
|
||||
matches!(api_path, "/v1/chat/completions")
|
||||
}
|
||||
|
||||
fn supported_apis(&self) -> Vec<&'static str> {
|
||||
vec!["/v1/chat/completions"]
|
||||
}
|
||||
}
|
||||
|
|
@ -1,7 +0,0 @@
|
|||
//! Claude provider implementation
|
||||
//!
|
||||
//! Claude will use a different API format in the future (/v1/messages)
|
||||
//! For now, fallback to OpenAI-compatible format
|
||||
|
||||
pub mod provider;
|
||||
pub use provider::ClaudeProvider;
|
||||
|
|
@ -1,93 +0,0 @@
|
|||
//! Claude provider implementation
|
||||
//!
|
||||
//! TODO: Implement Claude-specific API format (/v1/messages) when needed
|
||||
//! For now, uses OpenAI-compatible format as fallback
|
||||
|
||||
use crate::providers::{ProviderInterface, ConversionMode};
|
||||
use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse, Usage};
|
||||
use crate::providers::traits::{ProviderRequest, ProviderResponse, StreamingResponse};
|
||||
use crate::providers::openai::provider::{OpenAIProvider, OpenAIStreamingResponse, OpenAIApiError};
|
||||
|
||||
/// Claude provider implementation
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ClaudeProvider;
|
||||
|
||||
// Trait implementations that delegate to OpenAI
|
||||
impl ProviderRequest for ClaudeProvider {
|
||||
type Error = OpenAIApiError;
|
||||
|
||||
fn try_from_bytes(&self, bytes: &[u8]) -> Result<ChatCompletionsRequest, Self::Error> {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderRequest::try_from_bytes(&openai_provider, bytes)
|
||||
}
|
||||
|
||||
fn to_provider_bytes(&self, request: &ChatCompletionsRequest, provider: super::super::ProviderId, mode: ConversionMode) -> Result<Vec<u8>, Self::Error> {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderRequest::to_provider_bytes(&openai_provider, request, provider, mode)
|
||||
}
|
||||
|
||||
fn extract_model<'a>(&self, request: &'a ChatCompletionsRequest) -> &'a str {
|
||||
&request.model
|
||||
}
|
||||
|
||||
fn is_streaming(&self, request: &ChatCompletionsRequest) -> bool {
|
||||
request.stream.unwrap_or_default()
|
||||
}
|
||||
|
||||
fn set_streaming_options(&self, request: &mut ChatCompletionsRequest) {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderRequest::set_streaming_options(&openai_provider, request)
|
||||
}
|
||||
|
||||
fn extract_messages_text(&self, request: &ChatCompletionsRequest) -> String {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderRequest::extract_messages_text(&openai_provider, request)
|
||||
}
|
||||
|
||||
fn extract_user_message(&self, request: &ChatCompletionsRequest) -> Option<String> {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderRequest::extract_user_message(&openai_provider, request)
|
||||
}
|
||||
}
|
||||
|
||||
impl ProviderResponse for ClaudeProvider {
|
||||
type Error = OpenAIApiError;
|
||||
type Usage = Usage;
|
||||
|
||||
fn try_from_bytes(&self, bytes: &[u8], provider: &super::super::ProviderId, mode: ConversionMode) -> Result<ChatCompletionsResponse, Self::Error> {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderResponse::try_from_bytes(&openai_provider, bytes, provider, mode)
|
||||
}
|
||||
|
||||
fn usage<'a>(&self, response: &'a ChatCompletionsResponse) -> Option<&'a Self::Usage> {
|
||||
Some(&response.usage)
|
||||
}
|
||||
|
||||
fn extract_usage_counts(&self, response: &ChatCompletionsResponse) -> Option<(usize, usize, usize)> {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderResponse::extract_usage_counts(&openai_provider, response)
|
||||
}
|
||||
}
|
||||
|
||||
impl StreamingResponse for ClaudeProvider {
|
||||
type Error = OpenAIApiError;
|
||||
type StreamChunk = crate::apis::openai::ChatCompletionsStreamResponse;
|
||||
type StreamingIter = OpenAIStreamingResponse;
|
||||
|
||||
fn try_from_bytes(&self, bytes: &[u8], provider: &super::super::ProviderId, mode: ConversionMode) -> Result<Self::StreamingIter, Self::Error> {
|
||||
let openai_provider = OpenAIProvider;
|
||||
StreamingResponse::try_from_bytes(&openai_provider, bytes, provider, mode)
|
||||
}
|
||||
}
|
||||
|
||||
impl ProviderInterface for ClaudeProvider {
|
||||
fn has_compatible_api(&self, api_path: &str) -> bool {
|
||||
// TODO: Update when Claude API is fully implemented
|
||||
matches!(api_path, "/v1/chat/completions" | "/v1/messages")
|
||||
}
|
||||
|
||||
fn supported_apis(&self) -> Vec<&'static str> {
|
||||
// TODO: Update when Claude API is fully implemented
|
||||
vec!["/v1/messages"]
|
||||
}
|
||||
}
|
||||
|
|
@ -1,6 +0,0 @@
|
|||
//! Deepseek provider implementation
|
||||
//!
|
||||
//! Deepseek uses OpenAI-compatible API format
|
||||
|
||||
pub mod provider;
|
||||
pub use provider::DeepseekProvider;
|
||||
|
|
@ -1,88 +0,0 @@
|
|||
//! Deepseek provider implementation
|
||||
|
||||
use crate::providers::{ProviderInterface, ConversionMode};
|
||||
use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse, Usage};
|
||||
use crate::providers::traits::{ProviderRequest, ProviderResponse, StreamingResponse};
|
||||
use crate::providers::openai::provider::{OpenAIProvider, OpenAIStreamingResponse, OpenAIApiError};
|
||||
|
||||
/// Deepseek provider implementation
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct DeepseekProvider;
|
||||
|
||||
// Trait implementations that delegate to OpenAI
|
||||
impl ProviderRequest for DeepseekProvider {
|
||||
type Error = OpenAIApiError;
|
||||
|
||||
fn try_from_bytes(&self, bytes: &[u8]) -> Result<ChatCompletionsRequest, Self::Error> {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderRequest::try_from_bytes(&openai_provider, bytes)
|
||||
}
|
||||
|
||||
fn to_provider_bytes(&self, request: &ChatCompletionsRequest, provider: super::super::ProviderId, mode: ConversionMode) -> Result<Vec<u8>, Self::Error> {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderRequest::to_provider_bytes(&openai_provider, request, provider, mode)
|
||||
}
|
||||
|
||||
fn extract_model<'a>(&self, request: &'a ChatCompletionsRequest) -> &'a str {
|
||||
&request.model
|
||||
}
|
||||
|
||||
fn is_streaming(&self, request: &ChatCompletionsRequest) -> bool {
|
||||
request.stream.unwrap_or_default()
|
||||
}
|
||||
|
||||
fn set_streaming_options(&self, request: &mut ChatCompletionsRequest) {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderRequest::set_streaming_options(&openai_provider, request)
|
||||
}
|
||||
|
||||
fn extract_messages_text(&self, request: &ChatCompletionsRequest) -> String {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderRequest::extract_messages_text(&openai_provider, request)
|
||||
}
|
||||
|
||||
fn extract_user_message(&self, request: &ChatCompletionsRequest) -> Option<String> {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderRequest::extract_user_message(&openai_provider, request)
|
||||
}
|
||||
}
|
||||
|
||||
impl ProviderResponse for DeepseekProvider {
|
||||
type Error = OpenAIApiError;
|
||||
type Usage = Usage;
|
||||
|
||||
fn try_from_bytes(&self, bytes: &[u8], provider: &super::super::ProviderId, mode: ConversionMode) -> Result<ChatCompletionsResponse, Self::Error> {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderResponse::try_from_bytes(&openai_provider, bytes, provider, mode)
|
||||
}
|
||||
|
||||
fn usage<'a>(&self, response: &'a ChatCompletionsResponse) -> Option<&'a Self::Usage> {
|
||||
Some(&response.usage)
|
||||
}
|
||||
|
||||
fn extract_usage_counts(&self, response: &ChatCompletionsResponse) -> Option<(usize, usize, usize)> {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderResponse::extract_usage_counts(&openai_provider, response)
|
||||
}
|
||||
}
|
||||
|
||||
impl StreamingResponse for DeepseekProvider {
|
||||
type Error = OpenAIApiError;
|
||||
type StreamChunk = crate::apis::openai::ChatCompletionsStreamResponse;
|
||||
type StreamingIter = OpenAIStreamingResponse;
|
||||
|
||||
fn try_from_bytes(&self, bytes: &[u8], provider: &super::super::ProviderId, mode: ConversionMode) -> Result<Self::StreamingIter, Self::Error> {
|
||||
let openai_provider = OpenAIProvider;
|
||||
StreamingResponse::try_from_bytes(&openai_provider, bytes, provider, mode)
|
||||
}
|
||||
}
|
||||
|
||||
impl ProviderInterface for DeepseekProvider {
|
||||
fn has_compatible_api(&self, api_path: &str) -> bool {
|
||||
matches!(api_path, "/v1/chat/completions")
|
||||
}
|
||||
|
||||
fn supported_apis(&self) -> Vec<&'static str> {
|
||||
vec!["/v1/chat/completions"]
|
||||
}
|
||||
}
|
||||
|
|
@ -1,7 +0,0 @@
|
|||
//! Gemini provider implementation
|
||||
//!
|
||||
//! Gemini will use a different API format in the future
|
||||
//! For now, fallback to OpenAI-compatible format
|
||||
|
||||
pub mod provider;
|
||||
pub use provider::GeminiProvider;
|
||||
|
|
@ -1,93 +0,0 @@
|
|||
//! Gemini provider implementation
|
||||
//!
|
||||
//! This module contains the Gemini provider that handles Google's Gemini API format
|
||||
//! requests in OpenAI-compatible format.
|
||||
|
||||
use crate::providers::{ProviderInterface, ConversionMode};
|
||||
use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse, Usage};
|
||||
use crate::providers::traits::{ProviderRequest, ProviderResponse, StreamingResponse};
|
||||
use crate::providers::openai::provider::{OpenAIProvider, OpenAIStreamingResponse, OpenAIApiError};
|
||||
|
||||
/// Gemini provider implementation
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct GeminiProvider;
|
||||
|
||||
// Trait implementations that delegate to OpenAI
|
||||
impl ProviderRequest for GeminiProvider {
|
||||
type Error = OpenAIApiError;
|
||||
|
||||
fn try_from_bytes(&self, bytes: &[u8]) -> Result<ChatCompletionsRequest, Self::Error> {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderRequest::try_from_bytes(&openai_provider, bytes)
|
||||
}
|
||||
|
||||
fn to_provider_bytes(&self, request: &ChatCompletionsRequest, provider: super::super::ProviderId, mode: ConversionMode) -> Result<Vec<u8>, Self::Error> {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderRequest::to_provider_bytes(&openai_provider, request, provider, mode)
|
||||
}
|
||||
|
||||
fn extract_model<'a>(&self, request: &'a ChatCompletionsRequest) -> &'a str {
|
||||
&request.model
|
||||
}
|
||||
|
||||
fn is_streaming(&self, request: &ChatCompletionsRequest) -> bool {
|
||||
request.stream.unwrap_or_default()
|
||||
}
|
||||
|
||||
fn set_streaming_options(&self, request: &mut ChatCompletionsRequest) {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderRequest::set_streaming_options(&openai_provider, request)
|
||||
}
|
||||
|
||||
fn extract_messages_text(&self, request: &ChatCompletionsRequest) -> String {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderRequest::extract_messages_text(&openai_provider, request)
|
||||
}
|
||||
|
||||
fn extract_user_message(&self, request: &ChatCompletionsRequest) -> Option<String> {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderRequest::extract_user_message(&openai_provider, request)
|
||||
}
|
||||
}
|
||||
|
||||
impl ProviderResponse for GeminiProvider {
|
||||
type Error = OpenAIApiError;
|
||||
type Usage = Usage;
|
||||
|
||||
fn try_from_bytes(&self, bytes: &[u8], provider: &super::super::ProviderId, mode: ConversionMode) -> Result<ChatCompletionsResponse, Self::Error> {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderResponse::try_from_bytes(&openai_provider, bytes, provider, mode)
|
||||
}
|
||||
|
||||
fn usage<'a>(&self, response: &'a ChatCompletionsResponse) -> Option<&'a Self::Usage> {
|
||||
Some(&response.usage)
|
||||
}
|
||||
|
||||
fn extract_usage_counts(&self, response: &ChatCompletionsResponse) -> Option<(usize, usize, usize)> {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderResponse::extract_usage_counts(&openai_provider, response)
|
||||
}
|
||||
}
|
||||
|
||||
impl StreamingResponse for GeminiProvider {
|
||||
type Error = OpenAIApiError;
|
||||
type StreamChunk = crate::apis::openai::ChatCompletionsStreamResponse;
|
||||
type StreamingIter = OpenAIStreamingResponse;
|
||||
|
||||
fn try_from_bytes(&self, bytes: &[u8], provider: &super::super::ProviderId, mode: ConversionMode) -> Result<Self::StreamingIter, Self::Error> {
|
||||
let openai_provider = OpenAIProvider;
|
||||
StreamingResponse::try_from_bytes(&openai_provider, bytes, provider, mode)
|
||||
}
|
||||
}
|
||||
|
||||
impl ProviderInterface for GeminiProvider {
|
||||
fn has_compatible_api(&self, api_path: &str) -> bool {
|
||||
// TODO: Update when Gemini API is fully implemented
|
||||
matches!(api_path, "/v1/chat/completions" | "/v1/models")
|
||||
}
|
||||
|
||||
fn supported_apis(&self) -> Vec<&'static str> {
|
||||
// TODO: Update when Gemini API is fully implemented
|
||||
vec!["/v1/models"]
|
||||
}
|
||||
}
|
||||
|
|
@ -1,7 +0,0 @@
|
|||
//! GitHub provider implementation
|
||||
//!
|
||||
//! GitHub will use a different API format in the future (/models)
|
||||
//! For now, fall back to the OpenAI-compatible format
|
||||
|
||||
pub mod provider;
|
||||
pub use provider::GitHubProvider;
|
||||
|
|
@ -1,93 +0,0 @@
|
|||
//! GitHub provider implementation
|
||||
//!
|
||||
//! This module contains the GitHub provider that handles GitHub API format
|
||||
//! requests in OpenAI-compatible format.
|
||||
|
||||
use crate::providers::{ProviderInterface, ConversionMode};
|
||||
use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse, Usage};
|
||||
use crate::providers::traits::{ProviderRequest, ProviderResponse, StreamingResponse};
|
||||
use crate::providers::openai::provider::{OpenAIProvider, OpenAIStreamingResponse, OpenAIApiError};
|
||||
|
||||
/// Stateless marker type for the GitHub provider.
///
/// All behaviour lives in the trait implementations for this type.
#[derive(Clone, Debug)]
pub struct GitHubProvider;
|
||||
|
||||
// Trait implementations that delegate to OpenAI
|
||||
impl ProviderRequest for GitHubProvider {
|
||||
type Error = OpenAIApiError;
|
||||
|
||||
fn try_from_bytes(&self, bytes: &[u8]) -> Result<ChatCompletionsRequest, Self::Error> {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderRequest::try_from_bytes(&openai_provider, bytes)
|
||||
}
|
||||
|
||||
fn to_provider_bytes(&self, request: &ChatCompletionsRequest, provider: super::super::ProviderId, mode: ConversionMode) -> Result<Vec<u8>, Self::Error> {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderRequest::to_provider_bytes(&openai_provider, request, provider, mode)
|
||||
}
|
||||
|
||||
fn extract_model<'a>(&self, request: &'a ChatCompletionsRequest) -> &'a str {
|
||||
&request.model
|
||||
}
|
||||
|
||||
fn is_streaming(&self, request: &ChatCompletionsRequest) -> bool {
|
||||
request.stream.unwrap_or_default()
|
||||
}
|
||||
|
||||
fn set_streaming_options(&self, request: &mut ChatCompletionsRequest) {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderRequest::set_streaming_options(&openai_provider, request)
|
||||
}
|
||||
|
||||
fn extract_messages_text(&self, request: &ChatCompletionsRequest) -> String {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderRequest::extract_messages_text(&openai_provider, request)
|
||||
}
|
||||
|
||||
fn extract_user_message(&self, request: &ChatCompletionsRequest) -> Option<String> {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderRequest::extract_user_message(&openai_provider, request)
|
||||
}
|
||||
}
|
||||
|
||||
impl ProviderResponse for GitHubProvider {
|
||||
type Error = OpenAIApiError;
|
||||
type Usage = Usage;
|
||||
|
||||
fn try_from_bytes(&self, bytes: &[u8], provider: &super::super::ProviderId, mode: ConversionMode) -> Result<ChatCompletionsResponse, Self::Error> {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderResponse::try_from_bytes(&openai_provider, bytes, provider, mode)
|
||||
}
|
||||
|
||||
fn usage<'a>(&self, response: &'a ChatCompletionsResponse) -> Option<&'a Self::Usage> {
|
||||
Some(&response.usage)
|
||||
}
|
||||
|
||||
fn extract_usage_counts(&self, response: &ChatCompletionsResponse) -> Option<(usize, usize, usize)> {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderResponse::extract_usage_counts(&openai_provider, response)
|
||||
}
|
||||
}
|
||||
|
||||
impl StreamingResponse for GitHubProvider {
|
||||
type Error = OpenAIApiError;
|
||||
type StreamChunk = crate::apis::openai::ChatCompletionsStreamResponse;
|
||||
type StreamingIter = OpenAIStreamingResponse;
|
||||
|
||||
fn try_from_bytes(&self, bytes: &[u8], provider: &super::super::ProviderId, mode: ConversionMode) -> Result<Self::StreamingIter, Self::Error> {
|
||||
let openai_provider = OpenAIProvider;
|
||||
StreamingResponse::try_from_bytes(&openai_provider, bytes, provider, mode)
|
||||
}
|
||||
}
|
||||
|
||||
/// API path compatibility information for GitHub.
impl ProviderInterface for GitHubProvider {
    /// Returns `true` when `api_path` is exactly `/v1/chat/completions` or `/models`.
    fn has_compatible_api(&self, api_path: &str) -> bool {
        // TODO: Update when GitHub API is fully implemented
        api_path == "/v1/chat/completions" || api_path == "/models"
    }

    /// Lists the API paths reported as supported (currently only `/models`).
    fn supported_apis(&self) -> Vec<&'static str> {
        // TODO: Update when GitHub API is fully implemented
        ["/models"].to_vec()
    }
}
|
||||
|
|
@ -1,6 +0,0 @@
|
|||
//! Groq provider implementation
|
||||
//!
|
||||
//! Groq uses OpenAI-compatible API format but with different endpoints
|
||||
|
||||
pub mod provider;
|
||||
pub use provider::GroqProvider;
|
||||
|
|
@ -1,91 +0,0 @@
|
|||
//! Groq provider implementation
|
||||
//!
|
||||
//! This module contains the Groq provider that handles Groq API format requests.
|
||||
//! Groq uses OpenAI-compatible format but may have provider-specific nuances.
|
||||
|
||||
use crate::providers::{ProviderInterface, ConversionMode};
|
||||
use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse, Usage};
|
||||
use crate::providers::traits::{ProviderRequest, ProviderResponse, StreamingResponse};
|
||||
use crate::providers::openai::provider::{OpenAIProvider, OpenAIStreamingResponse, OpenAIApiError};
|
||||
|
||||
/// Stateless marker type for the Groq provider.
///
/// All behaviour lives in the trait implementations for this type.
#[derive(Clone, Debug)]
pub struct GroqProvider;
|
||||
|
||||
// Trait implementations that delegate to OpenAI
|
||||
impl ProviderRequest for GroqProvider {
|
||||
type Error = OpenAIApiError;
|
||||
|
||||
fn try_from_bytes(&self, bytes: &[u8]) -> Result<ChatCompletionsRequest, Self::Error> {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderRequest::try_from_bytes(&openai_provider, bytes)
|
||||
}
|
||||
|
||||
fn to_provider_bytes(&self, request: &ChatCompletionsRequest, provider: super::super::ProviderId, mode: ConversionMode) -> Result<Vec<u8>, Self::Error> {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderRequest::to_provider_bytes(&openai_provider, request, provider, mode)
|
||||
}
|
||||
|
||||
fn extract_model<'a>(&self, request: &'a ChatCompletionsRequest) -> &'a str {
|
||||
&request.model
|
||||
}
|
||||
|
||||
fn is_streaming(&self, request: &ChatCompletionsRequest) -> bool {
|
||||
request.stream.unwrap_or_default()
|
||||
}
|
||||
|
||||
fn set_streaming_options(&self, request: &mut ChatCompletionsRequest) {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderRequest::set_streaming_options(&openai_provider, request)
|
||||
}
|
||||
|
||||
fn extract_messages_text(&self, request: &ChatCompletionsRequest) -> String {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderRequest::extract_messages_text(&openai_provider, request)
|
||||
}
|
||||
|
||||
fn extract_user_message(&self, request: &ChatCompletionsRequest) -> Option<String> {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderRequest::extract_user_message(&openai_provider, request)
|
||||
}
|
||||
}
|
||||
|
||||
impl ProviderResponse for GroqProvider {
|
||||
type Error = OpenAIApiError;
|
||||
type Usage = Usage;
|
||||
|
||||
fn try_from_bytes(&self, bytes: &[u8], provider: &super::super::ProviderId, mode: ConversionMode) -> Result<ChatCompletionsResponse, Self::Error> {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderResponse::try_from_bytes(&openai_provider, bytes, provider, mode)
|
||||
}
|
||||
|
||||
fn usage<'a>(&self, response: &'a ChatCompletionsResponse) -> Option<&'a Self::Usage> {
|
||||
Some(&response.usage)
|
||||
}
|
||||
|
||||
fn extract_usage_counts(&self, response: &ChatCompletionsResponse) -> Option<(usize, usize, usize)> {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderResponse::extract_usage_counts(&openai_provider, response)
|
||||
}
|
||||
}
|
||||
|
||||
impl StreamingResponse for GroqProvider {
|
||||
type Error = OpenAIApiError;
|
||||
type StreamChunk = crate::apis::openai::ChatCompletionsStreamResponse;
|
||||
type StreamingIter = OpenAIStreamingResponse;
|
||||
|
||||
fn try_from_bytes(&self, bytes: &[u8], provider: &super::super::ProviderId, mode: ConversionMode) -> Result<Self::StreamingIter, Self::Error> {
|
||||
let openai_provider = OpenAIProvider;
|
||||
StreamingResponse::try_from_bytes(&openai_provider, bytes, provider, mode)
|
||||
}
|
||||
}
|
||||
|
||||
/// API path compatibility information for Groq.
impl ProviderInterface for GroqProvider {
    /// Returns `true` when `api_path` is exactly `/v1/chat/completions` or
    /// `/openai/v1/chat/completions`.
    fn has_compatible_api(&self, api_path: &str) -> bool {
        api_path == "/v1/chat/completions" || api_path == "/openai/v1/chat/completions"
    }

    /// Lists the API paths reported as supported (only `/openai/v1/chat/completions`).
    fn supported_apis(&self) -> Vec<&'static str> {
        ["/openai/v1/chat/completions"].to_vec()
    }
}
|
||||
|
|
@ -1,6 +0,0 @@
|
|||
//! Mistral provider implementation
|
||||
//!
|
||||
//! Mistral uses OpenAI-compatible API format
|
||||
|
||||
pub mod provider;
|
||||
pub use provider::MistralProvider;
|
||||
|
|
@ -1,88 +0,0 @@
|
|||
//! Mistral provider implementation
|
||||
|
||||
use crate::providers::{ProviderInterface, ConversionMode};
|
||||
use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse, Usage};
|
||||
use crate::providers::traits::{ProviderRequest, ProviderResponse, StreamingResponse};
|
||||
use crate::providers::openai::provider::{OpenAIProvider, OpenAIStreamingResponse, OpenAIApiError};
|
||||
|
||||
/// Stateless marker type for the Mistral provider.
///
/// All behaviour lives in the trait implementations for this type.
#[derive(Clone, Debug)]
pub struct MistralProvider;
|
||||
|
||||
// Trait implementations that delegate to OpenAI
|
||||
impl ProviderRequest for MistralProvider {
|
||||
type Error = OpenAIApiError;
|
||||
|
||||
fn try_from_bytes(&self, bytes: &[u8]) -> Result<ChatCompletionsRequest, Self::Error> {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderRequest::try_from_bytes(&openai_provider, bytes)
|
||||
}
|
||||
|
||||
fn to_provider_bytes(&self, request: &ChatCompletionsRequest, provider: super::super::ProviderId, mode: ConversionMode) -> Result<Vec<u8>, Self::Error> {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderRequest::to_provider_bytes(&openai_provider, request, provider, mode)
|
||||
}
|
||||
|
||||
fn extract_model<'a>(&self, request: &'a ChatCompletionsRequest) -> &'a str {
|
||||
&request.model
|
||||
}
|
||||
|
||||
fn is_streaming(&self, request: &ChatCompletionsRequest) -> bool {
|
||||
request.stream.unwrap_or_default()
|
||||
}
|
||||
|
||||
fn set_streaming_options(&self, request: &mut ChatCompletionsRequest) {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderRequest::set_streaming_options(&openai_provider, request)
|
||||
}
|
||||
|
||||
fn extract_messages_text(&self, request: &ChatCompletionsRequest) -> String {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderRequest::extract_messages_text(&openai_provider, request)
|
||||
}
|
||||
|
||||
fn extract_user_message(&self, request: &ChatCompletionsRequest) -> Option<String> {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderRequest::extract_user_message(&openai_provider, request)
|
||||
}
|
||||
}
|
||||
|
||||
impl ProviderResponse for MistralProvider {
|
||||
type Error = OpenAIApiError;
|
||||
type Usage = Usage;
|
||||
|
||||
fn try_from_bytes(&self, bytes: &[u8], provider: &super::super::ProviderId, mode: ConversionMode) -> Result<ChatCompletionsResponse, Self::Error> {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderResponse::try_from_bytes(&openai_provider, bytes, provider, mode)
|
||||
}
|
||||
|
||||
fn usage<'a>(&self, response: &'a ChatCompletionsResponse) -> Option<&'a Self::Usage> {
|
||||
Some(&response.usage)
|
||||
}
|
||||
|
||||
fn extract_usage_counts(&self, response: &ChatCompletionsResponse) -> Option<(usize, usize, usize)> {
|
||||
let openai_provider = OpenAIProvider;
|
||||
ProviderResponse::extract_usage_counts(&openai_provider, response)
|
||||
}
|
||||
}
|
||||
|
||||
impl StreamingResponse for MistralProvider {
|
||||
type Error = OpenAIApiError;
|
||||
type StreamChunk = crate::apis::openai::ChatCompletionsStreamResponse;
|
||||
type StreamingIter = OpenAIStreamingResponse;
|
||||
|
||||
fn try_from_bytes(&self, bytes: &[u8], provider: &super::super::ProviderId, mode: ConversionMode) -> Result<Self::StreamingIter, Self::Error> {
|
||||
let openai_provider = OpenAIProvider;
|
||||
StreamingResponse::try_from_bytes(&openai_provider, bytes, provider, mode)
|
||||
}
|
||||
}
|
||||
|
||||
/// API path compatibility information for Mistral.
impl ProviderInterface for MistralProvider {
    /// Returns `true` only when `api_path` is exactly `/v1/chat/completions`.
    fn has_compatible_api(&self, api_path: &str) -> bool {
        api_path == "/v1/chat/completions"
    }

    /// Lists the API paths reported as supported (only `/v1/chat/completions`).
    fn supported_apis(&self) -> Vec<&'static str> {
        ["/v1/chat/completions"].to_vec()
    }
}
|
||||
|
|
@ -5,24 +5,11 @@
|
|||
|
||||
pub mod traits;
|
||||
pub mod openai;
|
||||
pub mod groq;
|
||||
pub mod mistral;
|
||||
pub mod deepseek;
|
||||
pub mod arch;
|
||||
pub mod gemini;
|
||||
pub mod claude;
|
||||
pub mod github;
|
||||
|
||||
// Re-export the main interfaces
|
||||
pub use traits::*;
|
||||
pub use openai::OpenAIProvider;
|
||||
pub use groq::GroqProvider;
|
||||
pub use mistral::MistralProvider;
|
||||
pub use deepseek::DeepseekProvider;
|
||||
pub use arch::ArchProvider;
|
||||
pub use gemini::GeminiProvider;
|
||||
pub use claude::ClaudeProvider;
|
||||
pub use github::GitHubProvider;
|
||||
// Note: OpenAIProvider has been deprecated in favor of function-based approach
|
||||
// OpenAI functionality is accessed through openai::builder and openai::types modules
|
||||
|
||||
use std::fmt::Display;
|
||||
|
||||
|
|
@ -69,219 +56,3 @@ impl Display for ProviderId {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ProviderId {
|
||||
/// Get the API endpoint path for this provider
|
||||
pub fn api_path(&self) -> &'static str {
|
||||
match self {
|
||||
ProviderId::OpenAI => "/v1/chat/completions",
|
||||
ProviderId::Groq => "/openai/v1/chat/completions",
|
||||
ProviderId::Gemini => "/v1/models", // TODO: Update when Gemini API is implemented
|
||||
ProviderId::Claude => "/v1/messages", // TODO: Update when Claude API is implemented
|
||||
ProviderId::Mistral => "/v1/chat/completions",
|
||||
ProviderId::Deepseek => "/v1/chat/completions",
|
||||
ProviderId::GitHub => "/models", // TODO: Update when GitHub models API is implemented
|
||||
ProviderId::Arch => "/v1/chat/completions",
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if this provider supports OpenAI v1/chat/completions API format
|
||||
pub fn supports_openai_format(&self) -> bool {
|
||||
matches!(
|
||||
self,
|
||||
ProviderId::OpenAI | ProviderId::Groq | ProviderId::Mistral | ProviderId::Deepseek | ProviderId::Arch
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Enum for dynamic dispatch of provider instances
|
||||
pub enum Provider {
|
||||
OpenAI(OpenAIProvider, ProviderId),
|
||||
Groq(GroqProvider, ProviderId),
|
||||
Mistral(MistralProvider, ProviderId),
|
||||
Deepseek(DeepseekProvider, ProviderId),
|
||||
Arch(ArchProvider, ProviderId),
|
||||
Gemini(GeminiProvider, ProviderId),
|
||||
Claude(ClaudeProvider, ProviderId),
|
||||
GitHub(GitHubProvider, ProviderId),
|
||||
}
|
||||
|
||||
impl Provider {
|
||||
/// Create a provider instance from a provider ID
|
||||
pub fn new(id: ProviderId) -> Self {
|
||||
match id {
|
||||
ProviderId::OpenAI => Provider::OpenAI(OpenAIProvider, id),
|
||||
ProviderId::Groq => Provider::Groq(GroqProvider, id),
|
||||
ProviderId::Mistral => Provider::Mistral(MistralProvider, id),
|
||||
ProviderId::Deepseek => Provider::Deepseek(DeepseekProvider, id),
|
||||
ProviderId::Arch => Provider::Arch(ArchProvider, id),
|
||||
ProviderId::Gemini => Provider::Gemini(GeminiProvider, id),
|
||||
ProviderId::Claude => Provider::Claude(ClaudeProvider, id),
|
||||
ProviderId::GitHub => Provider::GitHub(GitHubProvider, id),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the provider ID
|
||||
pub fn id(&self) -> ProviderId {
|
||||
match self {
|
||||
Provider::OpenAI(_, id) => *id,
|
||||
Provider::Groq(_, id) => *id,
|
||||
Provider::Mistral(_, id) => *id,
|
||||
Provider::Deepseek(_, id) => *id,
|
||||
Provider::Arch(_, id) => *id,
|
||||
Provider::Gemini(_, id) => *id,
|
||||
Provider::Claude(_, id) => *id,
|
||||
Provider::GitHub(_, id) => *id,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Implement traits directly on the Provider enum
|
||||
impl ProviderRequest for Provider {
|
||||
type Error = openai::provider::OpenAIApiError;
|
||||
|
||||
fn try_from_bytes(&self, bytes: &[u8]) -> Result<crate::apis::openai::ChatCompletionsRequest, Self::Error> {
|
||||
match self {
|
||||
Provider::OpenAI(provider, _) => ProviderRequest::try_from_bytes(provider, bytes),
|
||||
Provider::Groq(provider, _) => ProviderRequest::try_from_bytes(provider, bytes),
|
||||
Provider::Mistral(provider, _) => ProviderRequest::try_from_bytes(provider, bytes),
|
||||
Provider::Deepseek(provider, _) => ProviderRequest::try_from_bytes(provider, bytes),
|
||||
Provider::Arch(provider, _) => ProviderRequest::try_from_bytes(provider, bytes),
|
||||
Provider::Gemini(provider, _) => ProviderRequest::try_from_bytes(provider, bytes),
|
||||
Provider::Claude(provider, _) => ProviderRequest::try_from_bytes(provider, bytes),
|
||||
Provider::GitHub(provider, _) => ProviderRequest::try_from_bytes(provider, bytes),
|
||||
}
|
||||
}
|
||||
|
||||
fn to_provider_bytes(&self, request: &crate::apis::openai::ChatCompletionsRequest, provider_id: super::ProviderId, mode: ConversionMode) -> Result<Vec<u8>, Self::Error> {
|
||||
match self {
|
||||
Provider::OpenAI(provider, _) => ProviderRequest::to_provider_bytes(provider, request, provider_id, mode),
|
||||
Provider::Groq(provider, _) => ProviderRequest::to_provider_bytes(provider, request, provider_id, mode),
|
||||
Provider::Mistral(provider, _) => ProviderRequest::to_provider_bytes(provider, request, provider_id, mode),
|
||||
Provider::Deepseek(provider, _) => ProviderRequest::to_provider_bytes(provider, request, provider_id, mode),
|
||||
Provider::Arch(provider, _) => ProviderRequest::to_provider_bytes(provider, request, provider_id, mode),
|
||||
Provider::Gemini(provider, _) => ProviderRequest::to_provider_bytes(provider, request, provider_id, mode),
|
||||
Provider::Claude(provider, _) => ProviderRequest::to_provider_bytes(provider, request, provider_id, mode),
|
||||
Provider::GitHub(provider, _) => ProviderRequest::to_provider_bytes(provider, request, provider_id, mode),
|
||||
}
|
||||
}
|
||||
|
||||
fn extract_model<'a>(&self, request: &'a crate::apis::openai::ChatCompletionsRequest) -> &'a str {
|
||||
// Since all providers use the same implementation, just use the first one
|
||||
&request.model
|
||||
}
|
||||
|
||||
fn is_streaming(&self, request: &crate::apis::openai::ChatCompletionsRequest) -> bool {
|
||||
// Since all providers use the same implementation, just use the first one
|
||||
request.stream.unwrap_or_default()
|
||||
}
|
||||
|
||||
fn set_streaming_options(&self, request: &mut crate::apis::openai::ChatCompletionsRequest) {
|
||||
match self {
|
||||
Provider::OpenAI(provider, _) => ProviderRequest::set_streaming_options(provider, request),
|
||||
Provider::Groq(provider, _) => ProviderRequest::set_streaming_options(provider, request),
|
||||
Provider::Mistral(provider, _) => ProviderRequest::set_streaming_options(provider, request),
|
||||
Provider::Deepseek(provider, _) => ProviderRequest::set_streaming_options(provider, request),
|
||||
Provider::Arch(provider, _) => ProviderRequest::set_streaming_options(provider, request),
|
||||
Provider::Gemini(provider, _) => ProviderRequest::set_streaming_options(provider, request),
|
||||
Provider::Claude(provider, _) => ProviderRequest::set_streaming_options(provider, request),
|
||||
Provider::GitHub(provider, _) => ProviderRequest::set_streaming_options(provider, request),
|
||||
}
|
||||
}
|
||||
|
||||
fn extract_messages_text(&self, request: &crate::apis::openai::ChatCompletionsRequest) -> String {
|
||||
match self {
|
||||
Provider::OpenAI(provider, _) => ProviderRequest::extract_messages_text(provider, request),
|
||||
Provider::Groq(provider, _) => ProviderRequest::extract_messages_text(provider, request),
|
||||
Provider::Mistral(provider, _) => ProviderRequest::extract_messages_text(provider, request),
|
||||
Provider::Deepseek(provider, _) => ProviderRequest::extract_messages_text(provider, request),
|
||||
Provider::Arch(provider, _) => ProviderRequest::extract_messages_text(provider, request),
|
||||
Provider::Gemini(provider, _) => ProviderRequest::extract_messages_text(provider, request),
|
||||
Provider::Claude(provider, _) => ProviderRequest::extract_messages_text(provider, request),
|
||||
Provider::GitHub(provider, _) => ProviderRequest::extract_messages_text(provider, request),
|
||||
}
|
||||
}
|
||||
|
||||
fn extract_user_message(&self, request: &crate::apis::openai::ChatCompletionsRequest) -> Option<String> {
|
||||
match self {
|
||||
Provider::OpenAI(provider, _) => ProviderRequest::extract_user_message(provider, request),
|
||||
Provider::Groq(provider, _) => ProviderRequest::extract_user_message(provider, request),
|
||||
Provider::Mistral(provider, _) => ProviderRequest::extract_user_message(provider, request),
|
||||
Provider::Deepseek(provider, _) => ProviderRequest::extract_user_message(provider, request),
|
||||
Provider::Arch(provider, _) => ProviderRequest::extract_user_message(provider, request),
|
||||
Provider::Gemini(provider, _) => ProviderRequest::extract_user_message(provider, request),
|
||||
Provider::Claude(provider, _) => ProviderRequest::extract_user_message(provider, request),
|
||||
Provider::GitHub(provider, _) => ProviderRequest::extract_user_message(provider, request),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ProviderResponse for Provider {
|
||||
type Error = openai::provider::OpenAIApiError;
|
||||
type Usage = crate::apis::openai::Usage;
|
||||
|
||||
fn try_from_bytes(&self, bytes: &[u8], provider_id: &super::ProviderId, mode: ConversionMode) -> Result<crate::apis::openai::ChatCompletionsResponse, Self::Error> {
|
||||
match self {
|
||||
Provider::OpenAI(provider, _) => ProviderResponse::try_from_bytes(provider, bytes, provider_id, mode),
|
||||
Provider::Groq(provider, _) => ProviderResponse::try_from_bytes(provider, bytes, provider_id, mode),
|
||||
Provider::Mistral(provider, _) => ProviderResponse::try_from_bytes(provider, bytes, provider_id, mode),
|
||||
Provider::Deepseek(provider, _) => ProviderResponse::try_from_bytes(provider, bytes, provider_id, mode),
|
||||
Provider::Arch(provider, _) => ProviderResponse::try_from_bytes(provider, bytes, provider_id, mode),
|
||||
Provider::Gemini(provider, _) => ProviderResponse::try_from_bytes(provider, bytes, provider_id, mode),
|
||||
Provider::Claude(provider, _) => ProviderResponse::try_from_bytes(provider, bytes, provider_id, mode),
|
||||
Provider::GitHub(provider, _) => ProviderResponse::try_from_bytes(provider, bytes, provider_id, mode),
|
||||
}
|
||||
}
|
||||
|
||||
fn usage<'a>(&self, response: &'a crate::apis::openai::ChatCompletionsResponse) -> Option<&'a Self::Usage> {
|
||||
// Since all providers use the same implementation, just use the direct implementation
|
||||
Some(&response.usage)
|
||||
}
|
||||
}
|
||||
|
||||
impl StreamingResponse for Provider {
|
||||
type Error = openai::provider::OpenAIApiError;
|
||||
type StreamChunk = crate::apis::openai::ChatCompletionsStreamResponse;
|
||||
type StreamingIter = openai::provider::OpenAIStreamingResponse;
|
||||
|
||||
fn try_from_bytes(&self, bytes: &[u8], provider_id: &super::ProviderId, mode: ConversionMode) -> Result<Self::StreamingIter, Self::Error> {
|
||||
match self {
|
||||
Provider::OpenAI(provider, _) => StreamingResponse::try_from_bytes(provider, bytes, provider_id, mode),
|
||||
Provider::Groq(provider, _) => StreamingResponse::try_from_bytes(provider, bytes, provider_id, mode),
|
||||
Provider::Mistral(provider, _) => StreamingResponse::try_from_bytes(provider, bytes, provider_id, mode),
|
||||
Provider::Deepseek(provider, _) => StreamingResponse::try_from_bytes(provider, bytes, provider_id, mode),
|
||||
Provider::Arch(provider, _) => StreamingResponse::try_from_bytes(provider, bytes, provider_id, mode),
|
||||
Provider::Gemini(provider, _) => StreamingResponse::try_from_bytes(provider, bytes, provider_id, mode),
|
||||
Provider::Claude(provider, _) => StreamingResponse::try_from_bytes(provider, bytes, provider_id, mode),
|
||||
Provider::GitHub(provider, _) => StreamingResponse::try_from_bytes(provider, bytes, provider_id, mode),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ProviderInterface for Provider {
|
||||
fn has_compatible_api(&self, api_path: &str) -> bool {
|
||||
match self {
|
||||
Provider::OpenAI(provider, _) => provider.has_compatible_api(api_path),
|
||||
Provider::Groq(provider, _) => provider.has_compatible_api(api_path),
|
||||
Provider::Mistral(provider, _) => provider.has_compatible_api(api_path),
|
||||
Provider::Deepseek(provider, _) => provider.has_compatible_api(api_path),
|
||||
Provider::Arch(provider, _) => provider.has_compatible_api(api_path),
|
||||
Provider::Gemini(provider, _) => provider.has_compatible_api(api_path),
|
||||
Provider::Claude(provider, _) => provider.has_compatible_api(api_path),
|
||||
Provider::GitHub(provider, _) => provider.has_compatible_api(api_path),
|
||||
}
|
||||
}
|
||||
|
||||
fn supported_apis(&self) -> Vec<&'static str> {
|
||||
match self {
|
||||
Provider::OpenAI(provider, _) => provider.supported_apis(),
|
||||
Provider::Groq(provider, _) => provider.supported_apis(),
|
||||
Provider::Mistral(provider, _) => provider.supported_apis(),
|
||||
Provider::Deepseek(provider, _) => provider.supported_apis(),
|
||||
Provider::Arch(provider, _) => provider.supported_apis(),
|
||||
Provider::Gemini(provider, _) => provider.supported_apis(),
|
||||
Provider::Claude(provider, _) => provider.supported_apis(),
|
||||
Provider::GitHub(provider, _) => provider.supported_apis(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,10 @@
|
|||
pub mod builder;
|
||||
pub mod types;
|
||||
pub mod provider;
|
||||
|
||||
// Re-export the main provider
|
||||
pub use provider::OpenAIProvider;
|
||||
// Re-export the main types and builder functionality
|
||||
pub use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse, ChatCompletionsStreamResponse};
|
||||
pub use builder::*;
|
||||
pub use types::*;
|
||||
|
||||
// Note: The OpenAIProvider struct has been deprecated in favor of the function-based approach in traits.rs
|
||||
// All provider functionality is now accessed through try_request_from_bytes, try_response_from_bytes, etc.
|
||||
|
|
|
|||
|
|
@ -1,217 +0,0 @@
|
|||
//! OpenAI provider interface implementations
|
||||
|
||||
use crate::apis::openai::*;
|
||||
use crate::providers::traits::*;
|
||||
|
||||
// Simple error type for OpenAI API operations
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum OpenAIApiError {
|
||||
#[error("JSON parsing error: {0}")]
|
||||
JsonError(#[from] serde_json::Error),
|
||||
#[error("UTF-8 parsing error: {0}")]
|
||||
Utf8Error(#[from] std::str::Utf8Error),
|
||||
#[error("Invalid streaming data: {0}")]
|
||||
InvalidStreamingData(String),
|
||||
#[error("Request conversion error: {0}")]
|
||||
RequestConversionError(String),
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// OpenAI Provider Definition
|
||||
// ============================================================================
|
||||
|
||||
/// Stateless marker type for the OpenAI provider.
///
/// Derives `Debug` and `Clone` so it matches the other provider marker types
/// (`GroqProvider`, `MistralProvider`, `GitHubProvider`) and can be logged or
/// duplicated like them.
#[derive(Debug, Clone)]
pub struct OpenAIProvider;
|
||||
|
||||
/// Concrete streaming response type, used to avoid lifetime issues.
///
/// Holds the lines of a streamed body as owned strings together with the
/// index of the next line to look at. Derives `Debug` and `Clone` so the
/// public type can be inspected and duplicated.
#[derive(Debug, Clone)]
pub struct OpenAIStreamingResponse {
    /// Lines of the streamed body, one entry per line.
    lines: Vec<String>,
    /// Index into `lines` of the next line to process.
    current_index: usize,
}
|
||||
|
||||
impl OpenAIStreamingResponse {
|
||||
fn new(data: String) -> Self {
|
||||
let lines: Vec<String> = data.lines().map(|s| s.to_string()).collect();
|
||||
Self {
|
||||
lines,
|
||||
current_index: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Iterator for OpenAIStreamingResponse {
|
||||
type Item = Result<ChatCompletionsStreamResponse, OpenAIApiError>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
while self.current_index < self.lines.len() {
|
||||
let line = &self.lines[self.current_index];
|
||||
self.current_index += 1;
|
||||
|
||||
if let Some(data) = line.strip_prefix("data: ") {
|
||||
let data = data.trim();
|
||||
if data == "[DONE]" {
|
||||
return None;
|
||||
}
|
||||
|
||||
if data == r#"{"type": "ping"}"# {
|
||||
continue; // Skip ping messages
|
||||
}
|
||||
|
||||
return Some(
|
||||
serde_json::from_str::<ChatCompletionsStreamResponse>(data).map_err(|e| {
|
||||
OpenAIApiError::InvalidStreamingData(format!("Error parsing: {}, data: {}", e, data))
|
||||
}),
|
||||
);
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// API path compatibility information for OpenAI.
impl ProviderInterface for OpenAIProvider {
    /// Returns `true` only when `api_path` is exactly `/v1/chat/completions`.
    fn has_compatible_api(&self, api_path: &str) -> bool {
        matches!(api_path, "/v1/chat/completions")
    }

    /// Lists the API paths reported as supported (only `/v1/chat/completions`).
    fn supported_apis(&self) -> Vec<&'static str> {
        ["/v1/chat/completions"].to_vec()
    }
}
|
||||
|
||||
// Direct trait implementations on OpenAIProvider
|
||||
impl ProviderRequest for OpenAIProvider {
|
||||
type Error = OpenAIApiError;
|
||||
|
||||
fn try_from_bytes(&self, bytes: &[u8]) -> Result<ChatCompletionsRequest, Self::Error> {
|
||||
let s = std::str::from_utf8(bytes)?;
|
||||
Ok(serde_json::from_str(s)?)
|
||||
}
|
||||
|
||||
fn to_provider_bytes(&self, request: &ChatCompletionsRequest, _provider: super::super::ProviderId, _mode: ConversionMode) -> Result<Vec<u8>, Self::Error> {
|
||||
Ok(serde_json::to_vec(request)?)
|
||||
}
|
||||
|
||||
fn extract_model<'a>(&self, request: &'a ChatCompletionsRequest) -> &'a str {
|
||||
&request.model
|
||||
}
|
||||
|
||||
fn is_streaming(&self, request: &ChatCompletionsRequest) -> bool {
|
||||
request.stream.unwrap_or_default()
|
||||
}
|
||||
|
||||
fn set_streaming_options(&self, request: &mut ChatCompletionsRequest) {
|
||||
if request.stream_options.is_none() {
|
||||
request.stream_options = Some(StreamOptions {
|
||||
include_usage: Some(true),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
fn extract_messages_text(&self, request: &ChatCompletionsRequest) -> String {
|
||||
request.messages
|
||||
.iter()
|
||||
.fold(String::new(), |acc, m| {
|
||||
acc + " " + &match &m.content {
|
||||
MessageContent::Text(text) => text.clone(),
|
||||
MessageContent::Parts(parts) => {
|
||||
parts.iter().map(|part| match part {
|
||||
ContentPart::Text { text } => text.clone(),
|
||||
ContentPart::ImageUrl { .. } => "[Image]".to_string(),
|
||||
}).collect::<Vec<_>>().join(" ")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn extract_user_message(&self, request: &ChatCompletionsRequest) -> Option<String> {
|
||||
request.messages.last().and_then(|msg| {
|
||||
match &msg.content {
|
||||
MessageContent::Text(text) => Some(text.clone()),
|
||||
MessageContent::Parts(parts) => {
|
||||
// Extract text from content parts, ignoring images
|
||||
let text_parts: Vec<String> = parts
|
||||
.iter()
|
||||
.filter_map(|part| match part {
|
||||
ContentPart::Text { text } => Some(text.clone()),
|
||||
ContentPart::ImageUrl { .. } => None,
|
||||
})
|
||||
.collect();
|
||||
if text_parts.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(text_parts.join(" "))
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl ProviderResponse for OpenAIProvider {
|
||||
type Error = OpenAIApiError;
|
||||
type Usage = Usage;
|
||||
|
||||
fn try_from_bytes(&self, bytes: &[u8], _provider: &super::super::ProviderId, _mode: ConversionMode) -> Result<ChatCompletionsResponse, Self::Error> {
|
||||
let s = std::str::from_utf8(bytes)?;
|
||||
Ok(serde_json::from_str(s)?)
|
||||
}
|
||||
|
||||
fn usage<'a>(&self, response: &'a ChatCompletionsResponse) -> Option<&'a Self::Usage> {
|
||||
Some(&response.usage)
|
||||
}
|
||||
|
||||
fn extract_usage_counts(&self, response: &ChatCompletionsResponse) -> Option<(usize, usize, usize)> {
|
||||
Some((
|
||||
response.usage.prompt_tokens as usize,
|
||||
response.usage.completion_tokens as usize,
|
||||
response.usage.total_tokens as usize,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
impl StreamingResponse for OpenAIProvider {
|
||||
type Error = OpenAIApiError;
|
||||
type StreamChunk = ChatCompletionsStreamResponse;
|
||||
type StreamingIter = OpenAIStreamingResponse;
|
||||
|
||||
fn try_from_bytes(&self, bytes: &[u8], _provider: &super::super::ProviderId, _mode: ConversionMode) -> Result<Self::StreamingIter, Self::Error> {
|
||||
let s = std::str::from_utf8(bytes)?;
|
||||
Ok(OpenAIStreamingResponse::new(s.to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Trait Implementations for OpenAI Types (Keep for TokenUsage only)
|
||||
// ============================================================================
|
||||
|
||||
/// Exposes the counters stored on `Usage` through the `TokenUsage` trait.
///
/// Every accessor converts the stored counter to `usize` with an `as` cast.
impl TokenUsage for Usage {
    /// Value of the `prompt_tokens` counter.
    fn prompt_tokens(&self) -> usize {
        self.prompt_tokens as usize
    }

    /// Value of the `completion_tokens` counter.
    fn completion_tokens(&self) -> usize {
        self.completion_tokens as usize
    }

    /// Value of the `total_tokens` counter.
    fn total_tokens(&self) -> usize {
        self.total_tokens as usize
    }
}
|
||||
|
||||
impl StreamChunk for ChatCompletionsStreamResponse {
|
||||
type Usage = Usage;
|
||||
|
||||
fn usage(&self) -> Option<&Self::Usage> {
|
||||
self.usage.as_ref()
|
||||
}
|
||||
}
|
||||
|
||||
impl StreamingResponse for OpenAIStreamingResponse {
|
||||
type Error = OpenAIApiError;
|
||||
type StreamChunk = ChatCompletionsStreamResponse;
|
||||
type StreamingIter = OpenAIStreamingResponse;
|
||||
|
||||
fn try_from_bytes(&self, bytes: &[u8], _provider: &super::super::ProviderId, _mode: ConversionMode) -> Result<Self::StreamingIter, Self::Error> {
|
||||
let s = std::str::from_utf8(bytes)?;
|
||||
Ok(OpenAIStreamingResponse::new(s.to_string()))
|
||||
}
|
||||
}
|
||||
|
|
@ -4,6 +4,7 @@
|
|||
//! handling of LLM requests and responses in the gateway.
|
||||
|
||||
use std::error::Error;
|
||||
use std::fmt;
|
||||
|
||||
/// Conversion mode for provider requests/responses
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
|
|
@ -14,30 +15,41 @@ pub enum ConversionMode {
|
|||
Passthrough,
|
||||
}
|
||||
|
||||
/// Trait for provider-specific request types
|
||||
pub trait ProviderRequest {
|
||||
type Error: Error + Send + Sync + 'static;
|
||||
/// Error produced while parsing or serializing a provider request.
///
/// Carries a human-readable description plus, optionally, the lower-level
/// error that caused the failure.
#[derive(Debug)]
pub struct ProviderRequestError {
    /// Human-readable description of what went wrong.
    pub message: String,
    /// Underlying cause, when one is available.
    pub source: Option<Box<dyn Error + Send + Sync>>,
}
|
||||
|
||||
/// Parse request from raw bytes
|
||||
fn try_from_bytes(&self, bytes: &[u8]) -> Result<crate::apis::openai::ChatCompletionsRequest, Self::Error>;
|
||||
/// Error produced while parsing a provider response.
///
/// Carries a human-readable description plus, optionally, the lower-level
/// error that caused the failure.
#[derive(Debug)]
pub struct ProviderResponseError {
    /// Human-readable description of what went wrong.
    pub message: String,
    /// Underlying cause, when one is available.
    pub source: Option<Box<dyn Error + Send + Sync>>,
}
|
||||
|
||||
/// Convert to provider-specific format
|
||||
fn to_provider_bytes(&self, request: &crate::apis::openai::ChatCompletionsRequest, provider: super::ProviderId, mode: ConversionMode) -> Result<Vec<u8>, Self::Error>;
|
||||
impl fmt::Display for ProviderRequestError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "Provider request error: {}", self.message)
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract the model name from the request
|
||||
fn extract_model<'a>(&self, request: &'a crate::apis::openai::ChatCompletionsRequest) -> &'a str;
|
||||
impl fmt::Display for ProviderResponseError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "Provider response error: {}", self.message)
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if this is a streaming request
|
||||
fn is_streaming(&self, request: &crate::apis::openai::ChatCompletionsRequest) -> bool;
|
||||
impl Error for ProviderRequestError {
|
||||
fn source(&self) -> Option<&(dyn Error + 'static)> {
|
||||
self.source.as_ref().map(|e| e.as_ref() as &(dyn Error + 'static))
|
||||
}
|
||||
}
|
||||
|
||||
/// Set streaming options (e.g., include_usage)
|
||||
fn set_streaming_options(&self, request: &mut crate::apis::openai::ChatCompletionsRequest);
|
||||
|
||||
/// Extract text content from messages for token counting
|
||||
fn extract_messages_text(&self, request: &crate::apis::openai::ChatCompletionsRequest) -> String;
|
||||
|
||||
/// Extract the user message for tracing/logging purposes
|
||||
fn extract_user_message(&self, request: &crate::apis::openai::ChatCompletionsRequest) -> Option<String>;
|
||||
impl Error for ProviderResponseError {
|
||||
fn source(&self) -> Option<&(dyn Error + 'static)> {
|
||||
self.source.as_ref().map(|e| e.as_ref() as &(dyn Error + 'static))
|
||||
}
|
||||
}
|
||||
|
||||
/// Trait for token usage information
|
||||
|
|
@ -47,46 +59,178 @@ pub trait TokenUsage {
|
|||
fn total_tokens(&self) -> usize;
|
||||
}
|
||||
|
||||
/// Trait for provider-specific request types
|
||||
pub trait ProviderRequest: Send + Sync {
|
||||
/// Extract the model name from the request
|
||||
fn model(&self) -> &str;
|
||||
|
||||
/// Check if this is a streaming request
|
||||
fn is_streaming(&self) -> bool;
|
||||
|
||||
/// Set streaming options (e.g., include_usage)
|
||||
fn set_streaming_options(&mut self);
|
||||
|
||||
/// Extract text content from messages for token counting
|
||||
fn extract_messages_text(&self) -> String;
|
||||
|
||||
/// Extract the user message for tracing/logging purposes
|
||||
fn extract_user_message(&self) -> Option<String>;
|
||||
|
||||
/// Convert to provider-specific format
|
||||
fn to_provider_bytes(&self, mode: ConversionMode) -> Result<Vec<u8>, ProviderRequestError>;
|
||||
}
|
||||
|
||||
/// Trait for provider-specific response types
|
||||
pub trait ProviderResponse {
|
||||
type Error: Error + Send + Sync + 'static;
|
||||
type Usage: TokenUsage;
|
||||
|
||||
/// Parse response from raw bytes
|
||||
fn try_from_bytes(&self, bytes: &[u8], provider: &super::ProviderId, mode: ConversionMode) -> Result<crate::apis::openai::ChatCompletionsResponse, Self::Error>;
|
||||
|
||||
/// Get usage information if available
|
||||
fn usage<'a>(&self, response: &'a crate::apis::openai::ChatCompletionsResponse) -> Option<&'a Self::Usage>;
|
||||
pub trait ProviderResponse: Send + Sync {
|
||||
/// Get usage information if available - returns dynamic trait object
|
||||
fn usage(&self) -> Option<&dyn TokenUsage>;
|
||||
|
||||
/// Extract token counts for metrics
|
||||
fn extract_usage_counts(&self, response: &crate::apis::openai::ChatCompletionsResponse) -> Option<(usize, usize, usize)> {
|
||||
self.usage(response).map(|u| (u.prompt_tokens(), u.completion_tokens(), u.total_tokens()))
|
||||
fn extract_usage_counts(&self) -> Option<(usize, usize, usize)> {
|
||||
self.usage().map(|u| (u.prompt_tokens(), u.completion_tokens(), u.total_tokens()))
|
||||
}
|
||||
}
|
||||
|
||||
/// Trait for streaming response chunks
|
||||
pub trait StreamChunk {
|
||||
type Usage: TokenUsage;
|
||||
/// Trait for provider-specific streaming response types
|
||||
pub trait ProviderStreamResponse: Send + Sync {
|
||||
/// Get the content delta for this chunk
|
||||
fn content_delta(&self) -> Option<&str>;
|
||||
|
||||
/// Get usage information if available
|
||||
fn usage(&self) -> Option<&Self::Usage>;
|
||||
/// Check if this is the final chunk in the stream
|
||||
fn is_final(&self) -> bool;
|
||||
|
||||
/// Get role information if available
|
||||
fn role(&self) -> Option<&str>;
|
||||
}
|
||||
|
||||
/// Trait for streaming response iterators
|
||||
pub trait StreamingResponse {
|
||||
type Error: Error + Send + Sync + 'static;
|
||||
type StreamChunk: StreamChunk;
|
||||
type StreamingIter: Iterator<Item = Result<Self::StreamChunk, Self::Error>>;
|
||||
|
||||
/// Parse streaming response from raw bytes
|
||||
fn try_from_bytes(&self, bytes: &[u8], provider: &super::ProviderId, mode: ConversionMode) -> Result<Self::StreamingIter, Self::Error>;
|
||||
///
|
||||
/// This trait ensures that implementing types are iterators that yield
|
||||
/// ProviderStreamResponse results.
|
||||
pub trait ProviderStreamResponseIter: Iterator<Item = Result<Box<dyn ProviderStreamResponse>, Box<dyn std::error::Error + Send + Sync>>> + Send + Sync {
|
||||
// No additional methods needed - just the Iterator constraint with proper bounds
|
||||
}
|
||||
|
||||
/// Main provider interface trait - simplified to only essential methods
|
||||
pub trait ProviderInterface: ProviderRequest + ProviderResponse + StreamingResponse {
|
||||
/// Check if this provider has a compatible API with the client request
|
||||
fn has_compatible_api(&self, api_path: &str) -> bool;
|
||||
// ============================================================================
|
||||
// PROVIDER FUNCTIONS - NO TRAITS, JUST PARAMETERIZED CONVERSION
|
||||
// ============================================================================
|
||||
//
|
||||
// ARCHITECTURAL DECISION: Function-based Provider API
|
||||
//
|
||||
// We chose this function-based approach over the original ProviderInterface trait
|
||||
// for several critical reasons:
|
||||
//
|
||||
// 1. TRAIT OBJECT LIMITATION:
|
||||
// - The original ProviderInterface supertraits declared associated types (Error, Usage, StreamChunk, StreamingIter)
|
||||
// - Such a trait cannot be used as a bare trait object (Box<dyn ProviderInterface>): every associated type would have to be named
|
||||
// - This prevented dynamic provider selection at runtime based on request headers
|
||||
// - Error: "the trait `ProviderInterface` cannot be made into an object"
|
||||
//
|
||||
// 2. DYNAMIC PROVIDER SELECTION REQUIREMENT:
|
||||
// - The gateway needs to select providers dynamically based on incoming headers
|
||||
// - Cannot know provider type at compile time - must dispatch at runtime
|
||||
// - Need ability to return generic trait objects that work polymorphically
|
||||
//
|
||||
// 3. WRAPPER TYPE ELIMINATION:
|
||||
// - Original design required wrapper types like OpenAIRequestWrapper, OpenAIResponseWrapper
|
||||
// - User wanted to implement traits directly on concrete types (ChatCompletionsRequest, etc.)
|
||||
// - Function-based approach allows direct trait implementations without wrappers
|
||||
//
|
||||
// 4. PARAMETERIZED CONVERSION PATTERN:
|
||||
// - Follows existing codebase pattern: TryFrom<(&[u8], &ProviderId)>
|
||||
// - Enables runtime provider selection while maintaining type safety
|
||||
// - Single implementation can handle multiple OpenAI-compatible providers
|
||||
//
|
||||
// 5. TYPE ERASURE FOR GENERIC INTERFACE:
|
||||
// - Functions return Box<dyn ProviderRequest/Response> - works as trait objects
|
||||
// - stream_context.rs can work with generic interfaces without knowing concrete types
|
||||
// - Maintains polymorphism while enabling dynamic dispatch
|
||||
// ============================================================================
|
||||
|
||||
/// Get supported API endpoints for this provider
|
||||
fn supported_apis(&self) -> Vec<&'static str>;
|
||||
use crate::ProviderId;
|
||||
|
||||
/// Parse request from bytes using provider ID - returns generic ProviderRequest trait object
|
||||
pub fn try_request_from_bytes(bytes: &[u8], provider_id: &ProviderId) -> Result<Box<dyn ProviderRequest>, ProviderRequestError> {
|
||||
match provider_id {
|
||||
// All these providers currently use OpenAI-compatible chat completions API
|
||||
// In the future, we can add provider-specific handling in separate match arms
|
||||
ProviderId::OpenAI | ProviderId::Groq | ProviderId::Mistral | ProviderId::Deepseek
|
||||
| ProviderId::Arch | ProviderId::Gemini | ProviderId::Claude | ProviderId::GitHub => {
|
||||
|
||||
let request = crate::apis::openai::ChatCompletionsRequest::try_from((bytes, provider_id))
|
||||
.map_err(|e| ProviderRequestError {
|
||||
message: format!("Failed to parse request: {}", e),
|
||||
source: Some(Box::new(e)),
|
||||
})?;
|
||||
|
||||
// Return as trait object - this enables polymorphic usage
|
||||
// ChatCompletionsRequest implements ProviderRequest directly (no wrapper needed)
|
||||
Ok(Box::new(request) as Box<dyn ProviderRequest>)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse response from bytes using provider ID - returns generic ProviderResponse trait object
|
||||
pub fn try_response_from_bytes(bytes: &[u8], provider_id: &ProviderId, _mode: ConversionMode) -> Result<Box<dyn ProviderResponse>, ProviderResponseError> {
|
||||
match provider_id {
|
||||
ProviderId::OpenAI | ProviderId::Groq | ProviderId::Mistral | ProviderId::Deepseek
|
||||
| ProviderId::Arch | ProviderId::Gemini | ProviderId::Claude | ProviderId::GitHub => {
|
||||
// Parameterized conversion allows provider-specific response parsing
|
||||
let response = crate::apis::openai::ChatCompletionsResponse::try_from((bytes, provider_id))
|
||||
.map_err(|e| ProviderResponseError {
|
||||
message: format!("Failed to parse response: {}", e),
|
||||
source: Some(Box::new(e)),
|
||||
})?;
|
||||
|
||||
// ChatCompletionsResponse implements ProviderResponse directly - no wrapper needed!
|
||||
Ok(Box::new(response) as Box<dyn ProviderResponse>)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Create streaming response using provider ID - returns clean ProviderStreamResponseIter trait object
|
||||
///
|
||||
/// This function returns a ProviderStreamResponseIter that's just an iterator,
|
||||
/// eliminating the complex nested Result<Box<dyn Iterator<...>>> type completely.
|
||||
pub fn try_streaming_from_bytes(bytes: &[u8], provider_id: &ProviderId, _mode: ConversionMode) -> Result<Box<dyn ProviderStreamResponseIter>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
match provider_id {
|
||||
ProviderId::OpenAI | ProviderId::Groq | ProviderId::Mistral | ProviderId::Deepseek
|
||||
| ProviderId::Arch | ProviderId::Gemini | ProviderId::Claude | ProviderId::GitHub => {
|
||||
// Parse SSE (Server-Sent Events) streaming data
|
||||
let s = std::str::from_utf8(bytes)?;
|
||||
let lines: Vec<String> = s.lines().map(|line| line.to_string()).collect();
|
||||
let iter = crate::apis::openai::SseChatCompletionIter::new(lines.into_iter());
|
||||
|
||||
// Return the iterator directly - it implements ProviderStreamResponseIter
|
||||
Ok(Box::new(iter))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if provider has compatible API
|
||||
///
|
||||
/// Replaces the old ProviderInterface::has_compatible_api method.
|
||||
/// This function enables runtime API compatibility checking without needing a provider instance.
|
||||
pub fn has_compatible_api(provider_id: &ProviderId, api_path: &str) -> bool {
|
||||
match provider_id {
|
||||
// Currently all these providers support OpenAI chat completions API
|
||||
// Future providers with different APIs will get their own match arms
|
||||
ProviderId::OpenAI | ProviderId::Groq | ProviderId::Mistral | ProviderId::Deepseek
|
||||
| ProviderId::Arch | ProviderId::Gemini | ProviderId::Claude | ProviderId::GitHub => {
|
||||
api_path == "/v1/chat/completions"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Get supported APIs for provider
|
||||
///
|
||||
/// Replaces the old ProviderInterface::supported_apis method.
|
||||
/// Returns a static list of supported API endpoints for the given provider.
|
||||
pub fn supported_apis(provider_id: &ProviderId) -> Vec<&'static str> {
|
||||
match provider_id {
|
||||
ProviderId::OpenAI | ProviderId::Groq | ProviderId::Mistral | ProviderId::Deepseek
|
||||
| ProviderId::Arch | ProviderId::Gemini | ProviderId::Claude | ProviderId::GitHub => {
|
||||
vec!["/v1/chat/completions"]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -10,10 +10,10 @@ use common::ratelimit::Header;
|
|||
use common::stats::{IncrementingMetric, RecordingMetric};
|
||||
use common::tracing::{Event, Span, TraceData, Traceparent};
|
||||
use common::{ratelimit, routing, tokenizer};
|
||||
use hermesllm::providers::traits::{
|
||||
ProviderRequest, ProviderResponse, StreamChunk, StreamingResponse, TokenUsage,
|
||||
use hermesllm::{
|
||||
try_request_from_bytes, try_response_from_bytes, try_streaming_from_bytes, ConversionMode,
|
||||
ProviderId,
|
||||
};
|
||||
use hermesllm::{ConversionMode, Provider, ProviderId};
|
||||
use http::StatusCode;
|
||||
use log::{debug, info, warn};
|
||||
use proxy_wasm::hostcalls::get_current_time;
|
||||
|
|
@ -79,8 +79,8 @@ impl StreamContext {
|
|||
.expect("the provider should be set when asked for it")
|
||||
}
|
||||
|
||||
fn get_provider(&self) -> Provider {
|
||||
self.llm_provider().create_provider()
|
||||
/// Returns the `ProviderId` of the LLM provider selected for this stream,
/// obtained via `llm_provider()`.
fn get_provider_id(&self) -> ProviderId {
    let selected = self.llm_provider();
    selected.to_provider_id()
}
|
||||
|
||||
fn select_llm_provider(&mut self) {
|
||||
|
|
@ -298,9 +298,9 @@ impl HttpContext for StreamContext {
|
|||
}
|
||||
};
|
||||
|
||||
let provider = self.get_provider();
|
||||
let provider_id = self.get_provider_id();
|
||||
|
||||
let mut deserialized_body = match ProviderRequest::try_from_bytes(&provider, &body_bytes) {
|
||||
let mut deserialized_body = match try_request_from_bytes(&body_bytes, &provider_id) {
|
||||
Ok(deserialized) => deserialized,
|
||||
Err(e) => {
|
||||
debug!(
|
||||
|
|
@ -329,10 +329,10 @@ impl HttpContext for StreamContext {
|
|||
};
|
||||
|
||||
// Use the provider interface methods for cleaner interaction
|
||||
let model_requested = provider.extract_model(&deserialized_body).to_string(); // Convert to owned string
|
||||
let model_requested = deserialized_body.model().to_string(); // Convert to owned string
|
||||
|
||||
// Extract user message for tracing
|
||||
self.user_message = provider.extract_user_message(&deserialized_body);
|
||||
self.user_message = deserialized_body.extract_user_message();
|
||||
|
||||
info!(
|
||||
"on_http_request_body: provider: {}, model requested (in body): {}, model selected: {}",
|
||||
|
|
@ -342,15 +342,15 @@ impl HttpContext for StreamContext {
|
|||
);
|
||||
|
||||
// Use provider interface for streaming detection and setup
|
||||
self.streaming_response = provider.is_streaming(&deserialized_body);
|
||||
self.streaming_response = deserialized_body.is_streaming();
|
||||
|
||||
// Set streaming options if needed
|
||||
if self.streaming_response {
|
||||
provider.set_streaming_options(&mut deserialized_body);
|
||||
deserialized_body.set_streaming_options();
|
||||
}
|
||||
|
||||
// Use provider interface for text extraction (after potential mutation)
|
||||
let input_tokens_str = provider.extract_messages_text(&deserialized_body);
|
||||
let input_tokens_str = deserialized_body.extract_messages_text();
|
||||
// enforce ratelimits on ingress
|
||||
if let Err(e) = self.enforce_ratelimits(&model_requested, input_tokens_str.as_str()) {
|
||||
self.send_server_error(
|
||||
|
|
@ -365,21 +365,18 @@ impl HttpContext for StreamContext {
|
|||
let _hermes_llm_provider_id = ProviderId::from(llm_provider_str.as_str());
|
||||
|
||||
// Convert chat completion request to llm provider specific request using provider interface
|
||||
let deserialized_body_bytes = match provider.to_provider_bytes(
|
||||
&deserialized_body,
|
||||
provider.id(),
|
||||
ConversionMode::Compatible,
|
||||
) {
|
||||
Ok(bytes) => bytes,
|
||||
Err(e) => {
|
||||
warn!("Failed to serialize request body: {}", e);
|
||||
self.send_server_error(
|
||||
ServerError::LogicError(format!("Request serialization error: {}", e)),
|
||||
Some(StatusCode::BAD_REQUEST),
|
||||
);
|
||||
return Action::Pause;
|
||||
}
|
||||
};
|
||||
let deserialized_body_bytes =
|
||||
match deserialized_body.to_provider_bytes(ConversionMode::Compatible) {
|
||||
Ok(bytes) => bytes,
|
||||
Err(e) => {
|
||||
warn!("Failed to serialize request body: {}", e);
|
||||
self.send_server_error(
|
||||
ServerError::LogicError(format!("Request serialization error: {}", e)),
|
||||
Some(StatusCode::BAD_REQUEST),
|
||||
);
|
||||
return Action::Pause;
|
||||
}
|
||||
};
|
||||
|
||||
self.set_http_request_body(0, body_size, &deserialized_body_bytes);
|
||||
|
||||
|
|
@ -550,16 +547,9 @@ impl HttpContext for StreamContext {
|
|||
|
||||
// Parse streaming response using OpenAI-compatible format
|
||||
// Since all providers use OpenAI-compatible streaming format
|
||||
let provider = self.get_provider();
|
||||
let provider_id =
|
||||
ProviderId::from(self.llm_provider().provider_interface.to_string().as_str());
|
||||
let provider_id = self.get_provider_id();
|
||||
|
||||
match StreamingResponse::try_from_bytes(
|
||||
&provider,
|
||||
&body,
|
||||
&provider_id,
|
||||
ConversionMode::Compatible,
|
||||
) {
|
||||
match try_streaming_from_bytes(&body, &provider_id, ConversionMode::Compatible) {
|
||||
Ok(mut streaming_response) => {
|
||||
// Process each streaming chunk
|
||||
while let Some(chunk_result) = streaming_response.next() {
|
||||
|
|
@ -587,14 +577,20 @@ impl HttpContext for StreamContext {
|
|||
}
|
||||
}
|
||||
|
||||
// Extract usage information if available
|
||||
if let Some(usage) = chunk.usage() {
|
||||
let completion_tokens = usage.completion_tokens();
|
||||
self.response_tokens += completion_tokens;
|
||||
debug!(
|
||||
"Streaming chunk completion tokens: {}",
|
||||
completion_tokens
|
||||
);
|
||||
// For streaming responses, we handle token counting differently
|
||||
// The ProviderStreamResponse trait provides content_delta, is_final, and role
|
||||
// Token counting for streaming responses typically happens with final usage chunk
|
||||
if chunk.is_final() {
|
||||
// For now, we'll implement basic token estimation
|
||||
// In a complete implementation, the final chunk would contain usage information
|
||||
debug!("Received final streaming chunk");
|
||||
}
|
||||
|
||||
// For now, estimate tokens from content delta
|
||||
if let Some(content) = chunk.content_delta() {
|
||||
// Rough estimation: ~4 characters per token
|
||||
let estimated_tokens = content.len() / 4;
|
||||
self.response_tokens += estimated_tokens.max(1);
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
|
|
@ -605,40 +601,37 @@ impl HttpContext for StreamContext {
|
|||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to parse streaming response: {}", e);
|
||||
return Action::Continue;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
debug!("non streaming response");
|
||||
let provider = self.get_provider();
|
||||
let response = match ProviderResponse::try_from_bytes(
|
||||
&provider,
|
||||
&body,
|
||||
&provider.id(),
|
||||
ConversionMode::Compatible,
|
||||
) {
|
||||
Ok(response) => response,
|
||||
Err(e) => {
|
||||
warn!(
|
||||
"could not parse response: {}, body str: {}",
|
||||
e,
|
||||
String::from_utf8_lossy(&body)
|
||||
);
|
||||
debug!(
|
||||
"on_http_response_body: S[{}], response body: {}",
|
||||
self.context_id,
|
||||
String::from_utf8_lossy(&body)
|
||||
);
|
||||
self.send_server_error(
|
||||
ServerError::LogicError(format!("Response parsing error: {}", e)),
|
||||
Some(StatusCode::BAD_REQUEST),
|
||||
);
|
||||
return Action::Continue;
|
||||
}
|
||||
};
|
||||
let provider_id = self.get_provider_id();
|
||||
let response =
|
||||
match try_response_from_bytes(&body, &provider_id, ConversionMode::Compatible) {
|
||||
Ok(response) => response,
|
||||
Err(e) => {
|
||||
warn!(
|
||||
"could not parse response: {}, body str: {}",
|
||||
e,
|
||||
String::from_utf8_lossy(&body)
|
||||
);
|
||||
debug!(
|
||||
"on_http_response_body: S[{}], response body: {}",
|
||||
self.context_id,
|
||||
String::from_utf8_lossy(&body)
|
||||
);
|
||||
self.send_server_error(
|
||||
ServerError::LogicError(format!("Response parsing error: {}", e)),
|
||||
Some(StatusCode::BAD_REQUEST),
|
||||
);
|
||||
return Action::Continue;
|
||||
}
|
||||
};
|
||||
|
||||
// Use provider interface to extract usage information
|
||||
if let Some((prompt_tokens, completion_tokens, total_tokens)) =
|
||||
provider.extract_usage_counts(&response)
|
||||
response.extract_usage_counts()
|
||||
{
|
||||
debug!(
|
||||
"Response usage: prompt={}, completion={}, total={}",
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue