mirror of
https://github.com/katanemo/plano.git
synced 2026-05-08 15:22:43 +02:00
Add support for v1/responses API (#622)
* making first commit. still need to work on streaming respones * making first commit. still need to work on streaming respones * stream buffer implementation with tests * adding grok API keys to workflow * fixed changes based on code review * adding support for bedrock models * fixed issues with translation to claude code --------- Co-authored-by: Salman Paracha <salmanparacha@MacBook-Pro-342.local>
This commit is contained in:
parent
b01a81927d
commit
a448c6e9cb
38 changed files with 7015 additions and 2955 deletions
|
|
@ -7,7 +7,7 @@ use thiserror::Error;
|
|||
|
||||
use super::ApiDefinition;
|
||||
use crate::providers::request::{ProviderRequest, ProviderRequestError};
|
||||
use crate::providers::response::ProviderStreamResponse;
|
||||
use crate::providers::streaming_response::ProviderStreamResponse;
|
||||
|
||||
// ============================================================================
|
||||
// AMAZON BEDROCK CONVERSE API ENUMERATION
|
||||
|
|
|
|||
|
|
@ -6,7 +6,8 @@ use std::collections::HashMap;
|
|||
|
||||
use super::ApiDefinition;
|
||||
use crate::providers::request::{ProviderRequest, ProviderRequestError};
|
||||
use crate::providers::response::{ProviderResponse, ProviderStreamResponse};
|
||||
use crate::providers::response::ProviderResponse;
|
||||
use crate::providers::streaming_response::ProviderStreamResponse;
|
||||
use crate::transforms::lib::ExtractText;
|
||||
use crate::MESSAGES_PATH;
|
||||
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
pub mod amazon_bedrock;
|
||||
pub mod amazon_bedrock_binary_frame;
|
||||
pub mod anthropic;
|
||||
pub mod openai;
|
||||
pub mod sse;
|
||||
pub mod openai_responses;
|
||||
pub mod streaming_shapes;
|
||||
|
||||
// Explicit exports to avoid naming conflicts
|
||||
pub use amazon_bedrock::{AmazonBedrockApi, ConverseRequest, ConverseStreamRequest};
|
||||
|
|
@ -88,8 +88,9 @@ mod tests {
|
|||
fn test_all_variants_method() {
|
||||
// Test that all_variants returns the expected variants
|
||||
let openai_variants = OpenAIApi::all_variants();
|
||||
assert_eq!(openai_variants.len(), 1);
|
||||
assert_eq!(openai_variants.len(), 2);
|
||||
assert!(openai_variants.contains(&OpenAIApi::ChatCompletions));
|
||||
assert!(openai_variants.contains(&OpenAIApi::Responses));
|
||||
|
||||
let anthropic_variants = AnthropicApi::all_variants();
|
||||
assert_eq!(anthropic_variants.len(), 1);
|
||||
|
|
|
|||
|
|
@ -7,9 +7,10 @@ use thiserror::Error;
|
|||
|
||||
use super::ApiDefinition;
|
||||
use crate::providers::request::{ProviderRequest, ProviderRequestError};
|
||||
use crate::providers::response::{ProviderResponse, ProviderStreamResponse, TokenUsage};
|
||||
use crate::providers::response::{ProviderResponse, TokenUsage};
|
||||
use crate::providers::streaming_response::ProviderStreamResponse;
|
||||
use crate::transforms::lib::ExtractText;
|
||||
use crate::CHAT_COMPLETIONS_PATH;
|
||||
use crate::{CHAT_COMPLETIONS_PATH, OPENAI_RESPONSES_API_PATH};
|
||||
|
||||
// ============================================================================
|
||||
// OPENAI API ENUMERATION
|
||||
|
|
@ -19,6 +20,7 @@ use crate::CHAT_COMPLETIONS_PATH;
|
|||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub enum OpenAIApi {
|
||||
ChatCompletions,
|
||||
Responses,
|
||||
// Future APIs can be added here:
|
||||
// Embeddings,
|
||||
// FineTuning,
|
||||
|
|
@ -29,12 +31,14 @@ impl ApiDefinition for OpenAIApi {
|
|||
fn endpoint(&self) -> &'static str {
|
||||
match self {
|
||||
OpenAIApi::ChatCompletions => CHAT_COMPLETIONS_PATH,
|
||||
OpenAIApi::Responses => OPENAI_RESPONSES_API_PATH,
|
||||
}
|
||||
}
|
||||
|
||||
fn from_endpoint(endpoint: &str) -> Option<Self> {
|
||||
match endpoint {
|
||||
CHAT_COMPLETIONS_PATH => Some(OpenAIApi::ChatCompletions),
|
||||
OPENAI_RESPONSES_API_PATH => Some(OpenAIApi::Responses),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
|
@ -42,23 +46,26 @@ impl ApiDefinition for OpenAIApi {
|
|||
fn supports_streaming(&self) -> bool {
|
||||
match self {
|
||||
OpenAIApi::ChatCompletions => true,
|
||||
OpenAIApi::Responses => true,
|
||||
}
|
||||
}
|
||||
|
||||
fn supports_tools(&self) -> bool {
|
||||
match self {
|
||||
OpenAIApi::ChatCompletions => true,
|
||||
OpenAIApi::Responses => true,
|
||||
}
|
||||
}
|
||||
|
||||
fn supports_vision(&self) -> bool {
|
||||
match self {
|
||||
OpenAIApi::ChatCompletions => true,
|
||||
OpenAIApi::Responses => true,
|
||||
}
|
||||
}
|
||||
|
||||
fn all_variants() -> Vec<Self> {
|
||||
vec![OpenAIApi::ChatCompletions]
|
||||
vec![OpenAIApi::ChatCompletions, OpenAIApi::Responses]
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -1077,8 +1084,9 @@ mod tests {
|
|||
|
||||
// Test all_variants
|
||||
let all_variants = OpenAIApi::all_variants();
|
||||
assert_eq!(all_variants.len(), 1);
|
||||
assert_eq!(all_variants[0], OpenAIApi::ChatCompletions);
|
||||
assert_eq!(all_variants.len(), 2);
|
||||
assert!(all_variants.contains(&OpenAIApi::ChatCompletions));
|
||||
assert!(all_variants.contains(&OpenAIApi::Responses));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
|
|||
1386
crates/hermesllm/src/apis/openai_responses.rs
Normal file
1386
crates/hermesllm/src/apis/openai_responses.rs
Normal file
File diff suppressed because it is too large
Load diff
|
|
@ -1,7 +1,6 @@
|
|||
use aws_smithy_eventstream::frame::DecodedFrame;
|
||||
use aws_smithy_eventstream::frame::MessageFrameDecoder;
|
||||
use bytes::Buf;
|
||||
use std::collections::HashSet;
|
||||
|
||||
/// AWS Event Stream frame decoder wrapper
|
||||
pub struct BedrockBinaryFrameDecoder<B>
|
||||
|
|
@ -10,7 +9,6 @@ where
|
|||
{
|
||||
decoder: MessageFrameDecoder,
|
||||
buffer: B,
|
||||
content_block_start_indices: HashSet<i32>,
|
||||
}
|
||||
|
||||
impl BedrockBinaryFrameDecoder<bytes::BytesMut> {
|
||||
|
|
@ -20,7 +18,6 @@ impl BedrockBinaryFrameDecoder<bytes::BytesMut> {
|
|||
Self {
|
||||
decoder: MessageFrameDecoder::new(),
|
||||
buffer,
|
||||
content_block_start_indices: std::collections::HashSet::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -33,7 +30,6 @@ where
|
|||
Self {
|
||||
decoder: MessageFrameDecoder::new(),
|
||||
buffer,
|
||||
content_block_start_indices: HashSet::new(),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -52,14 +48,4 @@ where
|
|||
pub fn has_remaining(&self) -> bool {
|
||||
self.buffer.has_remaining()
|
||||
}
|
||||
|
||||
/// Check if a content_block_start event has been sent for the given index
|
||||
pub fn has_content_block_start_been_sent(&self, index: i32) -> bool {
|
||||
self.content_block_start_indices.contains(&index)
|
||||
}
|
||||
|
||||
/// Mark that a content_block_start event has been sent for the given index
|
||||
pub fn set_content_block_start_sent(&mut self, index: i32) {
|
||||
self.content_block_start_indices.insert(index);
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,507 @@
|
|||
use crate::apis::streaming_shapes::sse::{SseEvent, SseStreamBufferTrait};
|
||||
use crate::apis::anthropic::MessagesStreamEvent;
|
||||
use crate::providers::streaming_response::ProviderStreamResponseType;
|
||||
use std::collections::HashSet;
|
||||
|
||||
/// SSE Stream Buffer for Anthropic Messages API streaming.
|
||||
///
|
||||
/// This buffer manages the wire format for Anthropic Messages API streaming,
|
||||
/// handling the specific event sequencing requirements:
|
||||
/// - MessageStart → ContentBlockStart → ContentBlockDelta(s) → ContentBlockStop → MessageDelta → MessageStop
|
||||
///
|
||||
/// When converting from OpenAI to Anthropic format, this buffer injects the required
|
||||
/// ContentBlockStart and ContentBlockStop events to maintain proper Anthropic protocol.
|
||||
pub struct AnthropicMessagesStreamBuffer {
|
||||
/// Buffered SSE events ready to be written to wire
|
||||
buffered_events: Vec<SseEvent>,
|
||||
|
||||
/// Track if we've seen a message_start event
|
||||
message_started: bool,
|
||||
|
||||
/// Track content block indices that have received ContentBlockStart events
|
||||
content_block_start_indices: HashSet<i32>,
|
||||
|
||||
/// Track if we need to inject ContentBlockStop before message_delta
|
||||
needs_content_block_stop: bool,
|
||||
|
||||
/// Track if we've seen a MessageDelta (so we need to send MessageStop at the end)
|
||||
seen_message_delta: bool,
|
||||
|
||||
/// Model name to use when generating message_start events
|
||||
model: Option<String>,
|
||||
}
|
||||
|
||||
impl AnthropicMessagesStreamBuffer {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
buffered_events: Vec::new(),
|
||||
message_started: false,
|
||||
content_block_start_indices: HashSet::new(),
|
||||
needs_content_block_stop: false,
|
||||
seen_message_delta: false,
|
||||
model: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if a content_block_start event has been sent for the given index
|
||||
fn has_content_block_start_been_sent(&self, index: i32) -> bool {
|
||||
self.content_block_start_indices.contains(&index)
|
||||
}
|
||||
|
||||
/// Mark that a content_block_start event has been sent for the given index
|
||||
fn set_content_block_start_sent(&mut self, index: i32) {
|
||||
self.content_block_start_indices.insert(index);
|
||||
}
|
||||
|
||||
/// Helper to create and format a ContentBlockStart SSE event
|
||||
fn create_content_block_start_event() -> SseEvent {
|
||||
let content_block_start = MessagesStreamEvent::ContentBlockStart {
|
||||
index: 0,
|
||||
content_block: crate::apis::anthropic::MessagesContentBlock::Text {
|
||||
text: String::new(),
|
||||
cache_control: None,
|
||||
},
|
||||
};
|
||||
let sse_string: String = content_block_start.into();
|
||||
|
||||
SseEvent {
|
||||
data: None,
|
||||
event: Some("content_block_start".to_string()),
|
||||
raw_line: sse_string.clone(),
|
||||
sse_transformed_lines: sse_string,
|
||||
provider_stream_response: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper to create and format a MessageStart SSE event
|
||||
fn create_message_start_event(model: &str) -> SseEvent {
|
||||
let message_start = MessagesStreamEvent::MessageStart {
|
||||
message: crate::apis::anthropic::MessagesStreamMessage {
|
||||
id: format!("msg_{}", uuid::Uuid::new_v4().to_string().replace("-", "")),
|
||||
obj_type: "message".to_string(),
|
||||
role: crate::apis::anthropic::MessagesRole::Assistant,
|
||||
content: vec![],
|
||||
model: model.to_string(),
|
||||
stop_reason: None,
|
||||
stop_sequence: None,
|
||||
usage: crate::apis::anthropic::MessagesUsage {
|
||||
input_tokens: 0,
|
||||
output_tokens: 0,
|
||||
cache_creation_input_tokens: None,
|
||||
cache_read_input_tokens: None,
|
||||
},
|
||||
},
|
||||
};
|
||||
let sse_string: String = message_start.into();
|
||||
|
||||
SseEvent {
|
||||
data: None,
|
||||
event: Some("message_start".to_string()),
|
||||
raw_line: sse_string.clone(),
|
||||
sse_transformed_lines: sse_string,
|
||||
provider_stream_response: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper to create and format a ContentBlockStop SSE event
|
||||
fn create_content_block_stop_event() -> SseEvent {
|
||||
let content_block_stop = MessagesStreamEvent::ContentBlockStop { index: 0 };
|
||||
let sse_string: String = content_block_stop.into();
|
||||
|
||||
SseEvent {
|
||||
data: None,
|
||||
event: Some("content_block_stop".to_string()),
|
||||
raw_line: sse_string.clone(),
|
||||
sse_transformed_lines: sse_string,
|
||||
provider_stream_response: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl SseStreamBufferTrait for AnthropicMessagesStreamBuffer {
|
||||
fn add_transformed_event(&mut self, event: SseEvent) {
|
||||
// Skip ping messages
|
||||
if event.should_skip() {
|
||||
return;
|
||||
}
|
||||
|
||||
// FIRST: Try to extract model name from the raw event data before transformation
|
||||
// The provider_stream_response has already been transformed to Anthropic format,
|
||||
// so we need to extract the model from the original raw data if available
|
||||
if self.model.is_none() {
|
||||
if let Some(data) = &event.data {
|
||||
// Try to parse as JSON and extract model field
|
||||
if let Ok(json) = serde_json::from_str::<serde_json::Value>(data) {
|
||||
if let Some(model) = json.get("model").and_then(|m| m.as_str()) {
|
||||
self.model = Some(model.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Match directly on the provider response type to handle event processing
|
||||
// We match on a reference first to determine the type, then move the event
|
||||
match &event.provider_stream_response {
|
||||
Some(ProviderStreamResponseType::MessagesStreamEvent(evt)) => {
|
||||
match evt {
|
||||
MessagesStreamEvent::MessageStart { .. } => {
|
||||
// Add the message_start event
|
||||
self.buffered_events.push(event);
|
||||
self.message_started = true;
|
||||
}
|
||||
MessagesStreamEvent::ContentBlockStart { index, .. } => {
|
||||
let index = *index as i32;
|
||||
// Inject message_start if needed
|
||||
if !self.message_started {
|
||||
let model = self.model.as_deref().unwrap_or("unknown");
|
||||
let message_start = AnthropicMessagesStreamBuffer::create_message_start_event(model);
|
||||
self.buffered_events.push(message_start);
|
||||
self.message_started = true;
|
||||
}
|
||||
|
||||
// Add the content_block_start event (from tool calls or other sources)
|
||||
self.buffered_events.push(event);
|
||||
self.set_content_block_start_sent(index);
|
||||
self.needs_content_block_stop = true;
|
||||
}
|
||||
MessagesStreamEvent::ContentBlockDelta { index, .. } => {
|
||||
let index = *index as i32;
|
||||
// Inject message_start if needed
|
||||
if !self.message_started {
|
||||
let model = self.model.as_deref().unwrap_or("unknown");
|
||||
let message_start = AnthropicMessagesStreamBuffer::create_message_start_event(model);
|
||||
self.buffered_events.push(message_start);
|
||||
self.message_started = true;
|
||||
}
|
||||
|
||||
// Check if ContentBlockStart was sent for this index
|
||||
if !self.has_content_block_start_been_sent(index) {
|
||||
// Inject ContentBlockStart before delta
|
||||
let content_block_start = AnthropicMessagesStreamBuffer::create_content_block_start_event();
|
||||
self.buffered_events.push(content_block_start);
|
||||
self.set_content_block_start_sent(index);
|
||||
self.needs_content_block_stop = true;
|
||||
}
|
||||
|
||||
// Content deltas are between ContentBlockStart and ContentBlockStop
|
||||
self.buffered_events.push(event);
|
||||
}
|
||||
MessagesStreamEvent::MessageDelta { usage, .. } => {
|
||||
// Inject ContentBlockStop before message_delta
|
||||
if self.needs_content_block_stop {
|
||||
let content_block_stop = AnthropicMessagesStreamBuffer::create_content_block_stop_event();
|
||||
self.buffered_events.push(content_block_stop);
|
||||
self.needs_content_block_stop = false;
|
||||
}
|
||||
|
||||
// Check if the last event was also a MessageDelta - if so, merge them
|
||||
// This handles Bedrock's split of stop_reason (MessageStop) and usage (Metadata)
|
||||
if let Some(last_event) = self.buffered_events.last_mut() {
|
||||
if let Some(ProviderStreamResponseType::MessagesStreamEvent(
|
||||
MessagesStreamEvent::MessageDelta {
|
||||
usage: last_usage,
|
||||
..
|
||||
}
|
||||
)) = &mut last_event.provider_stream_response {
|
||||
// Merge: take stop_reason from first, usage from second (if non-zero)
|
||||
if usage.input_tokens > 0 || usage.output_tokens > 0 {
|
||||
*last_usage = usage.clone();
|
||||
}
|
||||
// Mark that we've seen MessageDelta (need to send MessageStop later)
|
||||
self.seen_message_delta = true;
|
||||
// Don't push the new event, we've merged it
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// No previous MessageDelta to merge with, add this one
|
||||
self.buffered_events.push(event);
|
||||
self.seen_message_delta = true;
|
||||
}
|
||||
MessagesStreamEvent::ContentBlockStop { .. } => {
|
||||
// ContentBlockStop received from upstream (e.g., Bedrock)
|
||||
// Clear the flag so we don't inject another one
|
||||
self.needs_content_block_stop = false;
|
||||
self.buffered_events.push(event);
|
||||
}
|
||||
MessagesStreamEvent::MessageStop => {
|
||||
// MessageStop received from upstream (e.g., OpenAI via [DONE])
|
||||
// Clear the flag so we don't inject another one
|
||||
self.seen_message_delta = false;
|
||||
self.buffered_events.push(event);
|
||||
}
|
||||
_ => {
|
||||
// Other Anthropic event types (Ping, etc.), just accumulate
|
||||
self.buffered_events.push(event);
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
// Non-Anthropic events or events without provider_stream_response, just accumulate
|
||||
self.buffered_events.push(event);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn into_bytes(&mut self) -> Vec<u8> {
|
||||
// Convert all accumulated events to bytes and clear buffer
|
||||
// NOTE: We do NOT inject ContentBlockStop here because it's injected when we see MessageDelta
|
||||
// or MessageStop. Injecting it here causes premature ContentBlockStop in the middle of streaming.
|
||||
|
||||
// Inject MessageStop after MessageDelta if we've seen one
|
||||
// This completes the Anthropic Messages API event sequence
|
||||
if self.seen_message_delta {
|
||||
let message_stop = MessagesStreamEvent::MessageStop;
|
||||
let sse_string: String = message_stop.into();
|
||||
let message_stop_event = SseEvent {
|
||||
data: None,
|
||||
event: Some("message_stop".to_string()),
|
||||
raw_line: sse_string.clone(),
|
||||
sse_transformed_lines: sse_string,
|
||||
provider_stream_response: None,
|
||||
};
|
||||
self.buffered_events.push(message_stop_event);
|
||||
self.seen_message_delta = false;
|
||||
}
|
||||
|
||||
let mut buffer = Vec::new();
|
||||
for event in self.buffered_events.drain(..) {
|
||||
let event_bytes: Vec<u8> = event.into();
|
||||
buffer.extend_from_slice(&event_bytes);
|
||||
}
|
||||
buffer
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::clients::{SupportedAPIsFromClient, SupportedUpstreamAPIs};
|
||||
use crate::apis::anthropic::AnthropicApi;
|
||||
use crate::apis::openai::OpenAIApi;
|
||||
use crate::apis::streaming_shapes::sse::SseStreamIter;
|
||||
|
||||
#[test]
|
||||
fn test_openai_to_anthropic_complete_transformation() {
|
||||
// OpenAI ChatCompletions input that will be transformed to Anthropic Messages API
|
||||
let raw_input = r#"data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{"role":"assistant","content":"Hello"},"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{"content":" world"},"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}
|
||||
|
||||
data: [DONE]"#;
|
||||
|
||||
println!("\n{}", "=".repeat(80));
|
||||
println!("TEST 1: OpenAI → Anthropic Messages API Complete Transformation");
|
||||
println!("{}", "=".repeat(80));
|
||||
println!("\nRAW INPUT (OpenAI ChatCompletions):");
|
||||
println!("{}", "-".repeat(80));
|
||||
println!("{}", raw_input);
|
||||
|
||||
// Setup API configuration for transformation (client wants Anthropic, upstream is OpenAI)
|
||||
let client_api = SupportedAPIsFromClient::AnthropicMessagesAPI(AnthropicApi::Messages);
|
||||
let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions);
|
||||
|
||||
// Parse events and apply transformation
|
||||
let stream_iter = SseStreamIter::try_from(raw_input.as_bytes()).unwrap();
|
||||
let mut buffer = AnthropicMessagesStreamBuffer::new();
|
||||
|
||||
for raw_event in stream_iter {
|
||||
let transformed_event = SseEvent::try_from((raw_event, &client_api, &upstream_api)).unwrap();
|
||||
buffer.add_transformed_event(transformed_event);
|
||||
}
|
||||
|
||||
let output_bytes = buffer.into_bytes();
|
||||
let output = String::from_utf8_lossy(&output_bytes);
|
||||
|
||||
println!("\nTRANSFORMED OUTPUT (Anthropic Messages API):");
|
||||
println!("{}", "-".repeat(80));
|
||||
println!("{}", output);
|
||||
|
||||
// Assertions
|
||||
assert!(!output_bytes.is_empty(), "Should have output");
|
||||
assert!(output.contains("event: message_start"), "Should have message_start");
|
||||
assert!(output.contains("event: content_block_start"), "Should have content_block_start (injected)");
|
||||
|
||||
let delta_count = output.matches("event: content_block_delta").count();
|
||||
assert_eq!(delta_count, 2, "Should have exactly 2 content_block_delta events");
|
||||
|
||||
// Verify both pieces of content are present
|
||||
assert!(output.contains("\"text\":\"Hello\""), "Should have first content delta 'Hello'");
|
||||
assert!(output.contains("\"text\":\" world\""), "Should have second content delta ' world'");
|
||||
|
||||
assert!(output.contains("event: content_block_stop"), "Should have content_block_stop (injected)");
|
||||
assert!(output.contains("event: message_delta"), "Should have message_delta");
|
||||
assert!(output.contains("event: message_stop"), "Should have message_stop");
|
||||
|
||||
println!("\nVALIDATION SUMMARY:");
|
||||
println!("{}", "-".repeat(80));
|
||||
println!("✓ Complete transformation: OpenAI ChatCompletions → Anthropic Messages API");
|
||||
println!("✓ Injected lifecycle events: message_start, content_block_start, content_block_stop");
|
||||
println!("✓ Content deltas: {} events (BOTH 'Hello' and ' world' preserved!)", delta_count);
|
||||
println!("✓ Complete stream with message_stop");
|
||||
println!("✓ Proper Anthropic protocol sequencing\n");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_openai_to_anthropic_partial_transformation() {
|
||||
// Partial OpenAI ChatCompletions stream - no [DONE]
|
||||
let raw_input = r#"data: {"id":"chatcmpl-456","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{"role":"assistant","content":"The weather"},"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-456","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{"content":" in San Francisco"},"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-456","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{"content":" is"},"finish_reason":null}]}"#;
|
||||
|
||||
println!("\n{}", "=".repeat(80));
|
||||
println!("TEST 2: OpenAI → Anthropic Partial Transformation (NO [DONE])");
|
||||
println!("{}", "=".repeat(80));
|
||||
println!("\nRAW INPUT (OpenAI ChatCompletions - NO [DONE]):");
|
||||
println!("{}", "-".repeat(80));
|
||||
println!("{}", raw_input);
|
||||
|
||||
// Setup API configuration for transformation
|
||||
let client_api = SupportedAPIsFromClient::AnthropicMessagesAPI(AnthropicApi::Messages);
|
||||
let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions);
|
||||
|
||||
// Parse and transform events
|
||||
let stream_iter = SseStreamIter::try_from(raw_input.as_bytes()).unwrap();
|
||||
let mut buffer = AnthropicMessagesStreamBuffer::new();
|
||||
|
||||
for raw_event in stream_iter {
|
||||
let transformed_event = SseEvent::try_from((raw_event, &client_api, &upstream_api)).unwrap();
|
||||
buffer.add_transformed_event(transformed_event);
|
||||
}
|
||||
|
||||
let output_bytes = buffer.into_bytes();
|
||||
let output = String::from_utf8_lossy(&output_bytes);
|
||||
|
||||
println!("\nTRANSFORMED OUTPUT (Anthropic Messages API):");
|
||||
println!("{}", "-".repeat(80));
|
||||
println!("{}", output);
|
||||
|
||||
// Assertions
|
||||
assert!(!output_bytes.is_empty(), "Should have output");
|
||||
assert!(output.contains("event: message_start"), "Should have message_start");
|
||||
assert!(output.contains("event: content_block_start"), "Should have content_block_start (injected)");
|
||||
|
||||
let delta_count = output.matches("event: content_block_delta").count();
|
||||
assert_eq!(delta_count, 3, "Should have exactly 3 content_block_delta events");
|
||||
|
||||
// Verify all three pieces of content are present
|
||||
assert!(output.contains("\"text\":\"The weather\""), "Should have first content delta");
|
||||
assert!(output.contains("\"text\":\" in San Francisco\""), "Should have second content delta");
|
||||
assert!(output.contains("\"text\":\" is\""), "Should have third content delta");
|
||||
|
||||
// For partial streams (no finish_reason, no [DONE]), we do NOT inject content_block_stop
|
||||
// because the stream may continue. This is correct behavior - only inject lifecycle events
|
||||
// when we have explicit signals from upstream (finish_reason, [DONE], etc.)
|
||||
assert!(!output.contains("event: content_block_stop"), "Should NOT have content_block_stop for partial stream");
|
||||
|
||||
// Should NOT have completion events
|
||||
assert!(!output.contains("event: message_delta"), "Should NOT have message_delta");
|
||||
assert!(!output.contains("event: message_stop"), "Should NOT have message_stop");
|
||||
|
||||
println!("\nVALIDATION SUMMARY:");
|
||||
println!("{}", "-".repeat(80));
|
||||
println!("✓ Partial transformation: OpenAI → Anthropic (stream interrupted)");
|
||||
println!("✓ Injected: message_start, content_block_start at beginning");
|
||||
println!("✓ Incremental deltas: {} events (ALL content preserved!)", delta_count);
|
||||
println!("✓ NO completion events (partial stream, no [DONE])");
|
||||
println!("✓ Buffer maintains Anthropic protocol for active streams\n");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_openai_tool_calling_to_anthropic_transformation() {
|
||||
// OpenAI ChatCompletions tool calling stream
|
||||
let raw_input = r#"data: {"id":"chatcmpl-Cgx6pZPBgfLcMqfT0ILIH2mID2zWQ","object":"chat.completion.chunk","created":1764353027,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_7eeb46f068","choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_2Uzw0AEZQeOex2CP2TKjcLKc","type":"function","function":{"name":"get_weather","arguments":""}}],"refusal":null},"logprobs":null,"finish_reason":null}],"obfuscation":"uSpCcO"}
|
||||
|
||||
data: {"id":"chatcmpl-Cgx6pZPBgfLcMqfT0ILIH2mID2zWQ","object":"chat.completion.chunk","created":1764353027,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_7eeb46f068","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\""}}]},"logprobs":null,"finish_reason":null}],"obfuscation":""}
|
||||
|
||||
data: {"id":"chatcmpl-Cgx6pZPBgfLcMqfT0ILIH2mID2zWQ","object":"chat.completion.chunk","created":1764353027,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_7eeb46f068","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"location"}}]},"logprobs":null,"finish_reason":null}],"obfuscation":"24WSqt08jtf"}
|
||||
|
||||
data: {"id":"chatcmpl-Cgx6pZPBgfLcMqfT0ILIH2mID2zWQ","object":"chat.completion.chunk","created":1764353027,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_7eeb46f068","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\":\""}}]},"logprobs":null,"finish_reason":null}],"obfuscation":"6CleV8twTxkKYg"}
|
||||
|
||||
data: {"id":"chatcmpl-Cgx6pZPBgfLcMqfT0ILIH2mID2zWQ","object":"chat.completion.chunk","created":1764353027,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_7eeb46f068","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"San"}}]},"logprobs":null,"finish_reason":null}],"obfuscation":""}
|
||||
|
||||
data: {"id":"chatcmpl-Cgx6pZPBgfLcMqfT0ILIH2mID2zWQ","object":"chat.completion.chunk","created":1764353027,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_7eeb46f068","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":" Francisco"}}]},"logprobs":null,"finish_reason":null}],"obfuscation":"1XLz89l3v"}
|
||||
|
||||
data: {"id":"chatcmpl-Cgx6pZPBgfLcMqfT0ILIH2mID2zWQ","object":"chat.completion.chunk","created":1764353027,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_7eeb46f068","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":","}}]},"logprobs":null,"finish_reason":null}],"obfuscation":"sh"}
|
||||
|
||||
data: {"id":"chatcmpl-Cgx6pZPBgfLcMqfT0ILIH2mID2zWQ","object":"chat.completion.chunk","created":1764353027,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_7eeb46f068","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":" CA"}}]},"logprobs":null,"finish_reason":null}],"obfuscation":""}
|
||||
|
||||
data: {"id":"chatcmpl-Cgx6pZPBgfLcMqfT0ILIH2mID2zWQ","object":"chat.completion.chunk","created":1764353027,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_7eeb46f068","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"}"}}]},"logprobs":null,"finish_reason":null}],"obfuscation":""}
|
||||
|
||||
data: {"id":"chatcmpl-Cgx6pZPBgfLcMqfT0ILIH2mID2zWQ","object":"chat.completion.chunk","created":1764353027,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_7eeb46f068","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}],"obfuscation":"I"}
|
||||
|
||||
data: [DONE]"#;
|
||||
|
||||
println!("\n{}", "=".repeat(80));
|
||||
println!("TEST 3: OpenAI Tool Calling → Anthropic Messages API Transformation");
|
||||
println!("{}", "=".repeat(80));
|
||||
println!("\nRAW INPUT (OpenAI ChatCompletions with Tool Calls):");
|
||||
println!("{}", "-".repeat(80));
|
||||
println!("{}", raw_input);
|
||||
|
||||
// Setup API configuration for transformation
|
||||
let client_api = SupportedAPIsFromClient::AnthropicMessagesAPI(AnthropicApi::Messages);
|
||||
let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions);
|
||||
|
||||
// Parse and transform events
|
||||
let stream_iter = SseStreamIter::try_from(raw_input.as_bytes()).unwrap();
|
||||
let mut buffer = AnthropicMessagesStreamBuffer::new();
|
||||
|
||||
for raw_event in stream_iter {
|
||||
let transformed_event = SseEvent::try_from((raw_event, &client_api, &upstream_api)).unwrap();
|
||||
buffer.add_transformed_event(transformed_event);
|
||||
}
|
||||
|
||||
let output_bytes = buffer.into_bytes();
|
||||
let output = String::from_utf8_lossy(&output_bytes);
|
||||
|
||||
println!("\nTRANSFORMED OUTPUT (Anthropic Messages API):");
|
||||
println!("{}", "-".repeat(80));
|
||||
println!("{}", output);
|
||||
|
||||
// Assertions for tool calling transformation
|
||||
assert!(!output_bytes.is_empty(), "Should have output");
|
||||
|
||||
// Should have lifecycle events (injected by buffer)
|
||||
assert!(output.contains("event: message_start"), "Should have message_start (injected)");
|
||||
assert!(output.contains("event: content_block_start"), "Should have content_block_start");
|
||||
assert!(output.contains("event: content_block_stop"), "Should have content_block_stop (injected)");
|
||||
assert!(output.contains("event: message_delta"), "Should have message_delta");
|
||||
assert!(output.contains("event: message_stop"), "Should have message_stop");
|
||||
|
||||
// Should have tool_use content block
|
||||
assert!(output.contains("\"type\":\"tool_use\""), "Should have tool_use type");
|
||||
assert!(output.contains("\"name\":\"get_weather\""), "Should have correct function name");
|
||||
assert!(output.contains("\"id\":\"call_2Uzw0AEZQeOex2CP2TKjcLKc\""), "Should have correct tool call ID");
|
||||
|
||||
// Count input_json_delta events - should match the number of argument chunks
|
||||
let delta_count = output.matches("event: content_block_delta").count();
|
||||
assert!(delta_count >= 8, "Should have at least 8 input_json_delta events");
|
||||
|
||||
// Verify argument deltas are present
|
||||
assert!(output.contains("\"type\":\"input_json_delta\""), "Should have input_json_delta type");
|
||||
assert!(output.contains("\"partial_json\":"), "Should have partial_json field");
|
||||
|
||||
// Verify the accumulated arguments contain the location
|
||||
assert!(output.contains("San"), "Arguments should contain 'San'");
|
||||
assert!(output.contains("Francisco"), "Arguments should contain 'Francisco'");
|
||||
assert!(output.contains("CA"), "Arguments should contain 'CA'");
|
||||
|
||||
// Verify stop reason is tool_use
|
||||
assert!(output.contains("\"stop_reason\":\"tool_use\""), "Should have stop_reason as tool_use");
|
||||
|
||||
println!("\nVALIDATION SUMMARY:");
|
||||
println!("{}", "-".repeat(80));
|
||||
println!("✓ Complete tool calling transformation: OpenAI → Anthropic Messages API");
|
||||
println!("✓ Injected lifecycle: message_start, content_block_stop");
|
||||
println!("✓ Tool metadata: name='get_weather', id='call_2Uzw0AEZQeOex2CP2TKjcLKc'");
|
||||
println!("✓ Argument deltas: {} events", delta_count);
|
||||
println!("✓ Complete JSON arguments: '{{\"location\":\"San Francisco, CA\"}}'");
|
||||
println!("✓ Stop reason: tool_use");
|
||||
println!("✓ Proper Anthropic tool_use protocol\n");
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,39 @@
|
|||
use crate::apis::streaming_shapes::sse::{SseEvent, SseStreamBufferTrait};
|
||||
|
||||
/// OpenAI Chat Completions SSE Stream Buffer for when client and upstream APIs match.
|
||||
pub struct OpenAIChatCompletionsStreamBuffer {
|
||||
/// Buffered SSE events ready to be written to wire
|
||||
buffered_events: Vec<SseEvent>,
|
||||
}
|
||||
|
||||
impl OpenAIChatCompletionsStreamBuffer {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
buffered_events: Vec::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl SseStreamBufferTrait for OpenAIChatCompletionsStreamBuffer {
|
||||
fn add_transformed_event(&mut self, event: SseEvent) {
|
||||
// Skip ping messages
|
||||
if event.should_skip() {
|
||||
return;
|
||||
}
|
||||
|
||||
// For OpenAI Chat Completions, events are already properly transformed
|
||||
// Just accumulate them for later wire transmission
|
||||
self.buffered_events.push(event);
|
||||
}
|
||||
|
||||
fn into_bytes(&mut self) -> Vec<u8> {
|
||||
// No finalization needed for OpenAI Chat Completions
|
||||
// The [DONE] marker is already handled by the transformation layer
|
||||
let mut buffer = Vec::new();
|
||||
for event in self.buffered_events.drain(..) {
|
||||
let event_bytes: Vec<u8> = event.into();
|
||||
buffer.extend_from_slice(&event_bytes);
|
||||
}
|
||||
buffer
|
||||
}
|
||||
}
|
||||
6
crates/hermesllm/src/apis/streaming_shapes/mod.rs
Normal file
6
crates/hermesllm/src/apis/streaming_shapes/mod.rs
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
pub mod sse;
|
||||
pub mod amazon_bedrock_binary_frame;
|
||||
pub mod anthropic_streaming_buffer;
|
||||
pub mod chat_completions_streaming_buffer;
|
||||
pub mod passthrough_streaming_buffer;
|
||||
pub mod responses_api_streaming_buffer;
|
||||
|
|
@ -0,0 +1,95 @@
|
|||
use crate::apis::streaming_shapes::sse::{SseEvent, SseStreamBufferTrait};
|
||||
|
||||
/// Passthrough SSE Stream Buffer for when client and upstream APIs match.
|
||||
pub struct PassthroughStreamBuffer {
|
||||
/// Buffered SSE events ready to be written to wire
|
||||
buffered_events: Vec<SseEvent>,
|
||||
}
|
||||
|
||||
impl PassthroughStreamBuffer {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
buffered_events: Vec::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl SseStreamBufferTrait for PassthroughStreamBuffer {
|
||||
fn add_transformed_event(&mut self, event: SseEvent) {
|
||||
// Skip ping messages
|
||||
if event.should_skip() {
|
||||
return;
|
||||
}
|
||||
|
||||
// Skip events with empty transformed lines (e.g., suppressed event-only lines)
|
||||
if event.sse_transformed_lines.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
// Just accumulate events as-is
|
||||
self.buffered_events.push(event);
|
||||
}
|
||||
|
||||
fn into_bytes(&mut self) -> Vec<u8> {
|
||||
// No finalization needed for passthrough - just convert accumulated events to bytes
|
||||
let mut buffer = Vec::new();
|
||||
for event in self.buffered_events.drain(..) {
|
||||
let event_bytes: Vec<u8> = event.into();
|
||||
buffer.extend_from_slice(&event_bytes);
|
||||
}
|
||||
buffer
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use crate::apis::streaming_shapes::passthrough_streaming_buffer::PassthroughStreamBuffer;
    use crate::apis::streaming_shapes::sse::{SseStreamIter, SseStreamBufferTrait};

    // Verifies the passthrough buffer is a true identity transform: a Chat
    // Completions tool-call stream (including the [DONE] terminator) must come
    // out byte-for-byte equal (modulo surrounding whitespace) to what went in.
    #[test]
    fn test_chat_completions_passthrough_buffer() {
        let raw_input = r#"data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_abc","type":"function","function":{"name":"get_weather","arguments":""}}],"refusal":null},"logprobs":null,"finish_reason":null}]}

data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\""}}]},"logprobs":null,"finish_reason":null}]}

data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"location"}}]},"logprobs":null,"finish_reason":null}]}

