mirror of
https://github.com/katanemo/plano.git
synced 2026-05-08 15:22:43 +02:00
cargo clippy (#660)
This commit is contained in:
parent
c75e7606f9
commit
ca95ffb63d
62 changed files with 1864 additions and 1187 deletions
|
|
@ -66,7 +66,7 @@ impl ApiDefinition for AmazonBedrockApi {
|
|||
|
||||
/// Amazon Bedrock Converse request
|
||||
#[skip_serializing_none]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
|
||||
pub struct ConverseRequest {
|
||||
/// The model ID or ARN to invoke
|
||||
pub model_id: String,
|
||||
|
|
@ -91,7 +91,7 @@ pub struct ConverseRequest {
|
|||
pub additional_model_response_field_paths: Option<Vec<String>>,
|
||||
/// Performance configuration
|
||||
#[serde(rename = "performanceConfig")]
|
||||
pub performance_config: Option<PerformanceConfiguration>,
|
||||
pub performance_config: Option<InferenceConfiguration>,
|
||||
/// Prompt variables for Prompt management
|
||||
#[serde(rename = "promptVariables")]
|
||||
pub prompt_variables: Option<HashMap<String, PromptVariableValues>>,
|
||||
|
|
@ -105,26 +105,6 @@ pub struct ConverseRequest {
|
|||
pub stream: bool,
|
||||
}
|
||||
|
||||
impl Default for ConverseRequest {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
model_id: String::new(),
|
||||
messages: None,
|
||||
system: None,
|
||||
inference_config: None,
|
||||
tool_config: None,
|
||||
guardrail_config: None,
|
||||
additional_model_request_fields: None,
|
||||
additional_model_response_field_paths: None,
|
||||
performance_config: None,
|
||||
prompt_variables: None,
|
||||
request_metadata: None,
|
||||
metadata: None,
|
||||
stream: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Amazon Bedrock ConverseStream request (same structure as Converse)
|
||||
pub type ConverseStreamRequest = ConverseRequest;
|
||||
|
||||
|
|
@ -204,8 +184,8 @@ impl ProviderRequest for ConverseRequest {
|
|||
self.tool_config.as_ref()?.tools.as_ref().map(|tools| {
|
||||
tools
|
||||
.iter()
|
||||
.filter_map(|tool| match tool {
|
||||
Tool::ToolSpec { tool_spec } => Some(tool_spec.name.clone()),
|
||||
.map(|tool| match tool {
|
||||
Tool::ToolSpec { tool_spec } => tool_spec.name.clone(),
|
||||
})
|
||||
.collect()
|
||||
})
|
||||
|
|
@ -242,17 +222,14 @@ impl ProviderRequest for ConverseRequest {
|
|||
// Add system messages if present
|
||||
if let Some(system) = &self.system {
|
||||
for sys_block in system {
|
||||
match sys_block {
|
||||
SystemContentBlock::Text { text } => {
|
||||
openai_messages.push(Message {
|
||||
role: Role::System,
|
||||
content: MessageContent::Text(text.clone()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
});
|
||||
}
|
||||
_ => {} // Skip other system content types
|
||||
if let SystemContentBlock::Text { text } = sys_block {
|
||||
openai_messages.push(Message {
|
||||
role: Role::System,
|
||||
content: MessageContent::Text(text.clone()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -266,7 +243,9 @@ impl ProviderRequest for ConverseRequest {
|
|||
};
|
||||
|
||||
// Extract text from content blocks
|
||||
let content = msg.content.iter()
|
||||
let content = msg
|
||||
.content
|
||||
.iter()
|
||||
.filter_map(|block| {
|
||||
if let ContentBlock::Text { text } = block {
|
||||
Some(text.clone())
|
||||
|
|
@ -311,16 +290,14 @@ impl ProviderRequest for ConverseRequest {
|
|||
_ => continue,
|
||||
};
|
||||
|
||||
let content = if let crate::apis::openai::MessageContent::Text(text) = &msg.content {
|
||||
vec![ContentBlock::Text { text: text.clone() }]
|
||||
} else {
|
||||
vec![]
|
||||
};
|
||||
let content =
|
||||
if let crate::apis::openai::MessageContent::Text(text) = &msg.content {
|
||||
vec![ContentBlock::Text { text: text.clone() }]
|
||||
} else {
|
||||
vec![]
|
||||
};
|
||||
|
||||
bedrock_messages.push(crate::apis::amazon_bedrock::Message {
|
||||
role,
|
||||
content,
|
||||
});
|
||||
bedrock_messages.push(crate::apis::amazon_bedrock::Message { role, content });
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
|
@ -369,7 +346,7 @@ pub enum ConverseStreamEvent {
|
|||
ContentBlockDelta(ContentBlockDeltaEvent),
|
||||
ContentBlockStop(ContentBlockStopEvent),
|
||||
MessageStop(MessageStopEvent),
|
||||
Metadata(ConverseStreamMetadataEvent),
|
||||
Metadata(Box<ConverseStreamMetadataEvent>),
|
||||
// Error events
|
||||
InternalServerException(BedrockException),
|
||||
ModelStreamErrorException(BedrockException),
|
||||
|
|
@ -1063,7 +1040,7 @@ impl TryFrom<&aws_smithy_eventstream::frame::DecodedFrame> for ConverseStreamEve
|
|||
"metadata" => {
|
||||
let event: ConverseStreamMetadataEvent =
|
||||
serde_json::from_slice(payload).map_err(BedrockError::Serialization)?;
|
||||
Ok(ConverseStreamEvent::Metadata(event))
|
||||
Ok(ConverseStreamEvent::Metadata(Box::new(event)))
|
||||
}
|
||||
unknown => Err(BedrockError::Validation {
|
||||
message: format!("Unknown event type: {}", unknown),
|
||||
|
|
@ -1106,10 +1083,10 @@ impl TryFrom<&aws_smithy_eventstream::frame::DecodedFrame> for ConverseStreamEve
|
|||
}
|
||||
}
|
||||
|
||||
impl Into<String> for ConverseStreamEvent {
|
||||
fn into(self) -> String {
|
||||
let transformed_json = serde_json::to_string(&self).unwrap_or_default();
|
||||
let event_type = match &self {
|
||||
impl From<ConverseStreamEvent> for String {
|
||||
fn from(val: ConverseStreamEvent) -> String {
|
||||
let transformed_json = serde_json::to_string(&val).unwrap_or_default();
|
||||
let event_type = match &val {
|
||||
ConverseStreamEvent::MessageStart { .. } => "message_start",
|
||||
ConverseStreamEvent::ContentBlockStart { .. } => "content_block_start",
|
||||
ConverseStreamEvent::ContentBlockDelta { .. } => "content_block_delta",
|
||||
|
|
|
|||
File diff suppressed because one or more lines are too long
|
|
@ -286,7 +286,6 @@ pub struct ImageUrl {
|
|||
}
|
||||
|
||||
/// A single message in a chat conversation
|
||||
|
||||
/// A tool call made by the assistant
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
|
||||
pub struct ToolCall {
|
||||
|
|
@ -388,7 +387,7 @@ pub enum StaticContentType {
|
|||
|
||||
/// Chat completions API response
|
||||
#[skip_serializing_none]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
|
||||
pub struct ChatCompletionsResponse {
|
||||
pub id: String,
|
||||
pub object: Option<String>,
|
||||
|
|
@ -402,22 +401,6 @@ pub struct ChatCompletionsResponse {
|
|||
pub metadata: Option<HashMap<String, Value>>,
|
||||
}
|
||||
|
||||
impl Default for ChatCompletionsResponse {
|
||||
fn default() -> Self {
|
||||
ChatCompletionsResponse {
|
||||
id: String::new(),
|
||||
object: None,
|
||||
created: 0,
|
||||
model: String::new(),
|
||||
choices: vec![],
|
||||
usage: Usage::default(),
|
||||
system_fingerprint: None,
|
||||
service_tier: None,
|
||||
metadata: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Finish reason for completion
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
|
|
@ -431,7 +414,7 @@ pub enum FinishReason {
|
|||
|
||||
/// Token usage information
|
||||
#[skip_serializing_none]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
|
||||
pub struct Usage {
|
||||
pub prompt_tokens: u32,
|
||||
pub completion_tokens: u32,
|
||||
|
|
@ -440,18 +423,6 @@ pub struct Usage {
|
|||
pub completion_tokens_details: Option<CompletionTokensDetails>,
|
||||
}
|
||||
|
||||
impl Default for Usage {
|
||||
fn default() -> Self {
|
||||
Usage {
|
||||
prompt_tokens: 0,
|
||||
completion_tokens: 0,
|
||||
total_tokens: 0,
|
||||
prompt_tokens_details: None,
|
||||
completion_tokens_details: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Detailed breakdown of prompt tokens
|
||||
#[skip_serializing_none]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
|
|
@ -472,7 +443,7 @@ pub struct CompletionTokensDetails {
|
|||
|
||||
/// A single choice in the response
|
||||
#[skip_serializing_none]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
|
||||
pub struct Choice {
|
||||
pub index: u32,
|
||||
pub message: ResponseMessage,
|
||||
|
|
@ -480,17 +451,6 @@ pub struct Choice {
|
|||
pub logprobs: Option<Value>,
|
||||
}
|
||||
|
||||
impl Default for Choice {
|
||||
fn default() -> Self {
|
||||
Choice {
|
||||
index: 0,
|
||||
message: ResponseMessage::default(),
|
||||
finish_reason: None,
|
||||
logprobs: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// STREAMING API TYPES
|
||||
// ============================================================================
|
||||
|
|
@ -608,7 +568,6 @@ pub enum OpenAIError {
|
|||
// ============================================================================
|
||||
/// Trait Implementations
|
||||
/// ===========================================================================
|
||||
|
||||
/// Parameterized conversion for ChatCompletionsRequest
|
||||
impl TryFrom<&[u8]> for ChatCompletionsRequest {
|
||||
type Error = OpenAIStreamError;
|
||||
|
|
@ -721,7 +680,7 @@ impl ProviderRequest for ChatCompletionsRequest {
|
|||
}
|
||||
|
||||
fn metadata(&self) -> &Option<HashMap<String, Value>> {
|
||||
return &self.metadata;
|
||||
&self.metadata
|
||||
}
|
||||
|
||||
fn remove_metadata_key(&mut self, key: &str) -> bool {
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
use std::collections::HashMap;
|
||||
use crate::providers::request::{ProviderRequest, ProviderRequestError};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_with::skip_serializing_none;
|
||||
use crate::providers::request::{ProviderRequest, ProviderRequestError};
|
||||
use std::collections::HashMap;
|
||||
|
||||
impl TryFrom<&[u8]> for ResponsesAPIRequest {
|
||||
type Error = serde_json::Error;
|
||||
|
|
@ -172,18 +172,14 @@ pub enum MessageRole {
|
|||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum InputContent {
|
||||
/// Text input
|
||||
InputText {
|
||||
text: String,
|
||||
},
|
||||
InputText { text: String },
|
||||
/// Image input via URL
|
||||
InputImage {
|
||||
image_url: String,
|
||||
detail: Option<String>,
|
||||
},
|
||||
/// File input via URL
|
||||
InputFile {
|
||||
file_url: String,
|
||||
},
|
||||
InputFile { file_url: String },
|
||||
/// Audio input
|
||||
InputAudio {
|
||||
data: Option<String>,
|
||||
|
|
@ -222,9 +218,7 @@ pub struct TextConfig {
|
|||
pub enum TextFormat {
|
||||
Text,
|
||||
JsonObject,
|
||||
JsonSchema {
|
||||
json_schema: serde_json::Value,
|
||||
},
|
||||
JsonSchema { json_schema: serde_json::Value },
|
||||
}
|
||||
|
||||
/// Reasoning effort levels
|
||||
|
|
@ -608,9 +602,7 @@ pub enum OutputContent {
|
|||
transcript: Option<String>,
|
||||
},
|
||||
/// Refusal output
|
||||
Refusal {
|
||||
refusal: String,
|
||||
},
|
||||
Refusal { refusal: String },
|
||||
}
|
||||
|
||||
/// Annotations for output text
|
||||
|
|
@ -663,13 +655,9 @@ pub struct FileSearchResult {
|
|||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum CodeInterpreterOutput {
|
||||
/// Text output
|
||||
Text {
|
||||
text: String,
|
||||
},
|
||||
Text { text: String },
|
||||
/// Image output
|
||||
Image {
|
||||
image: String,
|
||||
},
|
||||
Image { image: String },
|
||||
}
|
||||
|
||||
/// Response usage statistics
|
||||
|
|
@ -951,9 +939,7 @@ pub enum ResponsesAPIStreamEvent {
|
|||
},
|
||||
|
||||
/// Done event (end of stream)
|
||||
Done {
|
||||
sequence_number: i32,
|
||||
},
|
||||
Done { sequence_number: i32 },
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
|
|
@ -1052,12 +1038,19 @@ impl ProviderRequest for ResponsesAPIRequest {
|
|||
MessageContent::Text(text) => text.clone(),
|
||||
MessageContent::Items(content_items) => {
|
||||
content_items.iter().fold(String::new(), |acc, content| {
|
||||
acc + " " + &match content {
|
||||
InputContent::InputText { text } => text.clone(),
|
||||
InputContent::InputImage { .. } => "[Image]".to_string(),
|
||||
InputContent::InputFile { .. } => "[File]".to_string(),
|
||||
InputContent::InputAudio { .. } => "[Audio]".to_string(),
|
||||
}
|
||||
acc + " "
|
||||
+ &match content {
|
||||
InputContent::InputText { text } => text.clone(),
|
||||
InputContent::InputImage { .. } => {
|
||||
"[Image]".to_string()
|
||||
}
|
||||
InputContent::InputFile { .. } => {
|
||||
"[File]".to_string()
|
||||
}
|
||||
InputContent::InputAudio { .. } => {
|
||||
"[Audio]".to_string()
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
};
|
||||
|
|
@ -1082,11 +1075,9 @@ impl ProviderRequest for ResponsesAPIRequest {
|
|||
match &msg.content {
|
||||
MessageContent::Text(text) => Some(text.clone()),
|
||||
MessageContent::Items(content_items) => {
|
||||
content_items.iter().find_map(|content| {
|
||||
match content {
|
||||
InputContent::InputText { text } => Some(text.clone()),
|
||||
_ => None,
|
||||
}
|
||||
content_items.iter().find_map(|content| match content {
|
||||
InputContent::InputText { text } => Some(text.clone()),
|
||||
_ => None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
@ -1176,9 +1167,12 @@ impl ProviderRequest for ResponsesAPIRequest {
|
|||
|
||||
// Extract text from message content
|
||||
let content = match &msg.content {
|
||||
crate::apis::openai_responses::MessageContent::Text(text) => text.clone(),
|
||||
crate::apis::openai_responses::MessageContent::Text(text) => {
|
||||
text.clone()
|
||||
}
|
||||
crate::apis::openai_responses::MessageContent::Items(items) => {
|
||||
items.iter()
|
||||
items
|
||||
.iter()
|
||||
.filter_map(|c| {
|
||||
if let InputContent::InputText { text } = c {
|
||||
Some(text.clone())
|
||||
|
|
@ -1214,7 +1208,8 @@ impl ProviderRequest for ResponsesAPIRequest {
|
|||
fn set_messages(&mut self, messages: &[crate::apis::openai::Message]) {
|
||||
// For ResponsesAPI, we need to convert messages back to input format
|
||||
// Extract system messages as instructions
|
||||
let system_text = messages.iter()
|
||||
let system_text = messages
|
||||
.iter()
|
||||
.filter(|msg| msg.role == crate::apis::openai::Role::System)
|
||||
.filter_map(|msg| {
|
||||
if let crate::apis::openai::MessageContent::Text(text) = &msg.content {
|
||||
|
|
@ -1233,23 +1228,27 @@ impl ProviderRequest for ResponsesAPIRequest {
|
|||
// Convert user/assistant messages to InputParam
|
||||
// For simplicity, we'll use the last user message as the input
|
||||
// or combine all non-system messages
|
||||
let input_messages: Vec<_> = messages.iter()
|
||||
let input_messages: Vec<_> = messages
|
||||
.iter()
|
||||
.filter(|msg| msg.role != crate::apis::openai::Role::System)
|
||||
.collect();
|
||||
|
||||
if !input_messages.is_empty() {
|
||||
// If there's only one message, use Text format
|
||||
if input_messages.len() == 1 {
|
||||
if let crate::apis::openai::MessageContent::Text(text) = &input_messages[0].content {
|
||||
if let crate::apis::openai::MessageContent::Text(text) = &input_messages[0].content
|
||||
{
|
||||
self.input = crate::apis::openai_responses::InputParam::Text(text.clone());
|
||||
}
|
||||
} else {
|
||||
// Multiple messages - combine them as text for now
|
||||
// A more sophisticated approach would use InputParam::Items
|
||||
let combined_text = input_messages.iter()
|
||||
let combined_text = input_messages
|
||||
.iter()
|
||||
.filter_map(|msg| {
|
||||
if let crate::apis::openai::MessageContent::Text(text) = &msg.content {
|
||||
Some(format!("{}: {}",
|
||||
Some(format!(
|
||||
"{}: {}",
|
||||
match msg.role {
|
||||
crate::apis::openai::Role::User => "User",
|
||||
crate::apis::openai::Role::Assistant => "Assistant",
|
||||
|
|
@ -1274,10 +1273,10 @@ impl ProviderRequest for ResponsesAPIRequest {
|
|||
// Into<String> Implementation for SSE Formatting
|
||||
// ============================================================================
|
||||
|
||||
impl Into<String> for ResponsesAPIStreamEvent {
|
||||
fn into(self) -> String {
|
||||
let transformed_json = serde_json::to_string(&self).unwrap_or_default();
|
||||
let event_type = match &self {
|
||||
impl From<ResponsesAPIStreamEvent> for String {
|
||||
fn from(val: ResponsesAPIStreamEvent) -> Self {
|
||||
let transformed_json = serde_json::to_string(&val).unwrap_or_default();
|
||||
let event_type = match &val {
|
||||
ResponsesAPIStreamEvent::ResponseCreated { .. } => "response.created",
|
||||
ResponsesAPIStreamEvent::ResponseInProgress { .. } => "response.in_progress",
|
||||
ResponsesAPIStreamEvent::ResponseCompleted { .. } => "response.completed",
|
||||
|
|
@ -1365,10 +1364,10 @@ impl crate::providers::streaming_response::ProviderStreamResponse for ResponsesA
|
|||
|
||||
fn role(&self) -> Option<&str> {
|
||||
match self {
|
||||
ResponsesAPIStreamEvent::ResponseOutputItemDone { item, .. } => match item {
|
||||
OutputItem::Message { role, .. } => Some(role.as_str()),
|
||||
_ => None,
|
||||
},
|
||||
ResponsesAPIStreamEvent::ResponseOutputItemDone {
|
||||
item: OutputItem::Message { role, .. },
|
||||
..
|
||||
} => Some(role.as_str()),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -34,10 +34,7 @@ where
|
|||
}
|
||||
|
||||
pub fn decode_frame(&mut self) -> Option<DecodedFrame> {
|
||||
match self.decoder.decode_frame(&mut self.buffer) {
|
||||
Ok(frame) => Some(frame),
|
||||
Err(_e) => None, // Fatal decode error
|
||||
}
|
||||
self.decoder.decode_frame(&mut self.buffer).ok()
|
||||
}
|
||||
|
||||
pub fn buffer_mut(&mut self) -> &mut B {
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
use crate::apis::streaming_shapes::sse::{SseEvent, SseStreamBufferTrait};
|
||||
use crate::apis::anthropic::MessagesStreamEvent;
|
||||
use crate::apis::streaming_shapes::sse::{SseEvent, SseStreamBufferTrait};
|
||||
use crate::providers::streaming_response::ProviderStreamResponseType;
|
||||
use std::collections::HashSet;
|
||||
|
||||
|
|
@ -31,6 +31,12 @@ pub struct AnthropicMessagesStreamBuffer {
|
|||
model: Option<String>,
|
||||
}
|
||||
|
||||
impl Default for AnthropicMessagesStreamBuffer {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl AnthropicMessagesStreamBuffer {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
|
|
@ -154,7 +160,8 @@ impl SseStreamBufferTrait for AnthropicMessagesStreamBuffer {
|
|||
// Inject message_start if needed
|
||||
if !self.message_started {
|
||||
let model = self.model.as_deref().unwrap_or("unknown");
|
||||
let message_start = AnthropicMessagesStreamBuffer::create_message_start_event(model);
|
||||
let message_start =
|
||||
AnthropicMessagesStreamBuffer::create_message_start_event(model);
|
||||
self.buffered_events.push(message_start);
|
||||
self.message_started = true;
|
||||
}
|
||||
|
|
@ -169,7 +176,8 @@ impl SseStreamBufferTrait for AnthropicMessagesStreamBuffer {
|
|||
// Inject message_start if needed
|
||||
if !self.message_started {
|
||||
let model = self.model.as_deref().unwrap_or("unknown");
|
||||
let message_start = AnthropicMessagesStreamBuffer::create_message_start_event(model);
|
||||
let message_start =
|
||||
AnthropicMessagesStreamBuffer::create_message_start_event(model);
|
||||
self.buffered_events.push(message_start);
|
||||
self.message_started = true;
|
||||
}
|
||||
|
|
@ -177,7 +185,8 @@ impl SseStreamBufferTrait for AnthropicMessagesStreamBuffer {
|
|||
// Check if ContentBlockStart was sent for this index
|
||||
if !self.has_content_block_start_been_sent(index) {
|
||||
// Inject ContentBlockStart before delta
|
||||
let content_block_start = AnthropicMessagesStreamBuffer::create_content_block_start_event();
|
||||
let content_block_start =
|
||||
AnthropicMessagesStreamBuffer::create_content_block_start_event();
|
||||
self.buffered_events.push(content_block_start);
|
||||
self.set_content_block_start_sent(index);
|
||||
self.needs_content_block_stop = true;
|
||||
|
|
@ -189,7 +198,8 @@ impl SseStreamBufferTrait for AnthropicMessagesStreamBuffer {
|
|||
MessagesStreamEvent::MessageDelta { usage, .. } => {
|
||||
// Inject ContentBlockStop before message_delta
|
||||
if self.needs_content_block_stop {
|
||||
let content_block_stop = AnthropicMessagesStreamBuffer::create_content_block_stop_event();
|
||||
let content_block_stop =
|
||||
AnthropicMessagesStreamBuffer::create_content_block_stop_event();
|
||||
self.buffered_events.push(content_block_stop);
|
||||
self.needs_content_block_stop = false;
|
||||
}
|
||||
|
|
@ -199,10 +209,10 @@ impl SseStreamBufferTrait for AnthropicMessagesStreamBuffer {
|
|||
if let Some(last_event) = self.buffered_events.last_mut() {
|
||||
if let Some(ProviderStreamResponseType::MessagesStreamEvent(
|
||||
MessagesStreamEvent::MessageDelta {
|
||||
usage: last_usage,
|
||||
..
|
||||
}
|
||||
)) = &mut last_event.provider_stream_response {
|
||||
usage: last_usage, ..
|
||||
},
|
||||
)) = &mut last_event.provider_stream_response
|
||||
{
|
||||
// Merge: take stop_reason from first, usage from second (if non-zero)
|
||||
if usage.input_tokens > 0 || usage.output_tokens > 0 {
|
||||
*last_usage = usage.clone();
|
||||
|
|
@ -243,7 +253,7 @@ impl SseStreamBufferTrait for AnthropicMessagesStreamBuffer {
|
|||
}
|
||||
}
|
||||
|
||||
fn into_bytes(&mut self) -> Vec<u8> {
|
||||
fn to_bytes(&mut self) -> Vec<u8> {
|
||||
// Convert all accumulated events to bytes and clear buffer
|
||||
// NOTE: We do NOT inject ContentBlockStop here because it's injected when we see MessageDelta
|
||||
// or MessageStop. Injecting it here causes premature ContentBlockStop in the middle of streaming.
|
||||
|
|
@ -276,10 +286,10 @@ impl SseStreamBufferTrait for AnthropicMessagesStreamBuffer {
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::clients::{SupportedAPIsFromClient, SupportedUpstreamAPIs};
|
||||
use crate::apis::anthropic::AnthropicApi;
|
||||
use crate::apis::openai::OpenAIApi;
|
||||
use crate::apis::streaming_shapes::sse::SseStreamIter;
|
||||
use crate::clients::{SupportedAPIsFromClient, SupportedUpstreamAPIs};
|
||||
|
||||
#[test]
|
||||
fn test_openai_to_anthropic_complete_transformation() {
|
||||
|
|
@ -308,11 +318,12 @@ data: [DONE]"#;
|
|||
let mut buffer = AnthropicMessagesStreamBuffer::new();
|
||||
|
||||
for raw_event in stream_iter {
|
||||
let transformed_event = SseEvent::try_from((raw_event, &client_api, &upstream_api)).unwrap();
|
||||
let transformed_event =
|
||||
SseEvent::try_from((raw_event, &client_api, &upstream_api)).unwrap();
|
||||
buffer.add_transformed_event(transformed_event);
|
||||
}
|
||||
|
||||
let output_bytes = buffer.into_bytes();
|
||||
let output_bytes = buffer.to_bytes();
|
||||
let output = String::from_utf8_lossy(&output_bytes);
|
||||
|
||||
println!("\nTRANSFORMED OUTPUT (Anthropic Messages API):");
|
||||
|
|
@ -321,25 +332,54 @@ data: [DONE]"#;
|
|||
|
||||
// Assertions
|
||||
assert!(!output_bytes.is_empty(), "Should have output");
|
||||
assert!(output.contains("event: message_start"), "Should have message_start");
|
||||
assert!(output.contains("event: content_block_start"), "Should have content_block_start (injected)");
|
||||
assert!(
|
||||
output.contains("event: message_start"),
|
||||
"Should have message_start"
|
||||
);
|
||||
assert!(
|
||||
output.contains("event: content_block_start"),
|
||||
"Should have content_block_start (injected)"
|
||||
);
|
||||
|
||||
let delta_count = output.matches("event: content_block_delta").count();
|
||||
assert_eq!(delta_count, 2, "Should have exactly 2 content_block_delta events");
|
||||
assert_eq!(
|
||||
delta_count, 2,
|
||||
"Should have exactly 2 content_block_delta events"
|
||||
);
|
||||
|
||||
// Verify both pieces of content are present
|
||||
assert!(output.contains("\"text\":\"Hello\""), "Should have first content delta 'Hello'");
|
||||
assert!(output.contains("\"text\":\" world\""), "Should have second content delta ' world'");
|
||||
assert!(
|
||||
output.contains("\"text\":\"Hello\""),
|
||||
"Should have first content delta 'Hello'"
|
||||
);
|
||||
assert!(
|
||||
output.contains("\"text\":\" world\""),
|
||||
"Should have second content delta ' world'"
|
||||
);
|
||||
|
||||
assert!(output.contains("event: content_block_stop"), "Should have content_block_stop (injected)");
|
||||
assert!(output.contains("event: message_delta"), "Should have message_delta");
|
||||
assert!(output.contains("event: message_stop"), "Should have message_stop");
|
||||
assert!(
|
||||
output.contains("event: content_block_stop"),
|
||||
"Should have content_block_stop (injected)"
|
||||
);
|
||||
assert!(
|
||||
output.contains("event: message_delta"),
|
||||
"Should have message_delta"
|
||||
);
|
||||
assert!(
|
||||
output.contains("event: message_stop"),
|
||||
"Should have message_stop"
|
||||
);
|
||||
|
||||
println!("\nVALIDATION SUMMARY:");
|
||||
println!("{}", "-".repeat(80));
|
||||
println!("✓ Complete transformation: OpenAI ChatCompletions → Anthropic Messages API");
|
||||
println!("✓ Injected lifecycle events: message_start, content_block_start, content_block_stop");
|
||||
println!("✓ Content deltas: {} events (BOTH 'Hello' and ' world' preserved!)", delta_count);
|
||||
println!(
|
||||
"✓ Injected lifecycle events: message_start, content_block_start, content_block_stop"
|
||||
);
|
||||
println!(
|
||||
"✓ Content deltas: {} events (BOTH 'Hello' and ' world' preserved!)",
|
||||
delta_count
|
||||
);
|
||||
println!("✓ Complete stream with message_stop");
|
||||
println!("✓ Proper Anthropic protocol sequencing\n");
|
||||
}
|
||||
|
|
@ -369,11 +409,12 @@ data: {"id":"chatcmpl-456","object":"chat.completion.chunk","created":1234567890
|
|||
let mut buffer = AnthropicMessagesStreamBuffer::new();
|
||||
|
||||
for raw_event in stream_iter {
|
||||
let transformed_event = SseEvent::try_from((raw_event, &client_api, &upstream_api)).unwrap();
|
||||
let transformed_event =
|
||||
SseEvent::try_from((raw_event, &client_api, &upstream_api)).unwrap();
|
||||
buffer.add_transformed_event(transformed_event);
|
||||
}
|
||||
|
||||
let output_bytes = buffer.into_bytes();
|
||||
let output_bytes = buffer.to_bytes();
|
||||
let output = String::from_utf8_lossy(&output_bytes);
|
||||
|
||||
println!("\nTRANSFORMED OUTPUT (Anthropic Messages API):");
|
||||
|
|
@ -382,31 +423,61 @@ data: {"id":"chatcmpl-456","object":"chat.completion.chunk","created":1234567890
|
|||
|
||||
// Assertions
|
||||
assert!(!output_bytes.is_empty(), "Should have output");
|
||||
assert!(output.contains("event: message_start"), "Should have message_start");
|
||||
assert!(output.contains("event: content_block_start"), "Should have content_block_start (injected)");
|
||||
assert!(
|
||||
output.contains("event: message_start"),
|
||||
"Should have message_start"
|
||||
);
|
||||
assert!(
|
||||
output.contains("event: content_block_start"),
|
||||
"Should have content_block_start (injected)"
|
||||
);
|
||||
|
||||
let delta_count = output.matches("event: content_block_delta").count();
|
||||
assert_eq!(delta_count, 3, "Should have exactly 3 content_block_delta events");
|
||||
assert_eq!(
|
||||
delta_count, 3,
|
||||
"Should have exactly 3 content_block_delta events"
|
||||
);
|
||||
|
||||
// Verify all three pieces of content are present
|
||||
assert!(output.contains("\"text\":\"The weather\""), "Should have first content delta");
|
||||
assert!(output.contains("\"text\":\" in San Francisco\""), "Should have second content delta");
|
||||
assert!(output.contains("\"text\":\" is\""), "Should have third content delta");
|
||||
assert!(
|
||||
output.contains("\"text\":\"The weather\""),
|
||||
"Should have first content delta"
|
||||
);
|
||||
assert!(
|
||||
output.contains("\"text\":\" in San Francisco\""),
|
||||
"Should have second content delta"
|
||||
);
|
||||
assert!(
|
||||
output.contains("\"text\":\" is\""),
|
||||
"Should have third content delta"
|
||||
);
|
||||
|
||||
// For partial streams (no finish_reason, no [DONE]), we do NOT inject content_block_stop
|
||||
// because the stream may continue. This is correct behavior - only inject lifecycle events
|
||||
// when we have explicit signals from upstream (finish_reason, [DONE], etc.)
|
||||
assert!(!output.contains("event: content_block_stop"), "Should NOT have content_block_stop for partial stream");
|
||||
assert!(
|
||||
!output.contains("event: content_block_stop"),
|
||||
"Should NOT have content_block_stop for partial stream"
|
||||
);
|
||||
|
||||
// Should NOT have completion events
|
||||
assert!(!output.contains("event: message_delta"), "Should NOT have message_delta");
|
||||
assert!(!output.contains("event: message_stop"), "Should NOT have message_stop");
|
||||
assert!(
|
||||
!output.contains("event: message_delta"),
|
||||
"Should NOT have message_delta"
|
||||
);
|
||||
assert!(
|
||||
!output.contains("event: message_stop"),
|
||||
"Should NOT have message_stop"
|
||||
);
|
||||
|
||||
println!("\nVALIDATION SUMMARY:");
|
||||
println!("{}", "-".repeat(80));
|
||||
println!("✓ Partial transformation: OpenAI → Anthropic (stream interrupted)");
|
||||
println!("✓ Injected: message_start, content_block_start at beginning");
|
||||
println!("✓ Incremental deltas: {} events (ALL content preserved!)", delta_count);
|
||||
println!(
|
||||
"✓ Incremental deltas: {} events (ALL content preserved!)",
|
||||
delta_count
|
||||
);
|
||||
println!("✓ NO completion events (partial stream, no [DONE])");
|
||||
println!("✓ Buffer maintains Anthropic protocol for active streams\n");
|
||||
}
|
||||
|
|
@ -452,11 +523,12 @@ data: [DONE]"#;
|
|||
let mut buffer = AnthropicMessagesStreamBuffer::new();
|
||||
|
||||
for raw_event in stream_iter {
|
||||
let transformed_event = SseEvent::try_from((raw_event, &client_api, &upstream_api)).unwrap();
|
||||
let transformed_event =
|
||||
SseEvent::try_from((raw_event, &client_api, &upstream_api)).unwrap();
|
||||
buffer.add_transformed_event(transformed_event);
|
||||
}
|
||||
|
||||
let output_bytes = buffer.into_bytes();
|
||||
let output_bytes = buffer.to_bytes();
|
||||
let output = String::from_utf8_lossy(&output_bytes);
|
||||
|
||||
println!("\nTRANSFORMED OUTPUT (Anthropic Messages API):");
|
||||
|
|
@ -467,32 +539,71 @@ data: [DONE]"#;
|
|||
assert!(!output_bytes.is_empty(), "Should have output");
|
||||
|
||||
// Should have lifecycle events (injected by buffer)
|
||||
assert!(output.contains("event: message_start"), "Should have message_start (injected)");
|
||||
assert!(output.contains("event: content_block_start"), "Should have content_block_start");
|
||||
assert!(output.contains("event: content_block_stop"), "Should have content_block_stop (injected)");
|
||||
assert!(output.contains("event: message_delta"), "Should have message_delta");
|
||||
assert!(output.contains("event: message_stop"), "Should have message_stop");
|
||||
assert!(
|
||||
output.contains("event: message_start"),
|
||||
"Should have message_start (injected)"
|
||||
);
|
||||
assert!(
|
||||
output.contains("event: content_block_start"),
|
||||
"Should have content_block_start"
|
||||
);
|
||||
assert!(
|
||||
output.contains("event: content_block_stop"),
|
||||
"Should have content_block_stop (injected)"
|
||||
);
|
||||
assert!(
|
||||
output.contains("event: message_delta"),
|
||||
"Should have message_delta"
|
||||
);
|
||||
assert!(
|
||||
output.contains("event: message_stop"),
|
||||
"Should have message_stop"
|
||||
);
|
||||
|
||||
// Should have tool_use content block
|
||||
assert!(output.contains("\"type\":\"tool_use\""), "Should have tool_use type");
|
||||
assert!(output.contains("\"name\":\"get_weather\""), "Should have correct function name");
|
||||
assert!(output.contains("\"id\":\"call_2Uzw0AEZQeOex2CP2TKjcLKc\""), "Should have correct tool call ID");
|
||||
assert!(
|
||||
output.contains("\"type\":\"tool_use\""),
|
||||
"Should have tool_use type"
|
||||
);
|
||||
assert!(
|
||||
output.contains("\"name\":\"get_weather\""),
|
||||
"Should have correct function name"
|
||||
);
|
||||
assert!(
|
||||
output.contains("\"id\":\"call_2Uzw0AEZQeOex2CP2TKjcLKc\""),
|
||||
"Should have correct tool call ID"
|
||||
);
|
||||
|
||||
// Count input_json_delta events - should match the number of argument chunks
|
||||
let delta_count = output.matches("event: content_block_delta").count();
|
||||
assert!(delta_count >= 8, "Should have at least 8 input_json_delta events");
|
||||
assert!(
|
||||
delta_count >= 8,
|
||||
"Should have at least 8 input_json_delta events"
|
||||
);
|
||||
|
||||
// Verify argument deltas are present
|
||||
assert!(output.contains("\"type\":\"input_json_delta\""), "Should have input_json_delta type");
|
||||
assert!(output.contains("\"partial_json\":"), "Should have partial_json field");
|
||||
assert!(
|
||||
output.contains("\"type\":\"input_json_delta\""),
|
||||
"Should have input_json_delta type"
|
||||
);
|
||||
assert!(
|
||||
output.contains("\"partial_json\":"),
|
||||
"Should have partial_json field"
|
||||
);
|
||||
|
||||
// Verify the accumulated arguments contain the location
|
||||
assert!(output.contains("San"), "Arguments should contain 'San'");
|
||||
assert!(output.contains("Francisco"), "Arguments should contain 'Francisco'");
|
||||
assert!(
|
||||
output.contains("Francisco"),
|
||||
"Arguments should contain 'Francisco'"
|
||||
);
|
||||
assert!(output.contains("CA"), "Arguments should contain 'CA'");
|
||||
|
||||
// Verify stop reason is tool_use
|
||||
assert!(output.contains("\"stop_reason\":\"tool_use\""), "Should have stop_reason as tool_use");
|
||||
assert!(
|
||||
output.contains("\"stop_reason\":\"tool_use\""),
|
||||
"Should have stop_reason as tool_use"
|
||||
);
|
||||
|
||||
println!("\nVALIDATION SUMMARY:");
|
||||
println!("{}", "-".repeat(80));
|
||||
|
|
|
|||
|
|
@ -6,6 +6,12 @@ pub struct OpenAIChatCompletionsStreamBuffer {
|
|||
buffered_events: Vec<SseEvent>,
|
||||
}
|
||||
|
||||
impl Default for OpenAIChatCompletionsStreamBuffer {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl OpenAIChatCompletionsStreamBuffer {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
|
|
@ -26,7 +32,7 @@ impl SseStreamBufferTrait for OpenAIChatCompletionsStreamBuffer {
|
|||
self.buffered_events.push(event);
|
||||
}
|
||||
|
||||
fn into_bytes(&mut self) -> Vec<u8> {
|
||||
fn to_bytes(&mut self) -> Vec<u8> {
|
||||
// No finalization needed for OpenAI Chat Completions
|
||||
// The [DONE] marker is already handled by the transformation layer
|
||||
let mut buffer = Vec::new();
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
pub mod sse;
|
||||
pub mod sse_chunk_processor;
|
||||
pub mod amazon_bedrock_binary_frame;
|
||||
pub mod anthropic_streaming_buffer;
|
||||
pub mod chat_completions_streaming_buffer;
|
||||
pub mod passthrough_streaming_buffer;
|
||||
pub mod responses_api_streaming_buffer;
|
||||
pub mod sse;
|
||||
pub mod sse_chunk_processor;
|
||||
|
|
|
|||
|
|
@ -6,6 +6,12 @@ pub struct PassthroughStreamBuffer {
|
|||
buffered_events: Vec<SseEvent>,
|
||||
}
|
||||
|
||||
impl Default for PassthroughStreamBuffer {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl PassthroughStreamBuffer {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
|
|
@ -30,7 +36,7 @@ impl SseStreamBufferTrait for PassthroughStreamBuffer {
|
|||
self.buffered_events.push(event);
|
||||
}
|
||||
|
||||
fn into_bytes(&mut self) -> Vec<u8> {
|
||||
fn to_bytes(&mut self) -> Vec<u8> {
|
||||
// No finalization needed for passthrough - just convert accumulated events to bytes
|
||||
let mut buffer = Vec::new();
|
||||
for event in self.buffered_events.drain(..) {
|
||||
|
|
@ -44,7 +50,7 @@ impl SseStreamBufferTrait for PassthroughStreamBuffer {
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::apis::streaming_shapes::passthrough_streaming_buffer::PassthroughStreamBuffer;
|
||||
use crate::apis::streaming_shapes::sse::{SseStreamIter, SseStreamBufferTrait};
|
||||
use crate::apis::streaming_shapes::sse::{SseStreamBufferTrait, SseStreamIter};
|
||||
|
||||
#[test]
|
||||
fn test_chat_completions_passthrough_buffer() {
|
||||
|
|
@ -73,7 +79,7 @@ mod tests {
|
|||
buffer.add_transformed_event(event);
|
||||
}
|
||||
|
||||
let output_bytes = buffer.into_bytes();
|
||||
let output_bytes = buffer.to_bytes();
|
||||
let output = String::from_utf8_lossy(&output_bytes);
|
||||
|
||||
println!("\nTRANSFORMED OUTPUT (ChatCompletions - Passthrough):");
|
||||
|
|
@ -84,7 +90,11 @@ mod tests {
|
|||
assert!(!output_bytes.is_empty());
|
||||
assert!(output.contains("chatcmpl-123"));
|
||||
assert!(output.contains("[DONE]"));
|
||||
assert_eq!(raw_input.trim(), output.trim(), "Passthrough should preserve input");
|
||||
assert_eq!(
|
||||
raw_input.trim(),
|
||||
output.trim(),
|
||||
"Passthrough should preserve input"
|
||||
);
|
||||
|
||||
println!("\nVALIDATION SUMMARY:");
|
||||
println!("{}", "-".repeat(80));
|
||||
|
|
|
|||
|
|
@ -1,10 +1,10 @@
|
|||
use std::collections::HashMap;
|
||||
use log::debug;
|
||||
use crate::apis::openai_responses::{
|
||||
ResponsesAPIStreamEvent, ResponsesAPIResponse, OutputItem, OutputItemStatus,
|
||||
ResponseStatus, TextConfig, TextFormat, Reasoning,
|
||||
OutputItem, OutputItemStatus, Reasoning, ResponseStatus, ResponsesAPIResponse,
|
||||
ResponsesAPIStreamEvent, TextConfig, TextFormat,
|
||||
};
|
||||
use crate::apis::streaming_shapes::sse::{SseEvent, SseStreamBufferTrait};
|
||||
use log::debug;
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Helper to convert ResponseAPIStreamEvent to SseEvent
|
||||
fn event_to_sse(event: ResponsesAPIStreamEvent) -> SseEvent {
|
||||
|
|
@ -16,10 +16,17 @@ fn event_to_sse(event: ResponsesAPIStreamEvent) -> SseEvent {
|
|||
ResponsesAPIStreamEvent::ResponseOutputItemDone { .. } => "response.output_item.done",
|
||||
ResponsesAPIStreamEvent::ResponseOutputTextDelta { .. } => "response.output_text.delta",
|
||||
ResponsesAPIStreamEvent::ResponseOutputTextDone { .. } => "response.output_text.done",
|
||||
ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDelta { .. } => "response.function_call_arguments.delta",
|
||||
ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDone { .. } => "response.function_call_arguments.done",
|
||||
ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDelta { .. } => {
|
||||
"response.function_call_arguments.delta"
|
||||
}
|
||||
ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDone { .. } => {
|
||||
"response.function_call_arguments.done"
|
||||
}
|
||||
unknown => {
|
||||
debug!("Unknown ResponsesAPIStreamEvent type encountered: {:?}", unknown);
|
||||
debug!(
|
||||
"Unknown ResponsesAPIStreamEvent type encountered: {:?}",
|
||||
unknown
|
||||
);
|
||||
"unknown"
|
||||
}
|
||||
};
|
||||
|
|
@ -85,6 +92,12 @@ pub struct ResponsesAPIStreamBuffer {
|
|||
buffered_events: Vec<SseEvent>,
|
||||
}
|
||||
|
||||
impl Default for ResponsesAPIStreamBuffer {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl ResponsesAPIStreamBuffer {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
|
|
@ -112,7 +125,11 @@ impl ResponsesAPIStreamBuffer {
|
|||
}
|
||||
|
||||
fn generate_item_id(prefix: &str) -> String {
|
||||
format!("{}_{}", prefix, uuid::Uuid::new_v4().to_string().replace("-", ""))
|
||||
format!(
|
||||
"{}_{}",
|
||||
prefix,
|
||||
uuid::Uuid::new_v4().to_string().replace("-", "")
|
||||
)
|
||||
}
|
||||
|
||||
fn get_or_create_item_id(&mut self, output_index: i32, prefix: &str) -> String {
|
||||
|
|
@ -160,7 +177,13 @@ impl ResponsesAPIStreamBuffer {
|
|||
}
|
||||
|
||||
/// Create output_item.added event for tool call
|
||||
fn create_tool_call_added_event(&mut self, output_index: i32, item_id: &str, call_id: &str, name: &str) -> SseEvent {
|
||||
fn create_tool_call_added_event(
|
||||
&mut self,
|
||||
output_index: i32,
|
||||
item_id: &str,
|
||||
call_id: &str,
|
||||
name: &str,
|
||||
) -> SseEvent {
|
||||
let event = ResponsesAPIStreamEvent::ResponseOutputItemAdded {
|
||||
output_index,
|
||||
item: OutputItem::FunctionCall {
|
||||
|
|
@ -237,9 +260,15 @@ impl ResponsesAPIStreamBuffer {
|
|||
// Emit done events for all accumulated content
|
||||
|
||||
// Text content done events
|
||||
let text_items: Vec<_> = self.text_content.iter().map(|(id, content)| (id.clone(), content.clone())).collect();
|
||||
let text_items: Vec<_> = self
|
||||
.text_content
|
||||
.iter()
|
||||
.map(|(id, content)| (id.clone(), content.clone()))
|
||||
.collect();
|
||||
for (item_id, content) in text_items {
|
||||
let output_index = self.output_items_added.iter()
|
||||
let output_index = self
|
||||
.output_items_added
|
||||
.iter()
|
||||
.find(|(_, id)| **id == item_id)
|
||||
.map(|(idx, _)| *idx)
|
||||
.unwrap_or(0);
|
||||
|
|
@ -270,9 +299,15 @@ impl ResponsesAPIStreamBuffer {
|
|||
}
|
||||
|
||||
// Function call done events
|
||||
let func_items: Vec<_> = self.function_arguments.iter().map(|(id, args)| (id.clone(), args.clone())).collect();
|
||||
let func_items: Vec<_> = self
|
||||
.function_arguments
|
||||
.iter()
|
||||
.map(|(id, args)| (id.clone(), args.clone()))
|
||||
.collect();
|
||||
for (item_id, arguments) in func_items {
|
||||
let output_index = self.output_items_added.iter()
|
||||
let output_index = self
|
||||
.output_items_added
|
||||
.iter()
|
||||
.find(|(_, id)| **id == item_id)
|
||||
.map(|(idx, _)| *idx)
|
||||
.unwrap_or(0);
|
||||
|
|
@ -286,9 +321,16 @@ impl ResponsesAPIStreamBuffer {
|
|||
};
|
||||
events.push(event_to_sse(args_done_event));
|
||||
|
||||
let (call_id, name) = self.tool_call_metadata.get(&output_index)
|
||||
let (call_id, name) = self
|
||||
.tool_call_metadata
|
||||
.get(&output_index)
|
||||
.cloned()
|
||||
.unwrap_or_else(|| (format!("call_{}", uuid::Uuid::new_v4()), "unknown".to_string()));
|
||||
.unwrap_or_else(|| {
|
||||
(
|
||||
format!("call_{}", uuid::Uuid::new_v4()),
|
||||
"unknown".to_string(),
|
||||
)
|
||||
});
|
||||
|
||||
let seq2 = self.next_sequence_number();
|
||||
let item_done_event = ResponsesAPIStreamEvent::ResponseOutputItemDone {
|
||||
|
|
@ -315,9 +357,16 @@ impl ResponsesAPIStreamBuffer {
|
|||
if let Some(item_id) = self.output_items_added.get(&output_index) {
|
||||
// Check if this is a function call
|
||||
if let Some(arguments) = self.function_arguments.get(item_id) {
|
||||
let (call_id, name) = self.tool_call_metadata.get(&output_index)
|
||||
let (call_id, name) = self
|
||||
.tool_call_metadata
|
||||
.get(&output_index)
|
||||
.cloned()
|
||||
.unwrap_or_else(|| (format!("call_{}", uuid::Uuid::new_v4()), "unknown".to_string()));
|
||||
.unwrap_or_else(|| {
|
||||
(
|
||||
format!("call_{}", uuid::Uuid::new_v4()),
|
||||
"unknown".to_string(),
|
||||
)
|
||||
});
|
||||
|
||||
output_items.push(OutputItem::FunctionCall {
|
||||
id: item_id.clone(),
|
||||
|
|
@ -397,9 +446,9 @@ impl SseStreamBufferTrait for ResponsesAPIStreamBuffer {
|
|||
let mut events = Vec::new();
|
||||
|
||||
// Capture upstream metadata from ResponseCreated or ResponseInProgress if present
|
||||
match stream_event {
|
||||
ResponsesAPIStreamEvent::ResponseCreated { response, .. } |
|
||||
ResponsesAPIStreamEvent::ResponseInProgress { response, .. } => {
|
||||
match stream_event.as_ref() {
|
||||
ResponsesAPIStreamEvent::ResponseCreated { response, .. }
|
||||
| ResponsesAPIStreamEvent::ResponseInProgress { response, .. } => {
|
||||
if self.upstream_response_metadata.is_none() {
|
||||
// Store the full upstream response as our metadata template
|
||||
self.upstream_response_metadata = Some(response.clone());
|
||||
|
|
@ -418,11 +467,16 @@ impl SseStreamBufferTrait for ResponsesAPIStreamBuffer {
|
|||
if !self.created_emitted {
|
||||
// Initialize metadata from first event if needed
|
||||
if self.response_id.is_none() {
|
||||
self.response_id = Some(format!("resp_{}", uuid::Uuid::new_v4().to_string().replace("-", "")));
|
||||
self.created_at = Some(std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs() as i64);
|
||||
self.response_id = Some(format!(
|
||||
"resp_{}",
|
||||
uuid::Uuid::new_v4().to_string().replace("-", "")
|
||||
));
|
||||
self.created_at = Some(
|
||||
std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs() as i64,
|
||||
);
|
||||
self.model = Some("unknown".to_string()); // Will be set by caller if available
|
||||
}
|
||||
|
||||
|
|
@ -436,58 +490,95 @@ impl SseStreamBufferTrait for ResponsesAPIStreamBuffer {
|
|||
}
|
||||
|
||||
// Process the delta event
|
||||
match stream_event {
|
||||
ResponsesAPIStreamEvent::ResponseOutputTextDelta { output_index, delta, .. } => {
|
||||
match stream_event.as_ref() {
|
||||
ResponsesAPIStreamEvent::ResponseOutputTextDelta {
|
||||
output_index,
|
||||
delta,
|
||||
..
|
||||
} => {
|
||||
let item_id = self.get_or_create_item_id(*output_index, "msg");
|
||||
|
||||
// Emit output_item.added if this is the first time we see this output index
|
||||
if !self.output_items_added.contains_key(output_index) {
|
||||
self.output_items_added.insert(*output_index, item_id.clone());
|
||||
self.output_items_added
|
||||
.insert(*output_index, item_id.clone());
|
||||
events.push(self.create_output_item_added_event(*output_index, &item_id));
|
||||
}
|
||||
|
||||
// Accumulate text content
|
||||
self.text_content.entry(item_id.clone())
|
||||
self.text_content
|
||||
.entry(item_id.clone())
|
||||
.and_modify(|content| content.push_str(delta))
|
||||
.or_insert_with(|| delta.clone());
|
||||
|
||||
// Emit text delta with filled-in item_id and sequence_number
|
||||
let mut delta_event = stream_event.clone();
|
||||
if let ResponsesAPIStreamEvent::ResponseOutputTextDelta { item_id: ref mut id, sequence_number: ref mut seq, .. } = delta_event {
|
||||
let mut delta_event = stream_event.as_ref().clone();
|
||||
if let ResponsesAPIStreamEvent::ResponseOutputTextDelta {
|
||||
item_id: ref mut id,
|
||||
sequence_number: ref mut seq,
|
||||
..
|
||||
} = &mut delta_event
|
||||
{
|
||||
*id = item_id;
|
||||
*seq = self.next_sequence_number();
|
||||
}
|
||||
events.push(event_to_sse(delta_event));
|
||||
}
|
||||
ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDelta { output_index, delta, call_id, name, .. } => {
|
||||
ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDelta {
|
||||
output_index,
|
||||
delta,
|
||||
call_id,
|
||||
name,
|
||||
..
|
||||
} => {
|
||||
let item_id = self.get_or_create_item_id(*output_index, "fc");
|
||||
|
||||
// Store metadata if provided (from initial tool call event)
|
||||
if let (Some(cid), Some(n)) = (call_id, name) {
|
||||
self.tool_call_metadata.insert(*output_index, (cid.clone(), n.clone()));
|
||||
self.tool_call_metadata
|
||||
.insert(*output_index, (cid.clone(), n.clone()));
|
||||
}
|
||||
|
||||
// Emit output_item.added if this is the first time we see this tool call
|
||||
if !self.output_items_added.contains_key(output_index) {
|
||||
self.output_items_added.insert(*output_index, item_id.clone());
|
||||
self.output_items_added
|
||||
.insert(*output_index, item_id.clone());
|
||||
|
||||
// For tool calls, we need call_id and name from metadata
|
||||
// These should now be populated from the event itself
|
||||
let (call_id, name) = self.tool_call_metadata.get(output_index)
|
||||
let (call_id, name) = self
|
||||
.tool_call_metadata
|
||||
.get(output_index)
|
||||
.cloned()
|
||||
.unwrap_or_else(|| (format!("call_{}", uuid::Uuid::new_v4()), "unknown".to_string()));
|
||||
.unwrap_or_else(|| {
|
||||
(
|
||||
format!("call_{}", uuid::Uuid::new_v4()),
|
||||
"unknown".to_string(),
|
||||
)
|
||||
});
|
||||
|
||||
events.push(self.create_tool_call_added_event(*output_index, &item_id, &call_id, &name));
|
||||
events.push(self.create_tool_call_added_event(
|
||||
*output_index,
|
||||
&item_id,
|
||||
&call_id,
|
||||
&name,
|
||||
));
|
||||
}
|
||||
|
||||
// Accumulate function arguments
|
||||
self.function_arguments.entry(item_id.clone())
|
||||
self.function_arguments
|
||||
.entry(item_id.clone())
|
||||
.and_modify(|args| args.push_str(delta))
|
||||
.or_insert_with(|| delta.clone());
|
||||
|
||||
// Emit function call arguments delta with filled-in item_id and sequence_number
|
||||
let mut delta_event = stream_event.clone();
|
||||
if let ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDelta { item_id: ref mut id, sequence_number: ref mut seq, .. } = delta_event {
|
||||
let mut delta_event = stream_event.as_ref().clone();
|
||||
if let ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDelta {
|
||||
item_id: ref mut id,
|
||||
sequence_number: ref mut seq,
|
||||
..
|
||||
} = &mut delta_event
|
||||
{
|
||||
*id = item_id;
|
||||
*seq = self.next_sequence_number();
|
||||
}
|
||||
|
|
@ -495,7 +586,7 @@ impl SseStreamBufferTrait for ResponsesAPIStreamBuffer {
|
|||
}
|
||||
_ => {
|
||||
// For other event types, just pass through with sequence number
|
||||
let other_event = stream_event.clone();
|
||||
let other_event = stream_event.as_ref().clone();
|
||||
// TODO: Add sequence number to other event types if needed
|
||||
events.push(event_to_sse(other_event));
|
||||
}
|
||||
|
|
@ -505,8 +596,7 @@ impl SseStreamBufferTrait for ResponsesAPIStreamBuffer {
|
|||
self.buffered_events.extend(events);
|
||||
}
|
||||
|
||||
|
||||
fn into_bytes(&mut self) -> Vec<u8> {
|
||||
fn to_bytes(&mut self) -> Vec<u8> {
|
||||
// For Responses API, we need special handling:
|
||||
// - Most events are already in buffered_events from add_transformed_event
|
||||
// - We should NOT finalize here - finalization happens when we detect [DONE] or end of stream
|
||||
|
|
@ -525,9 +615,9 @@ impl SseStreamBufferTrait for ResponsesAPIStreamBuffer {
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::clients::{SupportedAPIsFromClient, SupportedUpstreamAPIs};
|
||||
use crate::apis::openai::OpenAIApi;
|
||||
use crate::apis::streaming_shapes::sse::SseStreamIter;
|
||||
use crate::clients::{SupportedAPIsFromClient, SupportedUpstreamAPIs};
|
||||
|
||||
#[test]
|
||||
fn test_chat_completions_to_responses_api_transformation() {
|
||||
|
|
@ -557,11 +647,12 @@ mod tests {
|
|||
|
||||
for raw_event in stream_iter {
|
||||
// Transform the event using the client/upstream APIs
|
||||
let transformed_event = SseEvent::try_from((raw_event, &client_api, &upstream_api)).unwrap();
|
||||
let transformed_event =
|
||||
SseEvent::try_from((raw_event, &client_api, &upstream_api)).unwrap();
|
||||
buffer.add_transformed_event(transformed_event);
|
||||
}
|
||||
|
||||
let output_bytes = buffer.into_bytes();
|
||||
let output_bytes = buffer.to_bytes();
|
||||
let output = String::from_utf8_lossy(&output_bytes);
|
||||
|
||||
println!("\nTRANSFORMED OUTPUT (ResponsesAPI):");
|
||||
|
|
@ -570,13 +661,34 @@ mod tests {
|
|||
|
||||
// Assertions
|
||||
assert!(!output_bytes.is_empty(), "Should have output");
|
||||
assert!(output.contains("response.created"), "Should have response.created");
|
||||
assert!(output.contains("response.in_progress"), "Should have response.in_progress");
|
||||
assert!(output.contains("response.output_item.added"), "Should have output_item.added");
|
||||
assert!(output.contains("response.output_text.delta"), "Should have text deltas");
|
||||
assert!(output.contains("response.output_text.done"), "Should have text.done");
|
||||
assert!(output.contains("response.output_item.done"), "Should have output_item.done");
|
||||
assert!(output.contains("response.completed"), "Should have response.completed");
|
||||
assert!(
|
||||
output.contains("response.created"),
|
||||
"Should have response.created"
|
||||
);
|
||||
assert!(
|
||||
output.contains("response.in_progress"),
|
||||
"Should have response.in_progress"
|
||||
);
|
||||
assert!(
|
||||
output.contains("response.output_item.added"),
|
||||
"Should have output_item.added"
|
||||
);
|
||||
assert!(
|
||||
output.contains("response.output_text.delta"),
|
||||
"Should have text deltas"
|
||||
);
|
||||
assert!(
|
||||
output.contains("response.output_text.done"),
|
||||
"Should have text.done"
|
||||
);
|
||||
assert!(
|
||||
output.contains("response.output_item.done"),
|
||||
"Should have output_item.done"
|
||||
);
|
||||
assert!(
|
||||
output.contains("response.completed"),
|
||||
"Should have response.completed"
|
||||
);
|
||||
|
||||
println!("\nVALIDATION SUMMARY:");
|
||||
println!("{}", "-".repeat(80));
|
||||
|
|
@ -616,7 +728,7 @@ mod tests {
|
|||
buffer.add_transformed_event(transformed);
|
||||
}
|
||||
|
||||
let output_bytes = buffer.into_bytes();
|
||||
let output_bytes = buffer.to_bytes();
|
||||
let output = String::from_utf8_lossy(&output_bytes);
|
||||
|
||||
println!("\nTRANSFORMED OUTPUT (ResponsesAPI):");
|
||||
|
|
@ -624,24 +736,55 @@ mod tests {
|
|||
println!("{}", output);
|
||||
|
||||
// Assertions
|
||||
assert!(output.contains("response.created"), "Should have response.created");
|
||||
assert!(output.contains("response.in_progress"), "Should have response.in_progress");
|
||||
assert!(output.contains("response.output_item.added"), "Should have output_item.added");
|
||||
assert!(output.contains("\"type\":\"function_call\""), "Should be function_call type");
|
||||
assert!(output.contains("\"name\":\"get_weather\""), "Should have function name");
|
||||
assert!(output.contains("\"call_id\":\"call_mD5ggLKk3SMKGPFqFdcpKg6q\""), "Should have correct call_id");
|
||||
assert!(
|
||||
output.contains("response.created"),
|
||||
"Should have response.created"
|
||||
);
|
||||
assert!(
|
||||
output.contains("response.in_progress"),
|
||||
"Should have response.in_progress"
|
||||
);
|
||||
assert!(
|
||||
output.contains("response.output_item.added"),
|
||||
"Should have output_item.added"
|
||||
);
|
||||
assert!(
|
||||
output.contains("\"type\":\"function_call\""),
|
||||
"Should be function_call type"
|
||||
);
|
||||
assert!(
|
||||
output.contains("\"name\":\"get_weather\""),
|
||||
"Should have function name"
|
||||
);
|
||||
assert!(
|
||||
output.contains("\"call_id\":\"call_mD5ggLKk3SMKGPFqFdcpKg6q\""),
|
||||
"Should have correct call_id"
|
||||
);
|
||||
|
||||
let delta_count = output.matches("event: response.function_call_arguments.delta").count();
|
||||
let delta_count = output
|
||||
.matches("event: response.function_call_arguments.delta")
|
||||
.count();
|
||||
assert_eq!(delta_count, 4, "Should have 4 delta events");
|
||||
|
||||
assert!(!output.contains("response.function_call_arguments.done"), "Should NOT have arguments.done");
|
||||
assert!(!output.contains("response.output_item.done"), "Should NOT have output_item.done");
|
||||
assert!(!output.contains("response.completed"), "Should NOT have response.completed");
|
||||
assert!(
|
||||
!output.contains("response.function_call_arguments.done"),
|
||||
"Should NOT have arguments.done"
|
||||
);
|
||||
assert!(
|
||||
!output.contains("response.output_item.done"),
|
||||
"Should NOT have output_item.done"
|
||||
);
|
||||
assert!(
|
||||
!output.contains("response.completed"),
|
||||
"Should NOT have response.completed"
|
||||
);
|
||||
|
||||
println!("\nVALIDATION SUMMARY:");
|
||||
println!("{}", "-".repeat(80));
|
||||
println!("✓ Lifecycle events: response.created, response.in_progress");
|
||||
println!("✓ Function call metadata: name='get_weather', call_id='call_mD5ggLKk3SMKGPFqFdcpKg6q'");
|
||||
println!(
|
||||
"✓ Function call metadata: name='get_weather', call_id='call_mD5ggLKk3SMKGPFqFdcpKg6q'"
|
||||
);
|
||||
println!("✓ Incremental deltas: 4 events (1 initial + 3 argument chunks)");
|
||||
println!("✓ NO completion events (partial stream, no [DONE])");
|
||||
println!("✓ Arguments accumulated: '{{\"location\":\"'\n");
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
use crate::providers::streaming_response::ProviderStreamResponse;
|
||||
use crate::providers::streaming_response::ProviderStreamResponseType;
|
||||
use crate::apis::streaming_shapes::chat_completions_streaming_buffer::OpenAIChatCompletionsStreamBuffer;
|
||||
use crate::apis::streaming_shapes::anthropic_streaming_buffer::AnthropicMessagesStreamBuffer;
|
||||
use crate::apis::streaming_shapes::chat_completions_streaming_buffer::OpenAIChatCompletionsStreamBuffer;
|
||||
use crate::apis::streaming_shapes::passthrough_streaming_buffer::PassthroughStreamBuffer;
|
||||
use crate::apis::streaming_shapes::responses_api_streaming_buffer::ResponsesAPIStreamBuffer;
|
||||
use crate::providers::streaming_response::ProviderStreamResponse;
|
||||
use crate::providers::streaming_response::ProviderStreamResponseType;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::error::Error;
|
||||
use std::fmt;
|
||||
|
|
@ -37,7 +37,7 @@ pub trait SseStreamBufferTrait: Send + Sync {
|
|||
///
|
||||
/// # Returns
|
||||
/// Bytes ready for wire transmission (may be empty if no events were accumulated)
|
||||
fn into_bytes(&mut self) -> Vec<u8>;
|
||||
fn to_bytes(&mut self) -> Vec<u8>;
|
||||
}
|
||||
|
||||
/// Unified SSE Stream Buffer enum that provides a zero-cost abstraction
|
||||
|
|
@ -45,7 +45,7 @@ pub enum SseStreamBuffer {
|
|||
Passthrough(PassthroughStreamBuffer),
|
||||
OpenAIChatCompletions(OpenAIChatCompletionsStreamBuffer),
|
||||
AnthropicMessages(AnthropicMessagesStreamBuffer),
|
||||
OpenAIResponses(ResponsesAPIStreamBuffer),
|
||||
OpenAIResponses(Box<ResponsesAPIStreamBuffer>),
|
||||
}
|
||||
|
||||
impl SseStreamBufferTrait for SseStreamBuffer {
|
||||
|
|
@ -58,12 +58,12 @@ impl SseStreamBufferTrait for SseStreamBuffer {
|
|||
}
|
||||
}
|
||||
|
||||
fn into_bytes(&mut self) -> Vec<u8> {
|
||||
fn to_bytes(&mut self) -> Vec<u8> {
|
||||
match self {
|
||||
Self::Passthrough(buffer) => buffer.into_bytes(),
|
||||
Self::OpenAIChatCompletions(buffer) => buffer.into_bytes(),
|
||||
Self::AnthropicMessages(buffer) => buffer.into_bytes(),
|
||||
Self::OpenAIResponses(buffer) => buffer.into_bytes(),
|
||||
Self::Passthrough(buffer) => buffer.to_bytes(),
|
||||
Self::OpenAIChatCompletions(buffer) => buffer.to_bytes(),
|
||||
Self::AnthropicMessages(buffer) => buffer.to_bytes(),
|
||||
Self::OpenAIResponses(buffer) => buffer.to_bytes(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -99,7 +99,7 @@ impl SseEvent {
|
|||
let sse_string: String = response.clone().into();
|
||||
|
||||
SseEvent {
|
||||
data: None, // Data is embedded in sse_transformed_lines
|
||||
data: None, // Data is embedded in sse_transformed_lines
|
||||
event: None, // Event type is embedded in sse_transformed_lines
|
||||
raw_line: sse_string.clone(),
|
||||
sse_transformed_lines: sse_string,
|
||||
|
|
@ -149,10 +149,8 @@ impl FromStr for SseEvent {
|
|||
});
|
||||
}
|
||||
|
||||
if trimmed_line.starts_with("data: ") {
|
||||
let data: String = trimmed_line[6..].to_string(); // Remove "data: " prefix
|
||||
// Allow empty data content after "data: " prefix
|
||||
// This handles cases like "data: " followed by newline
|
||||
if let Some(stripped) = trimmed_line.strip_prefix("data: ") {
|
||||
let data: String = stripped.to_string();
|
||||
if data.trim().is_empty() {
|
||||
return Err(SseParseError {
|
||||
message: "Empty data field after 'data: ' prefix".to_string(),
|
||||
|
|
@ -166,8 +164,8 @@ impl FromStr for SseEvent {
|
|||
sse_transformed_lines: line.to_string(),
|
||||
provider_stream_response: None,
|
||||
})
|
||||
} else if trimmed_line.starts_with("event: ") {
|
||||
let event_type = trimmed_line[7..].to_string();
|
||||
} else if let Some(stripped) = trimmed_line.strip_prefix("event: ") {
|
||||
let event_type = stripped.to_string();
|
||||
if event_type.is_empty() {
|
||||
return Err(SseParseError {
|
||||
message: "Empty event field is not a valid SSE event".to_string(),
|
||||
|
|
@ -183,7 +181,10 @@ impl FromStr for SseEvent {
|
|||
})
|
||||
} else {
|
||||
Err(SseParseError {
|
||||
message: format!("Line does not start with 'data: ' or 'event: ': {}", trimmed_line),
|
||||
message: format!(
|
||||
"Line does not start with 'data: ' or 'event: ': {}",
|
||||
trimmed_line
|
||||
),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
@ -196,16 +197,16 @@ impl fmt::Display for SseEvent {
|
|||
}
|
||||
|
||||
// Into implementation to convert SseEvent to bytes for response buffer
|
||||
impl Into<Vec<u8>> for SseEvent {
|
||||
fn into(self) -> Vec<u8> {
|
||||
impl From<SseEvent> for Vec<u8> {
|
||||
fn from(val: SseEvent) -> Self {
|
||||
// For generated events (like ResponsesAPI), sse_transformed_lines already includes trailing \n\n
|
||||
// For parsed events (like passthrough), we need to add the \n\n separator
|
||||
if self.sse_transformed_lines.ends_with("\n\n") {
|
||||
if val.sse_transformed_lines.ends_with("\n\n") {
|
||||
// Already properly formatted with trailing newlines
|
||||
self.sse_transformed_lines.into_bytes()
|
||||
val.sse_transformed_lines.into_bytes()
|
||||
} else {
|
||||
// Add SSE event separator
|
||||
format!("{}\n\n", self.sse_transformed_lines).into_bytes()
|
||||
format!("{}\n\n", val.sse_transformed_lines).into_bytes()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -10,6 +10,12 @@ pub struct SseChunkProcessor {
|
|||
incomplete_event_buffer: Vec<u8>,
|
||||
}
|
||||
|
||||
impl Default for SseChunkProcessor {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl SseChunkProcessor {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
|
|
@ -93,8 +99,8 @@ impl SseChunkProcessor {
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::clients::endpoints::{SupportedAPIsFromClient, SupportedUpstreamAPIs};
|
||||
use crate::apis::openai::OpenAIApi;
|
||||
use crate::clients::endpoints::{SupportedAPIsFromClient, SupportedUpstreamAPIs};
|
||||
|
||||
#[test]
|
||||
fn test_complete_events_process_immediately() {
|
||||
|
|
@ -104,7 +110,9 @@ mod tests {
|
|||
|
||||
let chunk1 = b"data: {\"id\":\"chatcmpl-123\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4o\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Hello\"},\"finish_reason\":null}]}\n\n";
|
||||
|
||||
let events = processor.process_chunk(chunk1, &client_api, &upstream_api).unwrap();
|
||||
let events = processor
|
||||
.process_chunk(chunk1, &client_api, &upstream_api)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(events.len(), 1);
|
||||
assert!(!processor.has_buffered_data());
|
||||
|
|
@ -119,18 +127,28 @@ mod tests {
|
|||
// First chunk with incomplete JSON
|
||||
let chunk1 = b"data: {\"id\":\"chatcmpl-123\",\"object\":\"chat.completion.chu";
|
||||
|
||||
let events1 = processor.process_chunk(chunk1, &client_api, &upstream_api).unwrap();
|
||||
let events1 = processor
|
||||
.process_chunk(chunk1, &client_api, &upstream_api)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(events1.len(), 0, "Incomplete event should not be processed");
|
||||
assert!(processor.has_buffered_data(), "Incomplete data should be buffered");
|
||||
assert!(
|
||||
processor.has_buffered_data(),
|
||||
"Incomplete data should be buffered"
|
||||
);
|
||||
|
||||
// Second chunk completes the JSON
|
||||
let chunk2 = b"nk\",\"created\":1234567890,\"model\":\"gpt-4o\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Hello\"},\"finish_reason\":null}]}\n\n";
|
||||
|
||||
let events2 = processor.process_chunk(chunk2, &client_api, &upstream_api).unwrap();
|
||||
let events2 = processor
|
||||
.process_chunk(chunk2, &client_api, &upstream_api)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(events2.len(), 1, "Complete event should be processed");
|
||||
assert!(!processor.has_buffered_data(), "Buffer should be cleared after completion");
|
||||
assert!(
|
||||
!processor.has_buffered_data(),
|
||||
"Buffer should be cleared after completion"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -142,10 +160,15 @@ mod tests {
|
|||
// Chunk with 2 complete events and 1 incomplete
|
||||
let chunk = b"data: {\"id\":\"chatcmpl-123\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4o\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"A\"},\"finish_reason\":null}]}\n\ndata: {\"id\":\"chatcmpl-124\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4o\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"B\"},\"finish_reason\":null}]}\n\ndata: {\"id\":\"chatcmpl-125\",\"object\":\"chat.completion.chu";
|
||||
|
||||
let events = processor.process_chunk(chunk, &client_api, &upstream_api).unwrap();
|
||||
let events = processor
|
||||
.process_chunk(chunk, &client_api, &upstream_api)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(events.len(), 2, "Two complete events should be processed");
|
||||
assert!(processor.has_buffered_data(), "Incomplete third event should be buffered");
|
||||
assert!(
|
||||
processor.has_buffered_data(),
|
||||
"Incomplete third event should be buffered"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -171,11 +194,23 @@ data: {"type":"content_block_stop","index":0}
|
|||
Ok(events) => {
|
||||
println!("Successfully processed {} events", events.len());
|
||||
for (i, event) in events.iter().enumerate() {
|
||||
println!("Event {}: event={:?}, has_data={}", i, event.event, event.data.is_some());
|
||||
println!(
|
||||
"Event {}: event={:?}, has_data={}",
|
||||
i,
|
||||
event.event,
|
||||
event.data.is_some()
|
||||
);
|
||||
}
|
||||
// Should successfully process both events (signature_delta + content_block_stop)
|
||||
assert!(events.len() >= 2, "Should process at least 2 complete events (signature_delta + stop), got {}", events.len());
|
||||
assert!(!processor.has_buffered_data(), "Complete events should not be buffered");
|
||||
assert!(
|
||||
events.len() >= 2,
|
||||
"Should process at least 2 complete events (signature_delta + stop), got {}",
|
||||
events.len()
|
||||
);
|
||||
assert!(
|
||||
!processor.has_buffered_data(),
|
||||
"Complete events should not be buffered"
|
||||
);
|
||||
}
|
||||
Err(e) => {
|
||||
panic!("Failed to process signature_delta chunk - this means SignatureDelta is not properly handled: {}", e);
|
||||
|
|
@ -194,12 +229,21 @@ data: {"type":"content_block_stop","index":0}
|
|||
// Second event is valid and should be processed
|
||||
let chunk = b"data: {\"id\":\"chatcmpl-123\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4o\",\"choices\":[{\"index\":0,\"delta\":{\"unsupported_field_causing_validation_error\":true},\"finish_reason\":null}]}\n\ndata: {\"id\":\"chatcmpl-124\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4o\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Hello\"},\"finish_reason\":null}]}\n\n";
|
||||
|
||||
let events = processor.process_chunk(chunk, &client_api, &upstream_api).unwrap();
|
||||
let events = processor
|
||||
.process_chunk(chunk, &client_api, &upstream_api)
|
||||
.unwrap();
|
||||
|
||||
// Should skip the invalid event and process the valid one
|
||||
// (If we were buffering all errors, we'd get 0 events and have buffered data)
|
||||
assert!(events.len() >= 1, "Should process at least the valid event, got {} events", events.len());
|
||||
assert!(!processor.has_buffered_data(), "Invalid (non-incomplete) events should not be buffered");
|
||||
assert!(
|
||||
!events.is_empty(),
|
||||
"Should process at least the valid event, got {} events",
|
||||
events.len()
|
||||
);
|
||||
assert!(
|
||||
!processor.has_buffered_data(),
|
||||
"Invalid (non-incomplete) events should not be buffered"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -227,14 +271,27 @@ data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text
|
|||
|
||||
match result {
|
||||
Ok(events) => {
|
||||
println!("Processed {} events (unsupported event should be skipped)", events.len());
|
||||
println!(
|
||||
"Processed {} events (unsupported event should be skipped)",
|
||||
events.len()
|
||||
);
|
||||
// Should process the 2 valid text_delta events and skip the unsupported one
|
||||
// We expect at least 2 events (the valid ones), unsupported should be skipped
|
||||
assert!(events.len() >= 2, "Should process at least 2 valid events, got {}", events.len());
|
||||
assert!(!processor.has_buffered_data(), "Unsupported events should be skipped, not buffered");
|
||||
assert!(
|
||||
events.len() >= 2,
|
||||
"Should process at least 2 valid events, got {}",
|
||||
events.len()
|
||||
);
|
||||
assert!(
|
||||
!processor.has_buffered_data(),
|
||||
"Unsupported events should be skipped, not buffered"
|
||||
);
|
||||
}
|
||||
Err(e) => {
|
||||
panic!("Should not fail on unsupported delta type, should skip it: {}", e);
|
||||
panic!(
|
||||
"Should not fail on unsupported delta type, should skip it: {}",
|
||||
e
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue