add support for agents (#564)

Adil Hafeez 2025-10-14 14:01:11 -07:00 committed by GitHub
parent f8991a3c4b
commit 96e0732089
41 changed files with 3571 additions and 856 deletions


@@ -1,6 +1,6 @@
use std::fmt::Display;
use crate::apis::{AnthropicApi, OpenAIApi};
use crate::clients::endpoints::SupportedAPIs;
use crate::apis::{OpenAIApi, AnthropicApi};
use std::fmt::Display;
/// Provider identifier enum - simple enum for identifying providers
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
@@ -50,41 +50,50 @@ impl ProviderId {
pub fn compatible_api_for_client(&self, client_api: &SupportedAPIs) -> SupportedAPIs {
match (self, client_api) {
// Claude/Anthropic providers natively support Anthropic APIs
(ProviderId::Anthropic, SupportedAPIs::AnthropicMessagesAPI(_)) => SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages),
(ProviderId::Anthropic, SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions)) => SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
(ProviderId::Anthropic, SupportedAPIs::AnthropicMessagesAPI(_)) => {
SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages)
}
(
ProviderId::Anthropic,
SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
) => SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
// OpenAI-compatible providers only support OpenAI chat completions
(ProviderId::OpenAI
| ProviderId::Groq
| ProviderId::Mistral
| ProviderId::Deepseek
| ProviderId::Arch
| ProviderId::Gemini
| ProviderId::GitHub
| ProviderId::AzureOpenAI
| ProviderId::XAI
| ProviderId::TogetherAI
| ProviderId::Ollama
| ProviderId::Moonshotai
| ProviderId::Zhipu
| ProviderId::Qwen,
SupportedAPIs::AnthropicMessagesAPI(_)) => SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
(
ProviderId::OpenAI
| ProviderId::Groq
| ProviderId::Mistral
| ProviderId::Deepseek
| ProviderId::Arch
| ProviderId::Gemini
| ProviderId::GitHub
| ProviderId::AzureOpenAI
| ProviderId::XAI
| ProviderId::TogetherAI
| ProviderId::Ollama
| ProviderId::Moonshotai
| ProviderId::Zhipu
| ProviderId::Qwen,
SupportedAPIs::AnthropicMessagesAPI(_),
) => SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
(ProviderId::OpenAI
| ProviderId::Groq
| ProviderId::Mistral
| ProviderId::Deepseek
| ProviderId::Arch
| ProviderId::Gemini
| ProviderId::GitHub
| ProviderId::AzureOpenAI
| ProviderId::XAI
| ProviderId::TogetherAI
| ProviderId::Ollama
| ProviderId::Moonshotai
| ProviderId::Zhipu
| ProviderId::Qwen,
SupportedAPIs::OpenAIChatCompletions(_)) => SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
(
ProviderId::OpenAI
| ProviderId::Groq
| ProviderId::Mistral
| ProviderId::Deepseek
| ProviderId::Arch
| ProviderId::Gemini
| ProviderId::GitHub
| ProviderId::AzureOpenAI
| ProviderId::XAI
| ProviderId::TogetherAI
| ProviderId::Ollama
| ProviderId::Moonshotai
| ProviderId::Zhipu
| ProviderId::Qwen,
SupportedAPIs::OpenAIChatCompletions(_),
) => SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
}
}
}
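For orientation, a minimal usage sketch (not part of this diff) of the routing rule above; the crate paths mirror the imports used elsewhere in this commit, and Groq stands in for any OpenAI-compatible provider.

use crate::apis::{AnthropicApi, OpenAIApi};
use crate::clients::endpoints::SupportedAPIs;
use crate::providers::ProviderId;

fn pick_upstream_api() {
    // A client speaking the Anthropic Messages API.
    let client_api = SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages);

    // Anthropic serves the Messages API natively ...
    assert!(matches!(
        ProviderId::Anthropic.compatible_api_for_client(&client_api),
        SupportedAPIs::AnthropicMessagesAPI(_)
    ));

    // ... while an OpenAI-compatible provider is routed to chat completions,
    // so the request and response must be transformed at the proxy boundary.
    assert!(matches!(
        ProviderId::Groq.compatible_api_for_client(&client_api),
        SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions)
    ));
}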


@@ -8,5 +8,5 @@ pub mod request;
pub mod response;
pub use id::ProviderId;
pub use request::{ProviderRequestType, ProviderRequest, ProviderRequestError} ;
pub use response::{ProviderResponseType, ProviderResponse, ProviderStreamResponse, TokenUsage };
pub use request::{ProviderRequest, ProviderRequestError, ProviderRequestType};
pub use response::{ProviderResponse, ProviderResponseType, ProviderStreamResponse, TokenUsage};


@@ -1,11 +1,11 @@
use crate::apis::openai::ChatCompletionsRequest;
use crate::apis::anthropic::MessagesRequest;
use crate::apis::openai::ChatCompletionsRequest;
use crate::clients::endpoints::SupportedAPIs;
use serde_json::Value;
use std::collections::HashMap;
use std::error::Error;
use std::fmt;
use std::collections::HashMap;
#[derive(Clone)]
pub enum ProviderRequestType {
ChatCompletionsRequest(ChatCompletionsRequest),
@@ -103,15 +103,18 @@ impl TryFrom<(&[u8], &SupportedAPIs)> for ProviderRequestType {
// Use SupportedApi to determine the appropriate request type
match client_api {
SupportedAPIs::OpenAIChatCompletions(_) => {
let chat_completion_request: ChatCompletionsRequest = ChatCompletionsRequest::try_from(bytes)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
Ok(ProviderRequestType::ChatCompletionsRequest(chat_completion_request))
}
SupportedAPIs::AnthropicMessagesAPI(_) => {
let messages_request: MessagesRequest = MessagesRequest::try_from(bytes)
let chat_completion_request: ChatCompletionsRequest =
ChatCompletionsRequest::try_from(bytes)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
Ok(ProviderRequestType::MessagesRequest(messages_request))
}
Ok(ProviderRequestType::ChatCompletionsRequest(
chat_completion_request,
))
}
SupportedAPIs::AnthropicMessagesAPI(_) => {
let messages_request: MessagesRequest = MessagesRequest::try_from(bytes)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
Ok(ProviderRequestType::MessagesRequest(messages_request))
}
}
}
}
@@ -120,40 +123,55 @@ impl TryFrom<(&[u8], &SupportedAPIs)> for ProviderRequestType {
impl TryFrom<(ProviderRequestType, &SupportedAPIs)> for ProviderRequestType {
type Error = ProviderRequestError;
fn try_from((request, upstream_api): (ProviderRequestType, &SupportedAPIs)) -> Result<Self, Self::Error> {
fn try_from(
(request, upstream_api): (ProviderRequestType, &SupportedAPIs),
) -> Result<Self, Self::Error> {
match (request, upstream_api) {
// Same API - no conversion needed, just clone the reference
(ProviderRequestType::ChatCompletionsRequest(chat_req), SupportedAPIs::OpenAIChatCompletions(_)) => {
Ok(ProviderRequestType::ChatCompletionsRequest(chat_req))
}
(ProviderRequestType::MessagesRequest(messages_req), SupportedAPIs::AnthropicMessagesAPI(_)) => {
Ok(ProviderRequestType::MessagesRequest(messages_req))
}
(
ProviderRequestType::ChatCompletionsRequest(chat_req),
SupportedAPIs::OpenAIChatCompletions(_),
) => Ok(ProviderRequestType::ChatCompletionsRequest(chat_req)),
(
ProviderRequestType::MessagesRequest(messages_req),
SupportedAPIs::AnthropicMessagesAPI(_),
) => Ok(ProviderRequestType::MessagesRequest(messages_req)),
// Cross-API conversion - cloning is necessary for transformation
(ProviderRequestType::ChatCompletionsRequest(chat_req), SupportedAPIs::AnthropicMessagesAPI(_)) => {
let messages_req = MessagesRequest::try_from(chat_req)
.map_err(|e| ProviderRequestError {
message: format!("Failed to convert ChatCompletionsRequest to MessagesRequest: {}", e),
source: Some(Box::new(e))
(
ProviderRequestType::ChatCompletionsRequest(chat_req),
SupportedAPIs::AnthropicMessagesAPI(_),
) => {
let messages_req =
MessagesRequest::try_from(chat_req).map_err(|e| ProviderRequestError {
message: format!(
"Failed to convert ChatCompletionsRequest to MessagesRequest: {}",
e
),
source: Some(Box::new(e)),
})?;
Ok(ProviderRequestType::MessagesRequest(messages_req))
}
(ProviderRequestType::MessagesRequest(messages_req), SupportedAPIs::OpenAIChatCompletions(_)) => {
let chat_req = ChatCompletionsRequest::try_from(messages_req)
.map_err(|e| ProviderRequestError {
message: format!("Failed to convert MessagesRequest to ChatCompletionsRequest: {}", e),
source: Some(Box::new(e))
})?;
(
ProviderRequestType::MessagesRequest(messages_req),
SupportedAPIs::OpenAIChatCompletions(_),
) => {
let chat_req = ChatCompletionsRequest::try_from(messages_req).map_err(|e| {
ProviderRequestError {
message: format!(
"Failed to convert MessagesRequest to ChatCompletionsRequest: {}",
e
),
source: Some(Box::new(e)),
}
})?;
Ok(ProviderRequestType::ChatCompletionsRequest(chat_req))
}
}
}
}
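Illustration only (not part of the diff): with the impl above, a caller can reshape a parsed request into whichever API the selected provider speaks; the paths below follow the re-exports in the providers module.

use crate::apis::anthropic::AnthropicApi;
use crate::clients::endpoints::SupportedAPIs;
use crate::providers::ProviderRequestType;

fn to_upstream_shape(request: ProviderRequestType) -> ProviderRequestType {
    // Assume the chosen upstream only speaks the Anthropic Messages API,
    // e.g. as returned by ProviderId::compatible_api_for_client.
    let upstream_api = SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages);

    // Same API: pass-through. Cross-API: field-level transformation that
    // surfaces a ProviderRequestError if the request cannot be mapped.
    ProviderRequestType::try_from((request, &upstream_api))
        .expect("request transformation failed")
}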
/// Error types for provider operations
#[derive(Debug)]
pub struct ProviderRequestError {
@@ -169,19 +187,20 @@ impl fmt::Display for ProviderRequestError {
impl Error for ProviderRequestError {
fn source(&self) -> Option<&(dyn Error + 'static)> {
self.source.as_ref().map(|e| e.as_ref() as &(dyn Error + 'static))
self.source
.as_ref()
.map(|e| e.as_ref() as &(dyn Error + 'static))
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::clients::endpoints::SupportedAPIs;
use crate::apis::anthropic::AnthropicApi::Messages;
use crate::apis::openai::OpenAIApi::ChatCompletions;
use crate::apis::anthropic::MessagesRequest as AnthropicMessagesRequest;
use crate::apis::openai::{ChatCompletionsRequest};
use crate::apis::openai::ChatCompletionsRequest;
use crate::apis::openai::OpenAIApi::ChatCompletions;
use crate::clients::endpoints::SupportedAPIs;
use crate::clients::transformer::ExtractText;
use serde_json::json;
@@ -202,7 +221,7 @@ mod tests {
ProviderRequestType::ChatCompletionsRequest(r) => {
assert_eq!(r.model, "gpt-4");
assert_eq!(r.messages.len(), 2);
},
}
_ => panic!("Expected ChatCompletionsRequest variant"),
}
}
@@ -225,7 +244,7 @@ mod tests {
ProviderRequestType::MessagesRequest(r) => {
assert_eq!(r.model, "claude-3-sonnet");
assert_eq!(r.messages.len(), 1);
},
}
_ => panic!("Expected MessagesRequest variant"),
}
}
@@ -247,7 +266,7 @@ mod tests {
ProviderRequestType::ChatCompletionsRequest(r) => {
assert_eq!(r.model, "gpt-4");
assert_eq!(r.messages.len(), 2);
},
}
_ => panic!("Expected ChatCompletionsRequest variant"),
}
}
@@ -271,7 +290,7 @@ mod tests {
ProviderRequestType::ChatCompletionsRequest(r) => {
assert_eq!(r.model, "claude-3-sonnet");
assert_eq!(r.messages.len(), 1);
},
}
_ => panic!("Expected ChatCompletionsRequest variant"),
}
}
@@ -280,13 +299,15 @@ mod tests {
fn test_v1_messages_to_v1_chat_completions_roundtrip() {
let anthropic_req = AnthropicMessagesRequest {
model: "claude-3-sonnet".to_string(),
system: Some(crate::apis::anthropic::MessagesSystemPrompt::Single("You are a helpful assistant".to_string())),
messages: vec![
crate::apis::anthropic::MessagesMessage {
role: crate::apis::anthropic::MessagesRole::User,
content: crate::apis::anthropic::MessagesMessageContent::Single("Hello!".to_string()),
}
],
system: Some(crate::apis::anthropic::MessagesSystemPrompt::Single(
"You are a helpful assistant".to_string(),
)),
messages: vec![crate::apis::anthropic::MessagesMessage {
role: crate::apis::anthropic::MessagesRole::User,
content: crate::apis::anthropic::MessagesMessageContent::Single(
"Hello!".to_string(),
),
}],
max_tokens: 128,
container: None,
mcp_servers: None,
@@ -302,16 +323,27 @@ mod tests {
metadata: None,
};
let openai_req = ChatCompletionsRequest::try_from(anthropic_req.clone()).expect("Anthropic->OpenAI conversion failed");
let anthropic_req2 = AnthropicMessagesRequest::try_from(openai_req).expect("OpenAI->Anthropic conversion failed");
let openai_req = ChatCompletionsRequest::try_from(anthropic_req.clone())
.expect("Anthropic->OpenAI conversion failed");
let anthropic_req2 = AnthropicMessagesRequest::try_from(openai_req)
.expect("OpenAI->Anthropic conversion failed");
assert_eq!(anthropic_req.model, anthropic_req2.model);
// Compare system prompt text if present
assert_eq!(
anthropic_req.system.as_ref().and_then(|s| match s { crate::apis::anthropic::MessagesSystemPrompt::Single(t) => Some(t), _ => None }),
anthropic_req2.system.as_ref().and_then(|s| match s { crate::apis::anthropic::MessagesSystemPrompt::Single(t) => Some(t), _ => None })
anthropic_req.system.as_ref().and_then(|s| match s {
crate::apis::anthropic::MessagesSystemPrompt::Single(t) => Some(t),
_ => None,
}),
anthropic_req2.system.as_ref().and_then(|s| match s {
crate::apis::anthropic::MessagesSystemPrompt::Single(t) => Some(t),
_ => None,
})
);
assert_eq!(
anthropic_req.messages[0].role,
anthropic_req2.messages[0].role
);
assert_eq!(anthropic_req.messages[0].role, anthropic_req2.messages[0].role);
// Compare message content text if present
assert_eq!(
anthropic_req.messages[0].content.extract_text(),
@@ -320,49 +352,54 @@ mod tests {
assert_eq!(anthropic_req.max_tokens, anthropic_req2.max_tokens);
}
#[test]
fn test_v1_chat_completions_to_v1_messages_roundtrip() {
use crate::apis::anthropic::MessagesRequest as AnthropicMessagesRequest;
use crate::apis::openai::{ChatCompletionsRequest, Message, Role, MessageContent};
#[test]
fn test_v1_chat_completions_to_v1_messages_roundtrip() {
use crate::apis::anthropic::MessagesRequest as AnthropicMessagesRequest;
use crate::apis::openai::{ChatCompletionsRequest, Message, MessageContent, Role};
let openai_req = ChatCompletionsRequest {
model: "gpt-4".to_string(),
messages: vec![
Message {
role: Role::System,
content: MessageContent::Text("You are a helpful assistant".to_string()),
name: None,
tool_calls: None,
tool_call_id: None,
},
Message {
role: Role::User,
content: MessageContent::Text("Hello!".to_string()),
name: None,
tool_calls: None,
tool_call_id: None,
}
],
temperature: Some(0.7),
top_p: Some(1.0),
max_tokens: Some(128),
stream: Some(false),
stop: Some(vec!["\n".to_string()]),
tools: None,
tool_choice: None,
parallel_tool_calls: None,
..Default::default()
};
let openai_req = ChatCompletionsRequest {
model: "gpt-4".to_string(),
messages: vec![
Message {
role: Role::System,
content: MessageContent::Text("You are a helpful assistant".to_string()),
name: None,
tool_calls: None,
tool_call_id: None,
},
Message {
role: Role::User,
content: MessageContent::Text("Hello!".to_string()),
name: None,
tool_calls: None,
tool_call_id: None,
},
],
temperature: Some(0.7),
top_p: Some(1.0),
max_tokens: Some(128),
stream: Some(false),
stop: Some(vec!["\n".to_string()]),
tools: None,
tool_choice: None,
parallel_tool_calls: None,
..Default::default()
};
let anthropic_req = AnthropicMessagesRequest::try_from(openai_req.clone()).expect("OpenAI->Anthropic conversion failed");
let openai_req2 = ChatCompletionsRequest::try_from(anthropic_req).expect("Anthropic->OpenAI conversion failed");
let anthropic_req = AnthropicMessagesRequest::try_from(openai_req.clone())
.expect("OpenAI->Anthropic conversion failed");
let openai_req2 = ChatCompletionsRequest::try_from(anthropic_req)
.expect("Anthropic->OpenAI conversion failed");
assert_eq!(openai_req.model, openai_req2.model);
assert_eq!(openai_req.messages[0].role, openai_req2.messages[0].role);
assert_eq!(openai_req.messages[0].content.extract_text(), openai_req2.messages[0].content.extract_text());
// After roundtrip, deprecated max_tokens should be converted to max_completion_tokens
let original_max_tokens = openai_req.max_completion_tokens.or(openai_req.max_tokens);
let roundtrip_max_tokens = openai_req2.max_completion_tokens.or(openai_req2.max_tokens);
assert_eq!(original_max_tokens, roundtrip_max_tokens);
}
assert_eq!(openai_req.model, openai_req2.model);
assert_eq!(openai_req.messages[0].role, openai_req2.messages[0].role);
assert_eq!(
openai_req.messages[0].content.extract_text(),
openai_req2.messages[0].content.extract_text()
);
// After roundtrip, deprecated max_tokens should be converted to max_completion_tokens
let original_max_tokens = openai_req.max_completion_tokens.or(openai_req.max_tokens);
let roundtrip_max_tokens = openai_req2.max_completion_tokens.or(openai_req2.max_tokens);
assert_eq!(original_max_tokens, roundtrip_max_tokens);
}
}
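A hedged sketch (not in the diff) of the byte-level constructor exercised by the tests above: the same body bytes are parsed into the request variant matching the API the client spoke, which here is assumed to be known from routing.

use crate::apis::openai::OpenAIApi;
use crate::clients::endpoints::SupportedAPIs;
use crate::providers::ProviderRequestType;

fn parse_client_body(body: &[u8]) -> ProviderRequestType {
    let client_api = SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions);
    // Invalid JSON or a mismatched schema surfaces as an InvalidData error.
    ProviderRequestType::try_from((body, &client_api))
        .expect("request body did not match the client API")
}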


@@ -1,15 +1,15 @@
use crate::providers::id::ProviderId;
use serde::{Serialize, Deserialize};
use serde::{Deserialize, Serialize};
use std::convert::TryFrom;
use std::error::Error;
use std::fmt;
use std::convert::TryFrom;
use std::str::FromStr;
use crate::apis::anthropic::MessagesResponse;
use crate::apis::anthropic::MessagesStreamEvent;
use crate::apis::openai::ChatCompletionsResponse;
use crate::apis::openai::ChatCompletionsStreamResponse;
use crate::apis::anthropic::MessagesStreamEvent;
use crate::clients::endpoints::SupportedAPIs;
use crate::apis::anthropic::MessagesResponse;
/// Trait for token usage information
pub trait TokenUsage {
@@ -38,7 +38,8 @@ pub trait ProviderResponse: Send + Sync {
/// Extract token counts for metrics
fn extract_usage_counts(&self) -> Option<(usize, usize, usize)> {
self.usage().map(|u| (u.prompt_tokens(), u.completion_tokens(), u.total_tokens()))
self.usage()
.map(|u| (u.prompt_tokens(), u.completion_tokens(), u.total_tokens()))
}
}
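A small, hedged usage sketch (not in the diff) of the default method above, for emitting token metrics from any provider response; the trait path follows the providers module re-exports.

use crate::providers::ProviderResponse;

// Responses without usage data are simply skipped.
fn record_usage(response: &impl ProviderResponse) {
    if let Some((prompt, completion, total)) = response.extract_usage_counts() {
        println!("tokens: prompt={} completion={} total={}", prompt, completion, total);
    }
}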
@@ -110,19 +111,19 @@ impl ProviderStreamResponse for ProviderStreamResponseType {
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SseEvent {
#[serde(rename = "data")]
pub data: Option<String>, // The JSON payload after "data: "
pub data: Option<String>, // The JSON payload after "data: "
#[serde(skip_serializing_if = "Option::is_none")]
pub event: Option<String>, // Optional event type (e.g., "message_start", "content_block_delta")
pub event: Option<String>, // Optional event type (e.g., "message_start", "content_block_delta")
#[serde(skip_serializing, skip_deserializing)]
pub raw_line: String, // The complete line as received including "data: " prefix and "\n\n"
#[serde(skip_serializing, skip_deserializing)]
pub sse_transform_buffer: String, // The complete line as received including "data: " prefix and "\n\n"
pub raw_line: String, // The complete line as received including "data: " prefix and "\n\n"
#[serde(skip_serializing, skip_deserializing)]
pub provider_stream_response: Option<ProviderStreamResponseType>, // Parsed provider stream response object
pub sse_transform_buffer: String, // The complete line as received including "data: " prefix and "\n\n"
#[serde(skip_serializing, skip_deserializing)]
pub provider_stream_response: Option<ProviderStreamResponseType>, // Parsed provider stream response object
}
impl SseEvent {
@@ -145,13 +146,13 @@ impl SseEvent {
/// Get the parsed provider response if available
pub fn provider_response(&self) -> Result<&dyn ProviderStreamResponse, std::io::Error> {
self.provider_stream_response.as_ref()
self.provider_stream_response
.as_ref()
.map(|resp| resp as &dyn ProviderStreamResponse)
.ok_or_else(|| {
std::io::Error::new(std::io::ErrorKind::NotFound, "Provider response not found")
})
}
}
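Hedged, module-local sketch (not part of the diff) of the parse/serialize round trip the SSE tests further down rely on: FromStr takes a raw "data: " line, and Display re-emits the wire format with its trailing blank line.

fn sse_round_trip() {
    let line = "data: {\"id\":\"chunk-1\",\"object\":\"chat.completion.chunk\"}";
    let event: SseEvent = line.parse().expect("well-formed SSE line");

    // The payload keeps its trailing blank line so it can be re-emitted verbatim.
    assert!(event
        .data
        .as_deref()
        .unwrap_or_default()
        .starts_with("{\"id\":\"chunk-1\""));
    assert!(event.to_string().starts_with("data: "));
    assert!(event.to_string().ends_with("\n\n"));
}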
impl FromStr for SseEvent {
@@ -172,7 +173,8 @@ impl FromStr for SseEvent {
sse_transform_buffer: line.to_string(),
provider_stream_response: None,
})
} else if line.starts_with("event: ") { //used by Anthropic
} else if line.starts_with("event: ") {
//used by Anthropic
let event_type = line[7..].to_string();
if event_type.is_empty() {
return Err(SseParseError {
@@ -207,12 +209,13 @@ impl Into<Vec<u8>> for SseEvent {
}
}
// --- Response transformation logic for client API compatibility ---
impl TryFrom<(&[u8], &SupportedAPIs, &ProviderId)> for ProviderResponseType {
type Error = std::io::Error;
fn try_from((bytes, client_api, provider_id): (&[u8], &SupportedAPIs, &ProviderId)) -> Result<Self, Self::Error> {
fn try_from(
(bytes, client_api, provider_id): (&[u8], &SupportedAPIs, &ProviderId),
) -> Result<Self, Self::Error> {
let upstream_api = provider_id.compatible_api_for_client(client_api);
match (&upstream_api, client_api) {
(SupportedAPIs::OpenAIChatCompletions(_), SupportedAPIs::OpenAIChatCompletions(_)) => {
@@ -230,8 +233,13 @@ impl TryFrom<(&[u8], &SupportedAPIs, &ProviderId)> for ProviderResponseType {
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
// Transform to OpenAI ChatCompletions format using the transformer
let chat_resp: ChatCompletionsResponse = anthropic_resp.try_into()
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Transformation error: {}", e)))?;
let chat_resp: ChatCompletionsResponse =
anthropic_resp.try_into().map_err(|e| {
std::io::Error::new(
std::io::ErrorKind::InvalidData,
format!("Transformation error: {}", e),
)
})?;
Ok(ProviderResponseType::ChatCompletionsResponse(chat_resp))
}
(SupportedAPIs::OpenAIChatCompletions(_), SupportedAPIs::AnthropicMessagesAPI(_)) => {
@@ -239,8 +247,12 @@ impl TryFrom<(&[u8], &SupportedAPIs, &ProviderId)> for ProviderResponseType {
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
// Transform to Anthropic Messages format using the transformer
let messages_resp: MessagesResponse = openai_resp.try_into()
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Transformation error: {}", e)))?;
let messages_resp: MessagesResponse = openai_resp.try_into().map_err(|e| {
std::io::Error::new(
std::io::ErrorKind::InvalidData,
format!("Transformation error: {}", e),
)
})?;
Ok(ProviderResponseType::MessagesResponse(messages_resp))
}
}
@@ -251,36 +263,50 @@ impl TryFrom<(&[u8], &SupportedAPIs, &ProviderId)> for ProviderResponseType {
impl TryFrom<(&[u8], &SupportedAPIs, &SupportedAPIs)> for ProviderStreamResponseType {
type Error = Box<dyn std::error::Error + Send + Sync>;
fn try_from((bytes, client_api, upstream_api): (&[u8], &SupportedAPIs, &SupportedAPIs)) -> Result<Self, Self::Error> {
fn try_from(
(bytes, client_api, upstream_api): (&[u8], &SupportedAPIs, &SupportedAPIs),
) -> Result<Self, Self::Error> {
match (upstream_api, client_api) {
(SupportedAPIs::OpenAIChatCompletions(_), SupportedAPIs::OpenAIChatCompletions(_)) => {
let resp: crate::apis::openai::ChatCompletionsStreamResponse = serde_json::from_slice(bytes)?;
Ok(ProviderStreamResponseType::ChatCompletionsStreamResponse(resp))
let resp: crate::apis::openai::ChatCompletionsStreamResponse =
serde_json::from_slice(bytes)?;
Ok(ProviderStreamResponseType::ChatCompletionsStreamResponse(
resp,
))
}
(SupportedAPIs::AnthropicMessagesAPI(_), SupportedAPIs::AnthropicMessagesAPI(_)) => {
let resp: crate::apis::anthropic::MessagesStreamEvent = serde_json::from_slice(bytes)?;
let resp: crate::apis::anthropic::MessagesStreamEvent =
serde_json::from_slice(bytes)?;
Ok(ProviderStreamResponseType::MessagesStreamEvent(resp))
}
(SupportedAPIs::AnthropicMessagesAPI(_), SupportedAPIs::OpenAIChatCompletions(_)) => {
let anthropic_resp: crate::apis::anthropic::MessagesStreamEvent = serde_json::from_slice(bytes)?;
let anthropic_resp: crate::apis::anthropic::MessagesStreamEvent =
serde_json::from_slice(bytes)?;
// Transform to OpenAI ChatCompletions stream format using the transformer
let chat_resp: crate::apis::openai::ChatCompletionsStreamResponse = anthropic_resp.try_into()?;
Ok(ProviderStreamResponseType::ChatCompletionsStreamResponse(chat_resp))
let chat_resp: crate::apis::openai::ChatCompletionsStreamResponse =
anthropic_resp.try_into()?;
Ok(ProviderStreamResponseType::ChatCompletionsStreamResponse(
chat_resp,
))
}
(SupportedAPIs::OpenAIChatCompletions(_), SupportedAPIs::AnthropicMessagesAPI(_)) => {
// Special case: Handle [DONE] marker for OpenAI -> Anthropic conversion
if bytes == b"[DONE]" {
return Ok(ProviderStreamResponseType::MessagesStreamEvent(
crate::apis::anthropic::MessagesStreamEvent::MessageStop
crate::apis::anthropic::MessagesStreamEvent::MessageStop,
));
}
let openai_resp: crate::apis::openai::ChatCompletionsStreamResponse = serde_json::from_slice(bytes)?;
let openai_resp: crate::apis::openai::ChatCompletionsStreamResponse =
serde_json::from_slice(bytes)?;
// Transform to Anthropic Messages stream format using the transformer
let messages_resp: crate::apis::anthropic::MessagesStreamEvent = openai_resp.try_into()?;
Ok(ProviderStreamResponseType::MessagesStreamEvent(messages_resp))
let messages_resp: crate::apis::anthropic::MessagesStreamEvent =
openai_resp.try_into()?;
Ok(ProviderStreamResponseType::MessagesStreamEvent(
messages_resp,
))
}
}
}
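Illustration only (module-local sketch, not in the diff): the [DONE] special case above turns the OpenAI stream terminator into an Anthropic message_stop event when the client speaks the Messages API, mirroring the dedicated test at the end of this file.

fn translate_done_marker() {
    let client_api =
        SupportedAPIs::AnthropicMessagesAPI(crate::apis::anthropic::AnthropicApi::Messages);
    let upstream_api =
        SupportedAPIs::OpenAIChatCompletions(crate::apis::openai::OpenAIApi::ChatCompletions);

    // No JSON body to deserialize: the marker is mapped directly to MessageStop.
    let event =
        ProviderStreamResponseType::try_from((b"[DONE]".as_slice(), &client_api, &upstream_api))
            .expect("[DONE] should map to message_stop");

    assert!(matches!(
        event,
        ProviderStreamResponseType::MessagesStreamEvent(
            crate::apis::anthropic::MessagesStreamEvent::MessageStop
        )
    ));
}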
@@ -290,7 +316,9 @@ impl TryFrom<(&[u8], &SupportedAPIs, &SupportedAPIs)> for ProviderStreamResponse
impl TryFrom<(SseEvent, &SupportedAPIs, &SupportedAPIs)> for SseEvent {
type Error = Box<dyn std::error::Error + Send + Sync>;
fn try_from((sse_event, client_api, upstream_api): (SseEvent, &SupportedAPIs, &SupportedAPIs)) -> Result<Self, Self::Error> {
fn try_from(
(sse_event, client_api, upstream_api): (SseEvent, &SupportedAPIs, &SupportedAPIs),
) -> Result<Self, Self::Error> {
// Create a new transformed event based on the original
let mut transformed_event = sse_event;
@@ -298,7 +326,8 @@ impl TryFrom<(SseEvent, &SupportedAPIs, &SupportedAPIs)> for SseEvent {
if transformed_event.data.is_some() {
let data_str = transformed_event.data.as_ref().unwrap();
let data_bytes = data_str.as_bytes();
let transformed_response = ProviderStreamResponseType::try_from((data_bytes, client_api, upstream_api))?;
let transformed_response =
ProviderStreamResponseType::try_from((data_bytes, client_api, upstream_api))?;
let transformed_json = serde_json::to_string(&transformed_response)?;
transformed_event.sse_transform_buffer = format!("data: {}\n\n", transformed_json);
transformed_event.provider_stream_response = Some(transformed_response);
@@ -344,7 +373,10 @@ impl TryFrom<(SseEvent, &SupportedAPIs, &SupportedAPIs)> for SseEvent {
transformed_event.sse_transform_buffer
);
} else {
transformed_event.sse_transform_buffer = format!("event: {}\n{}", event_type, transformed_event.sse_transform_buffer);
transformed_event.sse_transform_buffer = format!(
"event: {}\n{}",
event_type, transformed_event.sse_transform_buffer
);
}
}
// If event_type is None, we just keep the data line as-is without an event line
@@ -396,7 +428,10 @@
I::Item: AsRef<str>,
{
pub fn new(lines: I) -> Self {
Self { lines, done_seen: false }
Self {
lines,
done_seen: false,
}
}
}
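A hedged sketch (not in the diff) of the stream iterator constructed above; per the tests below it skips keep-alive "ping" payloads and stops yielding once the [DONE] marker is seen. The iterator's item type is not shown in this hunk, so it is left opaque here.

fn drain_stream() {
    let lines = vec![
        "data: {\"id\": \"msg1\", \"object\": \"chat.completion.chunk\"}".to_string(),
        "data: {\"type\": \"ping\"}".to_string(),
        "data: [DONE]".to_string(),
    ];

    // Assumes SseStreamIter implements Iterator, as the tests below suggest.
    for event in SseStreamIter::new(lines.into_iter()) {
        // Only real chunks arrive here; pings and the [DONE] line are consumed.
        let _ = event;
    }
}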
@@ -451,7 +486,6 @@ pub struct ProviderResponseError {
pub source: Option<Box<dyn Error + Send + Sync>>,
}
impl fmt::Display for ProviderResponseError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "Provider response error: {}", self.message)
@@ -460,17 +494,19 @@ impl fmt::Display for ProviderResponseError {
impl Error for ProviderResponseError {
fn source(&self) -> Option<&(dyn Error + 'static)> {
self.source.as_ref().map(|e| e.as_ref() as &(dyn Error + 'static))
self.source
.as_ref()
.map(|e| e.as_ref() as &(dyn Error + 'static))
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::apis::anthropic::AnthropicApi;
use crate::apis::openai::OpenAIApi;
use crate::clients::endpoints::SupportedAPIs;
use crate::providers::id::ProviderId;
use crate::apis::openai::OpenAIApi;
use crate::apis::anthropic::AnthropicApi;
use serde_json::json;
#[test]
@@ -491,13 +527,17 @@ mod tests {
"system_fingerprint": null
});
let bytes = serde_json::to_vec(&resp).unwrap();
let result = ProviderResponseType::try_from((bytes.as_slice(), &SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions), &ProviderId::OpenAI));
let result = ProviderResponseType::try_from((
bytes.as_slice(),
&SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
&ProviderId::OpenAI,
));
assert!(result.is_ok());
match result.unwrap() {
ProviderResponseType::ChatCompletionsResponse(r) => {
assert_eq!(r.model, "gpt-4");
assert_eq!(r.choices.len(), 1);
},
}
_ => panic!("Expected ChatCompletionsResponse variant"),
}
}
@@ -516,13 +556,17 @@ mod tests {
"usage": { "input_tokens": 10, "output_tokens": 25, "cache_creation_input_tokens": 5, "cache_read_input_tokens": 3 }
});
let bytes = serde_json::to_vec(&resp).unwrap();
let result = ProviderResponseType::try_from((bytes.as_slice(), &SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages), &ProviderId::Anthropic));
let result = ProviderResponseType::try_from((
bytes.as_slice(),
&SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages),
&ProviderId::Anthropic,
));
assert!(result.is_ok());
match result.unwrap() {
ProviderResponseType::MessagesResponse(r) => {
assert_eq!(r.model, "claude-3-sonnet-20240229");
assert_eq!(r.content.len(), 1);
},
}
_ => panic!("Expected MessagesResponse variant"),
}
}
@@ -546,14 +590,18 @@ mod tests {
"usage": { "prompt_tokens": 10, "completion_tokens": 25, "total_tokens": 35 }
});
let bytes = serde_json::to_vec(&resp).unwrap();
let result = ProviderResponseType::try_from((bytes.as_slice(), &SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages), &ProviderId::OpenAI));
let result = ProviderResponseType::try_from((
bytes.as_slice(),
&SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages),
&ProviderId::OpenAI,
));
assert!(result.is_ok());
match result.unwrap() {
ProviderResponseType::MessagesResponse(r) => {
assert_eq!(r.model, "gpt-4");
assert_eq!(r.usage.input_tokens, 10);
assert_eq!(r.usage.output_tokens, 25);
},
}
_ => panic!("Expected MessagesResponse variant"),
}
}
@@ -584,14 +632,18 @@ mod tests {
}
});
let bytes = serde_json::to_vec(&resp).unwrap();
let result = ProviderResponseType::try_from((bytes.as_slice(), &SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions), &ProviderId::Anthropic));
let result = ProviderResponseType::try_from((
bytes.as_slice(),
&SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
&ProviderId::Anthropic,
));
assert!(result.is_ok());
match result.unwrap() {
ProviderResponseType::ChatCompletionsResponse(r) => {
assert_eq!(r.model, "claude-3-sonnet-20240229");
assert_eq!(r.usage.prompt_tokens, 10);
assert_eq!(r.usage.completion_tokens, 25);
},
}
_ => panic!("Expected ChatCompletionsResponse variant"),
}
}
@@ -603,11 +655,17 @@ mod tests {
let event: Result<SseEvent, _> = line.parse();
assert!(event.is_ok());
let event = event.unwrap();
assert_eq!(event.data, Some("{\"id\":\"test\",\"object\":\"chat.completion.chunk\"}\n\n".to_string()));
assert_eq!(
event.data,
Some("{\"id\":\"test\",\"object\":\"chat.completion.chunk\"}\n\n".to_string())
);
// Test conversion back to line using Display trait
let wire_format = event.to_string();
assert_eq!(wire_format, "data: {\"id\":\"test\",\"object\":\"chat.completion.chunk\"}\n\n");
assert_eq!(
wire_format,
"data: {\"id\":\"test\",\"object\":\"chat.completion.chunk\"}\n\n"
);
// Test [DONE] marker - should be valid SSE event
let done_line = "data: [DONE]";
@@ -639,10 +697,12 @@ mod tests {
event: None,
raw_line: r#"data: {"id":"test","object":"chat.completion.chunk"}
"#.to_string(),
"#
.to_string(),
sse_transform_buffer: r#"data: {"id":"test","object":"chat.completion.chunk"}
"#.to_string(),
"#
.to_string(),
provider_stream_response: None,
};
@@ -679,7 +739,8 @@ mod tests {
data: Some(r#"{"id": "test", "object": "chat.completion.chunk"}"#.to_string()),
event: Some("content_block_delta".to_string()),
raw_line: r#"data: {"id": "test", "object": "chat.completion.chunk"}"#.to_string(),
sse_transform_buffer: r#"data: {"id": "test", "object": "chat.completion.chunk"}"#.to_string(),
sse_transform_buffer: r#"data: {"id": "test", "object": "chat.completion.chunk"}"#
.to_string(),
provider_stream_response: None,
};
assert!(!normal_event.should_skip());
@@ -705,7 +766,7 @@ mod tests {
"data: {\"type\": \"ping\"}".to_string(), // This should be filtered out
"data: {\"id\": \"msg2\", \"object\": \"chat.completion.chunk\"}".to_string(),
"data: {\"type\": \"ping\"}".to_string(), // This should be filtered out
"data: [DONE]".to_string(), // This should end the stream
"data: [DONE]".to_string(), // This should end the stream
];
let mut iter = SseStreamIter::new(test_lines.into_iter());
@@ -773,13 +834,15 @@
#[test]
fn test_provider_stream_response_event_type() {
use crate::apis::anthropic::{MessagesStreamEvent, MessagesContentDelta};
use crate::apis::anthropic::{MessagesContentDelta, MessagesStreamEvent};
use crate::apis::openai::ChatCompletionsStreamResponse;
// Test Anthropic event type
let anthropic_event = MessagesStreamEvent::ContentBlockDelta {
index: 0,
delta: MessagesContentDelta::TextDelta { text: "Hello".to_string() },
delta: MessagesContentDelta::TextDelta {
text: "Hello".to_string(),
},
};
let provider_type = ProviderStreamResponseType::MessagesStreamEvent(anthropic_event);
assert_eq!(provider_type.event_type(), Some("content_block_delta"));
@@ -806,15 +869,23 @@ mod tests {
// Test that [DONE] marker is properly converted to MessageStop in the transformation layer
let done_bytes = b"[DONE]";
let client_api = SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages);
let upstream_api = SupportedAPIs::OpenAIChatCompletions(crate::apis::openai::OpenAIApi::ChatCompletions);
let upstream_api =
SupportedAPIs::OpenAIChatCompletions(crate::apis::openai::OpenAIApi::ChatCompletions);
let result = ProviderStreamResponseType::try_from((done_bytes.as_slice(), &client_api, &upstream_api));
let result = ProviderStreamResponseType::try_from((
done_bytes.as_slice(),
&client_api,
&upstream_api,
));
assert!(result.is_ok());
if let Ok(ProviderStreamResponseType::MessagesStreamEvent(event)) = result {
// Verify it's a MessageStop event
assert_eq!(event.event_type(), Some("message_stop"));
assert!(matches!(event, crate::apis::anthropic::MessagesStreamEvent::MessageStop));
assert!(matches!(
event,
crate::apis::anthropic::MessagesStreamEvent::MessageStop
));
} else {
panic!("Expected MessagesStreamEvent::MessageStop");
}