data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}]}

data: [DONE]"#;

        println!("\n{}", "=".repeat(80));
        println!("TEST 1: ChatCompletions Passthrough Buffer");
        println!("{}", "=".repeat(80));
        println!("\nRAW INPUT (ChatCompletions):");
        println!("{}", "-".repeat(80));
        println!("{}", raw_input);

        // Parse and process through buffer
        let stream_iter = SseStreamIter::try_from(raw_input.as_bytes()).unwrap();
        let mut buffer = PassthroughStreamBuffer::new();

        for event in stream_iter {
            buffer.add_transformed_event(event);
        }

        let output_bytes = buffer.into_bytes();
        let output = String::from_utf8_lossy(&output_bytes);

        println!("\nTRANSFORMED OUTPUT (ChatCompletions - Passthrough):");
        println!("{}", "-".repeat(80));
        println!("{}", output);

        // Assertions
        assert!(!output_bytes.is_empty());
        assert!(output.contains("chatcmpl-123"));
        assert!(output.contains("[DONE]"));
        // The key passthrough property: identity up to leading/trailing whitespace.
        assert_eq!(raw_input.trim(), output.trim(), "Passthrough should preserve input");

        println!("\nVALIDATION SUMMARY:");
        println!("{}", "-".repeat(80));
        println!("✓ Passthrough buffer: input = output (no transformation)");
        println!("✓ All events preserved including [DONE]");
        println!("✓ Function calling events preserved\n");
    }
}
|
||||
|
|
@ -0,0 +1,600 @@
|
|||
use std::collections::HashMap;
|
||||
use log::debug;
|
||||
use crate::apis::openai_responses::{
|
||||
ResponsesAPIStreamEvent, ResponsesAPIResponse, OutputItem, OutputItemStatus,
|
||||
ResponseStatus, TextConfig, TextFormat, Reasoning,
|
||||
};
|
||||
use crate::apis::streaming_shapes::sse::{SseEvent, SseStreamBufferTrait};
|
||||
|
||||
/// Helper to convert ResponseAPIStreamEvent to SseEvent
|
||||
fn event_to_sse(event: ResponsesAPIStreamEvent) -> SseEvent {
|
||||
let event_type = match &event {
|
||||
ResponsesAPIStreamEvent::ResponseCreated { .. } => "response.created",
|
||||
ResponsesAPIStreamEvent::ResponseInProgress { .. } => "response.in_progress",
|
||||
ResponsesAPIStreamEvent::ResponseCompleted { .. } => "response.completed",
|
||||
ResponsesAPIStreamEvent::ResponseOutputItemAdded { .. } => "response.output_item.added",
|
||||
ResponsesAPIStreamEvent::ResponseOutputItemDone { .. } => "response.output_item.done",
|
||||
ResponsesAPIStreamEvent::ResponseOutputTextDelta { .. } => "response.output_text.delta",
|
||||
ResponsesAPIStreamEvent::ResponseOutputTextDone { .. } => "response.output_text.done",
|
||||
ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDelta { .. } => "response.function_call_arguments.delta",
|
||||
ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDone { .. } => "response.function_call_arguments.done",
|
||||
unknown => {
|
||||
debug!("Unknown ResponsesAPIStreamEvent type encountered: {:?}", unknown);
|
||||
"unknown"
|
||||
}
|
||||
};
|
||||
|
||||
let json_data = match serde_json::to_string(&event) {
|
||||
Ok(data) => data,
|
||||
Err(e) => {
|
||||
debug!("Error serializing ResponsesAPIStreamEvent to JSON: {}", e);
|
||||
String::new()
|
||||
}
|
||||
};
|
||||
let wire_format: String = event.into();
|
||||
|
||||
SseEvent {
|
||||
data: Some(json_data),
|
||||
event: Some(event_type.to_string()),
|
||||
raw_line: wire_format.clone(),
|
||||
sse_transformed_lines: wire_format,
|
||||
provider_stream_response: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// SSE Stream Buffer for ResponsesAPIStreamEvent with full lifecycle management.
///
/// This buffer manages the wire format for v1/responses streaming, handling
/// delta events and emitting complete lifecycle events.
///
pub struct ResponsesAPIStreamBuffer {
    /// Sequence number for events
    // Monotonically increasing; advanced via `next_sequence_number`.
    sequence_number: i32,

    /// Track item IDs by output index
    // NOTE(review): overlaps with `output_items_added` below (both map
    // output_index -> item_id); consider consolidating — verify usage first.
    item_ids: HashMap<i32, String>,

    /// Response metadata
    // Lazily initialized on the first event if not supplied by the caller.
    response_id: Option<String>,
    model: Option<String>,
    created_at: Option<i64>,

    /// Lifecycle state flags
    // Guarantee response.created / response.in_progress are emitted exactly once.
    created_emitted: bool,
    in_progress_emitted: bool,

    /// Track which output items we've added
    output_items_added: HashMap<i32, String>, // output_index -> item_id

    /// Accumulated content by item_id
    text_content: HashMap<String, String>,
    function_arguments: HashMap<String, String>,

    /// Tool call metadata by output_index
    tool_call_metadata: HashMap<i32, (String, String)>, // output_index -> (call_id, name)

    /// Final completed response (for logging/tracing/persistence)
    // Populated by `finalize`; readable via `get_completed_response`.
    completed_response: Option<ResponsesAPIResponse>,

    /// Buffered SSE events ready to be written to wire
    buffered_events: Vec<SseEvent>,
}
|
||||
|
||||
impl ResponsesAPIStreamBuffer {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
sequence_number: 0,
|
||||
item_ids: HashMap::new(),
|
||||
response_id: None,
|
||||
model: None,
|
||||
created_at: None,
|
||||
created_emitted: false,
|
||||
in_progress_emitted: false,
|
||||
output_items_added: HashMap::new(),
|
||||
text_content: HashMap::new(),
|
||||
function_arguments: HashMap::new(),
|
||||
tool_call_metadata: HashMap::new(),
|
||||
completed_response: None,
|
||||
buffered_events: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn next_sequence_number(&mut self) -> i32 {
|
||||
let seq = self.sequence_number;
|
||||
self.sequence_number += 1;
|
||||
seq
|
||||
}
|
||||
|
||||
fn generate_item_id(prefix: &str) -> String {
|
||||
format!("{}_{}", prefix, uuid::Uuid::new_v4().to_string().replace("-", ""))
|
||||
}
|
||||
|
||||
fn get_or_create_item_id(&mut self, output_index: i32, prefix: &str) -> String {
|
||||
if let Some(id) = self.item_ids.get(&output_index) {
|
||||
return id.clone();
|
||||
}
|
||||
let id = ResponsesAPIStreamBuffer::generate_item_id(prefix);
|
||||
self.item_ids.insert(output_index, id.clone());
|
||||
id
|
||||
}
|
||||
|
||||
/// Create response.created event
|
||||
fn create_response_created_event(&mut self) -> SseEvent {
|
||||
let response = self.build_response(ResponseStatus::InProgress);
|
||||
let event = ResponsesAPIStreamEvent::ResponseCreated {
|
||||
response,
|
||||
sequence_number: self.next_sequence_number(),
|
||||
};
|
||||
event_to_sse(event)
|
||||
}
|
||||
|
||||
/// Create response.in_progress event
|
||||
fn create_response_in_progress_event(&mut self) -> SseEvent {
|
||||
let response = self.build_response(ResponseStatus::InProgress);
|
||||
let event = ResponsesAPIStreamEvent::ResponseInProgress {
|
||||
response,
|
||||
sequence_number: self.next_sequence_number(),
|
||||
};
|
||||
event_to_sse(event)
|
||||
}
|
||||
|
||||
/// Create output_item.added event for text
|
||||
fn create_output_item_added_event(&mut self, output_index: i32, item_id: &str) -> SseEvent {
|
||||
let event = ResponsesAPIStreamEvent::ResponseOutputItemAdded {
|
||||
output_index,
|
||||
item: OutputItem::Message {
|
||||
id: item_id.to_string(),
|
||||
status: OutputItemStatus::InProgress,
|
||||
role: "assistant".to_string(),
|
||||
content: vec![],
|
||||
},
|
||||
sequence_number: self.next_sequence_number(),
|
||||
};
|
||||
event_to_sse(event)
|
||||
}
|
||||
|
||||
/// Create output_item.added event for tool call
|
||||
fn create_tool_call_added_event(&mut self, output_index: i32, item_id: &str, call_id: &str, name: &str) -> SseEvent {
|
||||
let event = ResponsesAPIStreamEvent::ResponseOutputItemAdded {
|
||||
output_index,
|
||||
item: OutputItem::FunctionCall {
|
||||
id: item_id.to_string(),
|
||||
status: OutputItemStatus::InProgress,
|
||||
call_id: call_id.to_string(),
|
||||
name: Some(name.to_string()),
|
||||
arguments: Some(String::new()),
|
||||
},
|
||||
sequence_number: self.next_sequence_number(),
|
||||
};
|
||||
event_to_sse(event)
|
||||
}
|
||||
|
||||
    /// Build the base response object with current state
    ///
    /// Produces a `ResponsesAPIResponse` skeleton carrying the buffered
    /// id/model/created_at metadata and the given `status`. `output` is left
    /// empty; `finalize` fills it in for the completed response.
    fn build_response(&self, status: ResponseStatus) -> ResponsesAPIResponse {
        ResponsesAPIResponse {
            // Missing metadata degrades gracefully: empty id, epoch 0, "unknown" model.
            id: self.response_id.clone().unwrap_or_default(),
            object: "response".to_string(),
            created_at: self.created_at.unwrap_or(0),
            status,
            error: None,
            incomplete_details: None,
            instructions: None,
            model: self.model.clone().unwrap_or_else(|| "unknown".to_string()),
            output: vec![],
            usage: None,
            // NOTE(review): the fields below are hard-coded defaults — they do
            // not reflect the actual upstream request settings (temperature,
            // tools, tool_choice, etc.). Confirm whether clients rely on them.
            parallel_tool_calls: true,
            conversation: None,
            previous_response_id: None,
            tools: vec![],
            tool_choice: "auto".to_string(),
            temperature: 1.0,
            top_p: 1.0,
            metadata: HashMap::new(),
            truncation: Some("disabled".to_string()),
            max_output_tokens: None,
            reasoning: Some(Reasoning {
                effort: None,
                summary: None,
            }),
            store: Some(true),
            text: Some(TextConfig {
                format: TextFormat::Text,
            }),
            audio: None,
            modalities: None,
            service_tier: Some("auto".to_string()),
            background: Some(false),
            top_logprobs: Some(0),
            max_tool_calls: None,
        }
    }
|
||||
|
||||
    /// Get the completed response after finalization (for logging/tracing/persistence)
    ///
    /// Returns `None` until `finalize` has run.
    pub fn get_completed_response(&self) -> Option<&ResponsesAPIResponse> {
        self.completed_response.as_ref()
    }

    /// Finalize the response by emitting all *.done events and response.completed.
    /// Call this when the stream is complete (after seeing [DONE] or end_of_stream).
    ///
    /// Emission order: per-item text done events, then per-item function-call
    /// done events, then a single `response.completed`.
    // NOTE(review): iteration follows HashMap order, so done-event ordering
    // across multiple items is nondeterministic — confirm clients tolerate it.
    pub fn finalize(&mut self) {
        let mut events = Vec::new();

        // Emit done events for all accumulated content

        // Text content done events
        // Cloned out first so `next_sequence_number` (&mut self) can be called inside the loop.
        let text_items: Vec<_> = self.text_content.iter().map(|(id, content)| (id.clone(), content.clone())).collect();
        for (item_id, content) in text_items {
            // Reverse lookup of the output index; defaults to 0 if never registered.
            let output_index = self.output_items_added.iter()
                .find(|(_, id)| **id == item_id)
                .map(|(idx, _)| *idx)
                .unwrap_or(0);

            let seq1 = self.next_sequence_number();
            let text_done_event = ResponsesAPIStreamEvent::ResponseOutputTextDone {
                item_id: item_id.clone(),
                output_index,
                content_index: 0,
                text: content.clone(),
                logprobs: vec![],
                sequence_number: seq1,
            };
            events.push(event_to_sse(text_done_event));

            let seq2 = self.next_sequence_number();
            // NOTE(review): the completed Message item carries empty `content`
            // even though text was accumulated — confirm whether it should
            // embed the final text.
            let item_done_event = ResponsesAPIStreamEvent::ResponseOutputItemDone {
                output_index,
                item: OutputItem::Message {
                    id: item_id.clone(),
                    status: OutputItemStatus::Completed,
                    role: "assistant".to_string(),
                    content: vec![],
                },
                sequence_number: seq2,
            };
            events.push(event_to_sse(item_done_event));
        }

        // Function call done events
        let func_items: Vec<_> = self.function_arguments.iter().map(|(id, args)| (id.clone(), args.clone())).collect();
        for (item_id, arguments) in func_items {
            let output_index = self.output_items_added.iter()
                .find(|(_, id)| **id == item_id)
                .map(|(idx, _)| *idx)
                .unwrap_or(0);

            let seq1 = self.next_sequence_number();
            let args_done_event = ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDone {
                output_index,
                item_id: item_id.clone(),
                arguments: arguments.clone(),
                sequence_number: seq1,
            };
            events.push(event_to_sse(args_done_event));

            // Synthesize call_id/name if no metadata was captured during streaming.
            let (call_id, name) = self.tool_call_metadata.get(&output_index)
                .cloned()
                .unwrap_or_else(|| (format!("call_{}", uuid::Uuid::new_v4()), "unknown".to_string()));

            let seq2 = self.next_sequence_number();
            let item_done_event = ResponsesAPIStreamEvent::ResponseOutputItemDone {
                output_index,
                item: OutputItem::FunctionCall {
                    id: item_id.clone(),
                    status: OutputItemStatus::Completed,
                    call_id,
                    name: Some(name),
                    arguments: Some(arguments.clone()),
                },
                sequence_number: seq2,
            };
            events.push(event_to_sse(item_done_event));
        }

        // Build final response
        let mut output_items = Vec::new();

        // Add tool calls to output
        // NOTE(review): only function-call items are placed in the final
        // `output`; accumulated text Message items are omitted — verify intent.
        for (item_id, arguments) in &self.function_arguments {
            let output_index = self.output_items_added.iter()
                .find(|(_, id)| *id == item_id)
                .map(|(idx, _)| *idx)
                .unwrap_or(0);

            let (call_id, name) = self.tool_call_metadata.get(&output_index)
                .cloned()
                .unwrap_or_else(|| (format!("call_{}", uuid::Uuid::new_v4()), "unknown".to_string()));

            output_items.push(OutputItem::FunctionCall {
                id: item_id.clone(),
                status: OutputItemStatus::Completed,
                call_id,
                name: Some(name),
                arguments: Some(arguments.clone()),
            });
        }

        let mut final_response = self.build_response(ResponseStatus::Completed);
        final_response.output = output_items;

        // Store completed response
        self.completed_response = Some(final_response.clone());

        // Emit response.completed
        let seq_final = self.next_sequence_number();
        let completed_event = ResponsesAPIStreamEvent::ResponseCompleted {
            response: final_response,
            sequence_number: seq_final,
        };
        events.push(event_to_sse(completed_event));

        // Add all finalization events to the buffer
        self.buffered_events.extend(events);
    }
|
||||
}
|
||||
|
||||
impl SseStreamBufferTrait for ResponsesAPIStreamBuffer {
|
||||
fn add_transformed_event(&mut self, event: SseEvent) {
|
||||
// Skip ping messages
|
||||
if event.should_skip() {
|
||||
return;
|
||||
}
|
||||
|
||||
// Handle [DONE] marker - trigger finalization
|
||||
if event.is_done() {
|
||||
self.finalize();
|
||||
return;
|
||||
}
|
||||
|
||||
// Extract the ResponseAPIStreamEvent from the SseEvent's provider_stream_response
|
||||
let provider_response = match event.provider_stream_response.as_ref() {
|
||||
Some(response) => response,
|
||||
None => {
|
||||
eprintln!("Warning: Event missing provider_stream_response");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
// Extract ResponseAPIStreamEvent from the enum
|
||||
let stream_event = match provider_response {
|
||||
crate::providers::streaming_response::ProviderStreamResponseType::ResponseAPIStreamEvent(evt) => evt,
|
||||
_ => {
|
||||
eprintln!("Warning: Expected ResponseAPIStreamEvent in provider_stream_response");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let mut events = Vec::new();
|
||||
|
||||
// Emit lifecycle events if not yet emitted
|
||||
if !self.created_emitted {
|
||||
// Initialize metadata from first event if needed
|
||||
if self.response_id.is_none() {
|
||||
self.response_id = Some(format!("resp_{}", uuid::Uuid::new_v4().to_string().replace("-", "")));
|
||||
self.created_at = Some(std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs() as i64);
|
||||
self.model = Some("unknown".to_string()); // Will be set by caller if available
|
||||
}
|
||||
|
||||
events.push(self.create_response_created_event());
|
||||
self.created_emitted = true;
|
||||
}
|
||||
|
||||
if !self.in_progress_emitted {
|
||||
events.push(self.create_response_in_progress_event());
|
||||
self.in_progress_emitted = true;
|
||||
}
|
||||
|
||||
// Process the delta event
|
||||
match stream_event {
|
||||
ResponsesAPIStreamEvent::ResponseOutputTextDelta { output_index, delta, .. } => {
|
||||
let item_id = self.get_or_create_item_id(*output_index, "msg");
|
||||
|
||||
// Emit output_item.added if this is the first time we see this output index
|
||||
if !self.output_items_added.contains_key(output_index) {
|
||||
self.output_items_added.insert(*output_index, item_id.clone());
|
||||
events.push(self.create_output_item_added_event(*output_index, &item_id));
|
||||
}
|
||||
|
||||
// Accumulate text content
|
||||
self.text_content.entry(item_id.clone())
|
||||
.and_modify(|content| content.push_str(delta))
|
||||
.or_insert_with(|| delta.clone());
|
||||
|
||||
// Emit text delta with filled-in item_id and sequence_number
|
||||
let mut delta_event = stream_event.clone();
|
||||
if let ResponsesAPIStreamEvent::ResponseOutputTextDelta { item_id: ref mut id, sequence_number: ref mut seq, .. } = delta_event {
|
||||
*id = item_id;
|
||||
*seq = self.next_sequence_number();
|
||||
}
|
||||
events.push(event_to_sse(delta_event));
|
||||
}
|
||||
ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDelta { output_index, delta, call_id, name, .. } => {
|
||||
let item_id = self.get_or_create_item_id(*output_index, "fc");
|
||||
|
||||
// Store metadata if provided (from initial tool call event)
|
||||
if let (Some(cid), Some(n)) = (call_id, name) {
|
||||
self.tool_call_metadata.insert(*output_index, (cid.clone(), n.clone()));
|
||||
}
|
||||
|
||||
// Emit output_item.added if this is the first time we see this tool call
|
||||
if !self.output_items_added.contains_key(output_index) {
|
||||
self.output_items_added.insert(*output_index, item_id.clone());
|
||||
|
||||
// For tool calls, we need call_id and name from metadata
|
||||
// These should now be populated from the event itself
|
||||
let (call_id, name) = self.tool_call_metadata.get(output_index)
|
||||
.cloned()
|
||||
.unwrap_or_else(|| (format!("call_{}", uuid::Uuid::new_v4()), "unknown".to_string()));
|
||||
|
||||
events.push(self.create_tool_call_added_event(*output_index, &item_id, &call_id, &name));
|
||||
}
|
||||
|
||||
// Accumulate function arguments
|
||||
self.function_arguments.entry(item_id.clone())
|
||||
.and_modify(|args| args.push_str(delta))
|
||||
.or_insert_with(|| delta.clone());
|
||||
|
||||
// Emit function call arguments delta with filled-in item_id and sequence_number
|
||||
let mut delta_event = stream_event.clone();
|
||||
if let ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDelta { item_id: ref mut id, sequence_number: ref mut seq, .. } = delta_event {
|
||||
*id = item_id;
|
||||
*seq = self.next_sequence_number();
|
||||
}
|
||||
events.push(event_to_sse(delta_event));
|
||||
}
|
||||
_ => {
|
||||
// For other event types, just pass through with sequence number
|
||||
let other_event = stream_event.clone();
|
||||
// TODO: Add sequence number to other event types if needed
|
||||
events.push(event_to_sse(other_event));
|
||||
}
|
||||
}
|
||||
|
||||
// Store all generated events in the buffer
|
||||
self.buffered_events.extend(events);
|
||||
}
|
||||
|
||||
|
||||
fn into_bytes(&mut self) -> Vec<u8> {
|
||||
// For Responses API, we need special handling:
|
||||
// - Most events are already in buffered_events from add_transformed_event
|
||||
// - We should NOT finalize here - finalization happens when we detect [DONE] or end of stream
|
||||
// - Just flush the accumulated events and clear the buffer
|
||||
|
||||
// Convert all accumulated events to bytes and clear buffer
|
||||
let mut buffer = Vec::new();
|
||||
for event in self.buffered_events.drain(..) {
|
||||
let event_bytes: Vec<u8> = event.into();
|
||||
buffer.extend_from_slice(&event_bytes);
|
||||
}
|
||||
buffer
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::clients::{SupportedAPIsFromClient, SupportedUpstreamAPIs};
    use crate::apis::openai::OpenAIApi;
    use crate::apis::streaming_shapes::sse::SseStreamIter;

    // Full lifecycle: a terminated ([DONE]) Chat Completions text stream must
    // produce the complete Responses API event sequence, including the
    // finalization events (*.done, response.completed).
    #[test]
    fn test_chat_completions_to_responses_api_transformation() {
        // ChatCompletions input that will be transformed to ResponsesAPI
        let raw_input = r#"data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{"role":"assistant","content":"Hello"},"finish_reason":null}]}

data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{"content":" world"},"finish_reason":null}]}

data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}

data: [DONE]"#;

        println!("\n{}", "=".repeat(80));
        println!("TEST 2: ChatCompletions → ResponsesAPI Transformation (with [DONE])");
        println!("{}", "=".repeat(80));
        println!("\nRAW INPUT (ChatCompletions):");
        println!("{}", "-".repeat(80));
        println!("{}", raw_input);

        // Setup API configuration for transformation
        let client_api = SupportedAPIsFromClient::OpenAIResponsesAPI(OpenAIApi::Responses);
        let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions);

        // Parse events and apply transformation
        let stream_iter = SseStreamIter::try_from(raw_input.as_bytes()).unwrap();
        let mut buffer = ResponsesAPIStreamBuffer::new();

        for raw_event in stream_iter {
            // Transform the event using the client/upstream APIs
            let transformed_event = SseEvent::try_from((raw_event, &client_api, &upstream_api)).unwrap();
            buffer.add_transformed_event(transformed_event);
        }

        let output_bytes = buffer.into_bytes();
        let output = String::from_utf8_lossy(&output_bytes);

        println!("\nTRANSFORMED OUTPUT (ResponsesAPI):");
        println!("{}", "-".repeat(80));
        println!("{}", output);

        // Assertions
        assert!(!output_bytes.is_empty(), "Should have output");
        assert!(output.contains("response.created"), "Should have response.created");
        assert!(output.contains("response.in_progress"), "Should have response.in_progress");
        assert!(output.contains("response.output_item.added"), "Should have output_item.added");
        assert!(output.contains("response.output_text.delta"), "Should have text deltas");
        assert!(output.contains("response.output_text.done"), "Should have text.done");
        assert!(output.contains("response.output_item.done"), "Should have output_item.done");
        assert!(output.contains("response.completed"), "Should have response.completed");

        println!("\nVALIDATION SUMMARY:");
        println!("{}", "-".repeat(80));
        println!("✓ Lifecycle events: response.created, response.in_progress, response.completed");
        println!("✓ Output item lifecycle: output_item.added, output_item.done");
        println!("✓ Text streaming: output_text.delta (2 deltas), output_text.done");
        println!("✓ Complete transformation with finalization ([DONE] processed)\n");
    }

    // Partial stream: with no [DONE] marker the buffer must emit lifecycle and
    // delta events only — no *.done and no response.completed.
    #[test]
    fn test_partial_streaming_incremental_output() {
        let raw_input = r#"data: {"id":"chatcmpl-CfpqklihniLRuuQfP7inMb2ghtGmT","object":"chat.completion.chunk","created":1764086794,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_7eeb46f068","choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_mD5ggLKk3SMKGPFqFdcpKg6q","type":"function","function":{"name":"get_weather","arguments":""}}],"refusal":null},"logprobs":null,"finish_reason":null}],"obfuscation":"PCFrpy"}

data: {"id":"chatcmpl-CfpqklihniLRuuQfP7inMb2ghtGmT","object":"chat.completion.chunk","created":1764086794,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_7eeb46f068","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\""}}]},"logprobs":null,"finish_reason":null}],"obfuscation":""}

