mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
fixed streaming from Anthropic Client to OpenAI
This commit is contained in:
parent
06c71c1392
commit
879c8eeff3
6 changed files with 402 additions and 110 deletions
|
|
@ -546,6 +546,17 @@ impl ProviderStreamResponse for MessagesStreamEvent {
|
|||
}
|
||||
}
|
||||
|
||||
fn event_type(&self) -> Option<&str> {
|
||||
Some(match self {
|
||||
MessagesStreamEvent::MessageStart { .. } => "message_start",
|
||||
MessagesStreamEvent::ContentBlockStart { .. } => "content_block_start",
|
||||
MessagesStreamEvent::ContentBlockDelta { .. } => "content_block_delta",
|
||||
MessagesStreamEvent::ContentBlockStop { .. } => "content_block_stop",
|
||||
MessagesStreamEvent::MessageDelta { .. } => "message_delta",
|
||||
MessagesStreamEvent::MessageStop => "message_stop",
|
||||
MessagesStreamEvent::Ping => "ping",
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
|
|
|||
|
|
@ -642,6 +642,9 @@ impl ProviderStreamResponse for ChatCompletionsStreamResponse {
|
|||
}))
|
||||
}
|
||||
|
||||
fn event_type(&self) -> Option<&str> {
|
||||
None // OpenAI doesn't use event types in SSE
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -52,7 +52,8 @@ mod tests {
|
|||
"#;
|
||||
|
||||
use crate::clients::endpoints::SupportedAPIs;
|
||||
let api = SupportedAPIs::OpenAIChatCompletions(crate::apis::OpenAIApi::ChatCompletions);
|
||||
let client_api = SupportedAPIs::OpenAIChatCompletions(crate::apis::OpenAIApi::ChatCompletions);
|
||||
let upstream_api = SupportedAPIs::OpenAIChatCompletions(crate::apis::OpenAIApi::ChatCompletions);
|
||||
|
||||
// Test the new simplified architecture - create SseStreamIter directly
|
||||
let sse_iter = SseStreamIter::try_from(sse_data.as_bytes());
|
||||
|
|
@ -68,10 +69,17 @@ mod tests {
|
|||
|
||||
// Test SseEvent properties
|
||||
assert!(!sse_event.is_done());
|
||||
assert!(sse_event.data.contains("Hello"));
|
||||
assert!(sse_event.data.as_ref().unwrap().contains("Hello"));
|
||||
|
||||
// Test that we can parse the event into a provider stream response
|
||||
let provider_response = sse_event.to_provider_stream_response(&api);
|
||||
let transformed_event = SseEvent::try_from((&sse_event, &client_api, &upstream_api));
|
||||
if let Err(e) = &transformed_event {
|
||||
println!("Transform error: {:?}", e);
|
||||
}
|
||||
assert!(transformed_event.is_ok());
|
||||
|
||||
let transformed_event = transformed_event.unwrap();
|
||||
let provider_response = transformed_event.provider_response();
|
||||
assert!(provider_response.is_ok());
|
||||
|
||||
let stream_response = provider_response.unwrap();
|
||||
|
|
|
|||
|
|
@ -6,6 +6,8 @@ use std::convert::TryFrom;
|
|||
use std::str::FromStr;
|
||||
|
||||
use crate::apis::openai::ChatCompletionsResponse;
|
||||
use crate::apis::openai::ChatCompletionsStreamResponse;
|
||||
use crate::apis::anthropic::MessagesStreamEvent;
|
||||
use crate::clients::endpoints::SupportedAPIs;
|
||||
use crate::apis::anthropic::MessagesResponse;
|
||||
|
||||
|
|
@ -19,8 +21,8 @@ pub enum ProviderResponseType {
|
|||
#[derive(Serialize, Debug, Clone)]
|
||||
#[serde(untagged)]
|
||||
pub enum ProviderStreamResponseType {
|
||||
ChatCompletionsStreamResponse(crate::apis::openai::ChatCompletionsStreamResponse),
|
||||
MessagesStreamEvent(crate::apis::anthropic::MessagesStreamEvent),
|
||||
ChatCompletionsStreamResponse(ChatCompletionsStreamResponse),
|
||||
MessagesStreamEvent(MessagesStreamEvent),
|
||||
}
|
||||
|
||||
pub trait ProviderResponse: Send + Sync {
|
||||
|
|
@ -59,6 +61,38 @@ pub trait ProviderStreamResponse: Send + Sync {
|
|||
/// Get role information if available
|
||||
fn role(&self) -> Option<&str>;
|
||||
|
||||
/// Get event type for SSE streaming (used by Anthropic)
|
||||
fn event_type(&self) -> Option<&str>;
|
||||
}
|
||||
|
||||
/// Forwards the [`ProviderStreamResponse`] accessors to whichever concrete
/// provider payload the enum currently wraps.
impl ProviderStreamResponse for ProviderStreamResponseType {
    /// Content delta as reported by the wrapped payload.
    fn content_delta(&self) -> Option<&str> {
        match self {
            Self::ChatCompletionsStreamResponse(chunk) => chunk.content_delta(),
            Self::MessagesStreamEvent(event) => event.content_delta(),
        }
    }

    /// Whether the wrapped payload reports itself as final.
    fn is_final(&self) -> bool {
        match self {
            Self::ChatCompletionsStreamResponse(chunk) => chunk.is_final(),
            Self::MessagesStreamEvent(event) => event.is_final(),
        }
    }

    /// Role information from the wrapped payload, if it has any.
    fn role(&self) -> Option<&str> {
        match self {
            Self::ChatCompletionsStreamResponse(chunk) => chunk.role(),
            Self::MessagesStreamEvent(event) => event.role(),
        }
    }

    /// SSE event name: delegated for Anthropic events, `None` for OpenAI chunks.
    fn event_type(&self) -> Option<&str> {
        match self {
            // OpenAI doesn't use event types
            Self::ChatCompletionsStreamResponse(_) => None,
            Self::MessagesStreamEvent(event) => event.event_type(),
        }
    }
}
|
||||
|
||||
// ============================================================================
|
||||
|
|
@ -69,11 +103,17 @@ pub trait ProviderStreamResponse: Send + Sync {
|
|||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SseEvent {
|
||||
#[serde(rename = "data")]
|
||||
pub data: String, // The JSON payload after "data: "
|
||||
pub data: Option<String>, // The JSON payload after "data: "
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub event: Option<String>, // Optional event type (e.g., "message_start", "content_block_delta")
|
||||
|
||||
#[serde(skip_serializing, skip_deserializing)]
|
||||
pub raw_line: String, // The complete line as received including "data: " prefix and "\n\n"
|
||||
|
||||
#[serde(skip_serializing, skip_deserializing)]
|
||||
pub raw_line_transformed: String, // The complete line as received including "data: " prefix and "\n\n"
|
||||
|
||||
#[serde(skip_serializing, skip_deserializing)]
|
||||
pub provider_stream_response: Option<ProviderStreamResponseType>, // Parsed provider stream response object
|
||||
}
|
||||
|
|
@ -81,40 +121,30 @@ pub struct SseEvent {
|
|||
impl SseEvent {
|
||||
/// Check if this event represents the end of the stream
|
||||
pub fn is_done(&self) -> bool {
|
||||
self.data == "[DONE]"
|
||||
self.data == Some("[DONE]".into())
|
||||
}
|
||||
|
||||
/// Check if this event should be skipped during processing
|
||||
/// This includes ping messages and other provider-specific events that don't contain content
|
||||
pub fn should_skip(&self) -> bool {
|
||||
// Skip ping messages (commonly used by providers for connection keep-alive)
|
||||
self.data == r#"{"type": "ping"}"#
|
||||
self.data == Some(r#"{"type": "ping"}"#.into())
|
||||
}
|
||||
|
||||
/// Check if this is an event-only SSE event (no data payload)
|
||||
pub fn is_event_only(&self) -> bool {
|
||||
self.event.is_some() && self.data.is_none()
|
||||
}
|
||||
|
||||
/// Get the parsed provider response if available
|
||||
pub fn provider_response(&self) -> Option<&ProviderStreamResponseType> {
|
||||
pub fn provider_response(&self) -> Result<&dyn ProviderStreamResponse, std::io::Error> {
|
||||
self.provider_stream_response.as_ref()
|
||||
.map(|resp| resp as &dyn ProviderStreamResponse)
|
||||
.ok_or_else(|| {
|
||||
std::io::Error::new(std::io::ErrorKind::NotFound, "Provider response not found")
|
||||
})
|
||||
}
|
||||
|
||||
/// Parse the data field into a ProviderStreamResponse for the given API
|
||||
pub fn to_provider_stream_response(&self, client_api: &SupportedAPIs) -> Result<Box<dyn ProviderStreamResponse>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
if self.is_done() {
|
||||
return Err("Cannot parse [DONE] event as ProviderStreamResponse".into());
|
||||
}
|
||||
|
||||
match client_api {
|
||||
SupportedAPIs::OpenAIChatCompletions(_) => {
|
||||
let response: crate::apis::openai::ChatCompletionsStreamResponse =
|
||||
serde_json::from_str(&self.data)?;
|
||||
Ok(Box::new(response))
|
||||
}
|
||||
SupportedAPIs::AnthropicMessagesAPI(_) => {
|
||||
let response: crate::apis::anthropic::MessagesStreamEvent =
|
||||
serde_json::from_str(&self.data)?;
|
||||
Ok(Box::new(response))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for SseEvent {
|
||||
|
|
@ -128,15 +158,30 @@ impl FromStr for SseEvent {
|
|||
message: "Empty data field is not a valid SSE event".to_string(),
|
||||
});
|
||||
}
|
||||
// [DONE] marker is a valid SSE event that indicates end of stream
|
||||
Ok(SseEvent {
|
||||
data,
|
||||
raw_line: format!("{}\n\n", line), // Store complete SSE format
|
||||
data: Some(data),
|
||||
event: None,
|
||||
raw_line: format!("{}\n\n", line),
|
||||
raw_line_transformed: format!("{}\n\n", line),
|
||||
provider_stream_response: None, // Will be populated later via TryFrom
|
||||
})
|
||||
} else if line.starts_with("event: ") { //used by Anthropic
|
||||
let event_type = line[7..].to_string();
|
||||
if event_type.is_empty() {
|
||||
return Err(SseParseError {
|
||||
message: "Empty event field is not a valid SSE event".to_string(),
|
||||
});
|
||||
}
|
||||
Ok(SseEvent {
|
||||
data: None,
|
||||
event: Some(event_type),
|
||||
raw_line: format!("{}\n\n", line),
|
||||
raw_line_transformed: format!("{}\n\n", line),
|
||||
provider_stream_response: None,
|
||||
})
|
||||
} else {
|
||||
Err(SseParseError {
|
||||
message: format!("Line does not start with 'data: ': {}", line),
|
||||
message: format!("Line does not start with 'data: ' or 'event: ': {}", line),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
@ -144,14 +189,14 @@ impl FromStr for SseEvent {
|
|||
|
||||
impl fmt::Display for SseEvent {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "{}", self.raw_line)
|
||||
write!(f, "{}", self.raw_line_transformed)
|
||||
}
|
||||
}
|
||||
|
||||
// Into implementation to convert SseEvent to bytes for response buffer
|
||||
impl Into<Vec<u8>> for SseEvent {
|
||||
fn into(self) -> Vec<u8> {
|
||||
format!("{}\n\n", self.raw_line).into_bytes()
|
||||
format!("{}\n\n", self.raw_line_transformed).into_bytes()
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -196,12 +241,11 @@ impl TryFrom<(&[u8], &SupportedAPIs, &ProviderId)> for ProviderResponseType {
|
|||
}
|
||||
|
||||
// Stream response transformation logic for client API compatibility
|
||||
impl TryFrom<(&[u8], &SupportedAPIs, &ProviderId)> for ProviderStreamResponseType {
|
||||
impl TryFrom<(&[u8], &SupportedAPIs, &SupportedAPIs)> for ProviderStreamResponseType {
|
||||
type Error = Box<dyn std::error::Error + Send + Sync>;
|
||||
|
||||
fn try_from((bytes, client_api, provider_id): (&[u8], &SupportedAPIs, &ProviderId)) -> Result<Self, Self::Error> {
|
||||
let upstream_api = provider_id.compatible_api_for_client(client_api);
|
||||
match (&upstream_api, client_api) {
|
||||
fn try_from((bytes, client_api, upstream_api): (&[u8], &SupportedAPIs, &SupportedAPIs)) -> Result<Self, Self::Error> {
|
||||
match (upstream_api, client_api) {
|
||||
(SupportedAPIs::OpenAIChatCompletions(_), SupportedAPIs::OpenAIChatCompletions(_)) => {
|
||||
let resp: crate::apis::openai::ChatCompletionsStreamResponse = serde_json::from_slice(bytes)?;
|
||||
Ok(ProviderStreamResponseType::ChatCompletionsStreamResponse(resp))
|
||||
|
|
@ -229,73 +273,81 @@ impl TryFrom<(&[u8], &SupportedAPIs, &ProviderId)> for ProviderStreamResponseTyp
|
|||
}
|
||||
|
||||
// TryFrom implementation to convert raw bytes to SseEvent with parsed provider response
|
||||
impl TryFrom<(&[u8], &SupportedAPIs, &ProviderId)> for SseEvent {
|
||||
impl TryFrom<(&SseEvent, &SupportedAPIs, &SupportedAPIs)> for SseEvent {
|
||||
type Error = Box<dyn std::error::Error + Send + Sync>;
|
||||
|
||||
fn try_from((bytes, client_api, provider_id): (&[u8], &SupportedAPIs, &ProviderId)) -> Result<Self, Self::Error> {
|
||||
// Convert bytes to string
|
||||
let body_str = std::str::from_utf8(bytes)?;
|
||||
let mut sse_event: SseEvent = body_str.parse()?;
|
||||
fn try_from((sse_event, client_api, upstream_api): (&SseEvent, &SupportedAPIs, &SupportedAPIs)) -> Result<Self, Self::Error> {
|
||||
// Create a new transformed event based on the original
|
||||
let mut transformed_event = sse_event.clone();
|
||||
|
||||
// If not [DONE], parse the data as a provider stream response (business logic layer)
|
||||
if !sse_event.is_done() {
|
||||
// Use the new ProviderStreamResponseType::try_from to parse the JSON data
|
||||
let provider_response = ProviderStreamResponseType::try_from((sse_event.data.as_bytes(), client_api, provider_id))?;
|
||||
sse_event.provider_stream_response = Some(provider_response);
|
||||
// If not [DONE] and has data, parse the data as a provider stream response (business logic layer)
|
||||
if !transformed_event.is_done() && sse_event.data.is_some() {
|
||||
let data_str = sse_event.data.as_ref().unwrap();
|
||||
let data_bytes = data_str.as_bytes();
|
||||
let transformed_response = ProviderStreamResponseType::try_from((data_bytes, client_api, upstream_api))?;
|
||||
let transformed_json = serde_json::to_string(&transformed_response)?;
|
||||
transformed_event.raw_line_transformed = format!("data: {}\n\n", transformed_json);
|
||||
transformed_event.provider_stream_response = Some(transformed_response);
|
||||
}
|
||||
|
||||
Ok(sse_event)
|
||||
}
|
||||
}
|
||||
|
||||
// TryFrom implementation for transforming SseEvent between API formats
|
||||
impl TryFrom<(SseEvent, &SupportedAPIs, &SupportedAPIs)> for SseEvent {
|
||||
type Error = Box<dyn std::error::Error + Send + Sync>;
|
||||
|
||||
fn try_from((mut event, upstream_api, client_api): (SseEvent, &SupportedAPIs, &SupportedAPIs)) -> Result<Self, Self::Error> {
|
||||
// If APIs are the same, no transformation needed
|
||||
if std::mem::discriminant(upstream_api) == std::mem::discriminant(client_api) {
|
||||
return Ok(event);
|
||||
}
|
||||
|
||||
// Handle [DONE] events - they don't need transformation
|
||||
if event.is_done() {
|
||||
return Ok(event);
|
||||
}
|
||||
|
||||
// Transform the data field based on API conversion
|
||||
let transformed_data = match (upstream_api, client_api) {
|
||||
(SupportedAPIs::OpenAIChatCompletions(_), SupportedAPIs::AnthropicMessagesAPI(_)) => {
|
||||
// Parse OpenAI response and convert to Anthropic
|
||||
let openai_response: crate::apis::openai::ChatCompletionsStreamResponse =
|
||||
serde_json::from_str(&event.data)?;
|
||||
let anthropic_response: crate::apis::anthropic::MessagesStreamEvent =
|
||||
openai_response.try_into()?;
|
||||
serde_json::to_string(&anthropic_response)?
|
||||
match (client_api, upstream_api) {
|
||||
(SupportedAPIs::OpenAIChatCompletions(_), SupportedAPIs::OpenAIChatCompletions(_)) => {
|
||||
// No transformation needed
|
||||
}
|
||||
(SupportedAPIs::AnthropicMessagesAPI(_), SupportedAPIs::AnthropicMessagesAPI(_)) => {
|
||||
// No transformation needed
|
||||
}
|
||||
(SupportedAPIs::AnthropicMessagesAPI(_), SupportedAPIs::OpenAIChatCompletions(_)) => {
|
||||
// Parse Anthropic response and convert to OpenAI
|
||||
let anthropic_response: crate::apis::anthropic::MessagesStreamEvent =
|
||||
serde_json::from_str(&event.data)?;
|
||||
let openai_response: crate::apis::openai::ChatCompletionsStreamResponse =
|
||||
anthropic_response.try_into()?;
|
||||
serde_json::to_string(&openai_response)?
|
||||
if let Some(provider_response) = &transformed_event.provider_stream_response {
|
||||
if let Some(event_type) = provider_response.event_type() {
|
||||
// This ensures the required Anthropic sequence: MessageStart → ContentBlockStart → ContentBlockDelta(s)
|
||||
if event_type == "message_start" {
|
||||
let content_block_start_json = serde_json::json!({
|
||||
"type": "content_block_start",
|
||||
"index": 0,
|
||||
"content_block": {
|
||||
"type": "text",
|
||||
"text": ""
|
||||
}
|
||||
});
|
||||
// Format as proper SSE: MessageStart first, then ContentBlockStart
|
||||
transformed_event.raw_line_transformed = format!(
|
||||
"event: {}\n{}\nevent: content_block_start\ndata: {}\n\n",
|
||||
event_type,
|
||||
transformed_event.raw_line_transformed,
|
||||
content_block_start_json,
|
||||
);
|
||||
} else if event_type == "message_delta" {
|
||||
let content_block_stop_json = serde_json::json!({
|
||||
"type": "content_block_stop",
|
||||
"index": 0
|
||||
});
|
||||
// Format as proper SSE: ContentBlockStop first, then MessageDelta
|
||||
transformed_event.raw_line_transformed = format!(
|
||||
"event: content_block_stop\ndata: {}\n\nevent: {}\n{}",
|
||||
content_block_stop_json,
|
||||
event_type,
|
||||
transformed_event.raw_line_transformed
|
||||
);
|
||||
} else {
|
||||
transformed_event.raw_line_transformed = format!("event: {}\n{}", event_type, transformed_event.raw_line_transformed);
|
||||
}
|
||||
}
|
||||
// If event_type is None, we just keep the data line as-is without an event line
|
||||
// This handles cases where the transformation might not produce a valid event type
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
return Err(format!("Unsupported API transformation: {:?} -> {:?}", upstream_api, client_api).into());
|
||||
(SupportedAPIs::OpenAIChatCompletions(_), SupportedAPIs::AnthropicMessagesAPI(_)) => {
|
||||
if sse_event.is_event_only() && sse_event.event.is_some() {
|
||||
transformed_event.raw_line_transformed = format!("\n"); // suppress the event upstream for OpenAI
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
// Update the event with transformed data and reconstruct raw_line
|
||||
event.data = transformed_data;
|
||||
event.raw_line = format!("data: {}", event.data);
|
||||
|
||||
Ok(event)
|
||||
Ok(transformed_event)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct SseParseError {
|
||||
pub message: String,
|
||||
|
|
@ -353,13 +405,16 @@ where
|
|||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
for line in &mut self.lines {
|
||||
if let Ok(event) = line.as_ref().parse::<SseEvent>() {
|
||||
// Check if this is the [DONE] marker - if so, end the stream
|
||||
if event.is_done() {
|
||||
let line_str = line.as_ref();
|
||||
|
||||
// Try to parse as either data: or event: line
|
||||
if let Ok(event) = line_str.parse::<SseEvent>() {
|
||||
// For data: lines, check if this is the [DONE] marker - if so, end the stream
|
||||
if event.data.is_some() && event.is_done() {
|
||||
return None;
|
||||
}
|
||||
// Skip events that should be filtered at the transport layer
|
||||
if event.should_skip() {
|
||||
// For data: lines, skip events that should be filtered at the transport layer
|
||||
if event.data.is_some() && event.should_skip() {
|
||||
continue;
|
||||
}
|
||||
return Some(event);
|
||||
|
|
@ -525,7 +580,7 @@ mod tests {
|
|||
let event: Result<SseEvent, _> = line.parse();
|
||||
assert!(event.is_ok());
|
||||
let event = event.unwrap();
|
||||
assert_eq!(event.data, r#"{"id":"test","object":"chat.completion.chunk"}"#);
|
||||
assert_eq!(event.data, Some(r#"{"id":"test","object":"chat.completion.chunk"}"#.to_string()));
|
||||
|
||||
// Test conversion back to line using Display trait
|
||||
let wire_format = event.to_string();
|
||||
|
|
@ -536,7 +591,7 @@ mod tests {
|
|||
let done_result: Result<SseEvent, _> = done_line.parse();
|
||||
assert!(done_result.is_ok());
|
||||
let done_event = done_result.unwrap();
|
||||
assert_eq!(done_event.data, "[DONE]");
|
||||
assert_eq!(done_event.data, Some("[DONE]".to_string()));
|
||||
assert!(done_event.is_done()); // Test the helper method
|
||||
|
||||
// Test non-DONE event
|
||||
|
|
@ -557,12 +612,16 @@ mod tests {
|
|||
fn test_sse_event_serde() {
|
||||
// Test serialization and deserialization with serde
|
||||
let event = SseEvent {
|
||||
data: r#"{"id":"test","object":"chat.completion.chunk"}"#.to_string(),
|
||||
data: Some(r#"{"id":"test","object":"chat.completion.chunk"}"#.to_string()),
|
||||
event: None,
|
||||
raw_line: r#"data: {"id":"test","object":"chat.completion.chunk"}
|
||||
|
||||
"#.to_string(),
|
||||
provider_stream_response: None,
|
||||
};
|
||||
raw_line_transformed: r#"data: {"id":"test","object":"chat.completion.chunk"}
|
||||
|
||||
"#.to_string(),
|
||||
provider_stream_response: None,
|
||||
};
|
||||
|
||||
// Test JSON serialization - raw_line should be skipped
|
||||
let json = serde_json::to_string(&event).unwrap();
|
||||
|
|
@ -583,8 +642,10 @@ mod tests {
|
|||
fn test_sse_event_should_skip() {
|
||||
// Test ping message should be skipped
|
||||
let ping_event = SseEvent {
|
||||
data: r#"{"type": "ping"}"#.to_string(),
|
||||
data: Some(r#"{"type": "ping"}"#.to_string()),
|
||||
event: None,
|
||||
raw_line: r#"data: {"type": "ping"}"#.to_string(),
|
||||
raw_line_transformed: r#"data: {"type": "ping"}"#.to_string(),
|
||||
provider_stream_response: None,
|
||||
};
|
||||
assert!(ping_event.should_skip());
|
||||
|
|
@ -592,8 +653,10 @@ mod tests {
|
|||
|
||||
// Test normal event should not be skipped
|
||||
let normal_event = SseEvent {
|
||||
data: r#"{"id": "test", "object": "chat.completion.chunk"}"#.to_string(),
|
||||
data: Some(r#"{"id": "test", "object": "chat.completion.chunk"}"#.to_string()),
|
||||
event: Some("content_block_delta".to_string()),
|
||||
raw_line: r#"data: {"id": "test", "object": "chat.completion.chunk"}"#.to_string(),
|
||||
raw_line_transformed: r#"data: {"id": "test", "object": "chat.completion.chunk"}"#.to_string(),
|
||||
provider_stream_response: None,
|
||||
};
|
||||
assert!(!normal_event.should_skip());
|
||||
|
|
@ -601,8 +664,10 @@ mod tests {
|
|||
|
||||
// Test [DONE] event should not be skipped (but is handled separately)
|
||||
let done_event = SseEvent {
|
||||
data: "[DONE]".to_string(),
|
||||
data: Some("[DONE]".to_string()),
|
||||
event: None,
|
||||
raw_line: "data: [DONE]".to_string(),
|
||||
raw_line_transformed: "data: [DONE]".to_string(),
|
||||
provider_stream_response: None,
|
||||
};
|
||||
assert!(!done_event.should_skip());
|
||||
|
|
@ -624,15 +689,82 @@ mod tests {
|
|||
|
||||
// First event should be msg1 (ping filtered out)
|
||||
let event1 = iter.next().unwrap();
|
||||
assert!(event1.data.contains("msg1"));
|
||||
assert!(event1.data.as_ref().unwrap().contains("msg1"));
|
||||
assert!(!event1.should_skip());
|
||||
|
||||
// Second event should be msg2 (ping filtered out)
|
||||
let event2 = iter.next().unwrap();
|
||||
assert!(event2.data.contains("msg2"));
|
||||
assert!(event2.data.as_ref().unwrap().contains("msg2"));
|
||||
assert!(!event2.should_skip());
|
||||
|
||||
// Iterator should end at [DONE] (no more events)
|
||||
assert!(iter.next().is_none());
|
||||
}
|
||||
|
||||
#[test]
fn test_sse_stream_iter_handles_anthropic_events() {
    // Anthropic-style input: every `data:` line is preceded by its own `event:` line.
    let test_lines: Vec<String> = [
        "event: message_start",
        r#"data: {"type":"message_start","message":{"id":"msg_123"}}"#,
        "event: content_block_delta",
        r#"data: {"type":"content_block_delta","delta":{"text":"Hello"}}"#,
        "data: [DONE]",
    ]
    .iter()
    .map(|line| line.to_string())
    .collect();

    let mut iter = SseStreamIter::new(test_lines.into_iter());

    // 1) the bare `event: message_start` line
    let start_marker = iter.next().unwrap();
    assert!(start_marker.is_event_only());
    assert_eq!(start_marker.event.as_deref(), Some("message_start"));
    assert!(start_marker.data.is_none());

    // 2) the matching `data:` line
    let start_payload = iter.next().unwrap();
    assert!(!start_payload.is_event_only());
    assert!(start_payload.event.is_none());
    assert!(start_payload.data.as_ref().unwrap().contains("message_start"));

    // 3) the bare `event: content_block_delta` line
    let delta_marker = iter.next().unwrap();
    assert!(delta_marker.is_event_only());
    assert_eq!(delta_marker.event.as_deref(), Some("content_block_delta"));

    // 4) the content delta payload
    let delta_payload = iter.next().unwrap();
    assert!(!delta_payload.is_event_only());
    assert!(delta_payload.data.as_ref().unwrap().contains("Hello"));

    // `data: [DONE]` terminates the iterator instead of being yielded
    assert!(iter.next().is_none());
}
|
||||
|
||||
#[test]
fn test_provider_stream_response_event_type() {
    use crate::apis::anthropic::{MessagesContentDelta, MessagesStreamEvent};
    use crate::apis::openai::ChatCompletionsStreamResponse;

    // An Anthropic event wrapped in the enum reports its SSE event name.
    let anthropic = ProviderStreamResponseType::MessagesStreamEvent(
        MessagesStreamEvent::ContentBlockDelta {
            index: 0,
            delta: MessagesContentDelta::TextDelta {
                text: "Hello".to_string(),
            },
        },
    );
    assert_eq!(anthropic.event_type(), Some("content_block_delta"));

    // An OpenAI chunk wrapped in the enum reports no event name.
    let openai = ProviderStreamResponseType::ChatCompletionsStreamResponse(
        ChatCompletionsStreamResponse {
            id: "test".to_string(),
            object: "chat.completion.chunk".to_string(),
            created: 123456789,
            model: "gpt-4".to_string(),
            choices: vec![],
            usage: None,
            system_fingerprint: None,
            service_tier: None,
        },
    );
    assert!(openai.event_type().is_none());
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -144,13 +144,17 @@ impl StreamContext {
|
|||
match self.resolved_api.as_ref() {
|
||||
Some(SupportedAPIs::AnthropicMessagesAPI(_)) => {
|
||||
// Anthropic API requires x-api-key and anthropic-version headers
|
||||
// Remove any existing Authorization header since Anthropic doesn't use it
|
||||
self.set_http_request_header("authorization", None);
|
||||
self.set_http_request_header("x-api-key", Some(llm_provider_api_key_value));
|
||||
self.set_http_request_header("anthropic-version", Some("2023-06-01"));
|
||||
}
|
||||
Some(SupportedAPIs::OpenAIChatCompletions(_)) | None => {
|
||||
// OpenAI and default: use Authorization Bearer token
|
||||
// Remove any existing x-api-key header since OpenAI doesn't use it
|
||||
self.set_http_request_header("x-api-key", None);
|
||||
let authorization_header_value = format!("Bearer {}", llm_provider_api_key_value);
|
||||
self.set_http_request_header("Authorization", Some(&authorization_header_value));
|
||||
self.set_http_request_header("authorization", Some(&authorization_header_value));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -426,7 +430,7 @@ impl StreamContext {
|
|||
for sse_event in sse_iter {
|
||||
// Transform event if upstream API != client API
|
||||
let transformed_event: SseEvent =
|
||||
match SseEvent::try_from((sse_event, &upstream_api, &client_api)) {
|
||||
match SseEvent::try_from((&sse_event, &client_api, &upstream_api)) {
|
||||
Ok(event) => event,
|
||||
Err(e) => {
|
||||
warn!("Failed to transform SSE event: {}", e);
|
||||
|
|
@ -436,7 +440,7 @@ impl StreamContext {
|
|||
|
||||
// Extract ProviderStreamResponse for processing (token counting, etc.)
|
||||
if !transformed_event.is_done() {
|
||||
match transformed_event.to_provider_stream_response(&client_api) {
|
||||
match transformed_event.provider_response() {
|
||||
Ok(provider_response) => {
|
||||
self.record_ttft_if_needed();
|
||||
|
||||
|
|
@ -910,6 +914,12 @@ impl HttpContext for StreamContext {
|
|||
if self.streaming_response {
|
||||
match self.handle_streaming_response(&body, provider_id) {
|
||||
Ok(serialized_body) => {
|
||||
debug!(
|
||||
"[ARCHGW_REQ_ID:{}] UPSTREAM_TRANSFORMED_RESPONSE: body_size={} content={}",
|
||||
self.request_identifier(),
|
||||
body.len(),
|
||||
String::from_utf8_lossy(&serialized_body)
|
||||
);
|
||||
self.set_http_response_body(0, body_size, &serialized_body);
|
||||
}
|
||||
Err(action) => return action,
|
||||
|
|
|
|||
|
|
@ -380,6 +380,65 @@ def test_claude_v1_messages_api():
|
|||
assert message.content[0].text == "Hello from Claude!"
|
||||
|
||||
|
||||
def test_claude_v1_messages_api_streaming():
    """Stream a Claude reply through the gateway's /v1/messages API and check the text."""
    gateway_root = LLM_GATEWAY_ENDPOINT.replace("/v1/chat/completions", "")
    client = anthropic.Anthropic(api_key="test-key", base_url=gateway_root)

    user_turn = {
        "role": "user",
        "content": "Hello, please respond with exactly: Hello from Claude!",
    }

    with client.messages.stream(
        model="claude-sonnet-4-20250514",
        max_tokens=50,
        messages=[user_turn],
    ) as stream:
        # text_stream yields only the text deltas, in arrival order
        full_text = "".join(stream.text_stream)

        # The helper also exposes the fully assembled Message object;
        # rebuild its text from the text-typed content blocks only.
        final = stream.get_final_message()
        final_text = "".join(
            block.text for block in final.content if block.type == "text"
        )

    assert full_text == "Hello from Claude!"
    assert final_text == "Hello from Claude!"
|
||||
|
||||
|
||||
def test_anthropic_client_with_openai_model_streaming():
    """Anthropic client on /v1/messages, served by an OpenAI model (gpt-4o-mini).

    Exercises the OpenAI upstream -> Anthropic client stream transformation,
    including the event lines the Anthropic client expects.
    """
    gateway_root = LLM_GATEWAY_ENDPOINT.replace("/v1/chat/completions", "")
    client = anthropic.Anthropic(api_key="test-key", base_url=gateway_root)

    user_turn = {
        "role": "user",
        "content": "Hello, please respond with exactly: Hello from GPT-4o-mini via Anthropic!",
    }

    with client.messages.stream(
        model="gpt-4o-mini",  # OpenAI model via Anthropic client
        max_tokens=50,
        messages=[user_turn],
    ) as stream:
        # text_stream yields only the text deltas, in arrival order
        full_text = "".join(stream.text_stream)

        # The helper also exposes the fully assembled Message object;
        # rebuild its text from the text-typed content blocks only.
        final = stream.get_final_message()
        final_text = "".join(
            block.text for block in final.content if block.type == "text"
        )

    assert full_text == "Hello from GPT-4o-mini via Anthropic!"
    assert final_text == "Hello from GPT-4o-mini via Anthropic!"
|
||||
|
||||
|
||||
def test_openai_gpt4o_mini_v1_messages_api():
|
||||
"""Test OpenAI GPT-4o-mini using /v1/chat/completions API through llm_gateway (port 12000)"""
|
||||
# Get the base URL from the LLM gateway endpoint
|
||||
|
|
@ -402,3 +461,72 @@ def test_openai_gpt4o_mini_v1_messages_api():
|
|||
)
|
||||
|
||||
assert completion.choices[0].message.content == "Hello from GPT-4o-mini!"
|
||||
|
||||
|
||||
def test_openai_gpt4o_mini_v1_messages_api_streaming():
    """Test OpenAI GPT-4o-mini using /v1/chat/completions API with streaming through llm_gateway (port 12000)"""
    # Get the base URL from the LLM gateway endpoint
    base_url = LLM_GATEWAY_ENDPOINT.replace("/v1/chat/completions", "")

    client = openai.OpenAI(
        api_key="test-key",  # Dummy key for testing
        base_url=f"{base_url}/v1",  # OpenAI needs /v1 suffix in base_url
    )

    stream = client.chat.completions.create(
        model="gpt-4o-mini",
        max_tokens=50,
        messages=[
            {
                "role": "user",
                "content": "Hello, please respond with exactly: Hello from GPT-4o-mini!",
            }
        ],
        stream=True,
    )

    # Collect the text of every streamed chunk. A chunk can arrive with an
    # empty `choices` list; skip it rather than fail with IndexError on [0].
    content_chunks = []
    for chunk in stream:
        if not chunk.choices:
            continue
        delta_text = chunk.choices[0].delta.content
        if delta_text:
            content_chunks.append(delta_text)

    # Reconstruct the full message
    full_content = "".join(content_chunks)
    assert full_content == "Hello from GPT-4o-mini!"
|
||||
|
||||
|
||||
def test_openai_client_with_claude_model_streaming():
    """Test OpenAI client using /v1/chat/completions API with Claude model (claude-sonnet-4-20250514)

    This tests the transformation: Anthropic upstream -> OpenAI client format with proper chunk handling
    """
    # Get the base URL from the LLM gateway endpoint
    base_url = LLM_GATEWAY_ENDPOINT.replace("/v1/chat/completions", "")

    client = openai.OpenAI(
        api_key="test-key",  # Dummy key for testing
        base_url=f"{base_url}/v1",  # OpenAI needs /v1 suffix in base_url
    )

    stream = client.chat.completions.create(
        model="claude-sonnet-4-20250514",  # Claude model via OpenAI client
        max_tokens=50,
        messages=[
            {
                "role": "user",
                "content": "Who are you? ALWAYS RESPOND WITH:I appreciate the request, but I should clarify that I'm Claude, made by Anthropic, not OpenAI. I don't want to create confusion about my origins.",
            }
        ],
        stream=True,
        temperature=0.1,
    )

    # Collect the text of every streamed chunk. A chunk can arrive with an
    # empty `choices` list; skip it rather than fail with IndexError on [0].
    content_chunks = []
    for chunk in stream:
        if not chunk.choices:
            continue
        delta_text = chunk.choices[0].delta.content
        if delta_text:
            content_chunks.append(delta_text)

    # Reconstruct the full message. "".join() always returns a str, so an
    # `is not None` check could never fail; require that some text actually
    # made it through the Anthropic -> OpenAI transformation.
    full_content = "".join(content_chunks)
    assert full_content, "expected at least one non-empty content delta in the stream"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue