fixing test cases, and making sure all references to the ChatCOmpletions* objects point to the new types

This commit is contained in:
Salman Paracha 2025-08-11 22:42:13 -07:00
parent df32c7e278
commit 7253a0f203
15 changed files with 224 additions and 838 deletions

View file

@ -2,6 +2,8 @@ use serde::{Deserialize, Serialize};
use serde_json::Value;
use serde_with::skip_serializing_none;
use std::collections::HashMap;
use std::fmt::Display;
use thiserror::Error;
use crate::{providers::ProviderRequestError, ConversionMode, ProviderRequest};
use super::ApiDefinition;
@ -116,8 +118,8 @@ pub enum Role {
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Message {
pub content: MessageContent,
pub role: Role,
pub content: MessageContent,
pub name: Option<String>,
/// Tool calls made by the assistant (only present for assistant role)
pub tool_calls: Option<Vec<ToolCall>>,
@ -171,6 +173,28 @@ pub enum MessageContent {
Parts(Vec<ContentPart>),
}
impl Display for MessageContent {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
MessageContent::Text(text) => write!(f, "{}", text),
MessageContent::Parts(parts) => {
let text_parts: Vec<String> = parts
.iter()
.filter_map(|part| match part {
ContentPart::Text { text } => Some(text.clone()),
ContentPart::ImageUrl { .. } => {
// skip image URLs or their data in text representation
None
}
})
.collect();
let combined_text = text_parts.join("\n");
write!(f, "{}", combined_text)
}
}
}
}
/// Individual content part within a message (text or image)
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(tag = "type")]
@ -560,6 +584,26 @@ impl TokenUsage for Usage {
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelDetail {
pub id: String,
pub object: String,
pub created: usize,
pub owned_by: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ModelObject {
#[serde(rename = "list")]
List,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Models {
pub object: ModelObject,
pub data: Vec<ModelDetail>,
}
// Error type for streaming operations
#[derive(Debug, thiserror::Error)]
pub enum OpenAIStreamError {
@ -571,6 +615,22 @@ pub enum OpenAIStreamError {
InvalidStreamingData(String),
}
#[derive(Debug, Error)]
pub enum OpenAIError {
#[error("json error: {0}")]
JsonParseError(#[from] serde_json::Error),
#[error("utf8 parsing error: {0}")]
Utf8Error(#[from] std::str::Utf8Error),
#[error("invalid streaming data err {source}, data: {data}")]
InvalidStreamingData {
source: serde_json::Error,
data: String,
},
#[error("unsupported provider: {provider}")]
UnsupportedProvider { provider: String },
}
/// SSE-based streaming iterator for OpenAI chat completions
/// Implements ProviderStreamResponseIter directly
pub struct SseChatCompletionIter<I>

View file

@ -1,114 +0,0 @@
use serde_json::Value;
use crate::providers::openai::types::{ChatCompletionsRequest, Message, StreamOptions};
#[derive(Debug, Clone)]
pub struct OpenAIRequestBuilder {
model: String,
messages: Vec<Message>,
temperature: Option<f32>,
top_p: Option<f32>,
n: Option<u32>,
max_tokens: Option<u32>,
stream: Option<bool>,
stop: Option<Vec<String>>,
presence_penalty: Option<f32>,
frequency_penalty: Option<f32>,
stream_options: Option<StreamOptions>,
tools: Option<Vec<Value>>,
}
impl OpenAIRequestBuilder {
pub fn new(model: impl Into<String>, messages: Vec<Message>) -> Self {
Self {
model: model.into(),
messages,
temperature: None,
top_p: None,
n: None,
max_tokens: None,
stream: None,
stop: None,
presence_penalty: None,
frequency_penalty: None,
stream_options: None,
tools: None,
}
}
pub fn temperature(mut self, temperature: f32) -> Self {
self.temperature = Some(temperature);
self
}
pub fn top_p(mut self, top_p: f32) -> Self {
self.top_p = Some(top_p);
self
}
pub fn n(mut self, n: u32) -> Self {
self.n = Some(n);
self
}
pub fn max_tokens(mut self, max_tokens: u32) -> Self {
self.max_tokens = Some(max_tokens);
self
}
pub fn stream(mut self, stream: bool) -> Self {
self.stream = Some(stream);
self
}
pub fn stop(mut self, stop: Vec<String>) -> Self {
self.stop = Some(stop);
self
}
pub fn presence_penalty(mut self, presence_penalty: f32) -> Self {
self.presence_penalty = Some(presence_penalty);
self
}
pub fn frequency_penalty(mut self, frequency_penalty: f32) -> Self {
self.frequency_penalty = Some(frequency_penalty);
self
}
pub fn stream_options(mut self, include_usage: bool) -> Self {
self.stream = Some(true);
self.stream_options = Some(StreamOptions { include_usage });
self
}
pub fn tools(mut self, tools: Vec<Value>) -> Self {
self.tools = Some(tools);
self
}
pub fn build(self) -> Result<ChatCompletionsRequest, &'static str> {
let request = ChatCompletionsRequest {
model: self.model,
messages: self.messages,
temperature: self.temperature,
top_p: self.top_p,
n: self.n,
max_tokens: self.max_tokens,
stream: self.stream,
stop: self.stop,
presence_penalty: self.presence_penalty,
frequency_penalty: self.frequency_penalty,
stream_options: self.stream_options,
tools: self.tools,
metadata: None,
};
Ok(request)
}
}
impl ChatCompletionsRequest {
pub fn builder(model: impl Into<String>, messages: Vec<Message>) -> OpenAIRequestBuilder {
OpenAIRequestBuilder::new(model, messages)
}
}

View file

@ -1,10 +1,5 @@
pub mod builder;
pub mod types;
// Re-export the main types and builder functionality
pub use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse, ChatCompletionsStreamResponse};
pub use builder::*;
pub use types::*;
// Note: The OpenAIProvider struct has been deprecated in favor of the function-based approach in traits.rs
// All provider functionality is now accessed through try_request_from_bytes, try_response_from_bytes, etc.

View file

@ -1,552 +0,0 @@
use std::collections::HashMap;
use std::fmt::Display;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use serde_with::skip_serializing_none;
use std::convert::TryFrom;
use std::str;
use thiserror::Error;
use crate::providers::ProviderId;
#[derive(Debug, Error)]
pub enum OpenAIError {
#[error("json error: {0}")]
JsonParseError(#[from] serde_json::Error),
#[error("utf8 parsing error: {0}")]
Utf8Error(#[from] std::str::Utf8Error),
#[error("invalid streaming data err {source}, data: {data}")]
InvalidStreamingData {
source: serde_json::Error,
data: String,
},
#[error("unsupported provider: {provider}")]
UnsupportedProvider { provider: String },
}
type Result<T> = std::result::Result<T, OpenAIError>;
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum MultiPartContentType {
#[serde(rename = "text")]
Text,
#[serde(rename = "image_url")]
ImageUrl,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ImageUrl {
pub url: String,
}
#[skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct MultiPartContent {
pub text: Option<String>,
pub image_url: Option<ImageUrl>,
#[serde(rename = "type")]
pub content_type: MultiPartContentType,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(untagged)]
pub enum ContentType {
Text(String),
MultiPart(Vec<MultiPartContent>),
}
impl Display for ContentType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ContentType::Text(text) => write!(f, "{}", text),
ContentType::MultiPart(multi_part) => {
let text_parts: Vec<String> = multi_part
.iter()
.filter_map(|part| {
if part.content_type == MultiPartContentType::Text {
part.text.clone()
} else if part.content_type == MultiPartContentType::ImageUrl {
// skip image URLs or their data in text representation
None
} else {
panic!("Unsupported content type: {:?}", part.content_type);
}
})
.collect();
let combined_text = text_parts.join("\n");
write!(f, "{}", combined_text)
}
}
}
}
#[skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
pub role: String,
pub content: Option<ContentType>,
}
impl Message {
pub fn new(content: String) -> Self {
Self {
role: "user".to_string(),
content: Some(ContentType::Text(content)),
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamOptions {
pub include_usage: bool,
}
#[skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ChatCompletionsRequest {
pub model: String,
pub messages: Vec<Message>,
pub temperature: Option<f32>,
pub top_p: Option<f32>,
pub n: Option<u32>,
pub max_tokens: Option<u32>,
pub stream: Option<bool>,
pub stop: Option<Vec<String>>,
pub presence_penalty: Option<f32>,
pub frequency_penalty: Option<f32>,
pub stream_options: Option<StreamOptions>,
pub tools: Option<Vec<Value>>,
pub metadata: Option<HashMap<String, Value>>,
}
impl TryFrom<&[u8]> for ChatCompletionsRequest {
type Error = OpenAIError;
fn try_from(bytes: &[u8]) -> Result<Self> {
serde_json::from_slice(bytes).map_err(OpenAIError::from)
}
}
#[skip_serializing_none]
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatCompletionsResponse {
pub id: String,
pub object: String,
pub created: u64,
pub choices: Vec<Choice>,
pub usage: Option<Usage>,
}
impl TryFrom<&[u8]> for ChatCompletionsResponse {
type Error = OpenAIError;
fn try_from(bytes: &[u8]) -> Result<Self> {
serde_json::from_slice(bytes).map_err(OpenAIError::from)
}
}
impl<'a> TryFrom<(&'a [u8], &'a ProviderId)> for ChatCompletionsResponse {
type Error = OpenAIError;
fn try_from(input: (&'a [u8], &'a ProviderId)) -> Result<Self> {
// Use input.provider as needed, if necessary
serde_json::from_slice(input.0).map_err(OpenAIError::from)
}
}
impl ChatCompletionsRequest {
pub fn to_bytes(&self, provider: ProviderId) -> Result<Vec<u8>> {
match provider {
ProviderId::OpenAI
| ProviderId::Arch
| ProviderId::Deepseek
| ProviderId::Mistral
| ProviderId::Groq
| ProviderId::Gemini
| ProviderId::Claude
| ProviderId::GitHub => serde_json::to_vec(self).map_err(OpenAIError::from),
}
}
}
#[skip_serializing_none]
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Choice {
pub index: u32,
pub message: Message,
pub finish_reason: Option<String>,
}
#[skip_serializing_none]
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Usage {
pub prompt_tokens: usize,
pub completion_tokens: usize,
pub total_tokens: usize,
}
#[skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeltaMessage {
pub role: Option<String>,
pub content: Option<ContentType>,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct StreamChoice {
pub index: u32,
pub delta: DeltaMessage,
pub finish_reason: Option<String>,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatCompletionStreamResponse {
pub id: String,
pub object: String,
pub created: u64,
pub model: String,
pub choices: Vec<StreamChoice>,
pub usage: Option<Usage>,
}
pub struct SseChatCompletionIter<I>
where
I: Iterator,
I::Item: AsRef<str>,
{
lines: I,
}
impl<I> SseChatCompletionIter<I>
where
I: Iterator,
I::Item: AsRef<str>,
{
pub fn new(lines: I) -> Self {
Self { lines }
}
}
impl<I> Iterator for SseChatCompletionIter<I>
where
I: Iterator,
I::Item: AsRef<str>,
{
type Item = Result<ChatCompletionStreamResponse>;
fn next(&mut self) -> Option<Self::Item> {
for line in &mut self.lines {
let line = line.as_ref();
if let Some(data) = line.strip_prefix("data: ") {
let data = data.trim();
if data == "[DONE]" {
return None;
}
if data == r#"{"type": "ping"}"# {
continue; // Skip ping messages - that is usually from anthropic
}
return Some(
serde_json::from_str::<ChatCompletionStreamResponse>(data).map_err(|e| {
OpenAIError::InvalidStreamingData {
source: e,
data: data.to_string(),
}
}),
);
}
}
None
}
}
impl<'a> TryFrom<&'a [u8]> for SseChatCompletionIter<str::Lines<'a>> {
type Error = OpenAIError;
fn try_from(bytes: &'a [u8]) -> Result<Self> {
let s = std::str::from_utf8(bytes)?;
Ok(SseChatCompletionIter::new(s.lines()))
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelDetail {
pub id: String,
pub object: String,
pub created: usize,
pub owned_by: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ModelObject {
#[serde(rename = "list")]
List,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Models {
pub object: ModelObject,
pub data: Vec<ModelDetail>,
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_content_type_display() {
let text_content = ContentType::Text("Hello, world!".to_string());
assert_eq!(text_content.to_string(), "Hello, world!");
let multi_part_content = ContentType::MultiPart(vec![
MultiPartContent {
text: Some("This is a text part.".to_string()),
content_type: MultiPartContentType::Text,
image_url: None,
},
MultiPartContent {
text: Some("https://example.com/image.png".to_string()),
content_type: MultiPartContentType::ImageUrl,
image_url: None,
},
]);
assert_eq!(multi_part_content.to_string(), "This is a text part.");
}
#[test]
fn test_chat_completions_request_text_type_array() {
const CHAT_COMPLETIONS_REQUEST: &str = r#"
{
"model": "gpt-3.5-turbo",
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": "What city do you want to know the weather for?"
},
{
"type": "text",
"text": "hello world"
}
]
}
]
}
"#;
let chat_completions_request: ChatCompletionsRequest =
serde_json::from_str(CHAT_COMPLETIONS_REQUEST).unwrap();
assert_eq!(chat_completions_request.model, "gpt-3.5-turbo");
if let Some(ContentType::MultiPart(multi_part_content)) =
chat_completions_request.messages[0].content.as_ref()
{
assert_eq!(multi_part_content.len(), 2);
assert_eq!(
multi_part_content[0].content_type,
MultiPartContentType::Text
);
assert_eq!(
multi_part_content[0].text,
Some("What city do you want to know the weather for?".to_string())
);
assert_eq!(
multi_part_content[1].content_type,
MultiPartContentType::Text
);
assert_eq!(multi_part_content[1].text, Some("hello world".to_string()));
} else {
panic!("Expected MultiPartContent");
}
}
#[test]
fn test_chat_completions_request_image_content() {
const CHAT_COMPLETIONS_REQUEST: &str = r#"
{
"stream": true,
"model": "openai/gpt-4o",
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": "describe this photo pls"
},
{
"type": "image_url",
"image_url": {
"url": "data:image/jpeg;base64,/9j/4AAQSkZJRgABAQAAAQABAAD/...=="
}
}
]
}
]
}"#;
let chat_completions_request: ChatCompletionsRequest =
serde_json::from_str(CHAT_COMPLETIONS_REQUEST).unwrap();
assert_eq!(chat_completions_request.model, "openai/gpt-4o");
if let Some(ContentType::MultiPart(multi_part_content)) =
chat_completions_request.messages[0].content.as_ref()
{
assert_eq!(multi_part_content.len(), 2);
assert_eq!(
multi_part_content[0].content_type,
MultiPartContentType::Text
);
assert_eq!(
multi_part_content[0].text,
Some("describe this photo pls".to_string())
);
assert_eq!(
multi_part_content[1].content_type,
MultiPartContentType::ImageUrl
);
assert_eq!(
multi_part_content[1].image_url,
Some(ImageUrl {
url: "data:image/jpeg;base64,/9j/4AAQSkZJRgABAQAAAQABAAD/...==".to_string(),
})
);
} else {
panic!("Expected MultiPartContent");
}
}
#[test]
fn test_sse_streaming() {
let json_data = r#"data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1700000000,"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"role":"assistant"},"finish_reason":null}]}
data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1700000000,"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"content":"Hello, how can I help you today?"},"finish_reason":null}]}
data: [DONE]"#;
let iter = SseChatCompletionIter::new(json_data.lines());
println!("Testing SSE Streaming");
for item in iter {
match item {
Ok(response) => {
println!("Received response: {:?}", response);
if response.choices.is_empty() {
continue;
}
for choice in response.choices {
if let Some(content) = choice.delta.content {
println!("Content: {}", content);
}
}
}
Err(e) => {
println!("Error parsing JSON: {}", e);
return;
}
}
}
}
#[test]
fn test_sse_streaming_try_from_bytes() {
let json_data = r#"data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1700000000,"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"role":"assistant"},"finish_reason":null}]}
data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1700000000,"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"content":"Hello, how can I help you today?"},"finish_reason":null}]}
data: [DONE]"#;
let iter = SseChatCompletionIter::try_from(json_data.as_bytes())
.expect("Failed to create SSE iterator");
println!("Testing SSE Streaming");
for item in iter {
match item {
Ok(response) => {
println!("Received response: {:?}", response);
if response.choices.is_empty() {
continue;
}
for choice in response.choices {
if let Some(content) = choice.delta.content {
println!("Content: {}", content);
}
}
}
Err(e) => {
println!("Error parsing JSON: {}", e);
return;
}
}
}
}
#[test]
fn parse_chat_completions_request() {
const CHAT_COMPLETIONS_REQUEST: &str = r#"
{
"model": "None",
"messages": [
{
"role": "user",
"content": "how is the weather in seattle"
}
],
"stream": true
} "#;
let _chat_completions_request: ChatCompletionsRequest =
ChatCompletionsRequest::try_from(CHAT_COMPLETIONS_REQUEST.as_bytes())
.expect("Failed to parse ChatCompletionsRequest");
}
#[test]
fn stream_chunk_parse_claude() {
const CHUNK_RESPONSE: &str = r#"data: {"id":"msg_01DZDMxYSgq8aPQxMQoBv6Kb","choices":[{"index":0,"delta":{"role":"assistant"}}],"created":1747685264,"model":"claude-3-7-sonnet-latest","object":"chat.completion.chunk"}
data: {"type": "ping"}
data: {"id":"msg_01DZDMxYSgq8aPQxMQoBv6Kb","choices":[{"index":0,"delta":{"content":"Hello!"}}],"created":1747685264,"model":"claude-3-7-sonnet-latest","object":"chat.completion.chunk"}
data: {"id":"msg_01DZDMxYSgq8aPQxMQoBv6Kb","choices":[{"index":0,"delta":{"content":" How can I assist you today? Whether"}}],"created":1747685264,"model":"claude-3-7-sonnet-latest","object":"chat.completion.chunk"}
data: {"id":"msg_01DZDMxYSgq8aPQxMQoBv6Kb","choices":[{"index":0,"delta":{"content":" you have a question, need information"}}],"created":1747685264,"model":"claude-3-7-sonnet-latest","object":"chat.completion.chunk"}
data: {"id":"msg_01DZDMxYSgq8aPQxMQoBv6Kb","choices":[{"index":0,"delta":{"content":", or just want to chat about"}}],"created":1747685264,"model":"claude-3-7-sonnet-latest","object":"chat.completion.chunk"}
data: {"id":"msg_01DZDMxYSgq8aPQxMQoBv6Kb","choices":[{"index":0,"delta":{"content":" something, I'm here to help. What woul"}}],"created":1747685264,"model":"claude-3-7-sonnet-latest","object":"chat.completion.chunk"}
data: {"id":"msg_01DZDMxYSgq8aPQxMQoBv6Kb","choices":[{"index":0,"delta":{"content":"d you like to talk about?"}}],"created":1747685264,"model":"claude-3-7-sonnet-latest","object":"chat.completion.chunk"}
data: {"id":"msg_01DZDMxYSgq8aPQxMQoBv6Kb","choices":[{"index":0,"delta":{},"finish_reason":"stop"}],"created":1747685264,"model":"claude-3-7-sonnet-latest","object":"chat.completion.chunk"}
data: [DONE]
"#;
let iter = SseChatCompletionIter::try_from(CHUNK_RESPONSE.as_bytes());
assert!(iter.is_ok(), "Failed to create SSE iterator");
let iter: SseChatCompletionIter<str::Lines<'_>> = iter.unwrap();
let all_text: Vec<String> = iter
.map(|item| {
let response = item.expect("Failed to parse response");
response
.choices
.into_iter()
.filter_map(|choice| choice.delta.content)
.map(|content| content.to_string())
.collect::<String>()
})
.collect();
assert_eq!(
all_text.len(),
8,
"Expected 8 chunks of text, but got {}",
all_text.len()
);
assert_eq!(
all_text.join(""),
"Hello! How can I assist you today? Whether you have a question, need information, or just want to chat about something, I'm here to help. What would you like to talk about?"
);
}
}