data: {"id":"chatcmpl-CfpqklihniLRuuQfP7inMb2ghtGmT","object":"chat.completion.chunk","created":1764086794,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_7eeb46f068","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"location"}}]},"logprobs":null,"finish_reason":null}],"obfuscation":"TC58A3QEIx8"}

data: {"id":"chatcmpl-CfpqklihniLRuuQfP7inMb2ghtGmT","object":"chat.completion.chunk","created":1764086794,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_7eeb46f068","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\":\""}}]},"logprobs":null,"finish_reason":null}],"obfuscation":"PK4oFzlVlGTUP5"}"#;

        println!("\n{}", "=".repeat(80));
        println!("TEST 3: Partial Streaming - Function Calling (NO [DONE])");
        println!("{}", "=".repeat(80));
        println!("\nRAW INPUT (ChatCompletions - NO [DONE]):");
        println!("{}", "-".repeat(80));
        println!("{}", raw_input);

        // Setup API configuration for transformation
        let client_api = SupportedAPIsFromClient::OpenAIResponsesAPI(OpenAIApi::Responses);
        let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions);

        // Transform all events
        let stream_iter = SseStreamIter::try_from(raw_input.as_bytes()).unwrap();
        let mut buffer = ResponsesAPIStreamBuffer::new();

        for raw_event in stream_iter {
            let transformed = SseEvent::try_from((raw_event, &client_api, &upstream_api)).unwrap();
            buffer.add_transformed_event(transformed);
        }

        let output_bytes = buffer.into_bytes();
        let output = String::from_utf8_lossy(&output_bytes);

        println!("\nTRANSFORMED OUTPUT (ResponsesAPI):");
        println!("{}", "-".repeat(80));
        println!("{}", output);

        // Assertions
        assert!(output.contains("response.created"), "Should have response.created");
        assert!(output.contains("response.in_progress"), "Should have response.in_progress");
        assert!(output.contains("response.output_item.added"), "Should have output_item.added");
        assert!(output.contains("\"type\":\"function_call\""), "Should be function_call type");
        assert!(output.contains("\"name\":\"get_weather\""), "Should have function name");
        assert!(output.contains("\"call_id\":\"call_mD5ggLKk3SMKGPFqFdcpKg6q\""), "Should have correct call_id");

        let delta_count = output.matches("event: response.function_call_arguments.delta").count();
        assert_eq!(delta_count, 4, "Should have 4 delta events");

        // No finalization without [DONE].
        assert!(!output.contains("response.function_call_arguments.done"), "Should NOT have arguments.done");
        assert!(!output.contains("response.output_item.done"), "Should NOT have output_item.done");
        assert!(!output.contains("response.completed"), "Should NOT have response.completed");

        println!("\nVALIDATION SUMMARY:");
        println!("{}", "-".repeat(80));
        println!("✓ Lifecycle events: response.created, response.in_progress");
        println!("✓ Function call metadata: name='get_weather', call_id='call_mD5ggLKk3SMKGPFqFdcpKg6q'");
        println!("✓ Incremental deltas: 4 events (1 initial + 3 argument chunks)");
        println!("✓ NO completion events (partial stream, no [DONE])");
        println!("✓ Arguments accumulated: '{{\"location\":\"'\n");
    }
}
|
||||
|
|
@ -1,10 +1,73 @@
|
|||
use crate::providers::response::ProviderStreamResponse;
|
||||
use crate::providers::response::ProviderStreamResponseType;
|
||||
use crate::providers::streaming_response::ProviderStreamResponse;
|
||||
use crate::providers::streaming_response::ProviderStreamResponseType;
|
||||
use crate::apis::streaming_shapes::chat_completions_streaming_buffer::OpenAIChatCompletionsStreamBuffer;
|
||||
use crate::apis::streaming_shapes::anthropic_streaming_buffer::AnthropicMessagesStreamBuffer;
|
||||
use crate::apis::streaming_shapes::passthrough_streaming_buffer::PassthroughStreamBuffer;
|
||||
use crate::apis::streaming_shapes::responses_api_streaming_buffer::ResponsesAPIStreamBuffer;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::error::Error;
|
||||
use std::fmt;
|
||||
use std::str::FromStr;
|
||||
|
||||
/// Trait defining the interface for SSE stream buffers.
|
||||
///
|
||||
/// This trait is implemented by both the enum `SseStreamBuffer` (for zero-cost dispatch)
|
||||
/// and individual buffer implementations (for direct use).
|
||||
///
|
||||
pub trait SseStreamBufferTrait: Send + Sync {
|
||||
/// Add a transformed SSE event to the buffer.
|
||||
///
|
||||
/// The buffer may inject additional events as needed based on internal state.
|
||||
/// For example, Anthropic buffers inject ContentBlockStart before the first ContentBlockDelta.
|
||||
///
|
||||
/// All events (original + injected) are accumulated internally for the next `into_bytes()` call.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `event` - A transformed SSE event to accumulate
|
||||
fn add_transformed_event(&mut self, event: SseEvent);
|
||||
|
||||
/// Get bytes for all accumulated events since the last call.
|
||||
///
|
||||
/// This method:
|
||||
/// - Converts all buffered events to wire format bytes
|
||||
/// - Clears the internal event buffer
|
||||
/// - Preserves state for subsequent `add_transformed_event()` calls
|
||||
///
|
||||
/// Call this after processing each chunk of upstream events to get bytes for immediate transmission.
|
||||
///
|
||||
/// # Returns
|
||||
/// Bytes ready for wire transmission (may be empty if no events were accumulated)
|
||||
fn into_bytes(&mut self) -> Vec<u8>;
|
||||
}
|
||||
|
||||
/// Unified SSE Stream Buffer enum that provides a zero-cost abstraction
|
||||
pub enum SseStreamBuffer {
|
||||
Passthrough(PassthroughStreamBuffer),
|
||||
OpenAIChatCompletions(OpenAIChatCompletionsStreamBuffer),
|
||||
AnthropicMessages(AnthropicMessagesStreamBuffer),
|
||||
OpenAIResponses(ResponsesAPIStreamBuffer),
|
||||
}
|
||||
|
||||
impl SseStreamBufferTrait for SseStreamBuffer {
|
||||
fn add_transformed_event(&mut self, event: SseEvent) {
|
||||
match self {
|
||||
Self::Passthrough(buffer) => buffer.add_transformed_event(event),
|
||||
Self::OpenAIChatCompletions(buffer) => buffer.add_transformed_event(event),
|
||||
Self::AnthropicMessages(buffer) => buffer.add_transformed_event(event),
|
||||
Self::OpenAIResponses(buffer) => buffer.add_transformed_event(event),
|
||||
}
|
||||
}
|
||||
|
||||
fn into_bytes(&mut self) -> Vec<u8> {
|
||||
match self {
|
||||
Self::Passthrough(buffer) => buffer.into_bytes(),
|
||||
Self::OpenAIChatCompletions(buffer) => buffer.into_bytes(),
|
||||
Self::AnthropicMessages(buffer) => buffer.into_bytes(),
|
||||
Self::OpenAIResponses(buffer) => buffer.into_bytes(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// SSE EVENT CONTAINER
|
||||
// ============================================================================
|
||||
|
|
@ -22,16 +85,31 @@ pub struct SseEvent {
|
|||
pub raw_line: String, // The complete line as received including "data: " prefix and "\n\n"
|
||||
|
||||
#[serde(skip_serializing, skip_deserializing)]
|
||||
pub sse_transform_buffer: String, // The complete line as received including "data: " prefix and "\n\n"
|
||||
pub sse_transformed_lines: String, // The complete line as received including "data: " prefix and "\n\n"
|
||||
|
||||
#[serde(skip_serializing, skip_deserializing)]
|
||||
pub provider_stream_response: Option<ProviderStreamResponseType>, // Parsed provider stream response object
|
||||
}
|
||||
|
||||
impl SseEvent {
|
||||
/// Create an SseEvent from a ProviderStreamResponseType
|
||||
/// This is useful for binary frame formats (like Bedrock) that need to be converted to SSE
|
||||
pub fn from_provider_response(response: ProviderStreamResponseType) -> Self {
|
||||
// Convert the provider response to SSE format string
|
||||
let sse_string: String = response.clone().into();
|
||||
|
||||
SseEvent {
|
||||
data: None, // Data is embedded in sse_transformed_lines
|
||||
event: None, // Event type is embedded in sse_transformed_lines
|
||||
raw_line: sse_string.clone(),
|
||||
sse_transformed_lines: sse_string,
|
||||
provider_stream_response: Some(response),
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if this event represents the end of the stream
|
||||
pub fn is_done(&self) -> bool {
|
||||
self.data == Some("[DONE]".into())
|
||||
self.data == Some("[DONE]".into()) || self.event == Some("message_stop".into())
|
||||
}
|
||||
|
||||
/// Check if this event should be skipped during processing
|
||||
|
|
@ -61,23 +139,35 @@ impl FromStr for SseEvent {
|
|||
type Err = SseParseError;
|
||||
|
||||
fn from_str(line: &str) -> Result<Self, Self::Err> {
|
||||
if line.starts_with("data: ") {
|
||||
let data: String = line[6..].to_string(); // Remove "data: " prefix
|
||||
if data.is_empty() {
|
||||
// Trim leading/trailing whitespace for parsing
|
||||
let trimmed_line = line.trim();
|
||||
|
||||
// Skip empty or whitespace-only lines (SSE event separators)
|
||||
if trimmed_line.is_empty() {
|
||||
return Err(SseParseError {
|
||||
message: "Empty line (SSE event separator)".to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
if trimmed_line.starts_with("data: ") {
|
||||
let data: String = trimmed_line[6..].to_string(); // Remove "data: " prefix
|
||||
// Allow empty data content after "data: " prefix
|
||||
// This handles cases like "data: " followed by newline
|
||||
if data.trim().is_empty() {
|
||||
return Err(SseParseError {
|
||||
message: "Empty data field is not a valid SSE event".to_string(),
|
||||
message: "Empty data field after 'data: ' prefix".to_string(),
|
||||
});
|
||||
}
|
||||
Ok(SseEvent {
|
||||
data: Some(data),
|
||||
event: None,
|
||||
raw_line: line.to_string(),
|
||||
sse_transform_buffer: line.to_string(),
|
||||
// Preserve original line format for passthrough, use trimmed for transformations
|
||||
sse_transformed_lines: line.to_string(),
|
||||
provider_stream_response: None,
|
||||
})
|
||||
} else if line.starts_with("event: ") {
|
||||
//used by Anthropic
|
||||
let event_type = line[7..].to_string();
|
||||
} else if trimmed_line.starts_with("event: ") {
|
||||
let event_type = trimmed_line[7..].to_string();
|
||||
if event_type.is_empty() {
|
||||
return Err(SseParseError {
|
||||
message: "Empty event field is not a valid SSE event".to_string(),
|
||||
|
|
@ -87,12 +177,13 @@ impl FromStr for SseEvent {
|
|||
data: None,
|
||||
event: Some(event_type),
|
||||
raw_line: line.to_string(),
|
||||
sse_transform_buffer: line.to_string(),
|
||||
// Preserve original line format for passthrough, use trimmed for transformations
|
||||
sse_transformed_lines: line.to_string(),
|
||||
provider_stream_response: None,
|
||||
})
|
||||
} else {
|
||||
Err(SseParseError {
|
||||
message: format!("Line does not start with 'data: ' or 'event: ': {}", line),
|
||||
message: format!("Line does not start with 'data: ' or 'event: ': {}", trimmed_line),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
@ -100,14 +191,14 @@ impl FromStr for SseEvent {
|
|||
|
||||
impl fmt::Display for SseEvent {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "{}", self.sse_transform_buffer)
|
||||
write!(f, "{}", self.sse_transformed_lines)
|
||||
}
|
||||
}
|
||||
|
||||
// Into implementation to convert SseEvent to bytes for response buffer
|
||||
impl Into<Vec<u8>> for SseEvent {
|
||||
fn into(self) -> Vec<u8> {
|
||||
format!("{}\n\n", self.sse_transform_buffer).into_bytes()
|
||||
format!("{}\n\n", self.sse_transformed_lines).into_bytes()
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
Add table
Add a link
Reference in a new issue