cargo clippy (#660)

This commit is contained in:
Adil Hafeez 2025-12-25 21:08:37 -08:00 committed by GitHub
parent c75e7606f9
commit ca95ffb63d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
62 changed files with 1864 additions and 1187 deletions

View file

@ -66,7 +66,7 @@ impl ApiDefinition for AmazonBedrockApi {
/// Amazon Bedrock Converse request
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)]
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct ConverseRequest {
/// The model ID or ARN to invoke
pub model_id: String,
@ -91,7 +91,7 @@ pub struct ConverseRequest {
pub additional_model_response_field_paths: Option<Vec<String>>,
/// Performance configuration
#[serde(rename = "performanceConfig")]
pub performance_config: Option<PerformanceConfiguration>,
pub performance_config: Option<InferenceConfiguration>,
/// Prompt variables for Prompt management
#[serde(rename = "promptVariables")]
pub prompt_variables: Option<HashMap<String, PromptVariableValues>>,
@ -105,26 +105,6 @@ pub struct ConverseRequest {
pub stream: bool,
}
impl Default for ConverseRequest {
fn default() -> Self {
Self {
model_id: String::new(),
messages: None,
system: None,
inference_config: None,
tool_config: None,
guardrail_config: None,
additional_model_request_fields: None,
additional_model_response_field_paths: None,
performance_config: None,
prompt_variables: None,
request_metadata: None,
metadata: None,
stream: false,
}
}
}
/// Amazon Bedrock ConverseStream request (same structure as Converse)
pub type ConverseStreamRequest = ConverseRequest;
@ -204,8 +184,8 @@ impl ProviderRequest for ConverseRequest {
self.tool_config.as_ref()?.tools.as_ref().map(|tools| {
tools
.iter()
.filter_map(|tool| match tool {
Tool::ToolSpec { tool_spec } => Some(tool_spec.name.clone()),
.map(|tool| match tool {
Tool::ToolSpec { tool_spec } => tool_spec.name.clone(),
})
.collect()
})
@ -242,17 +222,14 @@ impl ProviderRequest for ConverseRequest {
// Add system messages if present
if let Some(system) = &self.system {
for sys_block in system {
match sys_block {
SystemContentBlock::Text { text } => {
openai_messages.push(Message {
role: Role::System,
content: MessageContent::Text(text.clone()),
name: None,
tool_calls: None,
tool_call_id: None,
});
}
_ => {} // Skip other system content types
if let SystemContentBlock::Text { text } = sys_block {
openai_messages.push(Message {
role: Role::System,
content: MessageContent::Text(text.clone()),
name: None,
tool_calls: None,
tool_call_id: None,
});
}
}
}
@ -266,7 +243,9 @@ impl ProviderRequest for ConverseRequest {
};
// Extract text from content blocks
let content = msg.content.iter()
let content = msg
.content
.iter()
.filter_map(|block| {
if let ContentBlock::Text { text } = block {
Some(text.clone())
@ -311,16 +290,14 @@ impl ProviderRequest for ConverseRequest {
_ => continue,
};
let content = if let crate::apis::openai::MessageContent::Text(text) = &msg.content {
vec![ContentBlock::Text { text: text.clone() }]
} else {
vec![]
};
let content =
if let crate::apis::openai::MessageContent::Text(text) = &msg.content {
vec![ContentBlock::Text { text: text.clone() }]
} else {
vec![]
};
bedrock_messages.push(crate::apis::amazon_bedrock::Message {
role,
content,
});
bedrock_messages.push(crate::apis::amazon_bedrock::Message { role, content });
}
_ => {}
}
@ -369,7 +346,7 @@ pub enum ConverseStreamEvent {
ContentBlockDelta(ContentBlockDeltaEvent),
ContentBlockStop(ContentBlockStopEvent),
MessageStop(MessageStopEvent),
Metadata(ConverseStreamMetadataEvent),
Metadata(Box<ConverseStreamMetadataEvent>),
// Error events
InternalServerException(BedrockException),
ModelStreamErrorException(BedrockException),
@ -1063,7 +1040,7 @@ impl TryFrom<&aws_smithy_eventstream::frame::DecodedFrame> for ConverseStreamEve
"metadata" => {
let event: ConverseStreamMetadataEvent =
serde_json::from_slice(payload).map_err(BedrockError::Serialization)?;
Ok(ConverseStreamEvent::Metadata(event))
Ok(ConverseStreamEvent::Metadata(Box::new(event)))
}
unknown => Err(BedrockError::Validation {
message: format!("Unknown event type: {}", unknown),
@ -1106,10 +1083,10 @@ impl TryFrom<&aws_smithy_eventstream::frame::DecodedFrame> for ConverseStreamEve
}
}
impl Into<String> for ConverseStreamEvent {
fn into(self) -> String {
let transformed_json = serde_json::to_string(&self).unwrap_or_default();
let event_type = match &self {
impl From<ConverseStreamEvent> for String {
fn from(val: ConverseStreamEvent) -> String {
let transformed_json = serde_json::to_string(&val).unwrap_or_default();
let event_type = match &val {
ConverseStreamEvent::MessageStart { .. } => "message_start",
ConverseStreamEvent::ContentBlockStart { .. } => "content_block_start",
ConverseStreamEvent::ContentBlockDelta { .. } => "content_block_delta",

File diff suppressed because one or more lines are too long

View file

@ -286,7 +286,6 @@ pub struct ImageUrl {
}
/// A single message in a chat conversation
/// A tool call made by the assistant
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct ToolCall {
@ -388,7 +387,7 @@ pub enum StaticContentType {
/// Chat completions API response
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)]
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct ChatCompletionsResponse {
pub id: String,
pub object: Option<String>,
@ -402,22 +401,6 @@ pub struct ChatCompletionsResponse {
pub metadata: Option<HashMap<String, Value>>,
}
impl Default for ChatCompletionsResponse {
fn default() -> Self {
ChatCompletionsResponse {
id: String::new(),
object: None,
created: 0,
model: String::new(),
choices: vec![],
usage: Usage::default(),
system_fingerprint: None,
service_tier: None,
metadata: None,
}
}
}
/// Finish reason for completion
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
@ -431,7 +414,7 @@ pub enum FinishReason {
/// Token usage information
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)]
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct Usage {
pub prompt_tokens: u32,
pub completion_tokens: u32,
@ -440,18 +423,6 @@ pub struct Usage {
pub completion_tokens_details: Option<CompletionTokensDetails>,
}
impl Default for Usage {
fn default() -> Self {
Usage {
prompt_tokens: 0,
completion_tokens: 0,
total_tokens: 0,
prompt_tokens_details: None,
completion_tokens_details: None,
}
}
}
/// Detailed breakdown of prompt tokens
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)]
@ -472,7 +443,7 @@ pub struct CompletionTokensDetails {
/// A single choice in the response
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)]
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct Choice {
pub index: u32,
pub message: ResponseMessage,
@ -480,17 +451,6 @@ pub struct Choice {
pub logprobs: Option<Value>,
}
impl Default for Choice {
fn default() -> Self {
Choice {
index: 0,
message: ResponseMessage::default(),
finish_reason: None,
logprobs: None,
}
}
}
// ============================================================================
// STREAMING API TYPES
// ============================================================================
@ -608,7 +568,6 @@ pub enum OpenAIError {
// ============================================================================
/// Trait Implementations
/// ===========================================================================
/// Parameterized conversion for ChatCompletionsRequest
impl TryFrom<&[u8]> for ChatCompletionsRequest {
type Error = OpenAIStreamError;
@ -721,7 +680,7 @@ impl ProviderRequest for ChatCompletionsRequest {
}
fn metadata(&self) -> &Option<HashMap<String, Value>> {
return &self.metadata;
&self.metadata
}
fn remove_metadata_key(&mut self, key: &str) -> bool {

View file

@ -1,7 +1,7 @@
use std::collections::HashMap;
use crate::providers::request::{ProviderRequest, ProviderRequestError};
use serde::{Deserialize, Serialize};
use serde_with::skip_serializing_none;
use crate::providers::request::{ProviderRequest, ProviderRequestError};
use std::collections::HashMap;
impl TryFrom<&[u8]> for ResponsesAPIRequest {
type Error = serde_json::Error;
@ -172,18 +172,14 @@ pub enum MessageRole {
#[serde(tag = "type", rename_all = "snake_case")]
pub enum InputContent {
/// Text input
InputText {
text: String,
},
InputText { text: String },
/// Image input via URL
InputImage {
image_url: String,
detail: Option<String>,
},
/// File input via URL
InputFile {
file_url: String,
},
InputFile { file_url: String },
/// Audio input
InputAudio {
data: Option<String>,
@ -222,9 +218,7 @@ pub struct TextConfig {
pub enum TextFormat {
Text,
JsonObject,
JsonSchema {
json_schema: serde_json::Value,
},
JsonSchema { json_schema: serde_json::Value },
}
/// Reasoning effort levels
@ -608,9 +602,7 @@ pub enum OutputContent {
transcript: Option<String>,
},
/// Refusal output
Refusal {
refusal: String,
},
Refusal { refusal: String },
}
/// Annotations for output text
@ -663,13 +655,9 @@ pub struct FileSearchResult {
#[serde(tag = "type", rename_all = "snake_case")]
pub enum CodeInterpreterOutput {
/// Text output
Text {
text: String,
},
Text { text: String },
/// Image output
Image {
image: String,
},
Image { image: String },
}
/// Response usage statistics
@ -951,9 +939,7 @@ pub enum ResponsesAPIStreamEvent {
},
/// Done event (end of stream)
Done {
sequence_number: i32,
},
Done { sequence_number: i32 },
}
// ============================================================================
@ -1052,12 +1038,19 @@ impl ProviderRequest for ResponsesAPIRequest {
MessageContent::Text(text) => text.clone(),
MessageContent::Items(content_items) => {
content_items.iter().fold(String::new(), |acc, content| {
acc + " " + &match content {
InputContent::InputText { text } => text.clone(),
InputContent::InputImage { .. } => "[Image]".to_string(),
InputContent::InputFile { .. } => "[File]".to_string(),
InputContent::InputAudio { .. } => "[Audio]".to_string(),
}
acc + " "
+ &match content {
InputContent::InputText { text } => text.clone(),
InputContent::InputImage { .. } => {
"[Image]".to_string()
}
InputContent::InputFile { .. } => {
"[File]".to_string()
}
InputContent::InputAudio { .. } => {
"[Audio]".to_string()
}
}
})
}
};
@ -1082,11 +1075,9 @@ impl ProviderRequest for ResponsesAPIRequest {
match &msg.content {
MessageContent::Text(text) => Some(text.clone()),
MessageContent::Items(content_items) => {
content_items.iter().find_map(|content| {
match content {
InputContent::InputText { text } => Some(text.clone()),
_ => None,
}
content_items.iter().find_map(|content| match content {
InputContent::InputText { text } => Some(text.clone()),
_ => None,
})
}
}
@ -1176,9 +1167,12 @@ impl ProviderRequest for ResponsesAPIRequest {
// Extract text from message content
let content = match &msg.content {
crate::apis::openai_responses::MessageContent::Text(text) => text.clone(),
crate::apis::openai_responses::MessageContent::Text(text) => {
text.clone()
}
crate::apis::openai_responses::MessageContent::Items(items) => {
items.iter()
items
.iter()
.filter_map(|c| {
if let InputContent::InputText { text } = c {
Some(text.clone())
@ -1214,7 +1208,8 @@ impl ProviderRequest for ResponsesAPIRequest {
fn set_messages(&mut self, messages: &[crate::apis::openai::Message]) {
// For ResponsesAPI, we need to convert messages back to input format
// Extract system messages as instructions
let system_text = messages.iter()
let system_text = messages
.iter()
.filter(|msg| msg.role == crate::apis::openai::Role::System)
.filter_map(|msg| {
if let crate::apis::openai::MessageContent::Text(text) = &msg.content {
@ -1233,23 +1228,27 @@ impl ProviderRequest for ResponsesAPIRequest {
// Convert user/assistant messages to InputParam
// For simplicity, we'll use the last user message as the input
// or combine all non-system messages
let input_messages: Vec<_> = messages.iter()
let input_messages: Vec<_> = messages
.iter()
.filter(|msg| msg.role != crate::apis::openai::Role::System)
.collect();
if !input_messages.is_empty() {
// If there's only one message, use Text format
if input_messages.len() == 1 {
if let crate::apis::openai::MessageContent::Text(text) = &input_messages[0].content {
if let crate::apis::openai::MessageContent::Text(text) = &input_messages[0].content
{
self.input = crate::apis::openai_responses::InputParam::Text(text.clone());
}
} else {
// Multiple messages - combine them as text for now
// A more sophisticated approach would use InputParam::Items
let combined_text = input_messages.iter()
let combined_text = input_messages
.iter()
.filter_map(|msg| {
if let crate::apis::openai::MessageContent::Text(text) = &msg.content {
Some(format!("{}: {}",
Some(format!(
"{}: {}",
match msg.role {
crate::apis::openai::Role::User => "User",
crate::apis::openai::Role::Assistant => "Assistant",
@ -1274,10 +1273,10 @@ impl ProviderRequest for ResponsesAPIRequest {
// Into<String> Implementation for SSE Formatting
// ============================================================================
impl Into<String> for ResponsesAPIStreamEvent {
fn into(self) -> String {
let transformed_json = serde_json::to_string(&self).unwrap_or_default();
let event_type = match &self {
impl From<ResponsesAPIStreamEvent> for String {
fn from(val: ResponsesAPIStreamEvent) -> Self {
let transformed_json = serde_json::to_string(&val).unwrap_or_default();
let event_type = match &val {
ResponsesAPIStreamEvent::ResponseCreated { .. } => "response.created",
ResponsesAPIStreamEvent::ResponseInProgress { .. } => "response.in_progress",
ResponsesAPIStreamEvent::ResponseCompleted { .. } => "response.completed",
@ -1365,10 +1364,10 @@ impl crate::providers::streaming_response::ProviderStreamResponse for ResponsesA
fn role(&self) -> Option<&str> {
match self {
ResponsesAPIStreamEvent::ResponseOutputItemDone { item, .. } => match item {
OutputItem::Message { role, .. } => Some(role.as_str()),
_ => None,
},
ResponsesAPIStreamEvent::ResponseOutputItemDone {
item: OutputItem::Message { role, .. },
..
} => Some(role.as_str()),
_ => None,
}
}

View file

@ -34,10 +34,7 @@ where
}
pub fn decode_frame(&mut self) -> Option<DecodedFrame> {
match self.decoder.decode_frame(&mut self.buffer) {
Ok(frame) => Some(frame),
Err(_e) => None, // Fatal decode error
}
self.decoder.decode_frame(&mut self.buffer).ok()
}
pub fn buffer_mut(&mut self) -> &mut B {

View file

@ -1,5 +1,5 @@
use crate::apis::streaming_shapes::sse::{SseEvent, SseStreamBufferTrait};
use crate::apis::anthropic::MessagesStreamEvent;
use crate::apis::streaming_shapes::sse::{SseEvent, SseStreamBufferTrait};
use crate::providers::streaming_response::ProviderStreamResponseType;
use std::collections::HashSet;
@ -31,6 +31,12 @@ pub struct AnthropicMessagesStreamBuffer {
model: Option<String>,
}
impl Default for AnthropicMessagesStreamBuffer {
fn default() -> Self {
Self::new()
}
}
impl AnthropicMessagesStreamBuffer {
pub fn new() -> Self {
Self {
@ -154,7 +160,8 @@ impl SseStreamBufferTrait for AnthropicMessagesStreamBuffer {
// Inject message_start if needed
if !self.message_started {
let model = self.model.as_deref().unwrap_or("unknown");
let message_start = AnthropicMessagesStreamBuffer::create_message_start_event(model);
let message_start =
AnthropicMessagesStreamBuffer::create_message_start_event(model);
self.buffered_events.push(message_start);
self.message_started = true;
}
@ -169,7 +176,8 @@ impl SseStreamBufferTrait for AnthropicMessagesStreamBuffer {
// Inject message_start if needed
if !self.message_started {
let model = self.model.as_deref().unwrap_or("unknown");
let message_start = AnthropicMessagesStreamBuffer::create_message_start_event(model);
let message_start =
AnthropicMessagesStreamBuffer::create_message_start_event(model);
self.buffered_events.push(message_start);
self.message_started = true;
}
@ -177,7 +185,8 @@ impl SseStreamBufferTrait for AnthropicMessagesStreamBuffer {
// Check if ContentBlockStart was sent for this index
if !self.has_content_block_start_been_sent(index) {
// Inject ContentBlockStart before delta
let content_block_start = AnthropicMessagesStreamBuffer::create_content_block_start_event();
let content_block_start =
AnthropicMessagesStreamBuffer::create_content_block_start_event();
self.buffered_events.push(content_block_start);
self.set_content_block_start_sent(index);
self.needs_content_block_stop = true;
@ -189,7 +198,8 @@ impl SseStreamBufferTrait for AnthropicMessagesStreamBuffer {
MessagesStreamEvent::MessageDelta { usage, .. } => {
// Inject ContentBlockStop before message_delta
if self.needs_content_block_stop {
let content_block_stop = AnthropicMessagesStreamBuffer::create_content_block_stop_event();
let content_block_stop =
AnthropicMessagesStreamBuffer::create_content_block_stop_event();
self.buffered_events.push(content_block_stop);
self.needs_content_block_stop = false;
}
@ -199,10 +209,10 @@ impl SseStreamBufferTrait for AnthropicMessagesStreamBuffer {
if let Some(last_event) = self.buffered_events.last_mut() {
if let Some(ProviderStreamResponseType::MessagesStreamEvent(
MessagesStreamEvent::MessageDelta {
usage: last_usage,
..
}
)) = &mut last_event.provider_stream_response {
usage: last_usage, ..
},
)) = &mut last_event.provider_stream_response
{
// Merge: take stop_reason from first, usage from second (if non-zero)
if usage.input_tokens > 0 || usage.output_tokens > 0 {
*last_usage = usage.clone();
@ -243,7 +253,7 @@ impl SseStreamBufferTrait for AnthropicMessagesStreamBuffer {
}
}
fn into_bytes(&mut self) -> Vec<u8> {
fn to_bytes(&mut self) -> Vec<u8> {
// Convert all accumulated events to bytes and clear buffer
// NOTE: We do NOT inject ContentBlockStop here because it's injected when we see MessageDelta
// or MessageStop. Injecting it here causes premature ContentBlockStop in the middle of streaming.
@ -276,10 +286,10 @@ impl SseStreamBufferTrait for AnthropicMessagesStreamBuffer {
#[cfg(test)]
mod tests {
use super::*;
use crate::clients::{SupportedAPIsFromClient, SupportedUpstreamAPIs};
use crate::apis::anthropic::AnthropicApi;
use crate::apis::openai::OpenAIApi;
use crate::apis::streaming_shapes::sse::SseStreamIter;
use crate::clients::{SupportedAPIsFromClient, SupportedUpstreamAPIs};
#[test]
fn test_openai_to_anthropic_complete_transformation() {
@ -308,11 +318,12 @@ data: [DONE]"#;
let mut buffer = AnthropicMessagesStreamBuffer::new();
for raw_event in stream_iter {
let transformed_event = SseEvent::try_from((raw_event, &client_api, &upstream_api)).unwrap();
let transformed_event =
SseEvent::try_from((raw_event, &client_api, &upstream_api)).unwrap();
buffer.add_transformed_event(transformed_event);
}
let output_bytes = buffer.into_bytes();
let output_bytes = buffer.to_bytes();
let output = String::from_utf8_lossy(&output_bytes);
println!("\nTRANSFORMED OUTPUT (Anthropic Messages API):");
@ -321,25 +332,54 @@ data: [DONE]"#;
// Assertions
assert!(!output_bytes.is_empty(), "Should have output");
assert!(output.contains("event: message_start"), "Should have message_start");
assert!(output.contains("event: content_block_start"), "Should have content_block_start (injected)");
assert!(
output.contains("event: message_start"),
"Should have message_start"
);
assert!(
output.contains("event: content_block_start"),
"Should have content_block_start (injected)"
);
let delta_count = output.matches("event: content_block_delta").count();
assert_eq!(delta_count, 2, "Should have exactly 2 content_block_delta events");
assert_eq!(
delta_count, 2,
"Should have exactly 2 content_block_delta events"
);
// Verify both pieces of content are present
assert!(output.contains("\"text\":\"Hello\""), "Should have first content delta 'Hello'");
assert!(output.contains("\"text\":\" world\""), "Should have second content delta ' world'");
assert!(
output.contains("\"text\":\"Hello\""),
"Should have first content delta 'Hello'"
);
assert!(
output.contains("\"text\":\" world\""),
"Should have second content delta ' world'"
);
assert!(output.contains("event: content_block_stop"), "Should have content_block_stop (injected)");
assert!(output.contains("event: message_delta"), "Should have message_delta");
assert!(output.contains("event: message_stop"), "Should have message_stop");
assert!(
output.contains("event: content_block_stop"),
"Should have content_block_stop (injected)"
);
assert!(
output.contains("event: message_delta"),
"Should have message_delta"
);
assert!(
output.contains("event: message_stop"),
"Should have message_stop"
);
println!("\nVALIDATION SUMMARY:");
println!("{}", "-".repeat(80));
println!("✓ Complete transformation: OpenAI ChatCompletions → Anthropic Messages API");
println!("✓ Injected lifecycle events: message_start, content_block_start, content_block_stop");
println!("✓ Content deltas: {} events (BOTH 'Hello' and ' world' preserved!)", delta_count);
println!(
"✓ Injected lifecycle events: message_start, content_block_start, content_block_stop"
);
println!(
"✓ Content deltas: {} events (BOTH 'Hello' and ' world' preserved!)",
delta_count
);
println!("✓ Complete stream with message_stop");
println!("✓ Proper Anthropic protocol sequencing\n");
}
@ -369,11 +409,12 @@ data: {"id":"chatcmpl-456","object":"chat.completion.chunk","created":1234567890
let mut buffer = AnthropicMessagesStreamBuffer::new();
for raw_event in stream_iter {
let transformed_event = SseEvent::try_from((raw_event, &client_api, &upstream_api)).unwrap();
let transformed_event =
SseEvent::try_from((raw_event, &client_api, &upstream_api)).unwrap();
buffer.add_transformed_event(transformed_event);
}
let output_bytes = buffer.into_bytes();
let output_bytes = buffer.to_bytes();
let output = String::from_utf8_lossy(&output_bytes);
println!("\nTRANSFORMED OUTPUT (Anthropic Messages API):");
@ -382,31 +423,61 @@ data: {"id":"chatcmpl-456","object":"chat.completion.chunk","created":1234567890
// Assertions
assert!(!output_bytes.is_empty(), "Should have output");
assert!(output.contains("event: message_start"), "Should have message_start");
assert!(output.contains("event: content_block_start"), "Should have content_block_start (injected)");
assert!(
output.contains("event: message_start"),
"Should have message_start"
);
assert!(
output.contains("event: content_block_start"),
"Should have content_block_start (injected)"
);
let delta_count = output.matches("event: content_block_delta").count();
assert_eq!(delta_count, 3, "Should have exactly 3 content_block_delta events");
assert_eq!(
delta_count, 3,
"Should have exactly 3 content_block_delta events"
);
// Verify all three pieces of content are present
assert!(output.contains("\"text\":\"The weather\""), "Should have first content delta");
assert!(output.contains("\"text\":\" in San Francisco\""), "Should have second content delta");
assert!(output.contains("\"text\":\" is\""), "Should have third content delta");
assert!(
output.contains("\"text\":\"The weather\""),
"Should have first content delta"
);
assert!(
output.contains("\"text\":\" in San Francisco\""),
"Should have second content delta"
);
assert!(
output.contains("\"text\":\" is\""),
"Should have third content delta"
);
// For partial streams (no finish_reason, no [DONE]), we do NOT inject content_block_stop
// because the stream may continue. This is correct behavior - only inject lifecycle events
// when we have explicit signals from upstream (finish_reason, [DONE], etc.)
assert!(!output.contains("event: content_block_stop"), "Should NOT have content_block_stop for partial stream");
assert!(
!output.contains("event: content_block_stop"),
"Should NOT have content_block_stop for partial stream"
);
// Should NOT have completion events
assert!(!output.contains("event: message_delta"), "Should NOT have message_delta");
assert!(!output.contains("event: message_stop"), "Should NOT have message_stop");
assert!(
!output.contains("event: message_delta"),
"Should NOT have message_delta"
);
assert!(
!output.contains("event: message_stop"),
"Should NOT have message_stop"
);
println!("\nVALIDATION SUMMARY:");
println!("{}", "-".repeat(80));
println!("✓ Partial transformation: OpenAI → Anthropic (stream interrupted)");
println!("✓ Injected: message_start, content_block_start at beginning");
println!("✓ Incremental deltas: {} events (ALL content preserved!)", delta_count);
println!(
"✓ Incremental deltas: {} events (ALL content preserved!)",
delta_count
);
println!("✓ NO completion events (partial stream, no [DONE])");
println!("✓ Buffer maintains Anthropic protocol for active streams\n");
}
@ -452,11 +523,12 @@ data: [DONE]"#;
let mut buffer = AnthropicMessagesStreamBuffer::new();
for raw_event in stream_iter {
let transformed_event = SseEvent::try_from((raw_event, &client_api, &upstream_api)).unwrap();
let transformed_event =
SseEvent::try_from((raw_event, &client_api, &upstream_api)).unwrap();
buffer.add_transformed_event(transformed_event);
}
let output_bytes = buffer.into_bytes();
let output_bytes = buffer.to_bytes();
let output = String::from_utf8_lossy(&output_bytes);
println!("\nTRANSFORMED OUTPUT (Anthropic Messages API):");
@ -467,32 +539,71 @@ data: [DONE]"#;
assert!(!output_bytes.is_empty(), "Should have output");
// Should have lifecycle events (injected by buffer)
assert!(output.contains("event: message_start"), "Should have message_start (injected)");
assert!(output.contains("event: content_block_start"), "Should have content_block_start");
assert!(output.contains("event: content_block_stop"), "Should have content_block_stop (injected)");
assert!(output.contains("event: message_delta"), "Should have message_delta");
assert!(output.contains("event: message_stop"), "Should have message_stop");
assert!(
output.contains("event: message_start"),
"Should have message_start (injected)"
);
assert!(
output.contains("event: content_block_start"),
"Should have content_block_start"
);
assert!(
output.contains("event: content_block_stop"),
"Should have content_block_stop (injected)"
);
assert!(
output.contains("event: message_delta"),
"Should have message_delta"
);
assert!(
output.contains("event: message_stop"),
"Should have message_stop"
);
// Should have tool_use content block
assert!(output.contains("\"type\":\"tool_use\""), "Should have tool_use type");
assert!(output.contains("\"name\":\"get_weather\""), "Should have correct function name");
assert!(output.contains("\"id\":\"call_2Uzw0AEZQeOex2CP2TKjcLKc\""), "Should have correct tool call ID");
assert!(
output.contains("\"type\":\"tool_use\""),
"Should have tool_use type"
);
assert!(
output.contains("\"name\":\"get_weather\""),
"Should have correct function name"
);
assert!(
output.contains("\"id\":\"call_2Uzw0AEZQeOex2CP2TKjcLKc\""),
"Should have correct tool call ID"
);
// Count input_json_delta events - should match the number of argument chunks
let delta_count = output.matches("event: content_block_delta").count();
assert!(delta_count >= 8, "Should have at least 8 input_json_delta events");
assert!(
delta_count >= 8,
"Should have at least 8 input_json_delta events"
);
// Verify argument deltas are present
assert!(output.contains("\"type\":\"input_json_delta\""), "Should have input_json_delta type");
assert!(output.contains("\"partial_json\":"), "Should have partial_json field");
assert!(
output.contains("\"type\":\"input_json_delta\""),
"Should have input_json_delta type"
);
assert!(
output.contains("\"partial_json\":"),
"Should have partial_json field"
);
// Verify the accumulated arguments contain the location
assert!(output.contains("San"), "Arguments should contain 'San'");
assert!(output.contains("Francisco"), "Arguments should contain 'Francisco'");
assert!(
output.contains("Francisco"),
"Arguments should contain 'Francisco'"
);
assert!(output.contains("CA"), "Arguments should contain 'CA'");
// Verify stop reason is tool_use
assert!(output.contains("\"stop_reason\":\"tool_use\""), "Should have stop_reason as tool_use");
assert!(
output.contains("\"stop_reason\":\"tool_use\""),
"Should have stop_reason as tool_use"
);
println!("\nVALIDATION SUMMARY:");
println!("{}", "-".repeat(80));

View file

@ -6,6 +6,12 @@ pub struct OpenAIChatCompletionsStreamBuffer {
buffered_events: Vec<SseEvent>,
}
impl Default for OpenAIChatCompletionsStreamBuffer {
fn default() -> Self {
Self::new()
}
}
impl OpenAIChatCompletionsStreamBuffer {
pub fn new() -> Self {
Self {
@ -26,7 +32,7 @@ impl SseStreamBufferTrait for OpenAIChatCompletionsStreamBuffer {
self.buffered_events.push(event);
}
fn into_bytes(&mut self) -> Vec<u8> {
fn to_bytes(&mut self) -> Vec<u8> {
// No finalization needed for OpenAI Chat Completions
// The [DONE] marker is already handled by the transformation layer
let mut buffer = Vec::new();

View file

@ -1,7 +1,7 @@
pub mod sse;
pub mod sse_chunk_processor;
pub mod amazon_bedrock_binary_frame;
pub mod anthropic_streaming_buffer;
pub mod chat_completions_streaming_buffer;
pub mod passthrough_streaming_buffer;
pub mod responses_api_streaming_buffer;
pub mod sse;
pub mod sse_chunk_processor;

View file

@ -6,6 +6,12 @@ pub struct PassthroughStreamBuffer {
buffered_events: Vec<SseEvent>,
}
impl Default for PassthroughStreamBuffer {
fn default() -> Self {
Self::new()
}
}
impl PassthroughStreamBuffer {
pub fn new() -> Self {
Self {
@ -30,7 +36,7 @@ impl SseStreamBufferTrait for PassthroughStreamBuffer {
self.buffered_events.push(event);
}
fn into_bytes(&mut self) -> Vec<u8> {
fn to_bytes(&mut self) -> Vec<u8> {
// No finalization needed for passthrough - just convert accumulated events to bytes
let mut buffer = Vec::new();
for event in self.buffered_events.drain(..) {
@ -44,7 +50,7 @@ impl SseStreamBufferTrait for PassthroughStreamBuffer {
#[cfg(test)]
mod tests {
use crate::apis::streaming_shapes::passthrough_streaming_buffer::PassthroughStreamBuffer;
use crate::apis::streaming_shapes::sse::{SseStreamIter, SseStreamBufferTrait};
use crate::apis::streaming_shapes::sse::{SseStreamBufferTrait, SseStreamIter};
#[test]
fn test_chat_completions_passthrough_buffer() {
@ -73,7 +79,7 @@ mod tests {
buffer.add_transformed_event(event);
}
let output_bytes = buffer.into_bytes();
let output_bytes = buffer.to_bytes();
let output = String::from_utf8_lossy(&output_bytes);
println!("\nTRANSFORMED OUTPUT (ChatCompletions - Passthrough):");
@ -84,7 +90,11 @@ mod tests {
assert!(!output_bytes.is_empty());
assert!(output.contains("chatcmpl-123"));
assert!(output.contains("[DONE]"));
assert_eq!(raw_input.trim(), output.trim(), "Passthrough should preserve input");
assert_eq!(
raw_input.trim(),
output.trim(),
"Passthrough should preserve input"
);
println!("\nVALIDATION SUMMARY:");
println!("{}", "-".repeat(80));

View file

@ -1,10 +1,10 @@
use std::collections::HashMap;
use log::debug;
use crate::apis::openai_responses::{
ResponsesAPIStreamEvent, ResponsesAPIResponse, OutputItem, OutputItemStatus,
ResponseStatus, TextConfig, TextFormat, Reasoning,
OutputItem, OutputItemStatus, Reasoning, ResponseStatus, ResponsesAPIResponse,
ResponsesAPIStreamEvent, TextConfig, TextFormat,
};
use crate::apis::streaming_shapes::sse::{SseEvent, SseStreamBufferTrait};
use log::debug;
use std::collections::HashMap;
/// Helper to convert ResponseAPIStreamEvent to SseEvent
fn event_to_sse(event: ResponsesAPIStreamEvent) -> SseEvent {
@ -16,10 +16,17 @@ fn event_to_sse(event: ResponsesAPIStreamEvent) -> SseEvent {
ResponsesAPIStreamEvent::ResponseOutputItemDone { .. } => "response.output_item.done",
ResponsesAPIStreamEvent::ResponseOutputTextDelta { .. } => "response.output_text.delta",
ResponsesAPIStreamEvent::ResponseOutputTextDone { .. } => "response.output_text.done",
ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDelta { .. } => "response.function_call_arguments.delta",
ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDone { .. } => "response.function_call_arguments.done",
ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDelta { .. } => {
"response.function_call_arguments.delta"
}
ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDone { .. } => {
"response.function_call_arguments.done"
}
unknown => {
debug!("Unknown ResponsesAPIStreamEvent type encountered: {:?}", unknown);
debug!(
"Unknown ResponsesAPIStreamEvent type encountered: {:?}",
unknown
);
"unknown"
}
};
@ -85,6 +92,12 @@ pub struct ResponsesAPIStreamBuffer {
buffered_events: Vec<SseEvent>,
}
impl Default for ResponsesAPIStreamBuffer {
fn default() -> Self {
Self::new()
}
}
impl ResponsesAPIStreamBuffer {
pub fn new() -> Self {
Self {
@ -112,7 +125,11 @@ impl ResponsesAPIStreamBuffer {
}
fn generate_item_id(prefix: &str) -> String {
format!("{}_{}", prefix, uuid::Uuid::new_v4().to_string().replace("-", ""))
format!(
"{}_{}",
prefix,
uuid::Uuid::new_v4().to_string().replace("-", "")
)
}
fn get_or_create_item_id(&mut self, output_index: i32, prefix: &str) -> String {
@ -160,7 +177,13 @@ impl ResponsesAPIStreamBuffer {
}
/// Create output_item.added event for tool call
fn create_tool_call_added_event(&mut self, output_index: i32, item_id: &str, call_id: &str, name: &str) -> SseEvent {
fn create_tool_call_added_event(
&mut self,
output_index: i32,
item_id: &str,
call_id: &str,
name: &str,
) -> SseEvent {
let event = ResponsesAPIStreamEvent::ResponseOutputItemAdded {
output_index,
item: OutputItem::FunctionCall {
@ -237,9 +260,15 @@ impl ResponsesAPIStreamBuffer {
// Emit done events for all accumulated content
// Text content done events
let text_items: Vec<_> = self.text_content.iter().map(|(id, content)| (id.clone(), content.clone())).collect();
let text_items: Vec<_> = self
.text_content
.iter()
.map(|(id, content)| (id.clone(), content.clone()))
.collect();
for (item_id, content) in text_items {
let output_index = self.output_items_added.iter()
let output_index = self
.output_items_added
.iter()
.find(|(_, id)| **id == item_id)
.map(|(idx, _)| *idx)
.unwrap_or(0);
@ -270,9 +299,15 @@ impl ResponsesAPIStreamBuffer {
}
// Function call done events
let func_items: Vec<_> = self.function_arguments.iter().map(|(id, args)| (id.clone(), args.clone())).collect();
let func_items: Vec<_> = self
.function_arguments
.iter()
.map(|(id, args)| (id.clone(), args.clone()))
.collect();
for (item_id, arguments) in func_items {
let output_index = self.output_items_added.iter()
let output_index = self
.output_items_added
.iter()
.find(|(_, id)| **id == item_id)
.map(|(idx, _)| *idx)
.unwrap_or(0);
@ -286,9 +321,16 @@ impl ResponsesAPIStreamBuffer {
};
events.push(event_to_sse(args_done_event));
let (call_id, name) = self.tool_call_metadata.get(&output_index)
let (call_id, name) = self
.tool_call_metadata
.get(&output_index)
.cloned()
.unwrap_or_else(|| (format!("call_{}", uuid::Uuid::new_v4()), "unknown".to_string()));
.unwrap_or_else(|| {
(
format!("call_{}", uuid::Uuid::new_v4()),
"unknown".to_string(),
)
});
let seq2 = self.next_sequence_number();
let item_done_event = ResponsesAPIStreamEvent::ResponseOutputItemDone {
@ -315,9 +357,16 @@ impl ResponsesAPIStreamBuffer {
if let Some(item_id) = self.output_items_added.get(&output_index) {
// Check if this is a function call
if let Some(arguments) = self.function_arguments.get(item_id) {
let (call_id, name) = self.tool_call_metadata.get(&output_index)
let (call_id, name) = self
.tool_call_metadata
.get(&output_index)
.cloned()
.unwrap_or_else(|| (format!("call_{}", uuid::Uuid::new_v4()), "unknown".to_string()));
.unwrap_or_else(|| {
(
format!("call_{}", uuid::Uuid::new_v4()),
"unknown".to_string(),
)
});
output_items.push(OutputItem::FunctionCall {
id: item_id.clone(),
@ -397,9 +446,9 @@ impl SseStreamBufferTrait for ResponsesAPIStreamBuffer {
let mut events = Vec::new();
// Capture upstream metadata from ResponseCreated or ResponseInProgress if present
match stream_event {
ResponsesAPIStreamEvent::ResponseCreated { response, .. } |
ResponsesAPIStreamEvent::ResponseInProgress { response, .. } => {
match stream_event.as_ref() {
ResponsesAPIStreamEvent::ResponseCreated { response, .. }
| ResponsesAPIStreamEvent::ResponseInProgress { response, .. } => {
if self.upstream_response_metadata.is_none() {
// Store the full upstream response as our metadata template
self.upstream_response_metadata = Some(response.clone());
@ -418,11 +467,16 @@ impl SseStreamBufferTrait for ResponsesAPIStreamBuffer {
if !self.created_emitted {
// Initialize metadata from first event if needed
if self.response_id.is_none() {
self.response_id = Some(format!("resp_{}", uuid::Uuid::new_v4().to_string().replace("-", "")));
self.created_at = Some(std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_secs() as i64);
self.response_id = Some(format!(
"resp_{}",
uuid::Uuid::new_v4().to_string().replace("-", "")
));
self.created_at = Some(
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_secs() as i64,
);
self.model = Some("unknown".to_string()); // Will be set by caller if available
}
@ -436,58 +490,95 @@ impl SseStreamBufferTrait for ResponsesAPIStreamBuffer {
}
// Process the delta event
match stream_event {
ResponsesAPIStreamEvent::ResponseOutputTextDelta { output_index, delta, .. } => {
match stream_event.as_ref() {
ResponsesAPIStreamEvent::ResponseOutputTextDelta {
output_index,
delta,
..
} => {
let item_id = self.get_or_create_item_id(*output_index, "msg");
// Emit output_item.added if this is the first time we see this output index
if !self.output_items_added.contains_key(output_index) {
self.output_items_added.insert(*output_index, item_id.clone());
self.output_items_added
.insert(*output_index, item_id.clone());
events.push(self.create_output_item_added_event(*output_index, &item_id));
}
// Accumulate text content
self.text_content.entry(item_id.clone())
self.text_content
.entry(item_id.clone())
.and_modify(|content| content.push_str(delta))
.or_insert_with(|| delta.clone());
// Emit text delta with filled-in item_id and sequence_number
let mut delta_event = stream_event.clone();
if let ResponsesAPIStreamEvent::ResponseOutputTextDelta { item_id: ref mut id, sequence_number: ref mut seq, .. } = delta_event {
let mut delta_event = stream_event.as_ref().clone();
if let ResponsesAPIStreamEvent::ResponseOutputTextDelta {
item_id: ref mut id,
sequence_number: ref mut seq,
..
} = &mut delta_event
{
*id = item_id;
*seq = self.next_sequence_number();
}
events.push(event_to_sse(delta_event));
}
ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDelta { output_index, delta, call_id, name, .. } => {
ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDelta {
output_index,
delta,
call_id,
name,
..
} => {
let item_id = self.get_or_create_item_id(*output_index, "fc");
// Store metadata if provided (from initial tool call event)
if let (Some(cid), Some(n)) = (call_id, name) {
self.tool_call_metadata.insert(*output_index, (cid.clone(), n.clone()));
self.tool_call_metadata
.insert(*output_index, (cid.clone(), n.clone()));
}
// Emit output_item.added if this is the first time we see this tool call
if !self.output_items_added.contains_key(output_index) {
self.output_items_added.insert(*output_index, item_id.clone());
self.output_items_added
.insert(*output_index, item_id.clone());
// For tool calls, we need call_id and name from metadata
// These should now be populated from the event itself
let (call_id, name) = self.tool_call_metadata.get(output_index)
let (call_id, name) = self
.tool_call_metadata
.get(output_index)
.cloned()
.unwrap_or_else(|| (format!("call_{}", uuid::Uuid::new_v4()), "unknown".to_string()));
.unwrap_or_else(|| {
(
format!("call_{}", uuid::Uuid::new_v4()),
"unknown".to_string(),
)
});
events.push(self.create_tool_call_added_event(*output_index, &item_id, &call_id, &name));
events.push(self.create_tool_call_added_event(
*output_index,
&item_id,
&call_id,
&name,
));
}
// Accumulate function arguments
self.function_arguments.entry(item_id.clone())
self.function_arguments
.entry(item_id.clone())
.and_modify(|args| args.push_str(delta))
.or_insert_with(|| delta.clone());
// Emit function call arguments delta with filled-in item_id and sequence_number
let mut delta_event = stream_event.clone();
if let ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDelta { item_id: ref mut id, sequence_number: ref mut seq, .. } = delta_event {
let mut delta_event = stream_event.as_ref().clone();
if let ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDelta {
item_id: ref mut id,
sequence_number: ref mut seq,
..
} = &mut delta_event
{
*id = item_id;
*seq = self.next_sequence_number();
}
@ -495,7 +586,7 @@ impl SseStreamBufferTrait for ResponsesAPIStreamBuffer {
}
_ => {
// For other event types, just pass through with sequence number
let other_event = stream_event.clone();
let other_event = stream_event.as_ref().clone();
// TODO: Add sequence number to other event types if needed
events.push(event_to_sse(other_event));
}
@ -505,8 +596,7 @@ impl SseStreamBufferTrait for ResponsesAPIStreamBuffer {
self.buffered_events.extend(events);
}
fn into_bytes(&mut self) -> Vec<u8> {
fn to_bytes(&mut self) -> Vec<u8> {
// For Responses API, we need special handling:
// - Most events are already in buffered_events from add_transformed_event
// - We should NOT finalize here - finalization happens when we detect [DONE] or end of stream
@ -525,9 +615,9 @@ impl SseStreamBufferTrait for ResponsesAPIStreamBuffer {
#[cfg(test)]
mod tests {
use super::*;
use crate::clients::{SupportedAPIsFromClient, SupportedUpstreamAPIs};
use crate::apis::openai::OpenAIApi;
use crate::apis::streaming_shapes::sse::SseStreamIter;
use crate::clients::{SupportedAPIsFromClient, SupportedUpstreamAPIs};
#[test]
fn test_chat_completions_to_responses_api_transformation() {
@ -557,11 +647,12 @@ mod tests {
for raw_event in stream_iter {
// Transform the event using the client/upstream APIs
let transformed_event = SseEvent::try_from((raw_event, &client_api, &upstream_api)).unwrap();
let transformed_event =
SseEvent::try_from((raw_event, &client_api, &upstream_api)).unwrap();
buffer.add_transformed_event(transformed_event);
}
let output_bytes = buffer.into_bytes();
let output_bytes = buffer.to_bytes();
let output = String::from_utf8_lossy(&output_bytes);
println!("\nTRANSFORMED OUTPUT (ResponsesAPI):");
@ -570,13 +661,34 @@ mod tests {
// Assertions
assert!(!output_bytes.is_empty(), "Should have output");
assert!(output.contains("response.created"), "Should have response.created");
assert!(output.contains("response.in_progress"), "Should have response.in_progress");
assert!(output.contains("response.output_item.added"), "Should have output_item.added");
assert!(output.contains("response.output_text.delta"), "Should have text deltas");
assert!(output.contains("response.output_text.done"), "Should have text.done");
assert!(output.contains("response.output_item.done"), "Should have output_item.done");
assert!(output.contains("response.completed"), "Should have response.completed");
assert!(
output.contains("response.created"),
"Should have response.created"
);
assert!(
output.contains("response.in_progress"),
"Should have response.in_progress"
);
assert!(
output.contains("response.output_item.added"),
"Should have output_item.added"
);
assert!(
output.contains("response.output_text.delta"),
"Should have text deltas"
);
assert!(
output.contains("response.output_text.done"),
"Should have text.done"
);
assert!(
output.contains("response.output_item.done"),
"Should have output_item.done"
);
assert!(
output.contains("response.completed"),
"Should have response.completed"
);
println!("\nVALIDATION SUMMARY:");
println!("{}", "-".repeat(80));
@ -616,7 +728,7 @@ mod tests {
buffer.add_transformed_event(transformed);
}
let output_bytes = buffer.into_bytes();
let output_bytes = buffer.to_bytes();
let output = String::from_utf8_lossy(&output_bytes);
println!("\nTRANSFORMED OUTPUT (ResponsesAPI):");
@ -624,24 +736,55 @@ mod tests {
println!("{}", output);
// Assertions
assert!(output.contains("response.created"), "Should have response.created");
assert!(output.contains("response.in_progress"), "Should have response.in_progress");
assert!(output.contains("response.output_item.added"), "Should have output_item.added");
assert!(output.contains("\"type\":\"function_call\""), "Should be function_call type");
assert!(output.contains("\"name\":\"get_weather\""), "Should have function name");
assert!(output.contains("\"call_id\":\"call_mD5ggLKk3SMKGPFqFdcpKg6q\""), "Should have correct call_id");
assert!(
output.contains("response.created"),
"Should have response.created"
);
assert!(
output.contains("response.in_progress"),
"Should have response.in_progress"
);
assert!(
output.contains("response.output_item.added"),
"Should have output_item.added"
);
assert!(
output.contains("\"type\":\"function_call\""),
"Should be function_call type"
);
assert!(
output.contains("\"name\":\"get_weather\""),
"Should have function name"
);
assert!(
output.contains("\"call_id\":\"call_mD5ggLKk3SMKGPFqFdcpKg6q\""),
"Should have correct call_id"
);
let delta_count = output.matches("event: response.function_call_arguments.delta").count();
let delta_count = output
.matches("event: response.function_call_arguments.delta")
.count();
assert_eq!(delta_count, 4, "Should have 4 delta events");
assert!(!output.contains("response.function_call_arguments.done"), "Should NOT have arguments.done");
assert!(!output.contains("response.output_item.done"), "Should NOT have output_item.done");
assert!(!output.contains("response.completed"), "Should NOT have response.completed");
assert!(
!output.contains("response.function_call_arguments.done"),
"Should NOT have arguments.done"
);
assert!(
!output.contains("response.output_item.done"),
"Should NOT have output_item.done"
);
assert!(
!output.contains("response.completed"),
"Should NOT have response.completed"
);
println!("\nVALIDATION SUMMARY:");
println!("{}", "-".repeat(80));
println!("✓ Lifecycle events: response.created, response.in_progress");
println!("✓ Function call metadata: name='get_weather', call_id='call_mD5ggLKk3SMKGPFqFdcpKg6q'");
println!(
"✓ Function call metadata: name='get_weather', call_id='call_mD5ggLKk3SMKGPFqFdcpKg6q'"
);
println!("✓ Incremental deltas: 4 events (1 initial + 3 argument chunks)");
println!("✓ NO completion events (partial stream, no [DONE])");
println!("✓ Arguments accumulated: '{{\"location\":\"'\n");

View file

@ -1,9 +1,9 @@
use crate::providers::streaming_response::ProviderStreamResponse;
use crate::providers::streaming_response::ProviderStreamResponseType;
use crate::apis::streaming_shapes::chat_completions_streaming_buffer::OpenAIChatCompletionsStreamBuffer;
use crate::apis::streaming_shapes::anthropic_streaming_buffer::AnthropicMessagesStreamBuffer;
use crate::apis::streaming_shapes::chat_completions_streaming_buffer::OpenAIChatCompletionsStreamBuffer;
use crate::apis::streaming_shapes::passthrough_streaming_buffer::PassthroughStreamBuffer;
use crate::apis::streaming_shapes::responses_api_streaming_buffer::ResponsesAPIStreamBuffer;
use crate::providers::streaming_response::ProviderStreamResponse;
use crate::providers::streaming_response::ProviderStreamResponseType;
use serde::{Deserialize, Serialize};
use std::error::Error;
use std::fmt;
@ -37,7 +37,7 @@ pub trait SseStreamBufferTrait: Send + Sync {
///
/// # Returns
/// Bytes ready for wire transmission (may be empty if no events were accumulated)
fn into_bytes(&mut self) -> Vec<u8>;
fn to_bytes(&mut self) -> Vec<u8>;
}
/// Unified SSE Stream Buffer enum that provides a zero-cost abstraction
@ -45,7 +45,7 @@ pub enum SseStreamBuffer {
Passthrough(PassthroughStreamBuffer),
OpenAIChatCompletions(OpenAIChatCompletionsStreamBuffer),
AnthropicMessages(AnthropicMessagesStreamBuffer),
OpenAIResponses(ResponsesAPIStreamBuffer),
OpenAIResponses(Box<ResponsesAPIStreamBuffer>),
}
impl SseStreamBufferTrait for SseStreamBuffer {
@ -58,12 +58,12 @@ impl SseStreamBufferTrait for SseStreamBuffer {
}
}
fn into_bytes(&mut self) -> Vec<u8> {
fn to_bytes(&mut self) -> Vec<u8> {
match self {
Self::Passthrough(buffer) => buffer.into_bytes(),
Self::OpenAIChatCompletions(buffer) => buffer.into_bytes(),
Self::AnthropicMessages(buffer) => buffer.into_bytes(),
Self::OpenAIResponses(buffer) => buffer.into_bytes(),
Self::Passthrough(buffer) => buffer.to_bytes(),
Self::OpenAIChatCompletions(buffer) => buffer.to_bytes(),
Self::AnthropicMessages(buffer) => buffer.to_bytes(),
Self::OpenAIResponses(buffer) => buffer.to_bytes(),
}
}
}
@ -99,7 +99,7 @@ impl SseEvent {
let sse_string: String = response.clone().into();
SseEvent {
data: None, // Data is embedded in sse_transformed_lines
data: None, // Data is embedded in sse_transformed_lines
event: None, // Event type is embedded in sse_transformed_lines
raw_line: sse_string.clone(),
sse_transformed_lines: sse_string,
@ -149,10 +149,8 @@ impl FromStr for SseEvent {
});
}
if trimmed_line.starts_with("data: ") {
let data: String = trimmed_line[6..].to_string(); // Remove "data: " prefix
// Allow empty data content after "data: " prefix
// This handles cases like "data: " followed by newline
if let Some(stripped) = trimmed_line.strip_prefix("data: ") {
let data: String = stripped.to_string();
if data.trim().is_empty() {
return Err(SseParseError {
message: "Empty data field after 'data: ' prefix".to_string(),
@ -166,8 +164,8 @@ impl FromStr for SseEvent {
sse_transformed_lines: line.to_string(),
provider_stream_response: None,
})
} else if trimmed_line.starts_with("event: ") {
let event_type = trimmed_line[7..].to_string();
} else if let Some(stripped) = trimmed_line.strip_prefix("event: ") {
let event_type = stripped.to_string();
if event_type.is_empty() {
return Err(SseParseError {
message: "Empty event field is not a valid SSE event".to_string(),
@ -183,7 +181,10 @@ impl FromStr for SseEvent {
})
} else {
Err(SseParseError {
message: format!("Line does not start with 'data: ' or 'event: ': {}", trimmed_line),
message: format!(
"Line does not start with 'data: ' or 'event: ': {}",
trimmed_line
),
})
}
}
@ -196,16 +197,16 @@ impl fmt::Display for SseEvent {
}
// Into implementation to convert SseEvent to bytes for response buffer
impl Into<Vec<u8>> for SseEvent {
fn into(self) -> Vec<u8> {
impl From<SseEvent> for Vec<u8> {
fn from(val: SseEvent) -> Self {
// For generated events (like ResponsesAPI), sse_transformed_lines already includes trailing \n\n
// For parsed events (like passthrough), we need to add the \n\n separator
if self.sse_transformed_lines.ends_with("\n\n") {
if val.sse_transformed_lines.ends_with("\n\n") {
// Already properly formatted with trailing newlines
self.sse_transformed_lines.into_bytes()
val.sse_transformed_lines.into_bytes()
} else {
// Add SSE event separator
format!("{}\n\n", self.sse_transformed_lines).into_bytes()
format!("{}\n\n", val.sse_transformed_lines).into_bytes()
}
}
}

View file

@ -10,6 +10,12 @@ pub struct SseChunkProcessor {
incomplete_event_buffer: Vec<u8>,
}
impl Default for SseChunkProcessor {
fn default() -> Self {
Self::new()
}
}
impl SseChunkProcessor {
pub fn new() -> Self {
Self {
@ -93,8 +99,8 @@ impl SseChunkProcessor {
#[cfg(test)]
mod tests {
use super::*;
use crate::clients::endpoints::{SupportedAPIsFromClient, SupportedUpstreamAPIs};
use crate::apis::openai::OpenAIApi;
use crate::clients::endpoints::{SupportedAPIsFromClient, SupportedUpstreamAPIs};
#[test]
fn test_complete_events_process_immediately() {
@ -104,7 +110,9 @@ mod tests {
let chunk1 = b"data: {\"id\":\"chatcmpl-123\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4o\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Hello\"},\"finish_reason\":null}]}\n\n";
let events = processor.process_chunk(chunk1, &client_api, &upstream_api).unwrap();
let events = processor
.process_chunk(chunk1, &client_api, &upstream_api)
.unwrap();
assert_eq!(events.len(), 1);
assert!(!processor.has_buffered_data());
@ -119,18 +127,28 @@ mod tests {
// First chunk with incomplete JSON
let chunk1 = b"data: {\"id\":\"chatcmpl-123\",\"object\":\"chat.completion.chu";
let events1 = processor.process_chunk(chunk1, &client_api, &upstream_api).unwrap();
let events1 = processor
.process_chunk(chunk1, &client_api, &upstream_api)
.unwrap();
assert_eq!(events1.len(), 0, "Incomplete event should not be processed");
assert!(processor.has_buffered_data(), "Incomplete data should be buffered");
assert!(
processor.has_buffered_data(),
"Incomplete data should be buffered"
);
// Second chunk completes the JSON
let chunk2 = b"nk\",\"created\":1234567890,\"model\":\"gpt-4o\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Hello\"},\"finish_reason\":null}]}\n\n";
let events2 = processor.process_chunk(chunk2, &client_api, &upstream_api).unwrap();
let events2 = processor
.process_chunk(chunk2, &client_api, &upstream_api)
.unwrap();
assert_eq!(events2.len(), 1, "Complete event should be processed");
assert!(!processor.has_buffered_data(), "Buffer should be cleared after completion");
assert!(
!processor.has_buffered_data(),
"Buffer should be cleared after completion"
);
}
#[test]
@ -142,10 +160,15 @@ mod tests {
// Chunk with 2 complete events and 1 incomplete
let chunk = b"data: {\"id\":\"chatcmpl-123\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4o\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"A\"},\"finish_reason\":null}]}\n\ndata: {\"id\":\"chatcmpl-124\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4o\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"B\"},\"finish_reason\":null}]}\n\ndata: {\"id\":\"chatcmpl-125\",\"object\":\"chat.completion.chu";
let events = processor.process_chunk(chunk, &client_api, &upstream_api).unwrap();
let events = processor
.process_chunk(chunk, &client_api, &upstream_api)
.unwrap();
assert_eq!(events.len(), 2, "Two complete events should be processed");
assert!(processor.has_buffered_data(), "Incomplete third event should be buffered");
assert!(
processor.has_buffered_data(),
"Incomplete third event should be buffered"
);
}
#[test]
@ -171,11 +194,23 @@ data: {"type":"content_block_stop","index":0}
Ok(events) => {
println!("Successfully processed {} events", events.len());
for (i, event) in events.iter().enumerate() {
println!("Event {}: event={:?}, has_data={}", i, event.event, event.data.is_some());
println!(
"Event {}: event={:?}, has_data={}",
i,
event.event,
event.data.is_some()
);
}
// Should successfully process both events (signature_delta + content_block_stop)
assert!(events.len() >= 2, "Should process at least 2 complete events (signature_delta + stop), got {}", events.len());
assert!(!processor.has_buffered_data(), "Complete events should not be buffered");
assert!(
events.len() >= 2,
"Should process at least 2 complete events (signature_delta + stop), got {}",
events.len()
);
assert!(
!processor.has_buffered_data(),
"Complete events should not be buffered"
);
}
Err(e) => {
panic!("Failed to process signature_delta chunk - this means SignatureDelta is not properly handled: {}", e);
@ -194,12 +229,21 @@ data: {"type":"content_block_stop","index":0}
// Second event is valid and should be processed
let chunk = b"data: {\"id\":\"chatcmpl-123\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4o\",\"choices\":[{\"index\":0,\"delta\":{\"unsupported_field_causing_validation_error\":true},\"finish_reason\":null}]}\n\ndata: {\"id\":\"chatcmpl-124\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4o\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Hello\"},\"finish_reason\":null}]}\n\n";
let events = processor.process_chunk(chunk, &client_api, &upstream_api).unwrap();
let events = processor
.process_chunk(chunk, &client_api, &upstream_api)
.unwrap();
// Should skip the invalid event and process the valid one
// (If we were buffering all errors, we'd get 0 events and have buffered data)
assert!(events.len() >= 1, "Should process at least the valid event, got {} events", events.len());
assert!(!processor.has_buffered_data(), "Invalid (non-incomplete) events should not be buffered");
assert!(
!events.is_empty(),
"Should process at least the valid event, got {} events",
events.len()
);
assert!(
!processor.has_buffered_data(),
"Invalid (non-incomplete) events should not be buffered"
);
}
#[test]
@ -227,14 +271,27 @@ data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text
match result {
Ok(events) => {
println!("Processed {} events (unsupported event should be skipped)", events.len());
println!(
"Processed {} events (unsupported event should be skipped)",
events.len()
);
// Should process the 2 valid text_delta events and skip the unsupported one
// We expect at least 2 events (the valid ones), unsupported should be skipped
assert!(events.len() >= 2, "Should process at least 2 valid events, got {}", events.len());
assert!(!processor.has_buffered_data(), "Unsupported events should be skipped, not buffered");
assert!(
events.len() >= 2,
"Should process at least 2 valid events, got {}",
events.len()
);
assert!(
!processor.has_buffered_data(),
"Unsupported events should be skipped, not buffered"
);
}
Err(e) => {
panic!("Should not fail on unsupported delta type, should skip it: {}", e);
panic!(
"Should not fail on unsupported delta type, should skip it: {}",
e
);
}
}
}