View file

@ -6,59 +6,6 @@
use std::error::Error;
use std::fmt;
/// Conversion mode for provider requests/responses
#[derive(Debug, Clone, Copy)]
pub enum ConversionMode {
/// Compatible: Convert between different provider formats to ensure compatibility
Compatible,
/// Passthrough: Pass requests/responses through with minimal modification
Passthrough,
}
/// Error types for provider operations
#[derive(Debug)]
pub struct ProviderRequestError {
pub message: String,
pub source: Option<Box<dyn Error + Send + Sync>>,
}
#[derive(Debug)]
pub struct ProviderResponseError {
pub message: String,
pub source: Option<Box<dyn Error + Send + Sync>>,
}
impl fmt::Display for ProviderRequestError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "Provider request error: {}", self.message)
}
}
impl fmt::Display for ProviderResponseError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "Provider response error: {}", self.message)
}
}
impl Error for ProviderRequestError {
fn source(&self) -> Option<&(dyn Error + 'static)> {
self.source.as_ref().map(|e| e.as_ref() as &(dyn Error + 'static))
}
}
impl Error for ProviderResponseError {
fn source(&self) -> Option<&(dyn Error + 'static)> {
self.source.as_ref().map(|e| e.as_ref() as &(dyn Error + 'static))
}
}
/// Trait for token usage information
pub trait TokenUsage {
fn completion_tokens(&self) -> usize;
fn prompt_tokens(&self) -> usize;
fn total_tokens(&self) -> usize;
}
/// Trait for provider-specific request types
pub trait ProviderRequest: Send + Sync {
/// Extract the model name from the request
@ -107,13 +54,26 @@ pub trait ProviderStreamResponse: Send + Sync {
}
/// Trait for streaming response iterators
///
/// This trait ensures that implementing types are iterators that yield
/// ProviderStreamResponse results.
pub trait ProviderStreamResponseIter: Iterator<Item = Result<Box<dyn ProviderStreamResponse>, Box<dyn std::error::Error + Send + Sync>>> + Send + Sync {
// No additional methods needed - just the Iterator constraint with proper bounds
}
/// Conversion mode for provider requests/responses
#[derive(Debug, Clone, Copy)]
pub enum ConversionMode {
/// Compatible: Convert between different provider formats to ensure compatibility
Compatible,
/// Passthrough: Pass requests/responses through with minimal modification
Passthrough,
}
/// Trait for token usage information
pub trait TokenUsage {
fn completion_tokens(&self) -> usize;
fn prompt_tokens(&self) -> usize;
fn total_tokens(&self) -> usize;
}
// ============================================================================
// PROVIDER FUNCTIONS - NO TRAITS, JUST PARAMETERIZED CONVERSION
// ============================================================================
@ -152,14 +112,42 @@ pub trait ProviderStreamResponseIter: Iterator<Item = Result<Box<dyn ProviderStr
use crate::ProviderId;
/// Parse request from bytes using provider ID - returns generic ProviderRequest trait object
pub fn try_request_from_bytes(bytes: &[u8], provider_id: &ProviderId) -> Result<Box<dyn ProviderRequest>, ProviderRequestError> {
// ============================================================================
// PROVIDER ADAPTER REGISTRY (Organizational Enhancement)
// ============================================================================
/// Provider adapter configuration
#[derive(Debug, Clone)]
pub struct ProviderConfig {
pub supported_apis: &'static [&'static str],
pub adapter_type: AdapterType,
}
#[derive(Debug, Clone)]
pub enum AdapterType {
OpenAICompatible,
// Future: Claude, Gemini, etc.
}
/// Get provider configuration
pub fn get_provider_config(provider_id: &ProviderId) -> ProviderConfig {
match provider_id {
// All these providers currently use OpenAI-compatible chat completions API
// In the future, we can add provider-specific handling in separate match arms
ProviderId::OpenAI | ProviderId::Groq | ProviderId::Mistral | ProviderId::Deepseek
| ProviderId::Arch | ProviderId::Gemini | ProviderId::Claude | ProviderId::GitHub => {
ProviderConfig {
supported_apis: &["/v1/chat/completions"],
adapter_type: AdapterType::OpenAICompatible,
}
}
}
}
/// Parse request from bytes using provider ID - returns generic ProviderRequest trait object
pub fn try_request_from_bytes(bytes: &[u8], provider_id: &ProviderId) -> Result<Box<dyn ProviderRequest>, ProviderRequestError> {
let config = get_provider_config(provider_id);
match config.adapter_type {
AdapterType::OpenAICompatible => {
let request = crate::apis::openai::ChatCompletionsRequest::try_from((bytes, provider_id))
.map_err(|e| ProviderRequestError {
message: format!("Failed to parse request: {}", e),
@ -175,9 +163,10 @@ pub fn try_request_from_bytes(bytes: &[u8], provider_id: &ProviderId) -> Result<
/// Parse response from bytes using provider ID - returns generic ProviderResponse trait object
pub fn try_response_from_bytes(bytes: &[u8], provider_id: &ProviderId, _mode: ConversionMode) -> Result<Box<dyn ProviderResponse>, ProviderResponseError> {
match provider_id {
ProviderId::OpenAI | ProviderId::Groq | ProviderId::Mistral | ProviderId::Deepseek
| ProviderId::Arch | ProviderId::Gemini | ProviderId::Claude | ProviderId::GitHub => {
let config = get_provider_config(provider_id);
match config.adapter_type {
AdapterType::OpenAICompatible => {
// Parameterized conversion allows provider-specific response parsing
let response = crate::apis::openai::ChatCompletionsResponse::try_from((bytes, provider_id))
.map_err(|e| ProviderResponseError {
@ -192,13 +181,11 @@ pub fn try_response_from_bytes(bytes: &[u8], provider_id: &ProviderId, _mode: Co
}
/// Create streaming response using provider ID - returns clean ProviderStreamResponseIter trait object
///
/// This function returns a ProviderStreamResponseIter that's just an iterator,
/// eliminating the complex nested Result<Box<dyn Iterator<...>>> type completely.
pub fn try_streaming_from_bytes(bytes: &[u8], provider_id: &ProviderId, _mode: ConversionMode) -> Result<Box<dyn ProviderStreamResponseIter>, Box<dyn std::error::Error + Send + Sync>> {
match provider_id {
ProviderId::OpenAI | ProviderId::Groq | ProviderId::Mistral | ProviderId::Deepseek
| ProviderId::Arch | ProviderId::Gemini | ProviderId::Claude | ProviderId::GitHub => {
let config = get_provider_config(provider_id);
match config.adapter_type {
AdapterType::OpenAICompatible => {
// Parse SSE (Server-Sent Events) streaming data
let s = std::str::from_utf8(bytes)?;
let lines: Vec<String> = s.lines().map(|line| line.to_string()).collect();
@ -211,29 +198,50 @@ pub fn try_streaming_from_bytes(bytes: &[u8], provider_id: &ProviderId, _mode: C
}
/// Check if provider has compatible API
///
/// Replaces the old ProviderInterface::has_compatible_api method.
/// This function enables runtime API compatibility checking without needing a provider instance.
pub fn has_compatible_api(provider_id: &ProviderId, api_path: &str) -> bool {
match provider_id {
// Currently all these providers support OpenAI chat completions API
// Future providers with different APIs will get their own match arms
ProviderId::OpenAI | ProviderId::Groq | ProviderId::Mistral | ProviderId::Deepseek
| ProviderId::Arch | ProviderId::Gemini | ProviderId::Claude | ProviderId::GitHub => {
api_path == "/v1/chat/completions"
}
}
let config = get_provider_config(provider_id);
config.supported_apis.iter().any(|&supported| supported == api_path)
}
/// Get supported APIs for provider
///
/// Replaces the old ProviderInterface::supported_apis method.
/// Returns a static list of supported API endpoints for the given provider.
pub fn supported_apis(provider_id: &ProviderId) -> Vec<&'static str> {
match provider_id {
ProviderId::OpenAI | ProviderId::Groq | ProviderId::Mistral | ProviderId::Deepseek
| ProviderId::Arch | ProviderId::Gemini | ProviderId::Claude | ProviderId::GitHub => {
vec!["/v1/chat/completions"]
}
let config = get_provider_config(provider_id);
config.supported_apis.to_vec()
}
/// Error types for provider operations
#[derive(Debug)]
pub struct ProviderRequestError {
pub message: String,
pub source: Option<Box<dyn Error + Send + Sync>>,
}
#[derive(Debug)]
pub struct ProviderResponseError {
pub message: String,
pub source: Option<Box<dyn Error + Send + Sync>>,
}
impl fmt::Display for ProviderRequestError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "Provider request error: {}", self.message)
}
}
impl fmt::Display for ProviderResponseError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "Provider response error: {}", self.message)
}
}
impl Error for ProviderRequestError {
fn source(&self) -> Option<&(dyn Error + 'static)> {
self.source.as_ref().map(|e| e.as_ref() as &(dyn Error + 'static))
}
}
impl Error for ProviderResponseError {
fn source(&self) -> Option<&(dyn Error + 'static)> {
self.source.as_ref().map(|e| e.as_ref() as &(dyn Error + 'static))
}
}