mirror of
https://github.com/katanemo/plano.git
synced 2026-06-08 14:55:14 +02:00
Add support for v1/responses API (#622)
* making first commit. still need to work on streaming responses * making first commit. still need to work on streaming responses * stream buffer implementation with tests * adding grok API keys to workflow * fixed changes based on code review * adding support for bedrock models * fixed issues with translation to claude code --------- Co-authored-by: Salman Paracha <salmanparacha@MacBook-Pro-342.local>
This commit is contained in:
parent
b01a81927d
commit
a448c6e9cb
38 changed files with 7015 additions and 2955 deletions
1
.github/workflows/e2e_tests.yml
vendored
1
.github/workflows/e2e_tests.yml
vendored
|
|
@ -59,6 +59,7 @@ jobs:
|
|||
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
AZURE_API_KEY: ${{ secrets.AZURE_API_KEY }}
|
||||
AWS_BEARER_TOKEN_BEDROCK: ${{ secrets.AWS_BEARER_TOKEN_BEDROCK }}
|
||||
GROK_API_KEY : ${{ secrets.GROK_API_KEY }}
|
||||
run: |
|
||||
python -mvenv venv
|
||||
source venv/bin/activate && cd tests/e2e && bash run_e2e_tests.sh
|
||||
|
|
|
|||
2
crates/Cargo.lock
generated
2
crates/Cargo.lock
generated
|
|
@ -912,10 +912,12 @@ version = "0.1.0"
|
|||
dependencies = [
|
||||
"aws-smithy-eventstream",
|
||||
"bytes",
|
||||
"log",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"serde_with",
|
||||
"thiserror 2.0.12",
|
||||
"uuid",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ use common::configuration::{ModelAlias, ModelUsagePreference};
|
|||
use common::consts::{ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER};
|
||||
use hermesllm::apis::openai::ChatCompletionsRequest;
|
||||
use hermesllm::clients::endpoints::SupportedUpstreamAPIs;
|
||||
use hermesllm::clients::SupportedAPIs;
|
||||
use hermesllm::clients::SupportedAPIsFromClient;
|
||||
use hermesllm::{ProviderRequest, ProviderRequestType};
|
||||
use http_body_util::combinators::BoxBody;
|
||||
use http_body_util::{BodyExt, Full};
|
||||
|
|
@ -39,7 +39,7 @@ pub async fn router_chat(
|
|||
|
||||
let mut client_request = match ProviderRequestType::try_from((
|
||||
&chat_request_bytes[..],
|
||||
&SupportedAPIs::from_endpoint(request_path.as_str()).unwrap(),
|
||||
&SupportedAPIsFromClient::from_endpoint(request_path.as_str()).unwrap(),
|
||||
)) {
|
||||
Ok(request) => request,
|
||||
Err(err) => {
|
||||
|
|
@ -58,7 +58,7 @@ pub async fn router_chat(
|
|||
let resolved_model = if let Some(model_aliases) = model_aliases.as_ref() {
|
||||
if let Some(model_alias) = model_aliases.get(&model_from_request) {
|
||||
debug!(
|
||||
"Model Alias: 'From {}' -> 'To{}'",
|
||||
"Model Alias: 'From {}' -> 'To {}'",
|
||||
model_from_request, model_alias.target
|
||||
);
|
||||
model_alias.target.clone()
|
||||
|
|
@ -91,10 +91,11 @@ pub async fn router_chat(
|
|||
Ok(
|
||||
ProviderRequestType::MessagesRequest(_)
|
||||
| ProviderRequestType::BedrockConverse(_)
|
||||
| ProviderRequestType::BedrockConverseStream(_),
|
||||
| ProviderRequestType::BedrockConverseStream(_)
|
||||
| ProviderRequestType::ResponsesAPIRequest(_),
|
||||
) => {
|
||||
// This should not happen after conversion to OpenAI format
|
||||
warn!("Unexpected: got MessagesRequest after converting to OpenAI format");
|
||||
warn!("Unexpected: got non-ChatCompletions request after converting to OpenAI format");
|
||||
let err_msg = "Request conversion failed".to_string();
|
||||
let mut bad_request = Response::new(full(err_msg));
|
||||
*bad_request.status_mut() = StatusCode::BAD_REQUEST;
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ use brightstaff::router::llm_router::RouterService;
|
|||
use brightstaff::utils::tracing::init_tracer;
|
||||
use bytes::Bytes;
|
||||
use common::configuration::Configuration;
|
||||
use common::consts::{CHAT_COMPLETIONS_PATH, MESSAGES_PATH};
|
||||
use common::consts::{CHAT_COMPLETIONS_PATH, MESSAGES_PATH, OPENAI_RESPONSES_API_PATH};
|
||||
use http_body_util::{combinators::BoxBody, BodyExt, Empty};
|
||||
use hyper::body::Incoming;
|
||||
use hyper::server::conn::http1;
|
||||
|
|
@ -123,7 +123,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
|
||||
async move {
|
||||
match (req.method(), req.uri().path()) {
|
||||
(&Method::POST, CHAT_COMPLETIONS_PATH | MESSAGES_PATH) => {
|
||||
(&Method::POST, CHAT_COMPLETIONS_PATH | MESSAGES_PATH | OPENAI_RESPONSES_API_PATH) => {
|
||||
let fully_qualified_url =
|
||||
format!("{}{}", llm_provider_url, req.uri().path());
|
||||
router_chat(req, router_service, fully_qualified_url, model_aliases)
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ pub const MESSAGES_KEY: &str = "messages";
|
|||
pub const ARCH_PROVIDER_HINT_HEADER: &str = "x-arch-llm-provider-hint";
|
||||
pub const ARCH_IS_STREAMING_HEADER: &str = "x-arch-streaming-request";
|
||||
pub const CHAT_COMPLETIONS_PATH: &str = "/v1/chat/completions";
|
||||
pub const OPENAI_RESPONSES_API_PATH: &str = "/v1/responses";
|
||||
pub const MESSAGES_PATH: &str = "/v1/messages";
|
||||
pub const HEALTHZ_PATH: &str = "/healthz";
|
||||
pub const X_ARCH_STATE_HEADER: &str = "x-arch-state";
|
||||
|
|
|
|||
|
|
@ -10,3 +10,5 @@ serde_with = {version = "3.12.0", features = ["base64"]}
|
|||
thiserror = "2.0.12"
|
||||
aws-smithy-eventstream = "0.60"
|
||||
bytes = "1.10"
|
||||
uuid = { version = "1.11", features = ["v4"] }
|
||||
log = "0.4"
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ use thiserror::Error;
|
|||
|
||||
use super::ApiDefinition;
|
||||
use crate::providers::request::{ProviderRequest, ProviderRequestError};
|
||||
use crate::providers::response::ProviderStreamResponse;
|
||||
use crate::providers::streaming_response::ProviderStreamResponse;
|
||||
|
||||
// ============================================================================
|
||||
// AMAZON BEDROCK CONVERSE API ENUMERATION
|
||||
|
|
|
|||
|
|
@ -6,7 +6,8 @@ use std::collections::HashMap;
|
|||
|
||||
use super::ApiDefinition;
|
||||
use crate::providers::request::{ProviderRequest, ProviderRequestError};
|
||||
use crate::providers::response::{ProviderResponse, ProviderStreamResponse};
|
||||
use crate::providers::response::ProviderResponse;
|
||||
use crate::providers::streaming_response::ProviderStreamResponse;
|
||||
use crate::transforms::lib::ExtractText;
|
||||
use crate::MESSAGES_PATH;
|
||||
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
pub mod amazon_bedrock;
|
||||
pub mod amazon_bedrock_binary_frame;
|
||||
pub mod anthropic;
|
||||
pub mod openai;
|
||||
pub mod sse;
|
||||
pub mod openai_responses;
|
||||
pub mod streaming_shapes;
|
||||
|
||||
// Explicit exports to avoid naming conflicts
|
||||
pub use amazon_bedrock::{AmazonBedrockApi, ConverseRequest, ConverseStreamRequest};
|
||||
|
|
@ -88,8 +88,9 @@ mod tests {
|
|||
fn test_all_variants_method() {
|
||||
// Test that all_variants returns the expected variants
|
||||
let openai_variants = OpenAIApi::all_variants();
|
||||
assert_eq!(openai_variants.len(), 1);
|
||||
assert_eq!(openai_variants.len(), 2);
|
||||
assert!(openai_variants.contains(&OpenAIApi::ChatCompletions));
|
||||
assert!(openai_variants.contains(&OpenAIApi::Responses));
|
||||
|
||||
let anthropic_variants = AnthropicApi::all_variants();
|
||||
assert_eq!(anthropic_variants.len(), 1);
|
||||
|
|
|
|||
|
|
@ -7,9 +7,10 @@ use thiserror::Error;
|
|||
|
||||
use super::ApiDefinition;
|
||||
use crate::providers::request::{ProviderRequest, ProviderRequestError};
|
||||
use crate::providers::response::{ProviderResponse, ProviderStreamResponse, TokenUsage};
|
||||
use crate::providers::response::{ProviderResponse, TokenUsage};
|
||||
use crate::providers::streaming_response::ProviderStreamResponse;
|
||||
use crate::transforms::lib::ExtractText;
|
||||
use crate::CHAT_COMPLETIONS_PATH;
|
||||
use crate::{CHAT_COMPLETIONS_PATH, OPENAI_RESPONSES_API_PATH};
|
||||
|
||||
// ============================================================================
|
||||
// OPENAI API ENUMERATION
|
||||
|
|
@ -19,6 +20,7 @@ use crate::CHAT_COMPLETIONS_PATH;
|
|||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub enum OpenAIApi {
|
||||
ChatCompletions,
|
||||
Responses,
|
||||
// Future APIs can be added here:
|
||||
// Embeddings,
|
||||
// FineTuning,
|
||||
|
|
@ -29,12 +31,14 @@ impl ApiDefinition for OpenAIApi {
|
|||
fn endpoint(&self) -> &'static str {
|
||||
match self {
|
||||
OpenAIApi::ChatCompletions => CHAT_COMPLETIONS_PATH,
|
||||
OpenAIApi::Responses => OPENAI_RESPONSES_API_PATH,
|
||||
}
|
||||
}
|
||||
|
||||
fn from_endpoint(endpoint: &str) -> Option<Self> {
|
||||
match endpoint {
|
||||
CHAT_COMPLETIONS_PATH => Some(OpenAIApi::ChatCompletions),
|
||||
OPENAI_RESPONSES_API_PATH => Some(OpenAIApi::Responses),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
|
@ -42,23 +46,26 @@ impl ApiDefinition for OpenAIApi {
|
|||
fn supports_streaming(&self) -> bool {
|
||||
match self {
|
||||
OpenAIApi::ChatCompletions => true,
|
||||
OpenAIApi::Responses => true,
|
||||
}
|
||||
}
|
||||
|
||||
fn supports_tools(&self) -> bool {
|
||||
match self {
|
||||
OpenAIApi::ChatCompletions => true,
|
||||
OpenAIApi::Responses => true,
|
||||
}
|
||||
}
|
||||
|
||||
fn supports_vision(&self) -> bool {
|
||||
match self {
|
||||
OpenAIApi::ChatCompletions => true,
|
||||
OpenAIApi::Responses => true,
|
||||
}
|
||||
}
|
||||
|
||||
fn all_variants() -> Vec<Self> {
|
||||
vec![OpenAIApi::ChatCompletions]
|
||||
vec![OpenAIApi::ChatCompletions, OpenAIApi::Responses]
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -1077,8 +1084,9 @@ mod tests {
|
|||
|
||||
// Test all_variants
|
||||
let all_variants = OpenAIApi::all_variants();
|
||||
assert_eq!(all_variants.len(), 1);
|
||||
assert_eq!(all_variants[0], OpenAIApi::ChatCompletions);
|
||||
assert_eq!(all_variants.len(), 2);
|
||||
assert!(all_variants.contains(&OpenAIApi::ChatCompletions));
|
||||
assert!(all_variants.contains(&OpenAIApi::Responses));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
|
|||
1386
crates/hermesllm/src/apis/openai_responses.rs
Normal file
1386
crates/hermesllm/src/apis/openai_responses.rs
Normal file
File diff suppressed because it is too large
Load diff
|
|
@ -1,7 +1,6 @@
|
|||
use aws_smithy_eventstream::frame::DecodedFrame;
|
||||
use aws_smithy_eventstream::frame::MessageFrameDecoder;
|
||||
use bytes::Buf;
|
||||
use std::collections::HashSet;
|
||||
|
||||
/// AWS Event Stream frame decoder wrapper
|
||||
pub struct BedrockBinaryFrameDecoder<B>
|
||||
|
|
@ -10,7 +9,6 @@ where
|
|||
{
|
||||
decoder: MessageFrameDecoder,
|
||||
buffer: B,
|
||||
content_block_start_indices: HashSet<i32>,
|
||||
}
|
||||
|
||||
impl BedrockBinaryFrameDecoder<bytes::BytesMut> {
|
||||
|
|
@ -20,7 +18,6 @@ impl BedrockBinaryFrameDecoder<bytes::BytesMut> {
|
|||
Self {
|
||||
decoder: MessageFrameDecoder::new(),
|
||||
buffer,
|
||||
content_block_start_indices: std::collections::HashSet::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -33,7 +30,6 @@ where
|
|||
Self {
|
||||
decoder: MessageFrameDecoder::new(),
|
||||
buffer,
|
||||
content_block_start_indices: HashSet::new(),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -52,14 +48,4 @@ where
|
|||
pub fn has_remaining(&self) -> bool {
|
||||
self.buffer.has_remaining()
|
||||
}
|
||||
|
||||
/// Check if a content_block_start event has been sent for the given index
|
||||
pub fn has_content_block_start_been_sent(&self, index: i32) -> bool {
|
||||
self.content_block_start_indices.contains(&index)
|
||||
}
|
||||
|
||||
/// Mark that a content_block_start event has been sent for the given index
|
||||
pub fn set_content_block_start_sent(&mut self, index: i32) {
|
||||
self.content_block_start_indices.insert(index);
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,507 @@
|
|||
use crate::apis::streaming_shapes::sse::{SseEvent, SseStreamBufferTrait};
|
||||
use crate::apis::anthropic::MessagesStreamEvent;
|
||||
use crate::providers::streaming_response::ProviderStreamResponseType;
|
||||
use std::collections::HashSet;
|
||||
|
||||
/// SSE Stream Buffer for Anthropic Messages API streaming.
|
||||
///
|
||||
/// This buffer manages the wire format for Anthropic Messages API streaming,
|
||||
/// handling the specific event sequencing requirements:
|
||||
/// - MessageStart → ContentBlockStart → ContentBlockDelta(s) → ContentBlockStop → MessageDelta → MessageStop
|
||||
///
|
||||
/// When converting from OpenAI to Anthropic format, this buffer injects the required
|
||||
/// ContentBlockStart and ContentBlockStop events to maintain proper Anthropic protocol.
|
||||
pub struct AnthropicMessagesStreamBuffer {
|
||||
/// Buffered SSE events ready to be written to wire
|
||||
buffered_events: Vec<SseEvent>,
|
||||
|
||||
/// Track if we've seen a message_start event
|
||||
message_started: bool,
|
||||
|
||||
/// Track content block indices that have received ContentBlockStart events
|
||||
content_block_start_indices: HashSet<i32>,
|
||||
|
||||
/// Track if we need to inject ContentBlockStop before message_delta
|
||||
needs_content_block_stop: bool,
|
||||
|
||||
/// Track if we've seen a MessageDelta (so we need to send MessageStop at the end)
|
||||
seen_message_delta: bool,
|
||||
|
||||
/// Model name to use when generating message_start events
|
||||
model: Option<String>,
|
||||
}
|
||||
|
||||
impl AnthropicMessagesStreamBuffer {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
buffered_events: Vec::new(),
|
||||
message_started: false,
|
||||
content_block_start_indices: HashSet::new(),
|
||||
needs_content_block_stop: false,
|
||||
seen_message_delta: false,
|
||||
model: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if a content_block_start event has been sent for the given index
|
||||
fn has_content_block_start_been_sent(&self, index: i32) -> bool {
|
||||
self.content_block_start_indices.contains(&index)
|
||||
}
|
||||
|
||||
/// Mark that a content_block_start event has been sent for the given index
|
||||
fn set_content_block_start_sent(&mut self, index: i32) {
|
||||
self.content_block_start_indices.insert(index);
|
||||
}
|
||||
|
||||
/// Helper to create and format a ContentBlockStart SSE event
|
||||
fn create_content_block_start_event() -> SseEvent {
|
||||
let content_block_start = MessagesStreamEvent::ContentBlockStart {
|
||||
index: 0,
|
||||
content_block: crate::apis::anthropic::MessagesContentBlock::Text {
|
||||
text: String::new(),
|
||||
cache_control: None,
|
||||
},
|
||||
};
|
||||
let sse_string: String = content_block_start.into();
|
||||
|
||||
SseEvent {
|
||||
data: None,
|
||||
event: Some("content_block_start".to_string()),
|
||||
raw_line: sse_string.clone(),
|
||||
sse_transformed_lines: sse_string,
|
||||
provider_stream_response: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper to create and format a MessageStart SSE event
|
||||
fn create_message_start_event(model: &str) -> SseEvent {
|
||||
let message_start = MessagesStreamEvent::MessageStart {
|
||||
message: crate::apis::anthropic::MessagesStreamMessage {
|
||||
id: format!("msg_{}", uuid::Uuid::new_v4().to_string().replace("-", "")),
|
||||
obj_type: "message".to_string(),
|
||||
role: crate::apis::anthropic::MessagesRole::Assistant,
|
||||
content: vec![],
|
||||
model: model.to_string(),
|
||||
stop_reason: None,
|
||||
stop_sequence: None,
|
||||
usage: crate::apis::anthropic::MessagesUsage {
|
||||
input_tokens: 0,
|
||||
output_tokens: 0,
|
||||
cache_creation_input_tokens: None,
|
||||
cache_read_input_tokens: None,
|
||||
},
|
||||
},
|
||||
};
|
||||
let sse_string: String = message_start.into();
|
||||
|
||||
SseEvent {
|
||||
data: None,
|
||||
event: Some("message_start".to_string()),
|
||||
raw_line: sse_string.clone(),
|
||||
sse_transformed_lines: sse_string,
|
||||
provider_stream_response: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper to create and format a ContentBlockStop SSE event
|
||||
fn create_content_block_stop_event() -> SseEvent {
|
||||
let content_block_stop = MessagesStreamEvent::ContentBlockStop { index: 0 };
|
||||
let sse_string: String = content_block_stop.into();
|
||||
|
||||
SseEvent {
|
||||
data: None,
|
||||
event: Some("content_block_stop".to_string()),
|
||||
raw_line: sse_string.clone(),
|
||||
sse_transformed_lines: sse_string,
|
||||
provider_stream_response: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl SseStreamBufferTrait for AnthropicMessagesStreamBuffer {
|
||||
fn add_transformed_event(&mut self, event: SseEvent) {
|
||||
// Skip ping messages
|
||||
if event.should_skip() {
|
||||
return;
|
||||
}
|
||||
|
||||
// FIRST: Try to extract model name from the raw event data before transformation
|
||||
// The provider_stream_response has already been transformed to Anthropic format,
|
||||
// so we need to extract the model from the original raw data if available
|
||||
if self.model.is_none() {
|
||||
if let Some(data) = &event.data {
|
||||
// Try to parse as JSON and extract model field
|
||||
if let Ok(json) = serde_json::from_str::<serde_json::Value>(data) {
|
||||
if let Some(model) = json.get("model").and_then(|m| m.as_str()) {
|
||||
self.model = Some(model.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Match directly on the provider response type to handle event processing
|
||||
// We match on a reference first to determine the type, then move the event
|
||||
match &event.provider_stream_response {
|
||||
Some(ProviderStreamResponseType::MessagesStreamEvent(evt)) => {
|
||||
match evt {
|
||||
MessagesStreamEvent::MessageStart { .. } => {
|
||||
// Add the message_start event
|
||||
self.buffered_events.push(event);
|
||||
self.message_started = true;
|
||||
}
|
||||
MessagesStreamEvent::ContentBlockStart { index, .. } => {
|
||||
let index = *index as i32;
|
||||
// Inject message_start if needed
|
||||
if !self.message_started {
|
||||
let model = self.model.as_deref().unwrap_or("unknown");
|
||||
let message_start = AnthropicMessagesStreamBuffer::create_message_start_event(model);
|
||||
self.buffered_events.push(message_start);
|
||||
self.message_started = true;
|
||||
}
|
||||
|
||||
// Add the content_block_start event (from tool calls or other sources)
|
||||
self.buffered_events.push(event);
|
||||
self.set_content_block_start_sent(index);
|
||||
self.needs_content_block_stop = true;
|
||||
}
|
||||
MessagesStreamEvent::ContentBlockDelta { index, .. } => {
|
||||
let index = *index as i32;
|
||||
// Inject message_start if needed
|
||||
if !self.message_started {
|
||||
let model = self.model.as_deref().unwrap_or("unknown");
|
||||
let message_start = AnthropicMessagesStreamBuffer::create_message_start_event(model);
|
||||
self.buffered_events.push(message_start);
|
||||
self.message_started = true;
|
||||
}
|
||||
|
||||
// Check if ContentBlockStart was sent for this index
|
||||
if !self.has_content_block_start_been_sent(index) {
|
||||
// Inject ContentBlockStart before delta
|
||||
let content_block_start = AnthropicMessagesStreamBuffer::create_content_block_start_event();
|
||||
self.buffered_events.push(content_block_start);
|
||||
self.set_content_block_start_sent(index);
|
||||
self.needs_content_block_stop = true;
|
||||
}
|
||||
|
||||
// Content deltas are between ContentBlockStart and ContentBlockStop
|
||||
self.buffered_events.push(event);
|
||||
}
|
||||
MessagesStreamEvent::MessageDelta { usage, .. } => {
|
||||
// Inject ContentBlockStop before message_delta
|
||||
if self.needs_content_block_stop {
|
||||
let content_block_stop = AnthropicMessagesStreamBuffer::create_content_block_stop_event();
|
||||
self.buffered_events.push(content_block_stop);
|
||||
self.needs_content_block_stop = false;
|
||||
}
|
||||
|
||||
// Check if the last event was also a MessageDelta - if so, merge them
|
||||
// This handles Bedrock's split of stop_reason (MessageStop) and usage (Metadata)
|
||||
if let Some(last_event) = self.buffered_events.last_mut() {
|
||||
if let Some(ProviderStreamResponseType::MessagesStreamEvent(
|
||||
MessagesStreamEvent::MessageDelta {
|
||||
usage: last_usage,
|
||||
..
|
||||
}
|
||||
)) = &mut last_event.provider_stream_response {
|
||||
// Merge: take stop_reason from first, usage from second (if non-zero)
|
||||
if usage.input_tokens > 0 || usage.output_tokens > 0 {
|
||||
*last_usage = usage.clone();
|
||||
}
|
||||
// Mark that we've seen MessageDelta (need to send MessageStop later)
|
||||
self.seen_message_delta = true;
|
||||
// Don't push the new event, we've merged it
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// No previous MessageDelta to merge with, add this one
|
||||
self.buffered_events.push(event);
|
||||
self.seen_message_delta = true;
|
||||
}
|
||||
MessagesStreamEvent::ContentBlockStop { .. } => {
|
||||
// ContentBlockStop received from upstream (e.g., Bedrock)
|
||||
// Clear the flag so we don't inject another one
|
||||
self.needs_content_block_stop = false;
|
||||
self.buffered_events.push(event);
|
||||
}
|
||||
MessagesStreamEvent::MessageStop => {
|
||||
// MessageStop received from upstream (e.g., OpenAI via [DONE])
|
||||
// Clear the flag so we don't inject another one
|
||||
self.seen_message_delta = false;
|
||||
self.buffered_events.push(event);
|
||||
}
|
||||
_ => {
|
||||
// Other Anthropic event types (Ping, etc.), just accumulate
|
||||
self.buffered_events.push(event);
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
// Non-Anthropic events or events without provider_stream_response, just accumulate
|
||||
self.buffered_events.push(event);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn into_bytes(&mut self) -> Vec<u8> {
|
||||
// Convert all accumulated events to bytes and clear buffer
|
||||
// NOTE: We do NOT inject ContentBlockStop here because it's injected when we see MessageDelta
|
||||
// or MessageStop. Injecting it here causes premature ContentBlockStop in the middle of streaming.
|
||||
|
||||
// Inject MessageStop after MessageDelta if we've seen one
|
||||
// This completes the Anthropic Messages API event sequence
|
||||
if self.seen_message_delta {
|
||||
let message_stop = MessagesStreamEvent::MessageStop;
|
||||
let sse_string: String = message_stop.into();
|
||||
let message_stop_event = SseEvent {
|
||||
data: None,
|
||||
event: Some("message_stop".to_string()),
|
||||
raw_line: sse_string.clone(),
|
||||
sse_transformed_lines: sse_string,
|
||||
provider_stream_response: None,
|
||||
};
|
||||
self.buffered_events.push(message_stop_event);
|
||||
self.seen_message_delta = false;
|
||||
}
|
||||
|
||||
let mut buffer = Vec::new();
|
||||
for event in self.buffered_events.drain(..) {
|
||||
let event_bytes: Vec<u8> = event.into();
|
||||
buffer.extend_from_slice(&event_bytes);
|
||||
}
|
||||
buffer
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::clients::{SupportedAPIsFromClient, SupportedUpstreamAPIs};
|
||||
use crate::apis::anthropic::AnthropicApi;
|
||||
use crate::apis::openai::OpenAIApi;
|
||||
use crate::apis::streaming_shapes::sse::SseStreamIter;
|
||||
|
||||
#[test]
|
||||
fn test_openai_to_anthropic_complete_transformation() {
|
||||
// OpenAI ChatCompletions input that will be transformed to Anthropic Messages API
|
||||
let raw_input = r#"data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{"role":"assistant","content":"Hello"},"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{"content":" world"},"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}
|
||||
|
||||
data: [DONE]"#;
|
||||
|
||||
println!("\n{}", "=".repeat(80));
|
||||
println!("TEST 1: OpenAI → Anthropic Messages API Complete Transformation");
|
||||
println!("{}", "=".repeat(80));
|
||||
println!("\nRAW INPUT (OpenAI ChatCompletions):");
|
||||
println!("{}", "-".repeat(80));
|
||||
println!("{}", raw_input);
|
||||
|
||||
// Setup API configuration for transformation (client wants Anthropic, upstream is OpenAI)
|
||||
let client_api = SupportedAPIsFromClient::AnthropicMessagesAPI(AnthropicApi::Messages);
|
||||
let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions);
|
||||
|
||||
// Parse events and apply transformation
|
||||
let stream_iter = SseStreamIter::try_from(raw_input.as_bytes()).unwrap();
|
||||
let mut buffer = AnthropicMessagesStreamBuffer::new();
|
||||
|
||||
for raw_event in stream_iter {
|
||||
let transformed_event = SseEvent::try_from((raw_event, &client_api, &upstream_api)).unwrap();
|
||||
buffer.add_transformed_event(transformed_event);
|
||||
}
|
||||
|
||||
let output_bytes = buffer.into_bytes();
|
||||
let output = String::from_utf8_lossy(&output_bytes);
|
||||
|
||||
println!("\nTRANSFORMED OUTPUT (Anthropic Messages API):");
|
||||
println!("{}", "-".repeat(80));
|
||||
println!("{}", output);
|
||||
|
||||
// Assertions
|
||||
assert!(!output_bytes.is_empty(), "Should have output");
|
||||
assert!(output.contains("event: message_start"), "Should have message_start");
|
||||
assert!(output.contains("event: content_block_start"), "Should have content_block_start (injected)");
|
||||
|
||||
let delta_count = output.matches("event: content_block_delta").count();
|
||||
assert_eq!(delta_count, 2, "Should have exactly 2 content_block_delta events");
|
||||
|
||||
// Verify both pieces of content are present
|
||||
assert!(output.contains("\"text\":\"Hello\""), "Should have first content delta 'Hello'");
|
||||
assert!(output.contains("\"text\":\" world\""), "Should have second content delta ' world'");
|
||||
|
||||
assert!(output.contains("event: content_block_stop"), "Should have content_block_stop (injected)");
|
||||
assert!(output.contains("event: message_delta"), "Should have message_delta");
|
||||
assert!(output.contains("event: message_stop"), "Should have message_stop");
|
||||
|
||||
println!("\nVALIDATION SUMMARY:");
|
||||
println!("{}", "-".repeat(80));
|
||||
println!("✓ Complete transformation: OpenAI ChatCompletions → Anthropic Messages API");
|
||||
println!("✓ Injected lifecycle events: message_start, content_block_start, content_block_stop");
|
||||
println!("✓ Content deltas: {} events (BOTH 'Hello' and ' world' preserved!)", delta_count);
|
||||
println!("✓ Complete stream with message_stop");
|
||||
println!("✓ Proper Anthropic protocol sequencing\n");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_openai_to_anthropic_partial_transformation() {
|
||||
// Partial OpenAI ChatCompletions stream - no [DONE]
|
||||
let raw_input = r#"data: {"id":"chatcmpl-456","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{"role":"assistant","content":"The weather"},"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-456","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{"content":" in San Francisco"},"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-456","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{"content":" is"},"finish_reason":null}]}"#;
|
||||
|
||||
println!("\n{}", "=".repeat(80));
|
||||
println!("TEST 2: OpenAI → Anthropic Partial Transformation (NO [DONE])");
|
||||
println!("{}", "=".repeat(80));
|
||||
println!("\nRAW INPUT (OpenAI ChatCompletions - NO [DONE]):");
|
||||
println!("{}", "-".repeat(80));
|
||||
println!("{}", raw_input);
|
||||
|
||||
// Setup API configuration for transformation
|
||||
let client_api = SupportedAPIsFromClient::AnthropicMessagesAPI(AnthropicApi::Messages);
|
||||
let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions);
|
||||
|
||||
// Parse and transform events
|
||||
let stream_iter = SseStreamIter::try_from(raw_input.as_bytes()).unwrap();
|
||||
let mut buffer = AnthropicMessagesStreamBuffer::new();
|
||||
|
||||
for raw_event in stream_iter {
|
||||
let transformed_event = SseEvent::try_from((raw_event, &client_api, &upstream_api)).unwrap();
|
||||
buffer.add_transformed_event(transformed_event);
|
||||
}
|
||||
|
||||
let output_bytes = buffer.into_bytes();
|
||||
let output = String::from_utf8_lossy(&output_bytes);
|
||||
|
||||
println!("\nTRANSFORMED OUTPUT (Anthropic Messages API):");
|
||||
println!("{}", "-".repeat(80));
|
||||
println!("{}", output);
|
||||
|
||||
// Assertions
|
||||
assert!(!output_bytes.is_empty(), "Should have output");
|
||||
assert!(output.contains("event: message_start"), "Should have message_start");
|
||||
assert!(output.contains("event: content_block_start"), "Should have content_block_start (injected)");
|
||||
|
||||
let delta_count = output.matches("event: content_block_delta").count();
|
||||
assert_eq!(delta_count, 3, "Should have exactly 3 content_block_delta events");
|
||||
|
||||
// Verify all three pieces of content are present
|
||||
assert!(output.contains("\"text\":\"The weather\""), "Should have first content delta");
|
||||
assert!(output.contains("\"text\":\" in San Francisco\""), "Should have second content delta");
|
||||
assert!(output.contains("\"text\":\" is\""), "Should have third content delta");
|
||||
|
||||
// For partial streams (no finish_reason, no [DONE]), we do NOT inject content_block_stop
|
||||
// because the stream may continue. This is correct behavior - only inject lifecycle events
|
||||
// when we have explicit signals from upstream (finish_reason, [DONE], etc.)
|
||||
assert!(!output.contains("event: content_block_stop"), "Should NOT have content_block_stop for partial stream");
|
||||
|
||||
// Should NOT have completion events
|
||||
assert!(!output.contains("event: message_delta"), "Should NOT have message_delta");
|
||||
assert!(!output.contains("event: message_stop"), "Should NOT have message_stop");
|
||||
|
||||
println!("\nVALIDATION SUMMARY:");
|
||||
println!("{}", "-".repeat(80));
|
||||
println!("✓ Partial transformation: OpenAI → Anthropic (stream interrupted)");
|
||||
println!("✓ Injected: message_start, content_block_start at beginning");
|
||||
println!("✓ Incremental deltas: {} events (ALL content preserved!)", delta_count);
|
||||
println!("✓ NO completion events (partial stream, no [DONE])");
|
||||
println!("✓ Buffer maintains Anthropic protocol for active streams\n");
|
||||
}
|
||||
|
||||
/// Feeds a captured OpenAI ChatCompletions tool-calling stream through the
/// OpenAI → Anthropic event transformation and checks the buffered wire output
/// for the lifecycle events, tool metadata, argument deltas and stop reason.
#[test]
fn test_openai_tool_calling_to_anthropic_transformation() {
    // Captured upstream stream: one tool call whose JSON arguments arrive in pieces,
    // followed by a `tool_calls` finish reason and the `[DONE]` marker.
    let raw_input = r#"data: {"id":"chatcmpl-Cgx6pZPBgfLcMqfT0ILIH2mID2zWQ","object":"chat.completion.chunk","created":1764353027,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_7eeb46f068","choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_2Uzw0AEZQeOex2CP2TKjcLKc","type":"function","function":{"name":"get_weather","arguments":""}}],"refusal":null},"logprobs":null,"finish_reason":null}],"obfuscation":"uSpCcO"}

data: {"id":"chatcmpl-Cgx6pZPBgfLcMqfT0ILIH2mID2zWQ","object":"chat.completion.chunk","created":1764353027,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_7eeb46f068","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\""}}]},"logprobs":null,"finish_reason":null}],"obfuscation":""}

data: {"id":"chatcmpl-Cgx6pZPBgfLcMqfT0ILIH2mID2zWQ","object":"chat.completion.chunk","created":1764353027,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_7eeb46f068","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"location"}}]},"logprobs":null,"finish_reason":null}],"obfuscation":"24WSqt08jtf"}

data: {"id":"chatcmpl-Cgx6pZPBgfLcMqfT0ILIH2mID2zWQ","object":"chat.completion.chunk","created":1764353027,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_7eeb46f068","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\":\""}}]},"logprobs":null,"finish_reason":null}],"obfuscation":"6CleV8twTxkKYg"}

data: {"id":"chatcmpl-Cgx6pZPBgfLcMqfT0ILIH2mID2zWQ","object":"chat.completion.chunk","created":1764353027,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_7eeb46f068","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"San"}}]},"logprobs":null,"finish_reason":null}],"obfuscation":""}

data: {"id":"chatcmpl-Cgx6pZPBgfLcMqfT0ILIH2mID2zWQ","object":"chat.completion.chunk","created":1764353027,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_7eeb46f068","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":" Francisco"}}]},"logprobs":null,"finish_reason":null}],"obfuscation":"1XLz89l3v"}

data: {"id":"chatcmpl-Cgx6pZPBgfLcMqfT0ILIH2mID2zWQ","object":"chat.completion.chunk","created":1764353027,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_7eeb46f068","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":","}}]},"logprobs":null,"finish_reason":null}],"obfuscation":"sh"}

data: {"id":"chatcmpl-Cgx6pZPBgfLcMqfT0ILIH2mID2zWQ","object":"chat.completion.chunk","created":1764353027,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_7eeb46f068","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":" CA"}}]},"logprobs":null,"finish_reason":null}],"obfuscation":""}

data: {"id":"chatcmpl-Cgx6pZPBgfLcMqfT0ILIH2mID2zWQ","object":"chat.completion.chunk","created":1764353027,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_7eeb46f068","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"}"}}]},"logprobs":null,"finish_reason":null}],"obfuscation":""}

data: {"id":"chatcmpl-Cgx6pZPBgfLcMqfT0ILIH2mID2zWQ","object":"chat.completion.chunk","created":1764353027,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_7eeb46f068","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}],"obfuscation":"I"}

data: [DONE]"#;

    // Separator lines reused by the diagnostic printout below.
    let heavy_rule = "=".repeat(80);
    let light_rule = "-".repeat(80);

    println!("\n{}", heavy_rule);
    println!("TEST 3: OpenAI Tool Calling → Anthropic Messages API Transformation");
    println!("{}", heavy_rule);
    println!("\nRAW INPUT (OpenAI ChatCompletions with Tool Calls):");
    println!("{}", light_rule);
    println!("{}", raw_input);

    // The client speaks Anthropic Messages while the upstream speaks ChatCompletions.
    let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions);
    let client_api = SupportedAPIsFromClient::AnthropicMessagesAPI(AnthropicApi::Messages);

    // Transform every parsed upstream event and hand it to the buffer.
    let mut buffer = AnthropicMessagesStreamBuffer::new();
    let parsed_events = SseStreamIter::try_from(raw_input.as_bytes()).unwrap();
    for upstream_event in parsed_events {
        let converted = SseEvent::try_from((upstream_event, &client_api, &upstream_api)).unwrap();
        buffer.add_transformed_event(converted);
    }

    let output_bytes = buffer.into_bytes();
    let output = String::from_utf8_lossy(&output_bytes);

    println!("\nTRANSFORMED OUTPUT (Anthropic Messages API):");
    println!("{}", light_rule);
    println!("{}", output);

    assert!(!output_bytes.is_empty(), "Should have output");

    // Small helper so each expectation reads as "needle, failure message".
    let expect_present = |needle: &str, message: &str| {
        assert!(output.contains(needle), "{}", message);
    };

    // Lifecycle events (some of them injected by the buffer).
    expect_present("event: message_start", "Should have message_start (injected)");
    expect_present("event: content_block_start", "Should have content_block_start");
    expect_present("event: content_block_stop", "Should have content_block_stop (injected)");
    expect_present("event: message_delta", "Should have message_delta");
    expect_present("event: message_stop", "Should have message_stop");

    // The tool_use content block and its metadata.
    expect_present("\"type\":\"tool_use\"", "Should have tool_use type");
    expect_present("\"name\":\"get_weather\"", "Should have correct function name");
    expect_present(
        "\"id\":\"call_2Uzw0AEZQeOex2CP2TKjcLKc\"",
        "Should have correct tool call ID",
    );

    // One content_block_delta is expected per upstream argument chunk.
    let delta_count = output.matches("event: content_block_delta").count();
    assert!(delta_count >= 8, "Should have at least 8 input_json_delta events");

    // Argument deltas use the input_json_delta shape.
    expect_present("\"type\":\"input_json_delta\"", "Should have input_json_delta type");
    expect_present("\"partial_json\":", "Should have partial_json field");

    // The pieces of the location argument survive the transformation.
    expect_present("San", "Arguments should contain 'San'");
    expect_present("Francisco", "Arguments should contain 'Francisco'");
    expect_present("CA", "Arguments should contain 'CA'");

    // finish_reason "tool_calls" maps to stop_reason "tool_use".
    expect_present("\"stop_reason\":\"tool_use\"", "Should have stop_reason as tool_use");

    println!("\nVALIDATION SUMMARY:");
    println!("{}", light_rule);
    println!("✓ Complete tool calling transformation: OpenAI → Anthropic Messages API");
    println!("✓ Injected lifecycle: message_start, content_block_stop");
    println!("✓ Tool metadata: name='get_weather', id='call_2Uzw0AEZQeOex2CP2TKjcLKc'");
    println!("✓ Argument deltas: {} events", delta_count);
    println!("✓ Complete JSON arguments: '{{\"location\":\"San Francisco, CA\"}}'");
    println!("✓ Stop reason: tool_use");
    println!("✓ Proper Anthropic tool_use protocol\n");
}
|
||||
}
|
||||
|
|
@ -0,0 +1,39 @@
|
|||
use crate::apis::streaming_shapes::sse::{SseEvent, SseStreamBufferTrait};
|
||||
|
||||
/// OpenAI Chat Completions SSE Stream Buffer for when client and upstream APIs match.
|
||||
pub struct OpenAIChatCompletionsStreamBuffer {
|
||||
/// Buffered SSE events ready to be written to wire
|
||||
buffered_events: Vec<SseEvent>,
|
||||
}
|
||||
|
||||
impl OpenAIChatCompletionsStreamBuffer {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
buffered_events: Vec::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl SseStreamBufferTrait for OpenAIChatCompletionsStreamBuffer {
|
||||
fn add_transformed_event(&mut self, event: SseEvent) {
|
||||
// Skip ping messages
|
||||
if event.should_skip() {
|
||||
return;
|
||||
}
|
||||
|
||||
// For OpenAI Chat Completions, events are already properly transformed
|
||||
// Just accumulate them for later wire transmission
|
||||
self.buffered_events.push(event);
|
||||
}
|
||||
|
||||
fn into_bytes(&mut self) -> Vec<u8> {
|
||||
// No finalization needed for OpenAI Chat Completions
|
||||
// The [DONE] marker is already handled by the transformation layer
|
||||
let mut buffer = Vec::new();
|
||||
for event in self.buffered_events.drain(..) {
|
||||
let event_bytes: Vec<u8> = event.into();
|
||||
buffer.extend_from_slice(&event_bytes);
|
||||
}
|
||||
buffer
|
||||
}
|
||||
}
|
||||
6
crates/hermesllm/src/apis/streaming_shapes/mod.rs
Normal file
6
crates/hermesllm/src/apis/streaming_shapes/mod.rs
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
pub mod sse;
|
||||
pub mod amazon_bedrock_binary_frame;
|
||||
pub mod anthropic_streaming_buffer;
|
||||
pub mod chat_completions_streaming_buffer;
|
||||
pub mod passthrough_streaming_buffer;
|
||||
pub mod responses_api_streaming_buffer;
|
||||
|
|
@ -0,0 +1,95 @@
|
|||
use crate::apis::streaming_shapes::sse::{SseEvent, SseStreamBufferTrait};
|
||||
|
||||
/// Passthrough SSE Stream Buffer for when client and upstream APIs match.
|
||||
pub struct PassthroughStreamBuffer {
|
||||
/// Buffered SSE events ready to be written to wire
|
||||
buffered_events: Vec<SseEvent>,
|
||||
}
|
||||
|
||||
impl PassthroughStreamBuffer {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
buffered_events: Vec::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl SseStreamBufferTrait for PassthroughStreamBuffer {
|
||||
fn add_transformed_event(&mut self, event: SseEvent) {
|
||||
// Skip ping messages
|
||||
if event.should_skip() {
|
||||
return;
|
||||
}
|
||||
|
||||
// Skip events with empty transformed lines (e.g., suppressed event-only lines)
|
||||
if event.sse_transformed_lines.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
// Just accumulate events as-is
|
||||
self.buffered_events.push(event);
|
||||
}
|
||||
|
||||
fn into_bytes(&mut self) -> Vec<u8> {
|
||||
// No finalization needed for passthrough - just convert accumulated events to bytes
|
||||
let mut buffer = Vec::new();
|
||||
for event in self.buffered_events.drain(..) {
|
||||
let event_bytes: Vec<u8> = event.into();
|
||||
buffer.extend_from_slice(&event_bytes);
|
||||
}
|
||||
buffer
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use crate::apis::streaming_shapes::passthrough_streaming_buffer::PassthroughStreamBuffer;
    use crate::apis::streaming_shapes::sse::{SseStreamIter, SseStreamBufferTrait};

    /// A ChatCompletions stream pushed through the passthrough buffer must come
    /// out unchanged (modulo surrounding whitespace), `[DONE]` included.
    #[test]
    fn test_chat_completions_passthrough_buffer() {
        let raw_input = r#"data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_abc","type":"function","function":{"name":"get_weather","arguments":""}}],"refusal":null},"logprobs":null,"finish_reason":null}]}

data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\""}}]},"logprobs":null,"finish_reason":null}]}

data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"location"}}]},"logprobs":null,"finish_reason":null}]}

data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}]}

data: [DONE]"#;

        // Separator lines reused by the diagnostic printout below.
        let heavy_rule = "=".repeat(80);
        let light_rule = "-".repeat(80);

        println!("\n{}", heavy_rule);
        println!("TEST 1: ChatCompletions Passthrough Buffer");
        println!("{}", heavy_rule);
        println!("\nRAW INPUT (ChatCompletions):");
        println!("{}", light_rule);
        println!("{}", raw_input);

        // Parse the stream and push every event through the buffer untouched.
        let mut buffer = PassthroughStreamBuffer::new();
        let parsed_events = SseStreamIter::try_from(raw_input.as_bytes()).unwrap();
        for parsed_event in parsed_events {
            buffer.add_transformed_event(parsed_event);
        }

        let output_bytes = buffer.into_bytes();
        let output = String::from_utf8_lossy(&output_bytes);

        println!("\nTRANSFORMED OUTPUT (ChatCompletions - Passthrough):");
        println!("{}", light_rule);
        println!("{}", output);

        // The buffer produced something, kept the ids and the terminator, and
        // the overall payload equals the input.
        assert!(!output_bytes.is_empty());
        for needle in ["chatcmpl-123", "[DONE]"] {
            assert!(output.contains(needle));
        }
        assert_eq!(raw_input.trim(), output.trim(), "Passthrough should preserve input");

        println!("\nVALIDATION SUMMARY:");
        println!("{}", light_rule);
        println!("✓ Passthrough buffer: input = output (no transformation)");
        println!("✓ All events preserved including [DONE]");
        println!("✓ Function calling events preserved\n");
    }
}
|
||||
|
|
@ -0,0 +1,600 @@
|
|||
use std::collections::HashMap;
|
||||
use log::debug;
|
||||
use crate::apis::openai_responses::{
|
||||
ResponsesAPIStreamEvent, ResponsesAPIResponse, OutputItem, OutputItemStatus,
|
||||
ResponseStatus, TextConfig, TextFormat, Reasoning,
|
||||
};
|
||||
use crate::apis::streaming_shapes::sse::{SseEvent, SseStreamBufferTrait};
|
||||
|
||||
/// Helper to convert ResponseAPIStreamEvent to SseEvent
|
||||
fn event_to_sse(event: ResponsesAPIStreamEvent) -> SseEvent {
|
||||
let event_type = match &event {
|
||||
ResponsesAPIStreamEvent::ResponseCreated { .. } => "response.created",
|
||||
ResponsesAPIStreamEvent::ResponseInProgress { .. } => "response.in_progress",
|
||||
ResponsesAPIStreamEvent::ResponseCompleted { .. } => "response.completed",
|
||||
ResponsesAPIStreamEvent::ResponseOutputItemAdded { .. } => "response.output_item.added",
|
||||
ResponsesAPIStreamEvent::ResponseOutputItemDone { .. } => "response.output_item.done",
|
||||
ResponsesAPIStreamEvent::ResponseOutputTextDelta { .. } => "response.output_text.delta",
|
||||
ResponsesAPIStreamEvent::ResponseOutputTextDone { .. } => "response.output_text.done",
|
||||
ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDelta { .. } => "response.function_call_arguments.delta",
|
||||
ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDone { .. } => "response.function_call_arguments.done",
|
||||
unknown => {
|
||||
debug!("Unknown ResponsesAPIStreamEvent type encountered: {:?}", unknown);
|
||||
"unknown"
|
||||
}
|
||||
};
|
||||
|
||||
let json_data = match serde_json::to_string(&event) {
|
||||
Ok(data) => data,
|
||||
Err(e) => {
|
||||
debug!("Error serializing ResponsesAPIStreamEvent to JSON: {}", e);
|
||||
String::new()
|
||||
}
|
||||
};
|
||||
let wire_format: String = event.into();
|
||||
|
||||
SseEvent {
|
||||
data: Some(json_data),
|
||||
event: Some(event_type.to_string()),
|
||||
raw_line: wire_format.clone(),
|
||||
sse_transformed_lines: wire_format,
|
||||
provider_stream_response: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// SSE Stream Buffer for ResponsesAPIStreamEvent with full lifecycle management.
|
||||
///
|
||||
/// This buffer manages the wire format for v1/responses streaming, handling
|
||||
/// delta events and emitting complete lifecycle events.
|
||||
///
|
||||
pub struct ResponsesAPIStreamBuffer {
|
||||
/// Sequence number for events
|
||||
sequence_number: i32,
|
||||
|
||||
/// Track item IDs by output index
|
||||
item_ids: HashMap<i32, String>,
|
||||
|
||||
/// Response metadata
|
||||
response_id: Option<String>,
|
||||
model: Option<String>,
|
||||
created_at: Option<i64>,
|
||||
|
||||
/// Lifecycle state flags
|
||||
created_emitted: bool,
|
||||
in_progress_emitted: bool,
|
||||
|
||||
/// Track which output items we've added
|
||||
output_items_added: HashMap<i32, String>, // output_index -> item_id
|
||||
|
||||
/// Accumulated content by item_id
|
||||
text_content: HashMap<String, String>,
|
||||
function_arguments: HashMap<String, String>,
|
||||
|
||||
/// Tool call metadata by output_index
|
||||
tool_call_metadata: HashMap<i32, (String, String)>, // output_index -> (call_id, name)
|
||||
|
||||
/// Final completed response (for logging/tracing/persistence)
|
||||
completed_response: Option<ResponsesAPIResponse>,
|
||||
|
||||
/// Buffered SSE events ready to be written to wire
|
||||
buffered_events: Vec<SseEvent>,
|
||||
}
|
||||
|
||||
impl ResponsesAPIStreamBuffer {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
sequence_number: 0,
|
||||
item_ids: HashMap::new(),
|
||||
response_id: None,
|
||||
model: None,
|
||||
created_at: None,
|
||||
created_emitted: false,
|
||||
in_progress_emitted: false,
|
||||
output_items_added: HashMap::new(),
|
||||
text_content: HashMap::new(),
|
||||
function_arguments: HashMap::new(),
|
||||
tool_call_metadata: HashMap::new(),
|
||||
completed_response: None,
|
||||
buffered_events: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn next_sequence_number(&mut self) -> i32 {
|
||||
let seq = self.sequence_number;
|
||||
self.sequence_number += 1;
|
||||
seq
|
||||
}
|
||||
|
||||
fn generate_item_id(prefix: &str) -> String {
|
||||
format!("{}_{}", prefix, uuid::Uuid::new_v4().to_string().replace("-", ""))
|
||||
}
|
||||
|
||||
fn get_or_create_item_id(&mut self, output_index: i32, prefix: &str) -> String {
|
||||
if let Some(id) = self.item_ids.get(&output_index) {
|
||||
return id.clone();
|
||||
}
|
||||
let id = ResponsesAPIStreamBuffer::generate_item_id(prefix);
|
||||
self.item_ids.insert(output_index, id.clone());
|
||||
id
|
||||
}
|
||||
|
||||
/// Create response.created event
|
||||
fn create_response_created_event(&mut self) -> SseEvent {
|
||||
let response = self.build_response(ResponseStatus::InProgress);
|
||||
let event = ResponsesAPIStreamEvent::ResponseCreated {
|
||||
response,
|
||||
sequence_number: self.next_sequence_number(),
|
||||
};
|
||||
event_to_sse(event)
|
||||
}
|
||||
|
||||
/// Create response.in_progress event
|
||||
fn create_response_in_progress_event(&mut self) -> SseEvent {
|
||||
let response = self.build_response(ResponseStatus::InProgress);
|
||||
let event = ResponsesAPIStreamEvent::ResponseInProgress {
|
||||
response,
|
||||
sequence_number: self.next_sequence_number(),
|
||||
};
|
||||
event_to_sse(event)
|
||||
}
|
||||
|
||||
/// Create output_item.added event for text
|
||||
fn create_output_item_added_event(&mut self, output_index: i32, item_id: &str) -> SseEvent {
|
||||
let event = ResponsesAPIStreamEvent::ResponseOutputItemAdded {
|
||||
output_index,
|
||||
item: OutputItem::Message {
|
||||
id: item_id.to_string(),
|
||||
status: OutputItemStatus::InProgress,
|
||||
role: "assistant".to_string(),
|
||||
content: vec![],
|
||||
},
|
||||
sequence_number: self.next_sequence_number(),
|
||||
};
|
||||
event_to_sse(event)
|
||||
}
|
||||
|
||||
/// Create output_item.added event for tool call
|
||||
fn create_tool_call_added_event(&mut self, output_index: i32, item_id: &str, call_id: &str, name: &str) -> SseEvent {
|
||||
let event = ResponsesAPIStreamEvent::ResponseOutputItemAdded {
|
||||
output_index,
|
||||
item: OutputItem::FunctionCall {
|
||||
id: item_id.to_string(),
|
||||
status: OutputItemStatus::InProgress,
|
||||
call_id: call_id.to_string(),
|
||||
name: Some(name.to_string()),
|
||||
arguments: Some(String::new()),
|
||||
},
|
||||
sequence_number: self.next_sequence_number(),
|
||||
};
|
||||
event_to_sse(event)
|
||||
}
|
||||
|
||||
/// Build the base response object with current state
|
||||
fn build_response(&self, status: ResponseStatus) -> ResponsesAPIResponse {
|
||||
ResponsesAPIResponse {
|
||||
id: self.response_id.clone().unwrap_or_default(),
|
||||
object: "response".to_string(),
|
||||
created_at: self.created_at.unwrap_or(0),
|
||||
status,
|
||||
error: None,
|
||||
incomplete_details: None,
|
||||
instructions: None,
|
||||
model: self.model.clone().unwrap_or_else(|| "unknown".to_string()),
|
||||
output: vec![],
|
||||
usage: None,
|
||||
parallel_tool_calls: true,
|
||||
conversation: None,
|
||||
previous_response_id: None,
|
||||
tools: vec![],
|
||||
tool_choice: "auto".to_string(),
|
||||
temperature: 1.0,
|
||||
top_p: 1.0,
|
||||
metadata: HashMap::new(),
|
||||
truncation: Some("disabled".to_string()),
|
||||
max_output_tokens: None,
|
||||
reasoning: Some(Reasoning {
|
||||
effort: None,
|
||||
summary: None,
|
||||
}),
|
||||
store: Some(true),
|
||||
text: Some(TextConfig {
|
||||
format: TextFormat::Text,
|
||||
}),
|
||||
audio: None,
|
||||
modalities: None,
|
||||
service_tier: Some("auto".to_string()),
|
||||
background: Some(false),
|
||||
top_logprobs: Some(0),
|
||||
max_tool_calls: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the completed response after finalization (for logging/tracing/persistence)
|
||||
pub fn get_completed_response(&self) -> Option<&ResponsesAPIResponse> {
|
||||
self.completed_response.as_ref()
|
||||
}
|
||||
|
||||
/// Finalize the response by emitting all *.done events and response.completed.
|
||||
/// Call this when the stream is complete (after seeing [DONE] or end_of_stream).
|
||||
pub fn finalize(&mut self) {
|
||||
let mut events = Vec::new();
|
||||
|
||||
// Emit done events for all accumulated content
|
||||
|
||||
// Text content done events
|
||||
let text_items: Vec<_> = self.text_content.iter().map(|(id, content)| (id.clone(), content.clone())).collect();
|
||||
for (item_id, content) in text_items {
|
||||
let output_index = self.output_items_added.iter()
|
||||
.find(|(_, id)| **id == item_id)
|
||||
.map(|(idx, _)| *idx)
|
||||
.unwrap_or(0);
|
||||
|
||||
let seq1 = self.next_sequence_number();
|
||||
let text_done_event = ResponsesAPIStreamEvent::ResponseOutputTextDone {
|
||||
item_id: item_id.clone(),
|
||||
output_index,
|
||||
content_index: 0,
|
||||
text: content.clone(),
|
||||
logprobs: vec![],
|
||||
sequence_number: seq1,
|
||||
};
|
||||
events.push(event_to_sse(text_done_event));
|
||||
|
||||
let seq2 = self.next_sequence_number();
|
||||
let item_done_event = ResponsesAPIStreamEvent::ResponseOutputItemDone {
|
||||
output_index,
|
||||
item: OutputItem::Message {
|
||||
id: item_id.clone(),
|
||||
status: OutputItemStatus::Completed,
|
||||
role: "assistant".to_string(),
|
||||
content: vec![],
|
||||
},
|
||||
sequence_number: seq2,
|
||||
};
|
||||
events.push(event_to_sse(item_done_event));
|
||||
}
|
||||
|
||||
// Function call done events
|
||||
let func_items: Vec<_> = self.function_arguments.iter().map(|(id, args)| (id.clone(), args.clone())).collect();
|
||||
for (item_id, arguments) in func_items {
|
||||
let output_index = self.output_items_added.iter()
|
||||
.find(|(_, id)| **id == item_id)
|
||||
.map(|(idx, _)| *idx)
|
||||
.unwrap_or(0);
|
||||
|
||||
let seq1 = self.next_sequence_number();
|
||||
let args_done_event = ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDone {
|
||||
output_index,
|
||||
item_id: item_id.clone(),
|
||||
arguments: arguments.clone(),
|
||||
sequence_number: seq1,
|
||||
};
|
||||
events.push(event_to_sse(args_done_event));
|
||||
|
||||
let (call_id, name) = self.tool_call_metadata.get(&output_index)
|
||||
.cloned()
|
||||
.unwrap_or_else(|| (format!("call_{}", uuid::Uuid::new_v4()), "unknown".to_string()));
|
||||
|
||||
let seq2 = self.next_sequence_number();
|
||||
let item_done_event = ResponsesAPIStreamEvent::ResponseOutputItemDone {
|
||||
output_index,
|
||||
item: OutputItem::FunctionCall {
|
||||
id: item_id.clone(),
|
||||
status: OutputItemStatus::Completed,
|
||||
call_id,
|
||||
name: Some(name),
|
||||
arguments: Some(arguments.clone()),
|
||||
},
|
||||
sequence_number: seq2,
|
||||
};
|
||||
events.push(event_to_sse(item_done_event));
|
||||
}
|
||||
|
||||
// Build final response
|
||||
let mut output_items = Vec::new();
|
||||
|
||||
// Add tool calls to output
|
||||
for (item_id, arguments) in &self.function_arguments {
|
||||
let output_index = self.output_items_added.iter()
|
||||
.find(|(_, id)| *id == item_id)
|
||||
.map(|(idx, _)| *idx)
|
||||
.unwrap_or(0);
|
||||
|
||||
let (call_id, name) = self.tool_call_metadata.get(&output_index)
|
||||
.cloned()
|
||||
.unwrap_or_else(|| (format!("call_{}", uuid::Uuid::new_v4()), "unknown".to_string()));
|
||||
|
||||
output_items.push(OutputItem::FunctionCall {
|
||||
id: item_id.clone(),
|
||||
status: OutputItemStatus::Completed,
|
||||
call_id,
|
||||
name: Some(name),
|
||||
arguments: Some(arguments.clone()),
|
||||
});
|
||||
}
|
||||
|
||||
let mut final_response = self.build_response(ResponseStatus::Completed);
|
||||
final_response.output = output_items;
|
||||
|
||||
// Store completed response
|
||||
self.completed_response = Some(final_response.clone());
|
||||
|
||||
// Emit response.completed
|
||||
let seq_final = self.next_sequence_number();
|
||||
let completed_event = ResponsesAPIStreamEvent::ResponseCompleted {
|
||||
response: final_response,
|
||||
sequence_number: seq_final,
|
||||
};
|
||||
events.push(event_to_sse(completed_event));
|
||||
|
||||
// Add all finalization events to the buffer
|
||||
self.buffered_events.extend(events);
|
||||
}
|
||||
}
|
||||
|
||||
impl SseStreamBufferTrait for ResponsesAPIStreamBuffer {
|
||||
fn add_transformed_event(&mut self, event: SseEvent) {
|
||||
// Skip ping messages
|
||||
if event.should_skip() {
|
||||
return;
|
||||
}
|
||||
|
||||
// Handle [DONE] marker - trigger finalization
|
||||
if event.is_done() {
|
||||
self.finalize();
|
||||
return;
|
||||
}
|
||||
|
||||
// Extract the ResponseAPIStreamEvent from the SseEvent's provider_stream_response
|
||||
let provider_response = match event.provider_stream_response.as_ref() {
|
||||
Some(response) => response,
|
||||
None => {
|
||||
eprintln!("Warning: Event missing provider_stream_response");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
// Extract ResponseAPIStreamEvent from the enum
|
||||
let stream_event = match provider_response {
|
||||
crate::providers::streaming_response::ProviderStreamResponseType::ResponseAPIStreamEvent(evt) => evt,
|
||||
_ => {
|
||||
eprintln!("Warning: Expected ResponseAPIStreamEvent in provider_stream_response");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let mut events = Vec::new();
|
||||
|
||||
// Emit lifecycle events if not yet emitted
|
||||
if !self.created_emitted {
|
||||
// Initialize metadata from first event if needed
|
||||
if self.response_id.is_none() {
|
||||
self.response_id = Some(format!("resp_{}", uuid::Uuid::new_v4().to_string().replace("-", "")));
|
||||
self.created_at = Some(std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs() as i64);
|
||||
self.model = Some("unknown".to_string()); // Will be set by caller if available
|
||||
}
|
||||
|
||||
events.push(self.create_response_created_event());
|
||||
self.created_emitted = true;
|
||||
}
|
||||
|
||||
if !self.in_progress_emitted {
|
||||
events.push(self.create_response_in_progress_event());
|
||||
self.in_progress_emitted = true;
|
||||
}
|
||||
|
||||
// Process the delta event
|
||||
match stream_event {
|
||||
ResponsesAPIStreamEvent::ResponseOutputTextDelta { output_index, delta, .. } => {
|
||||
let item_id = self.get_or_create_item_id(*output_index, "msg");
|
||||
|
||||
// Emit output_item.added if this is the first time we see this output index
|
||||
if !self.output_items_added.contains_key(output_index) {
|
||||
self.output_items_added.insert(*output_index, item_id.clone());
|
||||
events.push(self.create_output_item_added_event(*output_index, &item_id));
|
||||
}
|
||||
|
||||
// Accumulate text content
|
||||
self.text_content.entry(item_id.clone())
|
||||
.and_modify(|content| content.push_str(delta))
|
||||
.or_insert_with(|| delta.clone());
|
||||
|
||||
// Emit text delta with filled-in item_id and sequence_number
|
||||
let mut delta_event = stream_event.clone();
|
||||
if let ResponsesAPIStreamEvent::ResponseOutputTextDelta { item_id: ref mut id, sequence_number: ref mut seq, .. } = delta_event {
|
||||
*id = item_id;
|
||||
*seq = self.next_sequence_number();
|
||||
}
|
||||
events.push(event_to_sse(delta_event));
|
||||
}
|
||||
ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDelta { output_index, delta, call_id, name, .. } => {
|
||||
let item_id = self.get_or_create_item_id(*output_index, "fc");
|
||||
|
||||
// Store metadata if provided (from initial tool call event)
|
||||
if let (Some(cid), Some(n)) = (call_id, name) {
|
||||
self.tool_call_metadata.insert(*output_index, (cid.clone(), n.clone()));
|
||||
}
|
||||
|
||||
// Emit output_item.added if this is the first time we see this tool call
|
||||
if !self.output_items_added.contains_key(output_index) {
|
||||
self.output_items_added.insert(*output_index, item_id.clone());
|
||||
|
||||
// For tool calls, we need call_id and name from metadata
|
||||
// These should now be populated from the event itself
|
||||
let (call_id, name) = self.tool_call_metadata.get(output_index)
|
||||
.cloned()
|
||||
.unwrap_or_else(|| (format!("call_{}", uuid::Uuid::new_v4()), "unknown".to_string()));
|
||||
|
||||
events.push(self.create_tool_call_added_event(*output_index, &item_id, &call_id, &name));
|
||||
}
|
||||
|
||||
// Accumulate function arguments
|
||||
self.function_arguments.entry(item_id.clone())
|
||||
.and_modify(|args| args.push_str(delta))
|
||||
.or_insert_with(|| delta.clone());
|
||||
|
||||
// Emit function call arguments delta with filled-in item_id and sequence_number
|
||||
let mut delta_event = stream_event.clone();
|
||||
if let ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDelta { item_id: ref mut id, sequence_number: ref mut seq, .. } = delta_event {
|
||||
*id = item_id;
|
||||
*seq = self.next_sequence_number();
|
||||
}
|
||||
events.push(event_to_sse(delta_event));
|
||||
}
|
||||
_ => {
|
||||
// For other event types, just pass through with sequence number
|
||||
let other_event = stream_event.clone();
|
||||
// TODO: Add sequence number to other event types if needed
|
||||
events.push(event_to_sse(other_event));
|
||||
}
|
||||
}
|
||||
|
||||
// Store all generated events in the buffer
|
||||
self.buffered_events.extend(events);
|
||||
}
|
||||
|
||||
|
||||
fn into_bytes(&mut self) -> Vec<u8> {
|
||||
// For Responses API, we need special handling:
|
||||
// - Most events are already in buffered_events from add_transformed_event
|
||||
// - We should NOT finalize here - finalization happens when we detect [DONE] or end of stream
|
||||
// - Just flush the accumulated events and clear the buffer
|
||||
|
||||
// Convert all accumulated events to bytes and clear buffer
|
||||
let mut buffer = Vec::new();
|
||||
for event in self.buffered_events.drain(..) {
|
||||
let event_bytes: Vec<u8> = event.into();
|
||||
buffer.extend_from_slice(&event_bytes);
|
||||
}
|
||||
buffer
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::apis::openai::OpenAIApi;
    use crate::apis::streaming_shapes::sse::SseStreamIter;
    use crate::clients::{SupportedAPIsFromClient, SupportedUpstreamAPIs};

    /// Width of the horizontal rules printed around each diagnostic section.
    const RULE_WIDTH: usize = 80;

    /// Print the three-line banner that introduces a test's diagnostic output.
    fn print_banner(title: &str) {
        println!("\n{}", "=".repeat(RULE_WIDTH));
        println!("{}", title);
        println!("{}", "=".repeat(RULE_WIDTH));
    }

    /// Print a section heading followed by a dashed rule.
    fn print_heading(heading: &str) {
        println!("{}", heading);
        println!("{}", "-".repeat(RULE_WIDTH));
    }

    /// Parse `raw_input` as a ChatCompletions SSE stream, transform each event for a
    /// Responses API client, feed it through a fresh buffer and return the flushed bytes.
    fn transform_through_responses_buffer(raw_input: &str) -> Vec<u8> {
        let client_api = SupportedAPIsFromClient::OpenAIResponsesAPI(OpenAIApi::Responses);
        let upstream_api =
            SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions);

        let raw_events = SseStreamIter::try_from(raw_input.as_bytes()).unwrap();
        let mut buffer = ResponsesAPIStreamBuffer::new();
        for raw_event in raw_events {
            let transformed =
                SseEvent::try_from((raw_event, &client_api, &upstream_api)).unwrap();
            buffer.add_transformed_event(transformed);
        }
        buffer.into_bytes()
    }

    #[test]
    fn test_chat_completions_to_responses_api_transformation() {
        // ChatCompletions input that will be transformed to ResponsesAPI
        let raw_input = r#"data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{"role":"assistant","content":"Hello"},"finish_reason":null}]}

data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{"content":" world"},"finish_reason":null}]}

data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}

data: [DONE]"#;

        print_banner("TEST 2: ChatCompletions → ResponsesAPI Transformation (with [DONE])");
        print_heading("\nRAW INPUT (ChatCompletions):");
        println!("{}", raw_input);

        let output_bytes = transform_through_responses_buffer(raw_input);
        let output = String::from_utf8_lossy(&output_bytes);

        print_heading("\nTRANSFORMED OUTPUT (ResponsesAPI):");
        println!("{}", output);

        // The complete stream must yield the full Responses API event lifecycle.
        assert!(!output_bytes.is_empty(), "Should have output");
        let required = [
            ("response.created", "Should have response.created"),
            ("response.in_progress", "Should have response.in_progress"),
            ("response.output_item.added", "Should have output_item.added"),
            ("response.output_text.delta", "Should have text deltas"),
            ("response.output_text.done", "Should have text.done"),
            ("response.output_item.done", "Should have output_item.done"),
            ("response.completed", "Should have response.completed"),
        ];
        for &(needle, failure) in required.iter() {
            assert!(output.contains(needle), "{}", failure);
        }

        print_heading("\nVALIDATION SUMMARY:");
        println!("✓ Lifecycle events: response.created, response.in_progress, response.completed");
        println!("✓ Output item lifecycle: output_item.added, output_item.done");
        println!("✓ Text streaming: output_text.delta (2 deltas), output_text.done");
        println!("✓ Complete transformation with finalization ([DONE] processed)\n");
    }

    #[test]
    fn test_partial_streaming_incremental_output() {
        // A tool-call stream that stops mid-arguments: there is no finish_reason and no [DONE].
        let raw_input = r#"data: {"id":"chatcmpl-CfpqklihniLRuuQfP7inMb2ghtGmT","object":"chat.completion.chunk","created":1764086794,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_7eeb46f068","choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_mD5ggLKk3SMKGPFqFdcpKg6q","type":"function","function":{"name":"get_weather","arguments":""}}],"refusal":null},"logprobs":null,"finish_reason":null}],"obfuscation":"PCFrpy"}

data: {"id":"chatcmpl-CfpqklihniLRuuQfP7inMb2ghtGmT","object":"chat.completion.chunk","created":1764086794,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_7eeb46f068","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\""}}]},"logprobs":null,"finish_reason":null}],"obfuscation":""}

data: {"id":"chatcmpl-CfpqklihniLRuuQfP7inMb2ghtGmT","object":"chat.completion.chunk","created":1764086794,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_7eeb46f068","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"location"}}]},"logprobs":null,"finish_reason":null}],"obfuscation":"TC58A3QEIx8"}

data: {"id":"chatcmpl-CfpqklihniLRuuQfP7inMb2ghtGmT","object":"chat.completion.chunk","created":1764086794,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_7eeb46f068","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\":\""}}]},"logprobs":null,"finish_reason":null}],"obfuscation":"PK4oFzlVlGTUP5"}"#;

        print_banner("TEST 3: Partial Streaming - Function Calling (NO [DONE])");
        print_heading("\nRAW INPUT (ChatCompletions - NO [DONE]):");
        println!("{}", raw_input);

        let output_bytes = transform_through_responses_buffer(raw_input);
        let output = String::from_utf8_lossy(&output_bytes);

        print_heading("\nTRANSFORMED OUTPUT (ResponsesAPI):");
        println!("{}", output);

        // Opening lifecycle events and the function-call metadata must be present.
        let required = [
            ("response.created", "Should have response.created"),
            ("response.in_progress", "Should have response.in_progress"),
            ("response.output_item.added", "Should have output_item.added"),
            ("\"type\":\"function_call\"", "Should be function_call type"),
            ("\"name\":\"get_weather\"", "Should have function name"),
            (
                "\"call_id\":\"call_mD5ggLKk3SMKGPFqFdcpKg6q\"",
                "Should have correct call_id",
            ),
        ];
        for &(needle, failure) in required.iter() {
            assert!(output.contains(needle), "{}", failure);
        }

        let delta_count = output
            .matches("event: response.function_call_arguments.delta")
            .count();
        assert_eq!(delta_count, 4, "Should have 4 delta events");

        // Without [DONE] the buffer must not emit any completion events.
        let forbidden = [
            (
                "response.function_call_arguments.done",
                "Should NOT have arguments.done",
            ),
            ("response.output_item.done", "Should NOT have output_item.done"),
            ("response.completed", "Should NOT have response.completed"),
        ];
        for &(needle, failure) in forbidden.iter() {
            assert!(!output.contains(needle), "{}", failure);
        }

        print_heading("\nVALIDATION SUMMARY:");
        println!("✓ Lifecycle events: response.created, response.in_progress");
        println!("✓ Function call metadata: name='get_weather', call_id='call_mD5ggLKk3SMKGPFqFdcpKg6q'");
        println!("✓ Incremental deltas: 4 events (1 initial + 3 argument chunks)");
        println!("✓ NO completion events (partial stream, no [DONE])");
        println!("✓ Arguments accumulated: '{{\"location\":\"'\n");
    }
}
|
||||
|
|
@ -1,10 +1,73 @@
|
|||
use crate::providers::response::ProviderStreamResponse;
|
||||
use crate::providers::response::ProviderStreamResponseType;
|
||||
use crate::providers::streaming_response::ProviderStreamResponse;
|
||||
use crate::providers::streaming_response::ProviderStreamResponseType;
|
||||
use crate::apis::streaming_shapes::chat_completions_streaming_buffer::OpenAIChatCompletionsStreamBuffer;
|
||||
use crate::apis::streaming_shapes::anthropic_streaming_buffer::AnthropicMessagesStreamBuffer;
|
||||
use crate::apis::streaming_shapes::passthrough_streaming_buffer::PassthroughStreamBuffer;
|
||||
use crate::apis::streaming_shapes::responses_api_streaming_buffer::ResponsesAPIStreamBuffer;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::error::Error;
|
||||
use std::fmt;
|
||||
use std::str::FromStr;
|
||||
|
||||
/// Trait defining the interface for SSE stream buffers.
|
||||
///
|
||||
/// This trait is implemented by both the enum `SseStreamBuffer` (for zero-cost dispatch)
|
||||
/// and individual buffer implementations (for direct use).
|
||||
///
|
||||
pub trait SseStreamBufferTrait: Send + Sync {
|
||||
/// Add a transformed SSE event to the buffer.
|
||||
///
|
||||
/// The buffer may inject additional events as needed based on internal state.
|
||||
/// For example, Anthropic buffers inject ContentBlockStart before the first ContentBlockDelta.
|
||||
///
|
||||
/// All events (original + injected) are accumulated internally for the next `into_bytes()` call.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `event` - A transformed SSE event to accumulate
|
||||
fn add_transformed_event(&mut self, event: SseEvent);
|
||||
|
||||
/// Get bytes for all accumulated events since the last call.
|
||||
///
|
||||
/// This method:
|
||||
/// - Converts all buffered events to wire format bytes
|
||||
/// - Clears the internal event buffer
|
||||
/// - Preserves state for subsequent `add_transformed_event()` calls
|
||||
///
|
||||
/// Call this after processing each chunk of upstream events to get bytes for immediate transmission.
|
||||
///
|
||||
/// # Returns
|
||||
/// Bytes ready for wire transmission (may be empty if no events were accumulated)
|
||||
fn into_bytes(&mut self) -> Vec<u8>;
|
||||
}
|
||||
|
||||
/// Unified SSE Stream Buffer enum that provides a zero-cost abstraction
|
||||
pub enum SseStreamBuffer {
|
||||
Passthrough(PassthroughStreamBuffer),
|
||||
OpenAIChatCompletions(OpenAIChatCompletionsStreamBuffer),
|
||||
AnthropicMessages(AnthropicMessagesStreamBuffer),
|
||||
OpenAIResponses(ResponsesAPIStreamBuffer),
|
||||
}
|
||||
|
||||
impl SseStreamBufferTrait for SseStreamBuffer {
|
||||
fn add_transformed_event(&mut self, event: SseEvent) {
|
||||
match self {
|
||||
Self::Passthrough(buffer) => buffer.add_transformed_event(event),
|
||||
Self::OpenAIChatCompletions(buffer) => buffer.add_transformed_event(event),
|
||||
Self::AnthropicMessages(buffer) => buffer.add_transformed_event(event),
|
||||
Self::OpenAIResponses(buffer) => buffer.add_transformed_event(event),
|
||||
}
|
||||
}
|
||||
|
||||
fn into_bytes(&mut self) -> Vec<u8> {
|
||||
match self {
|
||||
Self::Passthrough(buffer) => buffer.into_bytes(),
|
||||
Self::OpenAIChatCompletions(buffer) => buffer.into_bytes(),
|
||||
Self::AnthropicMessages(buffer) => buffer.into_bytes(),
|
||||
Self::OpenAIResponses(buffer) => buffer.into_bytes(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// SSE EVENT CONTAINER
|
||||
// ============================================================================
|
||||
|
|
@ -22,16 +85,31 @@ pub struct SseEvent {
|
|||
pub raw_line: String, // The complete line as received including "data: " prefix and "\n\n"
|
||||
|
||||
#[serde(skip_serializing, skip_deserializing)]
|
||||
pub sse_transform_buffer: String, // The complete line as received including "data: " prefix and "\n\n"
|
||||
pub sse_transformed_lines: String, // The complete line as received including "data: " prefix and "\n\n"
|
||||
|
||||
#[serde(skip_serializing, skip_deserializing)]
|
||||
pub provider_stream_response: Option<ProviderStreamResponseType>, // Parsed provider stream response object
|
||||
}
|
||||
|
||||
impl SseEvent {
|
||||
/// Create an SseEvent from a ProviderStreamResponseType
|
||||
/// This is useful for binary frame formats (like Bedrock) that need to be converted to SSE
|
||||
pub fn from_provider_response(response: ProviderStreamResponseType) -> Self {
|
||||
// Convert the provider response to SSE format string
|
||||
let sse_string: String = response.clone().into();
|
||||
|
||||
SseEvent {
|
||||
data: None, // Data is embedded in sse_transformed_lines
|
||||
event: None, // Event type is embedded in sse_transformed_lines
|
||||
raw_line: sse_string.clone(),
|
||||
sse_transformed_lines: sse_string,
|
||||
provider_stream_response: Some(response),
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if this event represents the end of the stream
|
||||
pub fn is_done(&self) -> bool {
|
||||
self.data == Some("[DONE]".into())
|
||||
self.data == Some("[DONE]".into()) || self.event == Some("message_stop".into())
|
||||
}
|
||||
|
||||
/// Check if this event should be skipped during processing
|
||||
|
|
@ -61,23 +139,35 @@ impl FromStr for SseEvent {
|
|||
type Err = SseParseError;
|
||||
|
||||
fn from_str(line: &str) -> Result<Self, Self::Err> {
|
||||
if line.starts_with("data: ") {
|
||||
let data: String = line[6..].to_string(); // Remove "data: " prefix
|
||||
if data.is_empty() {
|
||||
// Trim leading/trailing whitespace for parsing
|
||||
let trimmed_line = line.trim();
|
||||
|
||||
// Skip empty or whitespace-only lines (SSE event separators)
|
||||
if trimmed_line.is_empty() {
|
||||
return Err(SseParseError {
|
||||
message: "Empty line (SSE event separator)".to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
if trimmed_line.starts_with("data: ") {
|
||||
let data: String = trimmed_line[6..].to_string(); // Remove "data: " prefix
|
||||
// Allow empty data content after "data: " prefix
|
||||
// This handles cases like "data: " followed by newline
|
||||
if data.trim().is_empty() {
|
||||
return Err(SseParseError {
|
||||
message: "Empty data field is not a valid SSE event".to_string(),
|
||||
message: "Empty data field after 'data: ' prefix".to_string(),
|
||||
});
|
||||
}
|
||||
Ok(SseEvent {
|
||||
data: Some(data),
|
||||
event: None,
|
||||
raw_line: line.to_string(),
|
||||
sse_transform_buffer: line.to_string(),
|
||||
// Preserve original line format for passthrough, use trimmed for transformations
|
||||
sse_transformed_lines: line.to_string(),
|
||||
provider_stream_response: None,
|
||||
})
|
||||
} else if line.starts_with("event: ") {
|
||||
//used by Anthropic
|
||||
let event_type = line[7..].to_string();
|
||||
} else if trimmed_line.starts_with("event: ") {
|
||||
let event_type = trimmed_line[7..].to_string();
|
||||
if event_type.is_empty() {
|
||||
return Err(SseParseError {
|
||||
message: "Empty event field is not a valid SSE event".to_string(),
|
||||
|
|
@ -87,12 +177,13 @@ impl FromStr for SseEvent {
|
|||
data: None,
|
||||
event: Some(event_type),
|
||||
raw_line: line.to_string(),
|
||||
sse_transform_buffer: line.to_string(),
|
||||
// Preserve original line format for passthrough, use trimmed for transformations
|
||||
sse_transformed_lines: line.to_string(),
|
||||
provider_stream_response: None,
|
||||
})
|
||||
} else {
|
||||
Err(SseParseError {
|
||||
message: format!("Line does not start with 'data: ' or 'event: ': {}", line),
|
||||
message: format!("Line does not start with 'data: ' or 'event: ': {}", trimmed_line),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
@ -100,14 +191,14 @@ impl FromStr for SseEvent {
|
|||
|
||||
impl fmt::Display for SseEvent {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "{}", self.sse_transform_buffer)
|
||||
write!(f, "{}", self.sse_transformed_lines)
|
||||
}
|
||||
}
|
||||
|
||||
// Into implementation to convert SseEvent to bytes for response buffer
|
||||
impl Into<Vec<u8>> for SseEvent {
|
||||
fn into(self) -> Vec<u8> {
|
||||
format!("{}\n\n", self.sse_transform_buffer).into_bytes()
|
||||
format!("{}\n\n", self.sse_transformed_lines).into_bytes()
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -4,9 +4,10 @@ use std::fmt;
|
|||
|
||||
/// Unified enum representing all supported API endpoints across providers
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum SupportedAPIs {
|
||||
pub enum SupportedAPIsFromClient {
|
||||
OpenAIChatCompletions(OpenAIApi),
|
||||
AnthropicMessagesAPI(AnthropicApi),
|
||||
OpenAIResponsesAPI(OpenAIApi),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
|
|
@ -15,17 +16,21 @@ pub enum SupportedUpstreamAPIs {
|
|||
AnthropicMessagesAPI(AnthropicApi),
|
||||
AmazonBedrockConverse(AmazonBedrockApi),
|
||||
AmazonBedrockConverseStream(AmazonBedrockApi),
|
||||
OpenAIResponsesAPI(OpenAIApi),
|
||||
}
|
||||
|
||||
impl fmt::Display for SupportedAPIs {
|
||||
impl fmt::Display for SupportedAPIsFromClient {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
SupportedAPIs::OpenAIChatCompletions(api) => {
|
||||
SupportedAPIsFromClient::OpenAIChatCompletions(api) => {
|
||||
write!(f, "OpenAI ({})", api.endpoint())
|
||||
}
|
||||
SupportedAPIs::AnthropicMessagesAPI(api) => {
|
||||
SupportedAPIsFromClient::AnthropicMessagesAPI(api) => {
|
||||
write!(f, "Anthropic AI ({})", api.endpoint())
|
||||
}
|
||||
SupportedAPIsFromClient::OpenAIResponsesAPI(api) => {
|
||||
write!(f, "OpenAI Responses ({})", api.endpoint())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -45,19 +50,27 @@ impl fmt::Display for SupportedUpstreamAPIs {
|
|||
SupportedUpstreamAPIs::AmazonBedrockConverseStream(api) => {
|
||||
write!(f, "Amazon Bedrock ({})", api.endpoint())
|
||||
}
|
||||
SupportedUpstreamAPIs::OpenAIResponsesAPI(api) => {
|
||||
write!(f, "OpenAI Responses ({})", api.endpoint())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl SupportedAPIs {
|
||||
impl SupportedAPIsFromClient {
|
||||
/// Create a SupportedApi from an endpoint path
|
||||
pub fn from_endpoint(endpoint: &str) -> Option<Self> {
|
||||
if let Some(openai_api) = OpenAIApi::from_endpoint(endpoint) {
|
||||
return Some(SupportedAPIs::OpenAIChatCompletions(openai_api));
|
||||
// Check if this is the Responses API endpoint
|
||||
if openai_api == OpenAIApi::Responses {
|
||||
return Some(SupportedAPIsFromClient::OpenAIResponsesAPI(openai_api));
|
||||
}
|
||||
// Otherwise it's ChatCompletions
|
||||
return Some(SupportedAPIsFromClient::OpenAIChatCompletions(openai_api));
|
||||
}
|
||||
|
||||
if let Some(anthropic_api) = AnthropicApi::from_endpoint(endpoint) {
|
||||
return Some(SupportedAPIs::AnthropicMessagesAPI(anthropic_api));
|
||||
return Some(SupportedAPIsFromClient::AnthropicMessagesAPI(anthropic_api));
|
||||
}
|
||||
|
||||
None
|
||||
|
|
@ -66,8 +79,9 @@ impl SupportedAPIs {
|
|||
/// Get the endpoint path for this API
|
||||
pub fn endpoint(&self) -> &'static str {
|
||||
match self {
|
||||
SupportedAPIs::OpenAIChatCompletions(api) => api.endpoint(),
|
||||
SupportedAPIs::AnthropicMessagesAPI(api) => api.endpoint(),
|
||||
SupportedAPIsFromClient::OpenAIChatCompletions(api) => api.endpoint(),
|
||||
SupportedAPIsFromClient::AnthropicMessagesAPI(api) => api.endpoint(),
|
||||
SupportedAPIsFromClient::OpenAIResponsesAPI(api) => api.endpoint(),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -94,8 +108,62 @@ impl SupportedAPIs {
|
|||
}
|
||||
};
|
||||
|
||||
// Helper function to route based on provider with a specific endpoint suffix
|
||||
let route_by_provider = |endpoint_suffix: &str| -> String {
|
||||
match provider_id {
|
||||
ProviderId::Groq => {
|
||||
if request_path.starts_with("/v1/") {
|
||||
build_endpoint("/openai", request_path)
|
||||
} else {
|
||||
build_endpoint("/v1", endpoint_suffix)
|
||||
}
|
||||
}
|
||||
ProviderId::Zhipu => {
|
||||
if request_path.starts_with("/v1/") {
|
||||
build_endpoint("/api/paas/v4", endpoint_suffix)
|
||||
} else {
|
||||
build_endpoint("/v1", endpoint_suffix)
|
||||
}
|
||||
}
|
||||
ProviderId::Qwen => {
|
||||
if request_path.starts_with("/v1/") {
|
||||
build_endpoint("/compatible-mode/v1", endpoint_suffix)
|
||||
} else {
|
||||
build_endpoint("/v1", endpoint_suffix)
|
||||
}
|
||||
}
|
||||
ProviderId::AzureOpenAI => {
|
||||
if request_path.starts_with("/v1/") {
|
||||
let suffix = endpoint_suffix.trim_start_matches('/');
|
||||
build_endpoint("/openai/deployments", &format!("/{}/{}?api-version=2025-01-01-preview", model_id, suffix))
|
||||
} else {
|
||||
build_endpoint("/v1", endpoint_suffix)
|
||||
}
|
||||
}
|
||||
ProviderId::Gemini => {
|
||||
if request_path.starts_with("/v1/") {
|
||||
build_endpoint("/v1beta/openai", endpoint_suffix)
|
||||
} else {
|
||||
build_endpoint("/v1", endpoint_suffix)
|
||||
}
|
||||
}
|
||||
ProviderId::AmazonBedrock => {
|
||||
if request_path.starts_with("/v1/") {
|
||||
if !is_streaming {
|
||||
build_endpoint("", &format!("/model/{}/converse", model_id))
|
||||
} else {
|
||||
build_endpoint("", &format!("/model/{}/converse-stream", model_id))
|
||||
}
|
||||
} else {
|
||||
build_endpoint("/v1", endpoint_suffix)
|
||||
}
|
||||
}
|
||||
_ => build_endpoint("/v1", endpoint_suffix),
|
||||
}
|
||||
};
|
||||
|
||||
match self {
|
||||
SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages) => match provider_id {
|
||||
SupportedAPIsFromClient::AnthropicMessagesAPI(AnthropicApi::Messages) => match provider_id {
|
||||
ProviderId::Anthropic => build_endpoint("/v1", "/messages"),
|
||||
ProviderId::AmazonBedrock => {
|
||||
if request_path.starts_with("/v1/") && !is_streaming {
|
||||
|
|
@ -108,55 +176,19 @@ impl SupportedAPIs {
|
|||
}
|
||||
_ => build_endpoint("/v1", "/chat/completions"),
|
||||
},
|
||||
_ => match provider_id {
|
||||
ProviderId::Groq => {
|
||||
if request_path.starts_with("/v1/") {
|
||||
build_endpoint("/openai", request_path)
|
||||
} else {
|
||||
build_endpoint("/v1", "/chat/completions")
|
||||
}
|
||||
SupportedAPIsFromClient::OpenAIResponsesAPI(_) => {
|
||||
// For Responses API, check if provider supports it, otherwise translate to chat/completions
|
||||
match provider_id {
|
||||
// OpenAI and compatible providers that support /v1/responses
|
||||
ProviderId::OpenAI => route_by_provider("/responses"),
|
||||
// All other providers: translate to /chat/completions
|
||||
_ => route_by_provider("/chat/completions"),
|
||||
}
|
||||
ProviderId::Zhipu => {
|
||||
if request_path.starts_with("/v1/") {
|
||||
build_endpoint("/api/paas/v4", "/chat/completions")
|
||||
} else {
|
||||
build_endpoint("/v1", "/chat/completions")
|
||||
}
|
||||
}
|
||||
ProviderId::Qwen => {
|
||||
if request_path.starts_with("/v1/") {
|
||||
build_endpoint("/compatible-mode/v1", "/chat/completions")
|
||||
} else {
|
||||
build_endpoint("/v1", "/chat/completions")
|
||||
}
|
||||
}
|
||||
ProviderId::AzureOpenAI => {
|
||||
if request_path.starts_with("/v1/") {
|
||||
build_endpoint("/openai/deployments", &format!("/{}/chat/completions?api-version=2025-01-01-preview", model_id))
|
||||
} else {
|
||||
build_endpoint("/v1", "/chat/completions")
|
||||
}
|
||||
}
|
||||
ProviderId::Gemini => {
|
||||
if request_path.starts_with("/v1/") {
|
||||
build_endpoint("/v1beta/openai", "/chat/completions")
|
||||
} else {
|
||||
build_endpoint("/v1", "/chat/completions")
|
||||
}
|
||||
}
|
||||
ProviderId::AmazonBedrock => {
|
||||
if request_path.starts_with("/v1/") {
|
||||
if !is_streaming {
|
||||
build_endpoint("", &format!("/model/{}/converse", model_id))
|
||||
} else {
|
||||
build_endpoint("", &format!("/model/{}/converse-stream", model_id))
|
||||
}
|
||||
} else {
|
||||
build_endpoint("/v1", "/chat/completions")
|
||||
}
|
||||
}
|
||||
_ => build_endpoint("/v1", "/chat/completions"),
|
||||
},
|
||||
}
|
||||
SupportedAPIsFromClient::OpenAIChatCompletions(_) => {
|
||||
// For Chat Completions API, use the standard chat/completions path
|
||||
route_by_provider("/chat/completions")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -198,22 +230,23 @@ mod tests {
|
|||
/// `from_endpoint` recognizes the known client paths and rejects everything else.
#[test]
fn test_is_supported_endpoint() {
    // OpenAI endpoints
    assert!(SupportedAPIsFromClient::from_endpoint("/v1/chat/completions").is_some());
    // Anthropic endpoints
    assert!(SupportedAPIsFromClient::from_endpoint("/v1/messages").is_some());

    // Unsupported endpoints
    assert!(SupportedAPIsFromClient::from_endpoint("/v1/unknown").is_none());
    assert!(SupportedAPIsFromClient::from_endpoint("/v2/chat").is_none());
    assert!(SupportedAPIsFromClient::from_endpoint("").is_none());
}
|
||||
|
||||
/// `supported_endpoints` lists exactly the three client-facing API paths.
#[test]
fn test_supported_endpoints() {
    let endpoints = supported_endpoints();
    assert_eq!(endpoints.len(), 3); // We have 3 APIs defined
    assert!(endpoints.contains(&"/v1/chat/completions"));
    assert!(endpoints.contains(&"/v1/messages"));
    assert!(endpoints.contains(&"/v1/responses"));
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -263,7 +296,7 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn test_target_endpoint_without_base_url_prefix() {
|
||||
let api = SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions);
|
||||
let api = SupportedAPIsFromClient::OpenAIChatCompletions(OpenAIApi::ChatCompletions);
|
||||
|
||||
// Test default OpenAI provider
|
||||
assert_eq!(
|
||||
|
|
@ -340,7 +373,7 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn test_target_endpoint_with_base_url_prefix() {
|
||||
let api = SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions);
|
||||
let api = SupportedAPIsFromClient::OpenAIChatCompletions(OpenAIApi::ChatCompletions);
|
||||
|
||||
// Test Zhipu with custom base_url_path_prefix
|
||||
assert_eq!(
|
||||
|
|
@ -405,7 +438,7 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn test_target_endpoint_with_empty_base_url_prefix() {
|
||||
let api = SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions);
|
||||
let api = SupportedAPIsFromClient::OpenAIChatCompletions(OpenAIApi::ChatCompletions);
|
||||
|
||||
// Test with just slashes - trims to empty, uses provider default
|
||||
assert_eq!(
|
||||
|
|
@ -434,7 +467,7 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn test_amazon_bedrock_endpoints() {
|
||||
let api = SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages);
|
||||
let api = SupportedAPIsFromClient::AnthropicMessagesAPI(AnthropicApi::Messages);
|
||||
|
||||
// Test Bedrock non-streaming without prefix
|
||||
assert_eq!(
|
||||
|
|
@ -487,7 +520,7 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn test_anthropic_messages_endpoint() {
|
||||
let api = SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages);
|
||||
let api = SupportedAPIsFromClient::AnthropicMessagesAPI(AnthropicApi::Messages);
|
||||
|
||||
// Test Anthropic without prefix
|
||||
assert_eq!(
|
||||
|
|
@ -516,7 +549,7 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn test_non_v1_request_paths() {
|
||||
let api = SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions);
|
||||
let api = SupportedAPIsFromClient::OpenAIChatCompletions(OpenAIApi::ChatCompletions);
|
||||
|
||||
// Test Groq with non-v1 path (should use default)
|
||||
assert_eq!(
|
||||
|
|
@ -557,7 +590,7 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn test_azure_openai_with_query_params() {
|
||||
let api = SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions);
|
||||
let api = SupportedAPIsFromClient::OpenAIChatCompletions(OpenAIApi::ChatCompletions);
|
||||
|
||||
// Test Azure without prefix - should include query params
|
||||
assert_eq!(
|
||||
|
|
|
|||
|
|
@ -1,9 +1,8 @@
|
|||
pub mod endpoints;
|
||||
pub mod lib;
|
||||
pub mod transformer;
|
||||
|
||||
// Re-export the main items for easier access
|
||||
pub use endpoints::{identify_provider, SupportedAPIs};
|
||||
pub use endpoints::*;
|
||||
pub use lib::*;
|
||||
|
||||
// Note: transformer module contains TryFrom trait implementations that are automatically available
|
||||
|
|
|
|||
|
|
@ -1,694 +0,0 @@
|
|||
// Re-export new transformation modules for backward compatibility
|
||||
|
||||
//KEEPING THE TESTS TO MAKE SURE ALL THE REFACTORING DIDN'T BREAK ANYTHING
|
||||
|
||||
// ============================================================================
|
||||
// TESTS
|
||||
// ============================================================================
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::apis::anthropic::*;
|
||||
use crate::apis::openai::*;
|
||||
use crate::transforms::*;
|
||||
use serde_json::json;
|
||||
type AnthropicMessagesRequest = MessagesRequest;
|
||||
|
||||
/// A simple Anthropic request with a system prompt converts to an OpenAI
/// request that carries over the model, sampling settings and stop sequences.
#[test]
fn test_anthropic_to_openai_basic_request() {
    let user_turn = MessagesMessage {
        role: MessagesRole::User,
        content: MessagesMessageContent::Single(String::from("Hello, world!")),
    };
    let source = AnthropicMessagesRequest {
        model: String::from("claude-3-sonnet-20240229"),
        system: Some(MessagesSystemPrompt::Single(String::from("You are helpful"))),
        messages: vec![user_turn],
        max_tokens: 1024,
        temperature: Some(0.7),
        top_p: Some(0.9),
        top_k: Some(50),
        stream: Some(false),
        stop_sequences: Some(vec![String::from("STOP")]),
        // Everything below is left unset.
        container: None,
        mcp_servers: None,
        service_tier: None,
        thinking: None,
        tools: None,
        tool_choice: None,
        metadata: None,
    };

    let converted: ChatCompletionsRequest = source.try_into().unwrap();

    assert_eq!(converted.model, "claude-3-sonnet-20240229");
    // The system prompt becomes its own message ahead of the user message.
    assert_eq!(converted.messages.len(), 2);
    assert_eq!(converted.max_completion_tokens, Some(1024));
    assert_eq!(converted.temperature, Some(0.7));
    assert_eq!(converted.top_p, Some(0.9));
    assert_eq!(converted.stream, Some(false));
    assert_eq!(converted.stop, Some(vec![String::from("STOP")]));
}
|
||||
|
||||
/// Anthropic -> OpenAI -> Anthropic keeps the key request fields intact.
#[test]
fn test_roundtrip_consistency() {
    let user_turn = MessagesMessage {
        role: MessagesRole::User,
        content: MessagesMessageContent::Single(String::from("User message")),
    };
    let before = AnthropicMessagesRequest {
        model: String::from("claude-3-sonnet"),
        system: Some(MessagesSystemPrompt::Single(String::from("System prompt"))),
        messages: vec![user_turn],
        max_tokens: 1000,
        temperature: Some(0.5),
        top_p: Some(1.0),
        stream: Some(false),
        // Everything below is left unset.
        top_k: None,
        stop_sequences: None,
        container: None,
        mcp_servers: None,
        service_tier: None,
        thinking: None,
        tools: None,
        tool_choice: None,
        metadata: None,
    };

    // Convert a copy to OpenAI and then back again, keeping `before` for comparison.
    let intermediate: ChatCompletionsRequest = before.clone().try_into().unwrap();
    let after: AnthropicMessagesRequest = intermediate.try_into().unwrap();

    // Check key fields are preserved
    assert_eq!(before.model, after.model);
    assert_eq!(before.max_tokens, after.max_tokens);
    assert_eq!(before.temperature, after.temperature);
    assert_eq!(before.top_p, after.top_p);
    assert_eq!(before.stream, after.stream);
    assert_eq!(before.messages.len(), after.messages.len());
}
|
||||
|
||||
/// An `Auto` tool choice with parallel tool use disabled maps to OpenAI's
/// `auto` tool choice with `parallel_tool_calls` set to `false`.
#[test]
fn test_tool_choice_auto() {
    let tool = MessagesTool {
        name: String::from("test_tool"),
        description: Some(String::from("A test tool")),
        input_schema: json!({"type": "object"}),
    };
    let auto_choice = MessagesToolChoice {
        kind: MessagesToolChoiceType::Auto,
        name: None,
        disable_parallel_tool_use: Some(true),
    };
    let source = AnthropicMessagesRequest {
        model: String::from("claude-3"),
        messages: vec![],
        max_tokens: 100,
        tools: Some(vec![tool]),
        tool_choice: Some(auto_choice),
        // Everything below is left unset.
        system: None,
        container: None,
        mcp_servers: None,
        service_tier: None,
        thinking: None,
        temperature: None,
        top_p: None,
        top_k: None,
        stream: None,
        stop_sequences: None,
        metadata: None,
    };

    let converted: ChatCompletionsRequest = source.try_into().unwrap();

    assert!(converted.tools.is_some());
    assert_eq!(converted.tools.as_ref().unwrap().len(), 1);

    match converted.tool_choice {
        Some(ToolChoice::Type(choice)) => assert_eq!(choice, ToolChoiceType::Auto),
        _ => panic!("Expected auto tool choice"),
    }

    assert_eq!(converted.parallel_tool_calls, Some(false));
}
|
||||
|
||||
/// When the OpenAI request leaves `max_tokens` unset, the converted Anthropic
/// request must carry `DEFAULT_MAX_TOKENS`.
#[test]
fn test_default_max_tokens_used_when_openai_has_none() {
    let greeting = Message {
        role: Role::User,
        content: MessageContent::Text(String::from("Hello")),
        name: None,
        tool_calls: None,
        tool_call_id: None,
    };
    let request_without_limit = ChatCompletionsRequest {
        model: String::from("gpt-4"),
        messages: vec![greeting],
        // Deliberately left unset so the conversion has to supply a value.
        max_tokens: None,
        ..Default::default()
    };

    let converted: AnthropicMessagesRequest = request_without_limit.try_into().unwrap();

    assert_eq!(converted.max_tokens, DEFAULT_MAX_TOKENS);
}
|
||||
|
||||
/// An Anthropic `MessageStart` event must become an OpenAI chunk that keeps
/// the id and model and has one choice announcing the assistant role, with no
/// content and no finish reason.
#[test]
fn test_anthropic_message_start_streaming() {
    let initial_usage = MessagesUsage {
        input_tokens: 5,
        output_tokens: 0,
        cache_creation_input_tokens: None,
        cache_read_input_tokens: None,
    };
    let started_message = MessagesStreamMessage {
        id: String::from("msg_stream_123"),
        obj_type: String::from("message"),
        role: MessagesRole::Assistant,
        content: vec![],
        model: String::from("claude-3"),
        stop_reason: None,
        stop_sequence: None,
        usage: initial_usage,
    };
    let start_event = MessagesStreamEvent::MessageStart {
        message: started_message,
    };

    let chunk: ChatCompletionsStreamResponse = start_event.try_into().unwrap();

    assert_eq!(chunk.id, "msg_stream_123");
    assert_eq!(chunk.object.as_deref(), Some("chat.completion.chunk"));
    assert_eq!(chunk.model, "claude-3");
    assert_eq!(chunk.choices.len(), 1);

    let first = &chunk.choices[0];
    assert_eq!(first.index, 0);
    assert_eq!(first.delta.role, Some(Role::Assistant));
    assert_eq!(first.delta.content, None);
    assert_eq!(first.finish_reason, None);
}
|
||||
|
||||
/// An Anthropic text delta must become an OpenAI chunk with one choice whose
/// delta holds exactly that text, with no role and no finish reason.
#[test]
fn test_anthropic_content_block_delta_streaming() {
    let text_delta = MessagesContentDelta::TextDelta {
        text: String::from("Hello, world!"),
    };
    let delta_event = MessagesStreamEvent::ContentBlockDelta {
        index: 0,
        delta: text_delta,
    };

    let chunk: ChatCompletionsStreamResponse = delta_event.try_into().unwrap();

    assert_eq!(chunk.object.as_deref(), Some("chat.completion.chunk"));
    assert_eq!(chunk.choices.len(), 1);

    let first = &chunk.choices[0];
    assert_eq!(first.index, 0);
    assert_eq!(first.delta.content, Some(String::from("Hello, world!")));
    assert_eq!(first.delta.role, None);
    assert_eq!(first.finish_reason, None);
}
|
||||
|
||||
/// A `ContentBlockStart` carrying a `ToolUse` block must become an OpenAI
/// chunk whose single choice has one tool-call delta with the same id and
/// function name.
#[test]
fn test_anthropic_tool_use_streaming() {
    let tool_block = MessagesContentBlock::ToolUse {
        id: String::from("call_123"),
        name: String::from("get_weather"),
        input: json!({}),
        cache_control: None,
    };
    let tool_start = MessagesStreamEvent::ContentBlockStart {
        index: 0,
        content_block: tool_block,
    };

    let openai_resp: ChatCompletionsStreamResponse = tool_start.try_into().unwrap();

    assert_eq!(openai_resp.choices.len(), 1);
    let choice = &openai_resp.choices[0];
    assert!(choice.delta.tool_calls.is_some());

    let calls = choice.delta.tool_calls.as_ref().unwrap();
    assert_eq!(calls.len(), 1);

    let first_call = &calls[0];
    assert_eq!(first_call.id, Some(String::from("call_123")));
    assert_eq!(
        first_call.function.as_ref().unwrap().name,
        Some(String::from("get_weather"))
    );
}
|
||||
|
||||
/// An `InputJsonDelta` must be forwarded verbatim as the `arguments` of a
/// single tool-call delta, even though the JSON fragment is incomplete.
#[test]
fn test_anthropic_tool_input_delta_streaming() {
    // Intentionally truncated JSON: the fragment is passed through as text.
    let fragment = r#"{"location": "San Francisco"#;

    let event = MessagesStreamEvent::ContentBlockDelta {
        index: 0,
        delta: MessagesContentDelta::InputJsonDelta {
            partial_json: fragment.to_string(),
        },
    };

    let openai_resp: ChatCompletionsStreamResponse = event.try_into().unwrap();

    assert_eq!(openai_resp.choices.len(), 1);
    let choice = &openai_resp.choices[0];
    assert!(choice.delta.tool_calls.is_some());

    let calls = choice.delta.tool_calls.as_ref().unwrap();
    assert_eq!(calls.len(), 1);
    assert_eq!(
        calls[0].function.as_ref().unwrap().arguments,
        Some(fragment.to_string())
    );
}
|
||||
|
||||
/// A `MessageDelta` with `EndTurn` and token usage must become an OpenAI
/// chunk that finishes with `Stop` and reports prompt, completion and total
/// token counts (10, 25 and 35).
#[test]
fn test_anthropic_message_delta_with_usage() {
    let final_delta = MessagesMessageDelta {
        stop_reason: MessagesStopReason::EndTurn,
        stop_sequence: None,
    };
    let reported_usage = MessagesUsage {
        input_tokens: 10,
        output_tokens: 25,
        cache_creation_input_tokens: None,
        cache_read_input_tokens: None,
    };
    let event = MessagesStreamEvent::MessageDelta {
        delta: final_delta,
        usage: reported_usage,
    };

    let openai_resp: ChatCompletionsStreamResponse = event.try_into().unwrap();

    assert_eq!(openai_resp.choices.len(), 1);
    assert_eq!(
        openai_resp.choices[0].finish_reason,
        Some(FinishReason::Stop)
    );

    assert!(openai_resp.usage.is_some());
    let totals = openai_resp.usage.unwrap();
    assert_eq!(totals.prompt_tokens, 10);
    assert_eq!(totals.completion_tokens, 25);
    assert_eq!(totals.total_tokens, 35);
}
|
||||
|
||||
/// `MessageStop` must become an OpenAI chunk with a single choice whose
/// finish reason is `Stop`.
#[test]
fn test_anthropic_message_stop_streaming() {
    let chunk: ChatCompletionsStreamResponse =
        MessagesStreamEvent::MessageStop.try_into().unwrap();

    assert_eq!(chunk.choices.len(), 1);
    assert_eq!(chunk.choices[0].finish_reason, Some(FinishReason::Stop));
}
|
||||
|
||||
/// `Ping` must become an OpenAI chunk object that carries no choices at all.
#[test]
fn test_anthropic_ping_streaming() {
    let chunk: ChatCompletionsStreamResponse = MessagesStreamEvent::Ping.try_into().unwrap();

    assert_eq!(chunk.object.as_deref(), Some("chat.completion.chunk"));
    // A ping has nothing to deliver, so the choice list stays empty.
    assert_eq!(chunk.choices.len(), 0);
}
|
||||
|
||||
/// An OpenAI chunk whose delta carries only the assistant role must become an
/// Anthropic `MessageStart` that keeps the chunk id, role and model.
#[test]
fn test_openai_to_anthropic_streaming_role_start() {
    let role_only_delta = MessageDelta {
        role: Some(Role::Assistant),
        content: None,
        refusal: None,
        function_call: None,
        tool_calls: None,
    };
    let chunk = ChatCompletionsStreamResponse {
        id: String::from("chatcmpl-123"),
        object: Some(String::from("chat.completion.chunk")),
        created: 1234567890,
        model: String::from("gpt-4"),
        choices: vec![StreamChoice {
            index: 0,
            delta: role_only_delta,
            finish_reason: None,
            logprobs: None,
        }],
        usage: None,
        system_fingerprint: None,
        service_tier: None,
    };

    let converted: MessagesStreamEvent = chunk.try_into().unwrap();

    if let MessagesStreamEvent::MessageStart { message } = converted {
        assert_eq!(message.id, "chatcmpl-123");
        assert_eq!(message.role, MessagesRole::Assistant);
        assert_eq!(message.model, "gpt-4");
    } else {
        panic!("Expected MessageStart event");
    }
}
|
||||
|
||||
/// An OpenAI chunk whose delta carries only text must become an Anthropic
/// `ContentBlockDelta` at index 0 holding a `TextDelta` with the same text.
#[test]
fn test_openai_to_anthropic_streaming_content_delta() {
    let text_only_delta = MessageDelta {
        role: None,
        content: Some(String::from("Hello there!")),
        refusal: None,
        function_call: None,
        tool_calls: None,
    };
    let chunk = ChatCompletionsStreamResponse {
        id: String::from("chatcmpl-123"),
        object: Some(String::from("chat.completion.chunk")),
        created: 1234567890,
        model: String::from("gpt-4"),
        choices: vec![StreamChoice {
            index: 0,
            delta: text_only_delta,
            finish_reason: None,
            logprobs: None,
        }],
        usage: None,
        system_fingerprint: None,
        service_tier: None,
    };

    let converted: MessagesStreamEvent = chunk.try_into().unwrap();

    // Peel the outer event first, then inspect the delta it carries.
    let (index, delta) = match converted {
        MessagesStreamEvent::ContentBlockDelta { index, delta } => (index, delta),
        _ => panic!("Expected ContentBlockDelta event"),
    };
    assert_eq!(index, 0);

    if let MessagesContentDelta::TextDelta { text } = delta {
        assert_eq!(text, "Hello there!");
    } else {
        panic!("Expected TextDelta");
    }
}
|
||||
|
||||
/// An OpenAI chunk that opens a function tool call (id and name present,
/// empty arguments) must become an Anthropic `ContentBlockStart` at index 0
/// holding a `ToolUse` block with the same id and name.
#[test]
fn test_openai_to_anthropic_streaming_tool_calls() {
    let opening_call = ToolCallDelta {
        index: 0,
        id: Some(String::from("call_abc123")),
        call_type: Some(String::from("function")),
        function: Some(FunctionCallDelta {
            name: Some(String::from("get_current_weather")),
            arguments: Some(String::from("")),
        }),
    };
    let tool_delta = MessageDelta {
        role: None,
        content: None,
        refusal: None,
        function_call: None,
        tool_calls: Some(vec![opening_call]),
    };
    let chunk = ChatCompletionsStreamResponse {
        id: String::from("chatcmpl-123"),
        object: Some(String::from("chat.completion.chunk")),
        created: 1234567890,
        model: String::from("gpt-4"),
        choices: vec![StreamChoice {
            index: 0,
            delta: tool_delta,
            finish_reason: None,
            logprobs: None,
        }],
        usage: None,
        system_fingerprint: None,
        service_tier: None,
    };

    let converted: MessagesStreamEvent = chunk.try_into().unwrap();

    // Peel the outer event first, then inspect the content block it carries.
    let (index, content_block) = match converted {
        MessagesStreamEvent::ContentBlockStart {
            index,
            content_block,
        } => (index, content_block),
        _ => panic!("Expected ContentBlockStart event"),
    };
    assert_eq!(index, 0);

    if let MessagesContentBlock::ToolUse { id, name, .. } = content_block {
        assert_eq!(id, "call_abc123");
        assert_eq!(name, "get_current_weather");
    } else {
        panic!("Expected ToolUse content block");
    }
}
|
||||
|
||||
/// A final OpenAI chunk (finish reason `Stop`, usage attached) must become an
/// Anthropic `MessageDelta` with stop reason `EndTurn` and the prompt and
/// completion counts copied into input and output tokens.
#[test]
fn test_openai_to_anthropic_streaming_final_usage() {
    let empty_delta = MessageDelta {
        role: None,
        content: None,
        refusal: None,
        function_call: None,
        tool_calls: None,
    };
    let final_usage = Usage {
        prompt_tokens: 15,
        completion_tokens: 30,
        total_tokens: 45,
        prompt_tokens_details: None,
        completion_tokens_details: None,
    };
    let chunk = ChatCompletionsStreamResponse {
        id: String::from("chatcmpl-123"),
        object: Some(String::from("chat.completion.chunk")),
        created: 1234567890,
        model: String::from("gpt-4"),
        choices: vec![StreamChoice {
            index: 0,
            delta: empty_delta,
            finish_reason: Some(FinishReason::Stop),
            logprobs: None,
        }],
        usage: Some(final_usage),
        system_fingerprint: None,
        service_tier: None,
    };

    let converted: MessagesStreamEvent = chunk.try_into().unwrap();

    if let MessagesStreamEvent::MessageDelta { delta, usage } = converted {
        assert_eq!(delta.stop_reason, MessagesStopReason::EndTurn);
        assert_eq!(usage.input_tokens, 15);
        assert_eq!(usage.output_tokens, 30);
    } else {
        panic!("Expected MessageDelta event");
    }
}
|
||||
|
||||
/// An OpenAI chunk with an empty `choices` list must be converted to an
/// Anthropic `Ping`.
#[test]
fn test_openai_empty_choices_to_anthropic_ping() {
    let chunk_without_choices = ChatCompletionsStreamResponse {
        id: String::from("chatcmpl-123"),
        object: Some(String::from("chat.completion.chunk")),
        created: 1234567890,
        model: String::from("gpt-4"),
        // The case under test: nothing to deliver.
        choices: vec![],
        usage: None,
        system_fingerprint: None,
        service_tier: None,
    };

    let converted: MessagesStreamEvent = chunk_without_choices.try_into().unwrap();

    if !matches!(converted, MessagesStreamEvent::Ping) {
        panic!("Expected Ping event for empty choices");
    }
}
|
||||
|
||||
/// Sends an Anthropic text delta through the OpenAI chunk shape and back, and
/// checks that the block index and the text are unchanged.
#[test]
fn test_streaming_roundtrip_consistency() {
    let source_event = MessagesStreamEvent::ContentBlockDelta {
        index: 0,
        delta: MessagesContentDelta::TextDelta {
            text: String::from("Test message"),
        },
    };

    // Anthropic -> OpenAI -> Anthropic.
    let as_openai: ChatCompletionsStreamResponse = source_event.try_into().unwrap();
    let restored: MessagesStreamEvent = as_openai.try_into().unwrap();

    // Peel the outer event first, then inspect the delta it carries.
    let (index, delta) = match restored {
        MessagesStreamEvent::ContentBlockDelta { index, delta } => (index, delta),
        _ => panic!("Expected ContentBlockDelta after roundtrip"),
    };
    assert_eq!(index, 0);

    if let MessagesContentDelta::TextDelta { text } = delta {
        assert_eq!(text, "Test message");
    } else {
        panic!("Expected TextDelta after roundtrip");
    }
}
|
||||
|
||||
/// Converts a tool-use start followed by two JSON argument fragments and
/// checks that each event maps to its own OpenAI chunk: the start carries the
/// call id and function name, and each fragment is forwarded unchanged.
#[test]
fn test_streaming_tool_argument_accumulation() {
    // The two fragments only form valid JSON once concatenated.
    let first_fragment = r#"{"location": "#;
    let second_fragment = r#"San Francisco", "unit": "fahrenheit"}"#;

    let tool_start = MessagesStreamEvent::ContentBlockStart {
        index: 0,
        content_block: MessagesContentBlock::ToolUse {
            id: String::from("call_weather"),
            name: String::from("get_weather"),
            input: json!({}),
            cache_control: None,
        },
    };
    let first_delta = MessagesStreamEvent::ContentBlockDelta {
        index: 0,
        delta: MessagesContentDelta::InputJsonDelta {
            partial_json: first_fragment.to_string(),
        },
    };
    let second_delta = MessagesStreamEvent::ContentBlockDelta {
        index: 0,
        delta: MessagesContentDelta::InputJsonDelta {
            partial_json: second_fragment.to_string(),
        },
    };

    // Convert everything up front, in stream order, before asserting.
    let start_chunk: ChatCompletionsStreamResponse = tool_start.try_into().unwrap();
    let first_chunk: ChatCompletionsStreamResponse = first_delta.try_into().unwrap();
    let second_chunk: ChatCompletionsStreamResponse = second_delta.try_into().unwrap();

    // Tool start: id and function name.
    let started_calls = start_chunk.choices[0].delta.tool_calls.as_ref().unwrap();
    assert_eq!(started_calls[0].id, Some(String::from("call_weather")));
    assert_eq!(
        started_calls[0].function.as_ref().unwrap().name,
        Some(String::from("get_weather"))
    );

    // Argument fragments: pulls the `arguments` of the first tool call in the
    // first choice of a chunk.
    let arguments_of = |chunk: &ChatCompletionsStreamResponse| {
        chunk.choices[0].delta.tool_calls.as_ref().unwrap()[0]
            .function
            .as_ref()
            .unwrap()
            .arguments
            .clone()
    };
    assert_eq!(arguments_of(&first_chunk), Some(first_fragment.to_string()));
    assert_eq!(
        arguments_of(&second_chunk),
        Some(second_fragment.to_string())
    );
}
|
||||
|
||||
/// Table-driven check of the Anthropic stop reason -> OpenAI finish reason
/// mapping for streaming `MessageDelta` events, followed by a loose check of
/// the reverse direction.
#[test]
fn test_streaming_multiple_finish_reasons() {
    // (Anthropic stop reason, OpenAI finish reason it must map to)
    let expectations = [
        (MessagesStopReason::EndTurn, FinishReason::Stop),
        (MessagesStopReason::MaxTokens, FinishReason::Length),
        (MessagesStopReason::ToolUse, FinishReason::ToolCalls),
        (MessagesStopReason::StopSequence, FinishReason::Stop),
    ];

    for (stop_reason, wanted_finish) in expectations {
        let event = MessagesStreamEvent::MessageDelta {
            delta: MessagesMessageDelta {
                stop_reason,
                stop_sequence: None,
            },
            usage: MessagesUsage {
                input_tokens: 10,
                output_tokens: 20,
                cache_creation_input_tokens: None,
                cache_read_input_tokens: None,
            },
        };

        let chunk: ChatCompletionsStreamResponse = event.try_into().unwrap();
        assert_eq!(chunk.choices[0].finish_reason, Some(wanted_finish));

        // Reverse direction. Two stop reasons share `FinishReason::Stop`, so
        // the exact original cannot always be recovered; only the shape of
        // the event and membership in the known set are checked.
        let restored: MessagesStreamEvent = chunk.try_into().unwrap();
        if let MessagesStreamEvent::MessageDelta { delta, .. } = restored {
            assert!(matches!(
                delta.stop_reason,
                MessagesStopReason::EndTurn
                    | MessagesStopReason::MaxTokens
                    | MessagesStopReason::ToolUse
                    | MessagesStopReason::StopSequence
            ));
        } else {
            panic!("Expected MessageDelta after roundtrip");
        }
    }
}
|
||||
|
||||
/// A chunk whose only choice has a completely empty delta and no finish
/// reason must convert without error and come out as a `Ping`.
#[test]
fn test_streaming_error_handling() {
    let hollow_delta = MessageDelta {
        role: None,
        content: None,
        refusal: None,
        function_call: None,
        tool_calls: None,
    };
    let hollow_chunk = ChatCompletionsStreamResponse {
        id: String::from("test"),
        object: Some(String::from("chat.completion.chunk")),
        created: 1234567890,
        model: String::from("test"),
        choices: vec![StreamChoice {
            index: 0,
            delta: hollow_delta,
            finish_reason: None,
            logprobs: None,
        }],
        usage: None,
        system_fingerprint: None,
        service_tier: None,
    };

    // Nothing meaningful to forward, so the result should be a Ping.
    let anthropic_event: MessagesStreamEvent = hollow_chunk.try_into().unwrap();
    assert!(matches!(anthropic_event, MessagesStreamEvent::Ping));
}
|
||||
|
||||
/// `ContentBlockStop` must become an OpenAI chunk with exactly one choice
/// whose delta is empty (no role, content or tool calls) and which has no
/// finish reason.
#[test]
fn test_streaming_content_block_stop() {
    let stop_event = MessagesStreamEvent::ContentBlockStop { index: 0 };

    let chunk: ChatCompletionsStreamResponse = stop_event.try_into().unwrap();

    assert_eq!(chunk.object.as_deref(), Some("chat.completion.chunk"));
    assert_eq!(chunk.choices.len(), 1);

    // The single choice is expected to be an empty placeholder.
    let only = &chunk.choices[0];
    assert_eq!(only.delta.role, None);
    assert_eq!(only.delta.content, None);
    assert_eq!(only.delta.tool_calls, None);
    assert_eq!(only.finish_reason, None);
}
|
||||
}
|
||||
|
|
@ -6,18 +6,21 @@ pub mod clients;
|
|||
pub mod providers;
|
||||
pub mod transforms;
|
||||
// Re-export important types and traits
|
||||
pub use apis::amazon_bedrock_binary_frame::BedrockBinaryFrameDecoder;
|
||||
pub use apis::sse::{SseEvent, SseStreamIter};
|
||||
pub use apis::streaming_shapes::amazon_bedrock_binary_frame::BedrockBinaryFrameDecoder;
|
||||
pub use apis::streaming_shapes::sse::{SseEvent, SseStreamIter};
|
||||
pub use aws_smithy_eventstream::frame::DecodedFrame;
|
||||
pub use providers::id::ProviderId;
|
||||
pub use providers::request::{ProviderRequest, ProviderRequestError, ProviderRequestType};
|
||||
pub use providers::response::{
|
||||
ProviderResponse, ProviderResponseError, ProviderResponseType, ProviderStreamResponse,
|
||||
ProviderStreamResponseType, TokenUsage,
|
||||
ProviderResponse, ProviderResponseType, TokenUsage, ProviderResponseError
|
||||
};
|
||||
pub use providers::streaming_response::{
|
||||
ProviderStreamResponse, ProviderStreamResponseType
|
||||
};
|
||||
|
||||
//TODO: Refactor such that commons doesn't depend on Hermes. For now this will clean up strings
|
||||
pub const CHAT_COMPLETIONS_PATH: &str = "/v1/chat/completions";
|
||||
pub const OPENAI_RESPONSES_API_PATH: &str = "/v1/responses";
|
||||
pub const MESSAGES_PATH: &str = "/v1/messages";
|
||||
|
||||
#[cfg(test)]
|
||||
|
|
@ -42,9 +45,9 @@ mod tests {
|
|||
data: [DONE]
|
||||
"#;
|
||||
|
||||
use crate::clients::endpoints::SupportedAPIs;
|
||||
use crate::clients::endpoints::SupportedAPIsFromClient;
|
||||
let client_api =
|
||||
SupportedAPIs::OpenAIChatCompletions(crate::apis::OpenAIApi::ChatCompletions);
|
||||
SupportedAPIsFromClient::OpenAIChatCompletions(crate::apis::OpenAIApi::ChatCompletions);
|
||||
let upstream_api =
|
||||
SupportedUpstreamAPIs::OpenAIChatCompletions(crate::apis::OpenAIApi::ChatCompletions);
|
||||
|
||||
|
|
@ -79,9 +82,16 @@ mod tests {
|
|||
assert_eq!(stream_response.content_delta(), Some("Hello"));
|
||||
assert!(!stream_response.is_final());
|
||||
|
||||
// Test that stream ends properly with [DONE] (SseStreamIter should stop before [DONE])
|
||||
// Test that stream ends properly with [DONE]
|
||||
// The iterator should return the [DONE] event, then None
|
||||
let done_event = streaming_iter.next();
|
||||
assert!(done_event.is_some(), "Should get [DONE] event");
|
||||
let done_event = done_event.unwrap();
|
||||
assert!(done_event.is_done(), "[DONE] event should be marked as done");
|
||||
|
||||
// After [DONE], iterator should return None
|
||||
let final_event = streaming_iter.next();
|
||||
assert!(final_event.is_none()); // Should be None because iterator stops at [DONE]
|
||||
assert!(final_event.is_none(), "Iterator should return None after [DONE]");
|
||||
}
|
||||
|
||||
/// Test AWS Event Stream decoding for Bedrock ConverseStream responses.
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
use crate::apis::{AmazonBedrockApi, AnthropicApi, OpenAIApi};
|
||||
use crate::clients::endpoints::{SupportedAPIs, SupportedUpstreamAPIs};
|
||||
use crate::clients::endpoints::{SupportedAPIsFromClient, SupportedUpstreamAPIs};
|
||||
use std::fmt::Display;
|
||||
|
||||
/// Provider identifier enum - simple enum for identifying providers
|
||||
|
|
@ -51,19 +51,24 @@ impl ProviderId {
|
|||
/// Given a client API, return the compatible upstream API for this provider
|
||||
pub fn compatible_api_for_client(
|
||||
&self,
|
||||
client_api: &SupportedAPIs,
|
||||
client_api: &SupportedAPIsFromClient,
|
||||
is_streaming: bool,
|
||||
) -> SupportedUpstreamAPIs {
|
||||
match (self, client_api) {
|
||||
// Claude/Anthropic providers natively support Anthropic APIs
|
||||
(ProviderId::Anthropic, SupportedAPIs::AnthropicMessagesAPI(_)) => {
|
||||
(ProviderId::Anthropic, SupportedAPIsFromClient::AnthropicMessagesAPI(_)) => {
|
||||
SupportedUpstreamAPIs::AnthropicMessagesAPI(AnthropicApi::Messages)
|
||||
}
|
||||
(
|
||||
ProviderId::Anthropic,
|
||||
SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
|
||||
SupportedAPIsFromClient::OpenAIChatCompletions(_),
|
||||
) => SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
|
||||
|
||||
// Anthropic doesn't support Responses API, fall back to chat completions
|
||||
(ProviderId::Anthropic, SupportedAPIsFromClient::OpenAIResponsesAPI(_)) => {
|
||||
SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions)
|
||||
}
|
||||
|
||||
// OpenAI-compatible providers only support OpenAI chat completions
|
||||
(
|
||||
ProviderId::OpenAI
|
||||
|
|
@ -80,7 +85,7 @@ impl ProviderId {
|
|||
| ProviderId::Moonshotai
|
||||
| ProviderId::Zhipu
|
||||
| ProviderId::Qwen,
|
||||
SupportedAPIs::AnthropicMessagesAPI(_),
|
||||
SupportedAPIsFromClient::AnthropicMessagesAPI(_),
|
||||
) => SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
|
||||
|
||||
(
|
||||
|
|
@ -98,11 +103,16 @@ impl ProviderId {
|
|||
| ProviderId::Moonshotai
|
||||
| ProviderId::Zhipu
|
||||
| ProviderId::Qwen,
|
||||
SupportedAPIs::OpenAIChatCompletions(_),
|
||||
SupportedAPIsFromClient::OpenAIChatCompletions(_),
|
||||
) => SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
|
||||
|
||||
// OpenAI Responses API - only OpenAI supports this
|
||||
(ProviderId::OpenAI, SupportedAPIsFromClient::OpenAIResponsesAPI(_)) => {
|
||||
SupportedUpstreamAPIs::OpenAIResponsesAPI(OpenAIApi::Responses)
|
||||
}
|
||||
|
||||
// Amazon Bedrock natively supports Bedrock APIs
|
||||
(ProviderId::AmazonBedrock, SupportedAPIs::OpenAIChatCompletions(_)) => {
|
||||
(ProviderId::AmazonBedrock, SupportedAPIsFromClient::OpenAIChatCompletions(_)) => {
|
||||
if is_streaming {
|
||||
SupportedUpstreamAPIs::AmazonBedrockConverseStream(
|
||||
AmazonBedrockApi::ConverseStream,
|
||||
|
|
@ -111,7 +121,7 @@ impl ProviderId {
|
|||
SupportedUpstreamAPIs::AmazonBedrockConverse(AmazonBedrockApi::Converse)
|
||||
}
|
||||
}
|
||||
(ProviderId::AmazonBedrock, SupportedAPIs::AnthropicMessagesAPI(_)) => {
|
||||
(ProviderId::AmazonBedrock, SupportedAPIsFromClient::AnthropicMessagesAPI(_)) => {
|
||||
if is_streaming {
|
||||
SupportedUpstreamAPIs::AmazonBedrockConverseStream(
|
||||
AmazonBedrockApi::ConverseStream,
|
||||
|
|
@ -120,6 +130,20 @@ impl ProviderId {
|
|||
SupportedUpstreamAPIs::AmazonBedrockConverse(AmazonBedrockApi::Converse)
|
||||
}
|
||||
}
|
||||
(ProviderId::AmazonBedrock, SupportedAPIsFromClient::OpenAIResponsesAPI(_)) => {
|
||||
if is_streaming {
|
||||
SupportedUpstreamAPIs::AmazonBedrockConverseStream(
|
||||
AmazonBedrockApi::ConverseStream,
|
||||
)
|
||||
} else {
|
||||
SupportedUpstreamAPIs::AmazonBedrockConverse(AmazonBedrockApi::Converse)
|
||||
}
|
||||
}
|
||||
|
||||
// Non-OpenAI providers: if client requested the Responses API, fall back to Chat Completions
|
||||
(_, SupportedAPIsFromClient::OpenAIResponsesAPI(_)) => {
|
||||
SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,7 +6,9 @@
|
|||
pub mod id;
|
||||
pub mod request;
|
||||
pub mod response;
|
||||
pub mod streaming_response;
|
||||
|
||||
pub use id::ProviderId;
|
||||
pub use request::{ProviderRequest, ProviderRequestError, ProviderRequestType};
|
||||
pub use response::{ProviderResponse, ProviderResponseType, ProviderStreamResponse, TokenUsage};
|
||||
pub use response::{ProviderResponse, ProviderResponseType, TokenUsage};
|
||||
pub use streaming_response::{ProviderStreamResponse, ProviderStreamResponseType};
|
||||
|
|
|
|||
|
|
@ -2,19 +2,21 @@ use crate::apis::anthropic::MessagesRequest;
|
|||
use crate::apis::openai::ChatCompletionsRequest;
|
||||
|
||||
use crate::apis::amazon_bedrock::{ConverseRequest, ConverseStreamRequest};
|
||||
use crate::clients::endpoints::SupportedAPIs;
|
||||
use crate::apis::openai_responses::ResponsesAPIRequest;
|
||||
use crate::clients::endpoints::SupportedAPIsFromClient;
|
||||
use crate::clients::endpoints::SupportedUpstreamAPIs;
|
||||
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
use std::error::Error;
|
||||
use std::fmt;
|
||||
/// A request payload in one of the supported API shapes; each variant wraps
/// the request struct for that shape (OpenAI chat completions, Anthropic
/// messages, Bedrock Converse / ConverseStream, OpenAI Responses API).
#[derive(Clone)]
|
||||
#[derive(Clone, Debug)]
|
||||
pub enum ProviderRequestType {
|
||||
ChatCompletionsRequest(ChatCompletionsRequest),
|
||||
MessagesRequest(MessagesRequest),
|
||||
BedrockConverse(ConverseRequest),
|
||||
BedrockConverseStream(ConverseStreamRequest),
|
||||
ResponsesAPIRequest(ResponsesAPIRequest),
|
||||
//add more request types here
|
||||
}
|
||||
pub trait ProviderRequest: Send + Sync {
|
||||
|
|
@ -49,6 +51,7 @@ impl ProviderRequest for ProviderRequestType {
|
|||
Self::MessagesRequest(r) => r.model(),
|
||||
Self::BedrockConverse(r) => r.model(),
|
||||
Self::BedrockConverseStream(r) => r.model(),
|
||||
Self::ResponsesAPIRequest(r) => r.model(),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -58,6 +61,7 @@ impl ProviderRequest for ProviderRequestType {
|
|||
Self::MessagesRequest(r) => r.set_model(model),
|
||||
Self::BedrockConverse(r) => r.set_model(model),
|
||||
Self::BedrockConverseStream(r) => r.set_model(model),
|
||||
Self::ResponsesAPIRequest(r) => r.set_model(model),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -67,6 +71,7 @@ impl ProviderRequest for ProviderRequestType {
|
|||
Self::MessagesRequest(r) => r.is_streaming(),
|
||||
Self::BedrockConverse(_) => false,
|
||||
Self::BedrockConverseStream(_) => true,
|
||||
Self::ResponsesAPIRequest(r) => r.is_streaming(),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -76,6 +81,7 @@ impl ProviderRequest for ProviderRequestType {
|
|||
Self::MessagesRequest(r) => r.extract_messages_text(),
|
||||
Self::BedrockConverse(r) => r.extract_messages_text(),
|
||||
Self::BedrockConverseStream(r) => r.extract_messages_text(),
|
||||
Self::ResponsesAPIRequest(r) => r.extract_messages_text(),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -85,6 +91,7 @@ impl ProviderRequest for ProviderRequestType {
|
|||
Self::MessagesRequest(r) => r.get_recent_user_message(),
|
||||
Self::BedrockConverse(r) => r.get_recent_user_message(),
|
||||
Self::BedrockConverseStream(r) => r.get_recent_user_message(),
|
||||
Self::ResponsesAPIRequest(r) => r.get_recent_user_message(),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -94,6 +101,7 @@ impl ProviderRequest for ProviderRequestType {
|
|||
Self::MessagesRequest(r) => r.to_bytes(),
|
||||
Self::BedrockConverse(r) => r.to_bytes(),
|
||||
Self::BedrockConverseStream(r) => r.to_bytes(),
|
||||
Self::ResponsesAPIRequest(r) => r.to_bytes(),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -103,6 +111,7 @@ impl ProviderRequest for ProviderRequestType {
|
|||
Self::MessagesRequest(r) => r.metadata(),
|
||||
Self::BedrockConverse(r) => r.metadata(),
|
||||
Self::BedrockConverseStream(r) => r.metadata(),
|
||||
Self::ResponsesAPIRequest(r) => r.metadata(),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -112,18 +121,19 @@ impl ProviderRequest for ProviderRequestType {
|
|||
Self::MessagesRequest(r) => r.remove_metadata_key(key),
|
||||
Self::BedrockConverse(r) => r.remove_metadata_key(key),
|
||||
Self::BedrockConverseStream(r) => r.remove_metadata_key(key),
|
||||
Self::ResponsesAPIRequest(r) => r.remove_metadata_key(key),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse the client API from a byte slice.
|
||||
impl TryFrom<(&[u8], &SupportedAPIs)> for ProviderRequestType {
|
||||
impl TryFrom<(&[u8], &SupportedAPIsFromClient)> for ProviderRequestType {
|
||||
type Error = std::io::Error;
|
||||
|
||||
fn try_from((bytes, client_api): (&[u8], &SupportedAPIs)) -> Result<Self, Self::Error> {
|
||||
fn try_from((bytes, client_api): (&[u8], &SupportedAPIsFromClient)) -> Result<Self, Self::Error> {
|
||||
// Use SupportedApi to determine the appropriate request type
|
||||
match client_api {
|
||||
SupportedAPIs::OpenAIChatCompletions(_) => {
|
||||
SupportedAPIsFromClient::OpenAIChatCompletions(_) => {
|
||||
let chat_completion_request: ChatCompletionsRequest =
|
||||
ChatCompletionsRequest::try_from(bytes)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||
|
|
@ -131,11 +141,20 @@ impl TryFrom<(&[u8], &SupportedAPIs)> for ProviderRequestType {
|
|||
chat_completion_request,
|
||||
))
|
||||
}
|
||||
SupportedAPIs::AnthropicMessagesAPI(_) => {
|
||||
SupportedAPIsFromClient::AnthropicMessagesAPI(_) => {
|
||||
let messages_request: MessagesRequest = MessagesRequest::try_from(bytes)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||
Ok(ProviderRequestType::MessagesRequest(messages_request))
|
||||
}
|
||||
|
||||
SupportedAPIsFromClient::OpenAIResponsesAPI(_) => {
|
||||
let responses_apirequest: ResponsesAPIRequest =
|
||||
ResponsesAPIRequest::try_from(bytes)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||
Ok(ProviderRequestType::ResponsesAPIRequest(
|
||||
responses_apirequest,
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -148,17 +167,13 @@ impl TryFrom<(ProviderRequestType, &SupportedUpstreamAPIs)> for ProviderRequestT
|
|||
(client_request, upstream_api): (ProviderRequestType, &SupportedUpstreamAPIs),
|
||||
) -> Result<Self, Self::Error> {
|
||||
match (client_request, upstream_api) {
|
||||
// Same API - no conversion needed, just clone the reference
|
||||
// ============================================================================
|
||||
// ChatCompletionsRequest conversions
|
||||
// ============================================================================
|
||||
(
|
||||
ProviderRequestType::ChatCompletionsRequest(chat_req),
|
||||
SupportedUpstreamAPIs::OpenAIChatCompletions(_),
|
||||
) => Ok(ProviderRequestType::ChatCompletionsRequest(chat_req)),
|
||||
(
|
||||
ProviderRequestType::MessagesRequest(messages_req),
|
||||
SupportedUpstreamAPIs::AnthropicMessagesAPI(_),
|
||||
) => Ok(ProviderRequestType::MessagesRequest(messages_req)),
|
||||
|
||||
// Cross-API conversion - cloning is necessary for transformation
|
||||
(
|
||||
ProviderRequestType::ChatCompletionsRequest(chat_req),
|
||||
SupportedUpstreamAPIs::AnthropicMessagesAPI(_),
|
||||
|
|
@ -173,7 +188,45 @@ impl TryFrom<(ProviderRequestType, &SupportedUpstreamAPIs)> for ProviderRequestT
|
|||
})?;
|
||||
Ok(ProviderRequestType::MessagesRequest(messages_req))
|
||||
}
|
||||
(
|
||||
ProviderRequestType::ChatCompletionsRequest(chat_req),
|
||||
SupportedUpstreamAPIs::AmazonBedrockConverse(_),
|
||||
) => {
|
||||
let bedrock_req = ConverseRequest::try_from(chat_req)
|
||||
.map_err(|e| ProviderRequestError {
|
||||
message: format!("Failed to convert ChatCompletionsRequest to Amazon Bedrock request: {}", e),
|
||||
source: Some(Box::new(e))
|
||||
})?;
|
||||
Ok(ProviderRequestType::BedrockConverse(bedrock_req))
|
||||
}
|
||||
(
|
||||
ProviderRequestType::ChatCompletionsRequest(chat_req),
|
||||
SupportedUpstreamAPIs::AmazonBedrockConverseStream(_),
|
||||
) => {
|
||||
let bedrock_req = ConverseStreamRequest::try_from(chat_req)
|
||||
.map_err(|e| ProviderRequestError {
|
||||
message: format!("Failed to convert ChatCompletionsRequest to Amazon Bedrock Stream request: {}", e),
|
||||
source: Some(Box::new(e))
|
||||
})?;
|
||||
Ok(ProviderRequestType::BedrockConverseStream(bedrock_req))
|
||||
}
|
||||
(
|
||||
ProviderRequestType::ChatCompletionsRequest(_),
|
||||
SupportedUpstreamAPIs::OpenAIResponsesAPI(_),
|
||||
) => {
|
||||
Err(ProviderRequestError {
|
||||
message: "Conversion from ChatCompletionsRequest to ResponsesAPIRequest is not supported. ResponsesAPI can only be used as a client API, not as an upstream API.".to_string(),
|
||||
source: None,
|
||||
})
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// MessagesRequest conversions
|
||||
// ============================================================================
|
||||
(
|
||||
ProviderRequestType::MessagesRequest(messages_req),
|
||||
SupportedUpstreamAPIs::AnthropicMessagesAPI(_),
|
||||
) => Ok(ProviderRequestType::MessagesRequest(messages_req)),
|
||||
(
|
||||
ProviderRequestType::MessagesRequest(messages_req),
|
||||
SupportedUpstreamAPIs::OpenAIChatCompletions(_),
|
||||
|
|
@ -189,31 +242,6 @@ impl TryFrom<(ProviderRequestType, &SupportedUpstreamAPIs)> for ProviderRequestT
|
|||
})?;
|
||||
Ok(ProviderRequestType::ChatCompletionsRequest(chat_req))
|
||||
}
|
||||
|
||||
// Cross-API conversions: OpenAI/Anthropic to Amazon Bedrock
|
||||
(
|
||||
ProviderRequestType::ChatCompletionsRequest(chat_req),
|
||||
SupportedUpstreamAPIs::AmazonBedrockConverse(_),
|
||||
) => {
|
||||
let bedrock_req = ConverseRequest::try_from(chat_req)
|
||||
.map_err(|e| ProviderRequestError {
|
||||
message: format!("Failed to convert ChatCompletionsRequest to Amazon Bedrock request: {}", e),
|
||||
source: Some(Box::new(e))
|
||||
})?;
|
||||
Ok(ProviderRequestType::BedrockConverse(bedrock_req))
|
||||
}
|
||||
|
||||
(
|
||||
ProviderRequestType::ChatCompletionsRequest(chat_req),
|
||||
SupportedUpstreamAPIs::AmazonBedrockConverseStream(_),
|
||||
) => {
|
||||
let bedrock_req = ConverseStreamRequest::try_from(chat_req)
|
||||
.map_err(|e| ProviderRequestError {
|
||||
message: format!("Failed to convert ChatCompletionsRequest to Amazon Bedrock request: {}", e),
|
||||
source: Some(Box::new(e))
|
||||
})?;
|
||||
Ok(ProviderRequestType::BedrockConverse(bedrock_req))
|
||||
}
|
||||
(
|
||||
ProviderRequestType::MessagesRequest(messages_req),
|
||||
SupportedUpstreamAPIs::AmazonBedrockConverse(_),
|
||||
|
|
@ -235,7 +263,97 @@ impl TryFrom<(ProviderRequestType, &SupportedUpstreamAPIs)> for ProviderRequestT
|
|||
let bedrock_req = ConverseStreamRequest::try_from(messages_req).map_err(|e| {
|
||||
ProviderRequestError {
|
||||
message: format!(
|
||||
"Failed to convert MessagesRequest to Amazon Bedrock request: {}",
|
||||
"Failed to convert MessagesRequest to Amazon Bedrock Stream request: {}",
|
||||
e
|
||||
),
|
||||
source: Some(Box::new(e)),
|
||||
}
|
||||
})?;
|
||||
Ok(ProviderRequestType::BedrockConverseStream(bedrock_req))
|
||||
}
|
||||
(
|
||||
ProviderRequestType::MessagesRequest(_),
|
||||
SupportedUpstreamAPIs::OpenAIResponsesAPI(_),
|
||||
) => {
|
||||
Err(ProviderRequestError {
|
||||
message: "Conversion from MessagesRequest to ResponsesAPIRequest is not supported. ResponsesAPI can only be used as a client API, not as an upstream API.".to_string(),
|
||||
source: None,
|
||||
})
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// ResponsesAPIRequest conversions (only converts TO other formats)
|
||||
// ============================================================================
|
||||
(
|
||||
ProviderRequestType::ResponsesAPIRequest(responses_req),
|
||||
SupportedUpstreamAPIs::OpenAIResponsesAPI(_),
|
||||
) => Ok(ProviderRequestType::ResponsesAPIRequest(responses_req)),
|
||||
|
||||
// ResponsesAPI -> ChatCompletions (direct conversion)
|
||||
(
|
||||
ProviderRequestType::ResponsesAPIRequest(responses_req),
|
||||
SupportedUpstreamAPIs::OpenAIChatCompletions(_),
|
||||
) => {
|
||||
let chat_req = ChatCompletionsRequest::try_from(responses_req).map_err(|e| {
|
||||
ProviderRequestError {
|
||||
message: format!(
|
||||
"Failed to convert ResponsesAPIRequest to ChatCompletionsRequest: {}",
|
||||
e
|
||||
),
|
||||
source: Some(Box::new(e)),
|
||||
}
|
||||
})?;
|
||||
Ok(ProviderRequestType::ChatCompletionsRequest(chat_req))
|
||||
}
|
||||
|
||||
// ResponsesAPI -> Anthropic Messages (via ChatCompletions)
|
||||
(
|
||||
ProviderRequestType::ResponsesAPIRequest(responses_req),
|
||||
SupportedUpstreamAPIs::AnthropicMessagesAPI(_),
|
||||
) => {
|
||||
// Chain: ResponsesAPI -> ChatCompletions -> MessagesRequest
|
||||
let chat_req = ChatCompletionsRequest::try_from(responses_req).map_err(|e| {
|
||||
ProviderRequestError {
|
||||
message: format!(
|
||||
"Failed to convert ResponsesAPIRequest to ChatCompletionsRequest: {}",
|
||||
e
|
||||
),
|
||||
source: Some(Box::new(e)),
|
||||
}
|
||||
})?;
|
||||
|
||||
let messages_req = MessagesRequest::try_from(chat_req).map_err(|e| {
|
||||
ProviderRequestError {
|
||||
message: format!(
|
||||
"Failed to convert ChatCompletionsRequest to MessagesRequest: {}",
|
||||
e
|
||||
),
|
||||
source: Some(Box::new(e)),
|
||||
}
|
||||
})?;
|
||||
Ok(ProviderRequestType::MessagesRequest(messages_req))
|
||||
}
|
||||
|
||||
// ResponsesAPI -> Bedrock Converse (via ChatCompletions)
|
||||
(
|
||||
ProviderRequestType::ResponsesAPIRequest(responses_req),
|
||||
SupportedUpstreamAPIs::AmazonBedrockConverse(_),
|
||||
) => {
|
||||
// Chain: ResponsesAPI -> ChatCompletions -> ConverseRequest
|
||||
let chat_req = ChatCompletionsRequest::try_from(responses_req).map_err(|e| {
|
||||
ProviderRequestError {
|
||||
message: format!(
|
||||
"Failed to convert ResponsesAPIRequest to ChatCompletionsRequest: {}",
|
||||
e
|
||||
),
|
||||
source: Some(Box::new(e)),
|
||||
}
|
||||
})?;
|
||||
|
||||
let bedrock_req = ConverseRequest::try_from(chat_req).map_err(|e| {
|
||||
ProviderRequestError {
|
||||
message: format!(
|
||||
"Failed to convert ChatCompletionsRequest to Amazon Bedrock request: {}",
|
||||
e
|
||||
),
|
||||
source: Some(Box::new(e)),
|
||||
|
|
@ -244,13 +362,50 @@ impl TryFrom<(ProviderRequestType, &SupportedUpstreamAPIs)> for ProviderRequestT
|
|||
Ok(ProviderRequestType::BedrockConverse(bedrock_req))
|
||||
}
|
||||
|
||||
// Amazon Bedrock to other APIs conversions
|
||||
// ResponsesAPI -> Bedrock Converse Stream (via ChatCompletions)
|
||||
(
|
||||
ProviderRequestType::ResponsesAPIRequest(responses_req),
|
||||
SupportedUpstreamAPIs::AmazonBedrockConverseStream(_),
|
||||
) => {
|
||||
// Chain: ResponsesAPI -> ChatCompletions -> ConverseStreamRequest
|
||||
let chat_req = ChatCompletionsRequest::try_from(responses_req).map_err(|e| {
|
||||
ProviderRequestError {
|
||||
message: format!(
|
||||
"Failed to convert ResponsesAPIRequest to ChatCompletionsRequest: {}",
|
||||
e
|
||||
),
|
||||
source: Some(Box::new(e)),
|
||||
}
|
||||
})?;
|
||||
|
||||
let bedrock_req = ConverseStreamRequest::try_from(chat_req).map_err(|e| {
|
||||
ProviderRequestError {
|
||||
message: format!(
|
||||
"Failed to convert ChatCompletionsRequest to Amazon Bedrock Stream request: {}",
|
||||
e
|
||||
),
|
||||
source: Some(Box::new(e)),
|
||||
}
|
||||
})?;
|
||||
Ok(ProviderRequestType::BedrockConverseStream(bedrock_req))
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Amazon Bedrock conversions (not supported as client API)
|
||||
// ============================================================================
|
||||
|
||||
(ProviderRequestType::BedrockConverse(_), _) => {
|
||||
todo!("Amazon Bedrock to ChatCompletionsRequest conversion not implemented yet")
|
||||
Err(ProviderRequestError {
|
||||
message: "Amazon Bedrock Converse is not supported as a client API. Only OpenAI ChatCompletions, Anthropic Messages, and OpenAI Responses APIs are supported as client APIs.".to_string(),
|
||||
source: None,
|
||||
})
|
||||
}
|
||||
|
||||
(ProviderRequestType::BedrockConverseStream(_), _) => {
|
||||
todo!("Amazon Bedrock Stream to ChatCompletionsRequest conversion not implemented yet")
|
||||
Err(ProviderRequestError {
|
||||
message: "Amazon Bedrock Converse Stream is not supported as a client API. Only OpenAI ChatCompletions, Anthropic Messages, and OpenAI Responses APIs are supported as client APIs.".to_string(),
|
||||
source: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -284,7 +439,7 @@ mod tests {
|
|||
use crate::apis::anthropic::MessagesRequest as AnthropicMessagesRequest;
|
||||
use crate::apis::openai::ChatCompletionsRequest;
|
||||
use crate::apis::openai::OpenAIApi::ChatCompletions;
|
||||
use crate::clients::endpoints::SupportedAPIs;
|
||||
use crate::clients::endpoints::SupportedAPIsFromClient;
|
||||
use crate::transforms::lib::ExtractText;
|
||||
use serde_json::json;
|
||||
|
||||
|
|
@ -298,7 +453,7 @@ mod tests {
|
|||
]
|
||||
});
|
||||
let bytes = serde_json::to_vec(&req).unwrap();
|
||||
let api = SupportedAPIs::OpenAIChatCompletions(ChatCompletions);
|
||||
let api = SupportedAPIsFromClient::OpenAIChatCompletions(ChatCompletions);
|
||||
let result = ProviderRequestType::try_from((bytes.as_slice(), &api));
|
||||
assert!(result.is_ok());
|
||||
match result.unwrap() {
|
||||
|
|
@ -321,7 +476,7 @@ mod tests {
|
|||
]
|
||||
});
|
||||
let bytes = serde_json::to_vec(&req).unwrap();
|
||||
let endpoint = SupportedAPIs::AnthropicMessagesAPI(Messages);
|
||||
let endpoint = SupportedAPIsFromClient::AnthropicMessagesAPI(Messages);
|
||||
let result = ProviderRequestType::try_from((bytes.as_slice(), &endpoint));
|
||||
assert!(result.is_ok());
|
||||
match result.unwrap() {
|
||||
|
|
@ -343,7 +498,7 @@ mod tests {
|
|||
]
|
||||
});
|
||||
let bytes = serde_json::to_vec(&req).unwrap();
|
||||
let endpoint = SupportedAPIs::OpenAIChatCompletions(ChatCompletions);
|
||||
let endpoint = SupportedAPIsFromClient::OpenAIChatCompletions(ChatCompletions);
|
||||
let result = ProviderRequestType::try_from((bytes.as_slice(), &endpoint));
|
||||
assert!(result.is_ok());
|
||||
match result.unwrap() {
|
||||
|
|
@ -366,7 +521,7 @@ mod tests {
|
|||
});
|
||||
let bytes = serde_json::to_vec(&req).unwrap();
|
||||
// Intentionally use OpenAI endpoint for Anthropic payload
|
||||
let endpoint = SupportedAPIs::OpenAIChatCompletions(ChatCompletions);
|
||||
let endpoint = SupportedAPIsFromClient::OpenAIChatCompletions(ChatCompletions);
|
||||
let result = ProviderRequestType::try_from((bytes.as_slice(), &endpoint));
|
||||
// Should parse as ChatCompletionsRequest, not error
|
||||
assert!(result.is_ok());
|
||||
|
|
@ -486,4 +641,272 @@ mod tests {
|
|||
let roundtrip_max_tokens = openai_req2.max_completion_tokens.or(openai_req2.max_tokens);
|
||||
assert_eq!(original_max_tokens, roundtrip_max_tokens);
|
||||
}
|
||||
|
||||
/// Raw JSON bytes tagged as coming from the OpenAI Responses client API must
/// deserialize into the `ResponsesAPIRequest` variant.
#[test]
fn test_responses_api_request_from_bytes() {
    use crate::apis::openai::OpenAIApi::Responses;

    // Smallest valid Responses payload: a model name plus plain-text input.
    let payload = serde_json::to_vec(&json!({
        "model": "gpt-4o",
        "input": "Hello, how are you?"
    }))
    .unwrap();
    let client_api = SupportedAPIsFromClient::OpenAIResponsesAPI(Responses);

    let result = ProviderRequestType::try_from((payload.as_slice(), &client_api));
    assert!(result.is_ok());

    if let ProviderRequestType::ResponsesAPIRequest(parsed) = result.unwrap() {
        assert_eq!(parsed.model, "gpt-4o");
    } else {
        panic!("Expected ResponsesAPIRequest variant");
    }
}
|
||||
|
||||
/// A Responses request routed to a Chat Completions upstream keeps its model,
/// sampling parameters and token limit, and its text input becomes one message.
#[test]
fn test_responses_api_to_chat_completions_conversion() {
    use crate::apis::openai::OpenAIApi::ChatCompletions;
    use crate::apis::openai_responses::{InputParam, ResponsesAPIRequest};

    let source_request = ResponsesAPIRequest {
        // Fields the assertions below look at.
        model: String::from("gpt-4o"),
        input: InputParam::Text(String::from("Hello, world!")),
        temperature: Some(0.7),
        top_p: Some(0.9),
        max_output_tokens: Some(100),
        stream: Some(false),
        // Everything else is left unset.
        metadata: None,
        tools: None,
        tool_choice: None,
        parallel_tool_calls: None,
        instructions: None,
        modalities: None,
        user: None,
        store: None,
        reasoning_effort: None,
        include: None,
        audio: None,
        text: None,
        service_tier: None,
        top_logprobs: None,
        stream_options: None,
        truncation: None,
        conversation: None,
        previous_response_id: None,
        max_tool_calls: None,
        background: None,
    };

    let target = SupportedUpstreamAPIs::OpenAIChatCompletions(ChatCompletions);
    let result = ProviderRequestType::try_from((
        ProviderRequestType::ResponsesAPIRequest(source_request),
        &target,
    ));
    assert!(result.is_ok());

    let converted = match result.unwrap() {
        ProviderRequestType::ChatCompletionsRequest(chat_req) => chat_req,
        _ => panic!("Expected ChatCompletionsRequest variant"),
    };
    assert_eq!(converted.model, "gpt-4o");
    assert_eq!(converted.temperature, Some(0.7));
    assert_eq!(converted.top_p, Some(0.9));
    assert_eq!(converted.max_completion_tokens, Some(100));
    assert_eq!(converted.messages.len(), 1);
}
|
||||
|
||||
/// A Responses request routed to an Anthropic Messages upstream is converted
/// through the chain ResponsesAPI -> ChatCompletions -> MessagesRequest.
#[test]
fn test_responses_api_to_anthropic_messages_conversion() {
    use crate::apis::anthropic::AnthropicApi::Messages;
    use crate::apis::openai_responses::{InputParam, ResponsesAPIRequest};

    let source_request = ResponsesAPIRequest {
        // Fields that matter for this conversion.
        model: String::from("gpt-4o"),
        input: InputParam::Text(String::from("Hello, Claude!")),
        instructions: Some(String::from("You are a helpful assistant")),
        temperature: Some(0.8),
        max_output_tokens: Some(150),
        stream: Some(false),
        // Everything else is left unset.
        metadata: None,
        tools: None,
        tool_choice: None,
        parallel_tool_calls: None,
        modalities: None,
        user: None,
        store: None,
        reasoning_effort: None,
        include: None,
        audio: None,
        text: None,
        service_tier: None,
        top_p: None,
        top_logprobs: None,
        stream_options: None,
        truncation: None,
        conversation: None,
        previous_response_id: None,
        max_tool_calls: None,
        background: None,
    };

    let target = SupportedUpstreamAPIs::AnthropicMessagesAPI(Messages);
    let result = ProviderRequestType::try_from((
        ProviderRequestType::ResponsesAPIRequest(source_request),
        &target,
    ));
    assert!(result.is_ok());

    let converted = match result.unwrap() {
        ProviderRequestType::MessagesRequest(messages_req) => messages_req,
        _ => panic!("Expected MessagesRequest variant"),
    };
    assert_eq!(converted.model, "gpt-4o");
    assert_eq!(converted.temperature, Some(0.8));
    assert_eq!(converted.max_tokens, 150);
    // The instructions are meant to travel as a system prompt rather than as an
    // extra conversation message, so exactly one message is expected.
    // NOTE(review): only the message count is asserted; the system prompt
    // itself is not checked by this test.
    assert_eq!(converted.messages.len(), 1);
}
|
||||
|
||||
/// A Responses request routed to an Amazon Bedrock Converse upstream is
/// converted through the chain ResponsesAPI -> ChatCompletions -> ConverseRequest.
#[test]
fn test_responses_api_to_bedrock_conversion() {
    use crate::apis::amazon_bedrock::AmazonBedrockApi::Converse;
    use crate::apis::openai_responses::{InputParam, ResponsesAPIRequest};

    let responses_req = ResponsesAPIRequest {
        model: "gpt-4o".to_string(),
        input: InputParam::Text("Hello, Bedrock!".to_string()),
        temperature: Some(0.5),
        max_output_tokens: Some(200),
        stream: Some(false),
        metadata: None,
        tools: None,
        tool_choice: None,
        parallel_tool_calls: None,
        instructions: None,
        modalities: None,
        user: None,
        store: None,
        reasoning_effort: None,
        include: None,
        audio: None,
        text: None,
        service_tier: None,
        top_p: None,
        top_logprobs: None,
        stream_options: None,
        truncation: None,
        conversation: None,
        previous_response_id: None,
        max_tool_calls: None,
        background: None,
    };

    let upstream_api = SupportedUpstreamAPIs::AmazonBedrockConverse(Converse);
    let result = ProviderRequestType::try_from((
        ProviderRequestType::ResponsesAPIRequest(responses_req),
        &upstream_api,
    ));

    assert!(result.is_ok());
    match result.unwrap() {
        ProviderRequestType::BedrockConverse(bedrock_req) => {
            // The model name is carried over unchanged as the Bedrock model id.
            assert_eq!(bedrock_req.model_id, "gpt-4o");
            // The text input must have produced a message list. Stated
            // positively instead of as `!…is_none()`.
            assert!(bedrock_req.messages.is_some());
        }
        _ => panic!("Expected BedrockConverse variant"),
    }
}
|
||||
|
||||
/// The Responses API is client-facing only: asking for it as the upstream of a
/// Chat Completions request must fail with an explanatory error.
#[test]
fn test_chat_completions_to_responses_api_not_supported() {
    use crate::apis::openai::OpenAIApi::Responses;
    use crate::apis::openai::{Message, MessageContent, Role};

    let user_message = Message {
        role: Role::User,
        content: MessageContent::Text(String::from("Hello!")),
        name: None,
        tool_calls: None,
        tool_call_id: None,
    };
    let source_request = ChatCompletionsRequest {
        model: String::from("gpt-4"),
        messages: vec![user_message],
        ..Default::default()
    };

    let target = SupportedUpstreamAPIs::OpenAIResponsesAPI(Responses);
    let result = ProviderRequestType::try_from((
        ProviderRequestType::ChatCompletionsRequest(source_request),
        &target,
    ));

    assert!(result.is_err());
    let err = result.unwrap_err();
    assert!(err.message.contains("ResponsesAPI can only be used as a client API"));
}
|
||||
|
||||
/// The Responses API is client-facing only: asking for it as the upstream of an
/// Anthropic Messages request must fail with an explanatory error.
#[test]
fn test_anthropic_messages_to_responses_api_not_supported() {
    use crate::apis::anthropic::MessagesRequest as AnthropicMessagesRequest;
    use crate::apis::anthropic::{MessagesMessage, MessagesMessageContent, MessagesRole};
    use crate::apis::openai::OpenAIApi::Responses;

    let user_message = MessagesMessage {
        role: MessagesRole::User,
        content: MessagesMessageContent::Single(String::from("Hello!")),
    };
    let source_request = AnthropicMessagesRequest {
        // Required fields.
        model: String::from("claude-3-sonnet"),
        messages: vec![user_message],
        max_tokens: 100,
        // Optional fields, all unset.
        container: None,
        mcp_servers: None,
        service_tier: None,
        thinking: None,
        temperature: None,
        top_p: None,
        top_k: None,
        stream: None,
        stop_sequences: None,
        system: None,
        tools: None,
        tool_choice: None,
        metadata: None,
    };

    let target = SupportedUpstreamAPIs::OpenAIResponsesAPI(Responses);
    let result = ProviderRequestType::try_from((
        ProviderRequestType::MessagesRequest(source_request),
        &target,
    ));

    assert!(result.is_err());
    let err = result.unwrap_err();
    assert!(err.message.contains("ResponsesAPI can only be used as a client API"));
}
|
||||
|
||||
/// Bedrock Converse requests are never accepted as the client-side format, so
/// converting one to any upstream API must return an error that lists the
/// supported client APIs.
#[test]
fn test_bedrock_as_client_api_not_supported() {
    use crate::apis::openai::OpenAIApi::ChatCompletions;

    // A default-constructed Bedrock request is enough: the conversion is
    // rejected based on the request variant alone.
    let client_request = ProviderRequestType::BedrockConverse(ConverseRequest::default());
    let target = SupportedUpstreamAPIs::OpenAIChatCompletions(ChatCompletions);

    let result = ProviderRequestType::try_from((client_request, &target));

    assert!(result.is_err());
    let err = result.unwrap_err();
    assert!(err.message.contains("not supported as a client API"));
    assert!(err
        .message
        .contains("OpenAI ChatCompletions, Anthropic Messages, and OpenAI Responses"));
}
|
||||
}
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
1348
crates/hermesllm/src/providers/streaming_response.rs
Normal file
1348
crates/hermesllm/src/providers/streaming_response.rs
Normal file
File diff suppressed because it is too large
Load diff
|
|
@ -11,11 +11,13 @@
|
|||
pub mod lib;
|
||||
pub mod request;
|
||||
pub mod response;
|
||||
pub mod response_streaming;
|
||||
|
||||
// Re-export commonly used items for convenience
|
||||
pub use lib::*;
|
||||
pub use request::*;
|
||||
pub use response::*;
|
||||
pub use response_streaming::*;
|
||||
|
||||
// ============================================================================
|
||||
// CONSTANTS
|
||||
|
|
|
|||
|
|
@ -12,6 +12,10 @@ use crate::apis::anthropic::{
|
|||
use crate::apis::openai::{
|
||||
ChatCompletionsRequest, Message, MessageContent, Role, Tool, ToolChoice, ToolChoiceType,
|
||||
};
|
||||
|
||||
use crate::apis::openai_responses::{
|
||||
ResponsesAPIRequest, InputContent, InputItem, InputParam, MessageRole, Modality, ReasoningEffort, Tool as ResponsesTool, ToolChoice as ResponsesToolChoice
|
||||
};
|
||||
use crate::clients::TransformError;
|
||||
use crate::transforms::lib::ExtractText;
|
||||
use crate::transforms::lib::*;
|
||||
|
|
@ -244,6 +248,202 @@ impl TryFrom<Message> for BedrockMessage {
|
|||
}
|
||||
}
|
||||
|
||||
impl TryFrom<ResponsesAPIRequest> for ChatCompletionsRequest {
|
||||
type Error = TransformError;
|
||||
|
||||
fn try_from(req: ResponsesAPIRequest) -> Result<Self, Self::Error> {
|
||||
|
||||
// Convert input to messages
|
||||
let messages = match req.input {
|
||||
InputParam::Text(text) => {
|
||||
// Simple text input becomes a user message
|
||||
vec![Message {
|
||||
role: Role::User,
|
||||
content: MessageContent::Text(text),
|
||||
name: None,
|
||||
tool_call_id: None,
|
||||
tool_calls: None,
|
||||
}]
|
||||
}
|
||||
InputParam::Items(items) => {
|
||||
// Convert input items to messages
|
||||
let mut converted_messages = Vec::new();
|
||||
|
||||
// Add instructions as system message if present
|
||||
if let Some(instructions) = &req.instructions {
|
||||
converted_messages.push(Message {
|
||||
role: Role::System,
|
||||
content: MessageContent::Text(instructions.clone()),
|
||||
name: None,
|
||||
tool_call_id: None,
|
||||
tool_calls: None,
|
||||
});
|
||||
}
|
||||
|
||||
// Convert each input item
|
||||
for item in items {
|
||||
match item {
|
||||
InputItem::Message(input_msg) => {
|
||||
let role = match input_msg.role {
|
||||
MessageRole::User => Role::User,
|
||||
MessageRole::Assistant => Role::Assistant,
|
||||
MessageRole::System => Role::System,
|
||||
MessageRole::Developer => Role::System, // Map developer to system
|
||||
};
|
||||
|
||||
// Convert content blocks
|
||||
let content = if input_msg.content.len() == 1 {
|
||||
// Single content item - check if it's simple text
|
||||
match &input_msg.content[0] {
|
||||
InputContent::InputText { text } => MessageContent::Text(text.clone()),
|
||||
_ => {
|
||||
// Convert to parts for non-text content
|
||||
MessageContent::Parts(
|
||||
input_msg.content.iter()
|
||||
.filter_map(|c| match c {
|
||||
InputContent::InputText { text } => {
|
||||
Some(crate::apis::openai::ContentPart::Text { text: text.clone() })
|
||||
}
|
||||
InputContent::InputImage { image_url, .. } => {
|
||||
Some(crate::apis::openai::ContentPart::ImageUrl {
|
||||
image_url: crate::apis::openai::ImageUrl {
|
||||
url: image_url.clone(),
|
||||
detail: None,
|
||||
}
|
||||
})
|
||||
}
|
||||
InputContent::InputFile { .. } => None, // Skip files for now
|
||||
InputContent::InputAudio { .. } => None, // Skip audio for now
|
||||
})
|
||||
.collect()
|
||||
)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Multiple content items - convert to parts
|
||||
MessageContent::Parts(
|
||||
input_msg.content.iter()
|
||||
.filter_map(|c| match c {
|
||||
InputContent::InputText { text } => {
|
||||
Some(crate::apis::openai::ContentPart::Text { text: text.clone() })
|
||||
}
|
||||
InputContent::InputImage { image_url, .. } => {
|
||||
Some(crate::apis::openai::ContentPart::ImageUrl {
|
||||
image_url: crate::apis::openai::ImageUrl {
|
||||
url: image_url.clone(),
|
||||
detail: None,
|
||||
}
|
||||
})
|
||||
}
|
||||
InputContent::InputFile { .. } => None, // Skip files for now
|
||||
InputContent::InputAudio { .. } => None, // Skip audio for now
|
||||
})
|
||||
.collect()
|
||||
)
|
||||
};
|
||||
|
||||
converted_messages.push(Message {
|
||||
role,
|
||||
content,
|
||||
name: None,
|
||||
tool_call_id: None,
|
||||
tool_calls: None,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
converted_messages
|
||||
}
|
||||
};
|
||||
|
||||
// Build the ChatCompletionsRequest
|
||||
Ok(ChatCompletionsRequest {
|
||||
model: req.model,
|
||||
messages,
|
||||
temperature: req.temperature,
|
||||
top_p: req.top_p,
|
||||
max_completion_tokens: req.max_output_tokens.map(|t| t as u32),
|
||||
stream: req.stream,
|
||||
metadata: req.metadata,
|
||||
user: req.user,
|
||||
store: req.store,
|
||||
service_tier: req.service_tier,
|
||||
top_logprobs: req.top_logprobs.map(|t| t as u32),
|
||||
modalities: req.modalities.map(|mods| {
|
||||
mods.into_iter().map(|m| {
|
||||
match m {
|
||||
Modality::Text => "text".to_string(),
|
||||
Modality::Audio => "audio".to_string(),
|
||||
}
|
||||
}).collect()
|
||||
}),
|
||||
stream_options: req.stream_options.map(|opts| {
|
||||
crate::apis::openai::StreamOptions {
|
||||
include_usage: opts.include_usage,
|
||||
}
|
||||
}),
|
||||
reasoning_effort: req.reasoning_effort.map(|effort| {
|
||||
match effort {
|
||||
ReasoningEffort::Low => "low".to_string(),
|
||||
ReasoningEffort::Medium => "medium".to_string(),
|
||||
ReasoningEffort::High => "high".to_string(),
|
||||
}
|
||||
}),
|
||||
tools: req.tools.map(|tools| {
|
||||
tools.into_iter().map(|tool| {
|
||||
|
||||
// Only convert Function tools - other types are not supported in ChatCompletions
|
||||
match tool {
|
||||
ResponsesTool::Function { name, description, parameters, strict } => Ok(Tool {
|
||||
tool_type: "function".to_string(),
|
||||
function: crate::apis::openai::Function {
|
||||
name,
|
||||
description,
|
||||
parameters: parameters.unwrap_or_else(|| serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": {}
|
||||
})),
|
||||
strict,
|
||||
}
|
||||
}),
|
||||
ResponsesTool::FileSearch { .. } => Err(TransformError::UnsupportedConversion(
|
||||
"FileSearch tool is not supported in ChatCompletions API. Only function tools are supported.".to_string()
|
||||
)),
|
||||
ResponsesTool::WebSearchPreview { .. } => Err(TransformError::UnsupportedConversion(
|
||||
"WebSearchPreview tool is not supported in ChatCompletions API. Only function tools are supported.".to_string()
|
||||
)),
|
||||
ResponsesTool::CodeInterpreter => Err(TransformError::UnsupportedConversion(
|
||||
"CodeInterpreter tool is not supported in ChatCompletions API. Only function tools are supported.".to_string()
|
||||
)),
|
||||
ResponsesTool::Computer { .. } => Err(TransformError::UnsupportedConversion(
|
||||
"Computer tool is not supported in ChatCompletions API. Only function tools are supported.".to_string()
|
||||
)),
|
||||
}
|
||||
}).collect::<Result<Vec<_>, _>>()
|
||||
}).transpose()?,
|
||||
tool_choice: req.tool_choice.map(|choice| {
|
||||
match choice {
|
||||
ResponsesToolChoice::String(s) => {
|
||||
match s.as_str() {
|
||||
"auto" => ToolChoice::Type(ToolChoiceType::Auto),
|
||||
"required" => ToolChoice::Type(ToolChoiceType::Required),
|
||||
"none" => ToolChoice::Type(ToolChoiceType::None),
|
||||
_ => ToolChoice::Type(ToolChoiceType::Auto), // Default to auto for unknown strings
|
||||
}
|
||||
}
|
||||
ResponsesToolChoice::Named { function, .. } => ToolChoice::Function {
|
||||
choice_type: "function".to_string(),
|
||||
function: crate::apis::openai::FunctionChoice { name: function.name }
|
||||
}
|
||||
}
|
||||
}),
|
||||
parallel_tool_calls: req.parallel_tool_calls,
|
||||
..Default::default()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<ChatCompletionsRequest> for AnthropicMessagesRequest {
|
||||
type Error = TransformError;
|
||||
|
||||
|
|
|
|||
|
|
@ -1,16 +1,11 @@
|
|||
use crate::apis::amazon_bedrock::{
|
||||
ContentBlockDelta, ConverseOutput, ConverseResponse, ConverseStreamEvent, StopReason,
|
||||
};
|
||||
use crate::apis::amazon_bedrock::{ConverseOutput, ConverseResponse, StopReason};
|
||||
use crate::apis::anthropic::{
|
||||
MessagesContentBlock, MessagesContentDelta, MessagesMessageDelta, MessagesResponse,
|
||||
MessagesRole, MessagesStopReason, MessagesStreamEvent, MessagesStreamMessage, MessagesUsage,
|
||||
};
|
||||
use crate::apis::openai::{
|
||||
ChatCompletionsResponse, ChatCompletionsStreamResponse, Role, ToolCallDelta,
|
||||
MessagesContentBlock, MessagesResponse,
|
||||
MessagesRole, MessagesStopReason, MessagesUsage,
|
||||
};
|
||||
use crate::apis::openai::ChatCompletionsResponse;
|
||||
use crate::clients::TransformError;
|
||||
use crate::transforms::lib::*;
|
||||
use serde_json::Value;
|
||||
|
||||
// ============================================================================
|
||||
// STANDARD RUST TRAIT IMPLEMENTATIONS - Using Into/TryFrom for convenience
|
||||
|
|
@ -120,289 +115,6 @@ impl TryFrom<ConverseResponse> for MessagesResponse {
|
|||
}
|
||||
}
|
||||
|
||||
impl TryFrom<ChatCompletionsStreamResponse> for MessagesStreamEvent {
|
||||
type Error = TransformError;
|
||||
|
||||
fn try_from(resp: ChatCompletionsStreamResponse) -> Result<Self, Self::Error> {
|
||||
if resp.choices.is_empty() {
|
||||
return Ok(MessagesStreamEvent::Ping);
|
||||
}
|
||||
|
||||
let choice = &resp.choices[0];
|
||||
|
||||
// Handle final chunk with usage
|
||||
let has_usage = resp.usage.is_some();
|
||||
if let Some(usage) = resp.usage {
|
||||
if let Some(finish_reason) = &choice.finish_reason {
|
||||
let anthropic_stop_reason: MessagesStopReason = finish_reason.clone().into();
|
||||
return Ok(MessagesStreamEvent::MessageDelta {
|
||||
delta: MessagesMessageDelta {
|
||||
stop_reason: anthropic_stop_reason,
|
||||
stop_sequence: None,
|
||||
},
|
||||
usage: usage.into(),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Handle role start
|
||||
if let Some(Role::Assistant) = choice.delta.role {
|
||||
return Ok(MessagesStreamEvent::MessageStart {
|
||||
message: MessagesStreamMessage {
|
||||
id: resp.id,
|
||||
obj_type: "message".to_string(),
|
||||
role: MessagesRole::Assistant,
|
||||
content: vec![],
|
||||
model: resp.model,
|
||||
stop_reason: None,
|
||||
stop_sequence: None,
|
||||
usage: MessagesUsage {
|
||||
input_tokens: 0,
|
||||
output_tokens: 0,
|
||||
cache_creation_input_tokens: None,
|
||||
cache_read_input_tokens: None,
|
||||
},
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
// Handle content delta
|
||||
if let Some(content) = &choice.delta.content {
|
||||
if !content.is_empty() {
|
||||
return Ok(MessagesStreamEvent::ContentBlockDelta {
|
||||
index: 0,
|
||||
delta: MessagesContentDelta::TextDelta {
|
||||
text: content.clone(),
|
||||
},
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Handle tool calls
|
||||
if let Some(tool_calls) = &choice.delta.tool_calls {
|
||||
return convert_tool_call_deltas(tool_calls.clone());
|
||||
}
|
||||
|
||||
// Handle finish reason - generate MessageDelta only (MessageStop comes later)
|
||||
if let Some(finish_reason) = &choice.finish_reason {
|
||||
// If we have usage data, it was already handled above
|
||||
// If not, we need to generate MessageDelta with default usage
|
||||
if !has_usage {
|
||||
let anthropic_stop_reason: MessagesStopReason = finish_reason.clone().into();
|
||||
return Ok(MessagesStreamEvent::MessageDelta {
|
||||
delta: MessagesMessageDelta {
|
||||
stop_reason: anthropic_stop_reason,
|
||||
stop_sequence: None,
|
||||
},
|
||||
usage: MessagesUsage {
|
||||
input_tokens: 0,
|
||||
output_tokens: 0,
|
||||
cache_creation_input_tokens: None,
|
||||
cache_read_input_tokens: None,
|
||||
},
|
||||
});
|
||||
}
|
||||
// If usage was already handled above, we don't need to do anything more here
|
||||
// MessageStop will be handled when [DONE] is encountered
|
||||
}
|
||||
|
||||
// Default to ping for unhandled cases
|
||||
Ok(MessagesStreamEvent::Ping)
|
||||
}
|
||||
}
|
||||
|
||||
impl Into<String> for MessagesStreamEvent {
|
||||
fn into(self) -> String {
|
||||
let transformed_json = serde_json::to_string(&self).unwrap_or_default();
|
||||
let event_type = match &self {
|
||||
MessagesStreamEvent::MessageStart { .. } => "message_start",
|
||||
MessagesStreamEvent::ContentBlockStart { .. } => "content_block_start",
|
||||
MessagesStreamEvent::ContentBlockDelta { .. } => "content_block_delta",
|
||||
MessagesStreamEvent::ContentBlockStop { .. } => "content_block_stop",
|
||||
MessagesStreamEvent::MessageDelta { .. } => "message_delta",
|
||||
MessagesStreamEvent::MessageStop => "message_stop",
|
||||
MessagesStreamEvent::Ping => "ping",
|
||||
};
|
||||
|
||||
let event = format!("event: {}\n", event_type);
|
||||
let data = format!("data: {}\n\n", transformed_json);
|
||||
event + &data
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<ConverseStreamEvent> for MessagesStreamEvent {
|
||||
type Error = TransformError;
|
||||
|
||||
fn try_from(event: ConverseStreamEvent) -> Result<Self, Self::Error> {
|
||||
match event {
|
||||
// MessageStart - convert to Anthropic MessageStart
|
||||
ConverseStreamEvent::MessageStart(start_event) => {
|
||||
let role = match start_event.role {
|
||||
crate::apis::amazon_bedrock::ConversationRole::User => MessagesRole::User,
|
||||
crate::apis::amazon_bedrock::ConversationRole::Assistant => {
|
||||
MessagesRole::Assistant
|
||||
}
|
||||
};
|
||||
|
||||
Ok(MessagesStreamEvent::MessageStart {
|
||||
message: MessagesStreamMessage {
|
||||
id: format!(
|
||||
"bedrock-stream-{}",
|
||||
std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_nanos()
|
||||
),
|
||||
obj_type: "message".to_string(),
|
||||
role,
|
||||
content: vec![],
|
||||
model: "bedrock-model".to_string(),
|
||||
stop_reason: None,
|
||||
stop_sequence: None,
|
||||
usage: MessagesUsage {
|
||||
input_tokens: 0,
|
||||
output_tokens: 0,
|
||||
cache_creation_input_tokens: None,
|
||||
cache_read_input_tokens: None,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// ContentBlockStart - convert to Anthropic ContentBlockStart
|
||||
ConverseStreamEvent::ContentBlockStart(start_event) => {
|
||||
// Note: Bedrock sends tool_use_id and name at start, with input coming in subsequent deltas
|
||||
// Anthropic expects the same pattern, so we initialize with an empty input object
|
||||
match start_event.start {
|
||||
crate::apis::amazon_bedrock::ContentBlockStart::ToolUse { tool_use } => {
|
||||
Ok(MessagesStreamEvent::ContentBlockStart {
|
||||
index: start_event.content_block_index as u32,
|
||||
content_block: MessagesContentBlock::ToolUse {
|
||||
id: tool_use.tool_use_id,
|
||||
name: tool_use.name,
|
||||
input: Value::Object(serde_json::Map::new()), // Empty - will be filled by deltas
|
||||
cache_control: None,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ContentBlockDelta - convert to Anthropic ContentBlockDelta
|
||||
ConverseStreamEvent::ContentBlockDelta(delta_event) => {
|
||||
let delta = match delta_event.delta {
|
||||
ContentBlockDelta::Text { text } => MessagesContentDelta::TextDelta { text },
|
||||
ContentBlockDelta::ToolUse { tool_use } => {
|
||||
MessagesContentDelta::InputJsonDelta {
|
||||
partial_json: tool_use.input,
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
Ok(MessagesStreamEvent::ContentBlockDelta {
|
||||
index: delta_event.content_block_index as u32,
|
||||
delta,
|
||||
})
|
||||
}
|
||||
|
||||
// ContentBlockStop - convert to Anthropic ContentBlockStop
|
||||
ConverseStreamEvent::ContentBlockStop(stop_event) => {
|
||||
Ok(MessagesStreamEvent::ContentBlockStop {
|
||||
index: stop_event.content_block_index as u32,
|
||||
})
|
||||
}
|
||||
|
||||
// MessageStop - convert to Anthropic MessageDelta with stop reason + MessageStop
|
||||
ConverseStreamEvent::MessageStop(stop_event) => {
|
||||
let anthropic_stop_reason = match stop_event.stop_reason {
|
||||
StopReason::EndTurn => MessagesStopReason::EndTurn,
|
||||
StopReason::ToolUse => MessagesStopReason::ToolUse,
|
||||
StopReason::MaxTokens => MessagesStopReason::MaxTokens,
|
||||
StopReason::StopSequence => MessagesStopReason::EndTurn,
|
||||
StopReason::GuardrailIntervened => MessagesStopReason::Refusal,
|
||||
StopReason::ContentFiltered => MessagesStopReason::Refusal,
|
||||
};
|
||||
|
||||
// Return MessageDelta (MessageStop will be sent separately by the streaming handler)
|
||||
Ok(MessagesStreamEvent::MessageDelta {
|
||||
delta: MessagesMessageDelta {
|
||||
stop_reason: anthropic_stop_reason,
|
||||
stop_sequence: None,
|
||||
},
|
||||
usage: MessagesUsage {
|
||||
input_tokens: 0,
|
||||
output_tokens: 0,
|
||||
cache_creation_input_tokens: None,
|
||||
cache_read_input_tokens: None,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// Metadata - convert usage information to MessageDelta
|
||||
ConverseStreamEvent::Metadata(metadata_event) => {
|
||||
Ok(MessagesStreamEvent::MessageDelta {
|
||||
delta: MessagesMessageDelta {
|
||||
stop_reason: MessagesStopReason::EndTurn,
|
||||
stop_sequence: None,
|
||||
},
|
||||
usage: MessagesUsage {
|
||||
input_tokens: metadata_event.usage.input_tokens,
|
||||
output_tokens: metadata_event.usage.output_tokens,
|
||||
cache_creation_input_tokens: metadata_event.usage.cache_write_input_tokens,
|
||||
cache_read_input_tokens: metadata_event.usage.cache_read_input_tokens,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// Exception events - convert to Ping (could be enhanced to return error events)
|
||||
ConverseStreamEvent::InternalServerException(_)
|
||||
| ConverseStreamEvent::ModelStreamErrorException(_)
|
||||
| ConverseStreamEvent::ServiceUnavailableException(_)
|
||||
| ConverseStreamEvent::ThrottlingException(_)
|
||||
| ConverseStreamEvent::ValidationException(_) => {
|
||||
// TODO: Consider adding proper error handling/events
|
||||
Ok(MessagesStreamEvent::Ping)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert tool call deltas to Anthropic stream events
|
||||
fn convert_tool_call_deltas(
|
||||
tool_calls: Vec<ToolCallDelta>,
|
||||
) -> Result<MessagesStreamEvent, TransformError> {
|
||||
for tool_call in tool_calls {
|
||||
if let Some(id) = &tool_call.id {
|
||||
// Tool call start
|
||||
if let Some(function) = &tool_call.function {
|
||||
if let Some(name) = &function.name {
|
||||
return Ok(MessagesStreamEvent::ContentBlockStart {
|
||||
index: tool_call.index,
|
||||
content_block: MessagesContentBlock::ToolUse {
|
||||
id: id.clone(),
|
||||
name: name.clone(),
|
||||
input: Value::Object(serde_json::Map::new()),
|
||||
cache_control: None,
|
||||
},
|
||||
});
|
||||
}
|
||||
}
|
||||
} else if let Some(function) = &tool_call.function {
|
||||
if let Some(arguments) = &function.arguments {
|
||||
// Tool arguments delta
|
||||
return Ok(MessagesStreamEvent::ContentBlockDelta {
|
||||
index: tool_call.index,
|
||||
delta: MessagesContentDelta::InputJsonDelta {
|
||||
partial_json: arguments.clone(),
|
||||
},
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to ping if no valid tool call found
|
||||
Ok(MessagesStreamEvent::Ping)
|
||||
}
|
||||
|
||||
/// Convert Bedrock Message to Anthropic content blocks
|
||||
///
|
||||
|
|
|
|||
|
|
@ -1,15 +1,13 @@
|
|||
use crate::apis::amazon_bedrock::{
|
||||
ConverseOutput, ConverseResponse, ConverseStreamEvent, StopReason,
|
||||
ConverseOutput, ConverseResponse, StopReason,
|
||||
};
|
||||
use crate::apis::anthropic::{
|
||||
MessagesContentBlock, MessagesContentDelta, MessagesResponse, MessagesStopReason,
|
||||
MessagesStreamEvent, MessagesUsage,
|
||||
MessagesContentBlock, MessagesResponse, MessagesUsage,
|
||||
};
|
||||
use crate::apis::openai::{
|
||||
ChatCompletionsResponse, ChatCompletionsStreamResponse, Choice, FinishReason,
|
||||
FunctionCallDelta, MessageContent, MessageDelta, ResponseMessage, Role, StreamChoice,
|
||||
ToolCallDelta, Usage,
|
||||
ChatCompletionsResponse, Choice, FinishReason, MessageContent, ResponseMessage, Role, Usage,
|
||||
};
|
||||
use crate::apis::openai_responses::ResponsesAPIResponse;
|
||||
use crate::clients::TransformError;
|
||||
use crate::transforms::lib::*;
|
||||
|
||||
|
|
@ -30,6 +28,163 @@ impl Into<Usage> for MessagesUsage {
|
|||
}
|
||||
}
|
||||
|
||||
impl TryFrom<ChatCompletionsResponse> for ResponsesAPIResponse {
|
||||
type Error = TransformError;
|
||||
|
||||
fn try_from(resp: ChatCompletionsResponse) -> Result<Self, Self::Error> {
|
||||
use crate::apis::openai_responses::{
|
||||
IncompleteDetails, IncompleteReason, OutputContent, OutputItem, OutputItemStatus,
|
||||
ResponseStatus, ResponseUsage, ResponsesAPIResponse,
|
||||
};
|
||||
|
||||
// Convert the first choice's message to output items
|
||||
let output = if let Some(choice) = resp.choices.first() {
|
||||
let mut items = Vec::new();
|
||||
|
||||
// Create a message output item from the response message
|
||||
let mut content = Vec::new();
|
||||
|
||||
// Add text content if present
|
||||
if let Some(text) = &choice.message.content {
|
||||
content.push(OutputContent::OutputText {
|
||||
text: text.clone(),
|
||||
annotations: vec![],
|
||||
logprobs: None,
|
||||
});
|
||||
}
|
||||
|
||||
// Add audio content if present (audio is a Value, need to handle it carefully)
|
||||
if let Some(audio) = &choice.message.audio {
|
||||
// Audio is serde_json::Value, try to extract data and transcript
|
||||
if let Some(audio_obj) = audio.as_object() {
|
||||
let data = audio_obj
|
||||
.get("data")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|s| s.to_string());
|
||||
let transcript = audio_obj
|
||||
.get("transcript")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|s| s.to_string());
|
||||
|
||||
content.push(OutputContent::OutputAudio { data, transcript });
|
||||
}
|
||||
}
|
||||
|
||||
// Add refusal content if present
|
||||
if let Some(refusal) = &choice.message.refusal {
|
||||
content.push(OutputContent::Refusal {
|
||||
refusal: refusal.clone(),
|
||||
});
|
||||
}
|
||||
|
||||
// Only add the message item if there's actual content (text, audio, or refusal)
|
||||
// Don't add empty message items when there are only tool calls
|
||||
if !content.is_empty() {
|
||||
items.push(OutputItem::Message {
|
||||
id: format!("msg_{}", resp.id),
|
||||
status: OutputItemStatus::Completed,
|
||||
role: match choice.message.role {
|
||||
Role::User => "user".to_string(),
|
||||
Role::Assistant => "assistant".to_string(),
|
||||
Role::System => "system".to_string(),
|
||||
Role::Tool => "tool".to_string(),
|
||||
},
|
||||
content,
|
||||
});
|
||||
}
|
||||
|
||||
// Add tool calls as function call items if present
|
||||
if let Some(tool_calls) = &choice.message.tool_calls {
|
||||
for tool_call in tool_calls {
|
||||
items.push(OutputItem::FunctionCall {
|
||||
id: format!("func_{}", tool_call.id),
|
||||
status: OutputItemStatus::Completed,
|
||||
call_id: tool_call.id.clone(),
|
||||
name: Some(tool_call.function.name.clone()),
|
||||
arguments: Some(tool_call.function.arguments.clone()),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
items
|
||||
} else {
|
||||
vec![]
|
||||
};
|
||||
|
||||
// Convert finish_reason to status
|
||||
let status = if let Some(choice) = resp.choices.first() {
|
||||
match choice.finish_reason {
|
||||
Some(FinishReason::Stop) => ResponseStatus::Completed,
|
||||
Some(FinishReason::ToolCalls) => ResponseStatus::Completed,
|
||||
Some(FinishReason::Length) => ResponseStatus::Incomplete,
|
||||
Some(FinishReason::ContentFilter) => ResponseStatus::Failed,
|
||||
_ => ResponseStatus::Completed,
|
||||
}
|
||||
} else {
|
||||
ResponseStatus::Completed
|
||||
};
|
||||
|
||||
// Convert usage
|
||||
let usage = ResponseUsage {
|
||||
input_tokens: resp.usage.prompt_tokens as i32,
|
||||
output_tokens: resp.usage.completion_tokens as i32,
|
||||
total_tokens: resp.usage.total_tokens as i32,
|
||||
input_tokens_details: resp.usage.prompt_tokens_details.map(|details| {
|
||||
crate::apis::openai_responses::TokenDetails {
|
||||
cached_tokens: details.cached_tokens.unwrap_or(0) as i32,
|
||||
}
|
||||
}),
|
||||
output_tokens_details: resp.usage.completion_tokens_details.map(|details| {
|
||||
crate::apis::openai_responses::OutputTokenDetails {
|
||||
reasoning_tokens: details.reasoning_tokens.unwrap_or(0) as i32,
|
||||
}
|
||||
}),
|
||||
};
|
||||
|
||||
// Set incomplete_details if status is incomplete
|
||||
let incomplete_details = if matches!(status, ResponseStatus::Incomplete) {
|
||||
Some(IncompleteDetails {
|
||||
reason: IncompleteReason::MaxOutputTokens,
|
||||
})
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
Ok(ResponsesAPIResponse {
|
||||
id: resp.id,
|
||||
object: "response".to_string(),
|
||||
created_at: resp.created as i64,
|
||||
status,
|
||||
background: Some(false),
|
||||
error: None,
|
||||
incomplete_details,
|
||||
instructions: None,
|
||||
max_output_tokens: None,
|
||||
max_tool_calls: None,
|
||||
model: resp.model,
|
||||
output,
|
||||
usage: Some(usage),
|
||||
parallel_tool_calls: true,
|
||||
conversation: None,
|
||||
previous_response_id: None,
|
||||
tools: vec![],
|
||||
tool_choice: "auto".to_string(),
|
||||
temperature: 1.0,
|
||||
top_p: 1.0,
|
||||
metadata: resp.metadata.unwrap_or_default(),
|
||||
truncation: None,
|
||||
reasoning: None,
|
||||
store: None,
|
||||
text: None,
|
||||
audio: None,
|
||||
modalities: None,
|
||||
service_tier: resp.service_tier,
|
||||
top_logprobs: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
impl TryFrom<MessagesResponse> for ChatCompletionsResponse {
|
||||
type Error = TransformError;
|
||||
|
||||
|
|
@ -173,416 +328,6 @@ impl TryFrom<ConverseResponse> for ChatCompletionsResponse {
|
|||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// STREAMING TRANSFORMATIONS
|
||||
// ============================================================================
|
||||
|
||||
impl TryFrom<MessagesStreamEvent> for ChatCompletionsStreamResponse {
|
||||
type Error = TransformError;
|
||||
|
||||
fn try_from(event: MessagesStreamEvent) -> Result<Self, Self::Error> {
|
||||
match event {
|
||||
MessagesStreamEvent::MessageStart { message } => Ok(create_openai_chunk(
|
||||
&message.id,
|
||||
&message.model,
|
||||
MessageDelta {
|
||||
role: Some(Role::Assistant),
|
||||
content: None,
|
||||
refusal: None,
|
||||
function_call: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
None,
|
||||
None,
|
||||
)),
|
||||
|
||||
MessagesStreamEvent::ContentBlockStart { content_block, .. } => {
|
||||
convert_content_block_start(content_block)
|
||||
}
|
||||
|
||||
MessagesStreamEvent::ContentBlockDelta { delta, .. } => convert_content_delta(delta),
|
||||
|
||||
MessagesStreamEvent::ContentBlockStop { .. } => Ok(create_empty_openai_chunk()),
|
||||
|
||||
MessagesStreamEvent::MessageDelta { delta, usage } => {
|
||||
let finish_reason: Option<FinishReason> = Some(delta.stop_reason.into());
|
||||
let openai_usage: Option<Usage> = Some(usage.into());
|
||||
|
||||
Ok(create_openai_chunk(
|
||||
"stream",
|
||||
"unknown",
|
||||
MessageDelta {
|
||||
role: None,
|
||||
content: None,
|
||||
refusal: None,
|
||||
function_call: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
finish_reason,
|
||||
openai_usage,
|
||||
))
|
||||
}
|
||||
|
||||
MessagesStreamEvent::MessageStop => Ok(create_openai_chunk(
|
||||
"stream",
|
||||
"unknown",
|
||||
MessageDelta {
|
||||
role: None,
|
||||
content: None,
|
||||
refusal: None,
|
||||
function_call: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
Some(FinishReason::Stop),
|
||||
None,
|
||||
)),
|
||||
|
||||
MessagesStreamEvent::Ping => Ok(ChatCompletionsStreamResponse {
|
||||
id: "stream".to_string(),
|
||||
object: Some("chat.completion.chunk".to_string()),
|
||||
created: current_timestamp(),
|
||||
model: "unknown".to_string(),
|
||||
choices: vec![],
|
||||
usage: None,
|
||||
system_fingerprint: None,
|
||||
service_tier: None,
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<ConverseStreamEvent> for ChatCompletionsStreamResponse {
|
||||
type Error = TransformError;
|
||||
|
||||
fn try_from(event: ConverseStreamEvent) -> Result<Self, Self::Error> {
|
||||
match event {
|
||||
ConverseStreamEvent::MessageStart(start_event) => {
|
||||
let role = match start_event.role {
|
||||
crate::apis::amazon_bedrock::ConversationRole::User => Role::User,
|
||||
crate::apis::amazon_bedrock::ConversationRole::Assistant => Role::Assistant,
|
||||
};
|
||||
|
||||
Ok(create_openai_chunk(
|
||||
"stream",
|
||||
"unknown",
|
||||
MessageDelta {
|
||||
role: Some(role),
|
||||
content: None,
|
||||
refusal: None,
|
||||
function_call: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
None,
|
||||
None,
|
||||
))
|
||||
}
|
||||
|
||||
ConverseStreamEvent::ContentBlockStart(start_event) => {
|
||||
use crate::apis::amazon_bedrock::ContentBlockStart;
|
||||
|
||||
match start_event.start {
|
||||
ContentBlockStart::ToolUse { tool_use } => Ok(create_openai_chunk(
|
||||
"stream",
|
||||
"unknown",
|
||||
MessageDelta {
|
||||
role: None,
|
||||
content: None,
|
||||
refusal: None,
|
||||
function_call: None,
|
||||
tool_calls: Some(vec![ToolCallDelta {
|
||||
index: start_event.content_block_index as u32,
|
||||
id: Some(tool_use.tool_use_id),
|
||||
call_type: Some("function".to_string()),
|
||||
function: Some(FunctionCallDelta {
|
||||
name: Some(tool_use.name),
|
||||
arguments: Some("".to_string()),
|
||||
}),
|
||||
}]),
|
||||
},
|
||||
None,
|
||||
None,
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
ConverseStreamEvent::ContentBlockDelta(delta_event) => {
|
||||
use crate::apis::amazon_bedrock::ContentBlockDelta;
|
||||
|
||||
match delta_event.delta {
|
||||
ContentBlockDelta::Text { text } => Ok(create_openai_chunk(
|
||||
"stream",
|
||||
"unknown",
|
||||
MessageDelta {
|
||||
role: None,
|
||||
content: Some(text),
|
||||
refusal: None,
|
||||
function_call: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
None,
|
||||
None,
|
||||
)),
|
||||
ContentBlockDelta::ToolUse { tool_use } => Ok(create_openai_chunk(
|
||||
"stream",
|
||||
"unknown",
|
||||
MessageDelta {
|
||||
role: None,
|
||||
content: None,
|
||||
refusal: None,
|
||||
function_call: None,
|
||||
tool_calls: Some(vec![ToolCallDelta {
|
||||
index: delta_event.content_block_index as u32,
|
||||
id: None,
|
||||
call_type: None,
|
||||
function: Some(FunctionCallDelta {
|
||||
name: None,
|
||||
arguments: Some(tool_use.input),
|
||||
}),
|
||||
}]),
|
||||
},
|
||||
None,
|
||||
None,
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
ConverseStreamEvent::ContentBlockStop(_) => Ok(create_empty_openai_chunk()),
|
||||
|
||||
ConverseStreamEvent::MessageStop(stop_event) => {
|
||||
let finish_reason = match stop_event.stop_reason {
|
||||
StopReason::EndTurn => FinishReason::Stop,
|
||||
StopReason::ToolUse => FinishReason::ToolCalls,
|
||||
StopReason::MaxTokens => FinishReason::Length,
|
||||
StopReason::StopSequence => FinishReason::Stop,
|
||||
StopReason::GuardrailIntervened => FinishReason::ContentFilter,
|
||||
StopReason::ContentFiltered => FinishReason::ContentFilter,
|
||||
};
|
||||
|
||||
Ok(create_openai_chunk(
|
||||
"stream",
|
||||
"unknown",
|
||||
MessageDelta {
|
||||
role: None,
|
||||
content: None,
|
||||
refusal: None,
|
||||
function_call: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
Some(finish_reason),
|
||||
None,
|
||||
))
|
||||
}
|
||||
|
||||
ConverseStreamEvent::Metadata(metadata_event) => {
|
||||
let usage = Usage {
|
||||
prompt_tokens: metadata_event.usage.input_tokens,
|
||||
completion_tokens: metadata_event.usage.output_tokens,
|
||||
total_tokens: metadata_event.usage.total_tokens,
|
||||
prompt_tokens_details: None,
|
||||
completion_tokens_details: None,
|
||||
};
|
||||
|
||||
Ok(create_openai_chunk(
|
||||
"stream",
|
||||
"unknown",
|
||||
MessageDelta {
|
||||
role: None,
|
||||
content: None,
|
||||
refusal: None,
|
||||
function_call: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
None,
|
||||
Some(usage),
|
||||
))
|
||||
}
|
||||
|
||||
// Error events - convert to empty chunks (errors should be handled elsewhere)
|
||||
ConverseStreamEvent::InternalServerException(_)
|
||||
| ConverseStreamEvent::ModelStreamErrorException(_)
|
||||
| ConverseStreamEvent::ServiceUnavailableException(_)
|
||||
| ConverseStreamEvent::ThrottlingException(_)
|
||||
| ConverseStreamEvent::ValidationException(_) => Ok(create_empty_openai_chunk()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert content block start to OpenAI chunk
|
||||
fn convert_content_block_start(
|
||||
content_block: MessagesContentBlock,
|
||||
) -> Result<ChatCompletionsStreamResponse, TransformError> {
|
||||
match content_block {
|
||||
MessagesContentBlock::Text { .. } => {
|
||||
// No immediate output for text block start
|
||||
Ok(create_empty_openai_chunk())
|
||||
}
|
||||
MessagesContentBlock::ToolUse { id, name, .. }
|
||||
| MessagesContentBlock::ServerToolUse { id, name, .. }
|
||||
| MessagesContentBlock::McpToolUse { id, name, .. } => {
|
||||
// Tool use start → OpenAI chunk with tool_calls
|
||||
Ok(create_openai_chunk(
|
||||
"stream",
|
||||
"unknown",
|
||||
MessageDelta {
|
||||
role: None,
|
||||
content: None,
|
||||
refusal: None,
|
||||
function_call: None,
|
||||
tool_calls: Some(vec![ToolCallDelta {
|
||||
index: 0,
|
||||
id: Some(id),
|
||||
call_type: Some("function".to_string()),
|
||||
function: Some(FunctionCallDelta {
|
||||
name: Some(name),
|
||||
arguments: Some("".to_string()),
|
||||
}),
|
||||
}]),
|
||||
},
|
||||
None,
|
||||
None,
|
||||
))
|
||||
}
|
||||
_ => Err(TransformError::UnsupportedContent(
|
||||
"Unsupported content block type in stream start".to_string(),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert content delta to OpenAI chunk
|
||||
fn convert_content_delta(
|
||||
delta: MessagesContentDelta,
|
||||
) -> Result<ChatCompletionsStreamResponse, TransformError> {
|
||||
match delta {
|
||||
MessagesContentDelta::TextDelta { text } => Ok(create_openai_chunk(
|
||||
"stream",
|
||||
"unknown",
|
||||
MessageDelta {
|
||||
role: None,
|
||||
content: Some(text),
|
||||
refusal: None,
|
||||
function_call: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
None,
|
||||
None,
|
||||
)),
|
||||
MessagesContentDelta::ThinkingDelta { thinking } => Ok(create_openai_chunk(
|
||||
"stream",
|
||||
"unknown",
|
||||
MessageDelta {
|
||||
role: None,
|
||||
content: Some(format!("thinking: {}", thinking)),
|
||||
refusal: None,
|
||||
function_call: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
None,
|
||||
None,
|
||||
)),
|
||||
MessagesContentDelta::InputJsonDelta { partial_json } => Ok(create_openai_chunk(
|
||||
"stream",
|
||||
"unknown",
|
||||
MessageDelta {
|
||||
role: None,
|
||||
content: None,
|
||||
refusal: None,
|
||||
function_call: None,
|
||||
tool_calls: Some(vec![ToolCallDelta {
|
||||
index: 0,
|
||||
id: None,
|
||||
call_type: None,
|
||||
function: Some(FunctionCallDelta {
|
||||
name: None,
|
||||
arguments: Some(partial_json),
|
||||
}),
|
||||
}]),
|
||||
},
|
||||
None,
|
||||
None,
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper to create OpenAI streaming chunk
|
||||
fn create_openai_chunk(
|
||||
id: &str,
|
||||
model: &str,
|
||||
delta: MessageDelta,
|
||||
finish_reason: Option<FinishReason>,
|
||||
usage: Option<Usage>,
|
||||
) -> ChatCompletionsStreamResponse {
|
||||
ChatCompletionsStreamResponse {
|
||||
id: id.to_string(),
|
||||
object: Some("chat.completion.chunk".to_string()),
|
||||
created: current_timestamp(),
|
||||
model: model.to_string(),
|
||||
choices: vec![StreamChoice {
|
||||
index: 0,
|
||||
delta,
|
||||
finish_reason,
|
||||
logprobs: None,
|
||||
}],
|
||||
usage,
|
||||
system_fingerprint: None,
|
||||
service_tier: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper to create empty OpenAI streaming chunk
|
||||
fn create_empty_openai_chunk() -> ChatCompletionsStreamResponse {
|
||||
create_openai_chunk(
|
||||
"stream",
|
||||
"unknown",
|
||||
MessageDelta {
|
||||
role: None,
|
||||
content: None,
|
||||
refusal: None,
|
||||
function_call: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
None,
|
||||
None,
|
||||
)
|
||||
}
|
||||
|
||||
/// Convert Anthropic content blocks to OpenAI message content
|
||||
fn convert_anthropic_content_to_openai(
|
||||
content: &[MessagesContentBlock],
|
||||
) -> Result<MessageContent, TransformError> {
|
||||
let mut text_parts = Vec::new();
|
||||
|
||||
for block in content {
|
||||
match block {
|
||||
MessagesContentBlock::Text { text, .. } => {
|
||||
text_parts.push(text.clone());
|
||||
}
|
||||
MessagesContentBlock::Thinking { thinking, .. } => {
|
||||
text_parts.push(format!("thinking: {}", thinking));
|
||||
}
|
||||
_ => {
|
||||
// Skip other content types for basic text conversion
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(MessageContent::Text(text_parts.join("\n")))
|
||||
}
|
||||
|
||||
// Stop Reason Conversions
|
||||
impl Into<FinishReason> for MessagesStopReason {
|
||||
fn into(self) -> FinishReason {
|
||||
match self {
|
||||
MessagesStopReason::EndTurn => FinishReason::Stop,
|
||||
MessagesStopReason::MaxTokens => FinishReason::Length,
|
||||
MessagesStopReason::StopSequence => FinishReason::Stop,
|
||||
MessagesStopReason::ToolUse => FinishReason::ToolCalls,
|
||||
MessagesStopReason::PauseTurn => FinishReason::Stop,
|
||||
MessagesStopReason::Refusal => FinishReason::ContentFilter,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert Bedrock Message to OpenAI content and tool calls
|
||||
/// This function extracts text content and tool calls from a Bedrock message
|
||||
fn convert_bedrock_message_to_openai(
|
||||
|
|
@ -627,6 +372,31 @@ fn convert_bedrock_message_to_openai(
|
|||
Ok((content, tool_calls))
|
||||
}
|
||||
|
||||
/// Convert Anthropic content blocks to OpenAI message content
|
||||
fn convert_anthropic_content_to_openai(
|
||||
content: &[MessagesContentBlock],
|
||||
) -> Result<MessageContent, TransformError> {
|
||||
let mut text_parts = Vec::new();
|
||||
|
||||
for block in content {
|
||||
match block {
|
||||
MessagesContentBlock::Text { text, .. } => {
|
||||
text_parts.push(text.clone());
|
||||
}
|
||||
MessagesContentBlock::Thinking { thinking, .. } => {
|
||||
text_parts.push(format!("thinking: {}", thinking));
|
||||
}
|
||||
_ => {
|
||||
// Skip other content types for basic text conversion
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(MessageContent::Text(text_parts.join("\n")))
|
||||
}
|
||||
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
|
@ -1166,4 +936,212 @@ mod tests {
|
|||
assert!(content.contains("Here's the analysis:"));
|
||||
// Note: Image blocks are not converted to text in the current implementation
|
||||
}
|
||||
|
||||
/// A plain assistant text reply converts into a response with the same id and
/// model, matching token counts, and exactly one message item holding one
/// text part.
#[test]
fn test_chat_completions_to_responses_api_basic() {
    use crate::apis::openai_responses::{OutputContent, OutputItem, ResponsesAPIResponse};

    // Text-only assistant message: no refusal, audio or tool calls.
    let assistant_message = crate::apis::openai::ResponseMessage {
        role: Role::Assistant,
        content: Some("Hello! How can I help you?".to_string()),
        refusal: None,
        annotations: None,
        audio: None,
        function_call: None,
        tool_calls: None,
    };

    let token_usage = Usage {
        prompt_tokens: 10,
        completion_tokens: 20,
        total_tokens: 30,
        prompt_tokens_details: None,
        completion_tokens_details: None,
    };

    let chat_response = ChatCompletionsResponse {
        id: "chatcmpl-123".to_string(),
        object: Some("chat.completion".to_string()),
        created: 1677652288,
        model: "gpt-4".to_string(),
        choices: vec![Choice {
            index: 0,
            message: assistant_message,
            finish_reason: Some(FinishReason::Stop),
            logprobs: None,
        }],
        usage: token_usage,
        system_fingerprint: None,
        service_tier: Some("default".to_string()),
        metadata: None,
    };

    let converted = ResponsesAPIResponse::try_from(chat_response).unwrap();

    // Identity fields are carried over; the object type is fixed.
    assert_eq!(converted.id, "chatcmpl-123");
    assert_eq!(converted.object, "response");
    assert_eq!(converted.model, "gpt-4");

    // Token counts map prompt -> input and completion -> output.
    let usage = converted.usage.unwrap();
    assert_eq!(usage.input_tokens, 10);
    assert_eq!(usage.output_tokens, 20);
    assert_eq!(usage.total_tokens, 30);

    // Exactly one output item: the assistant message.
    assert_eq!(converted.output.len(), 1);
    let content = match &converted.output[0] {
        OutputItem::Message { role, content, .. } => {
            assert_eq!(role, "assistant");
            content
        }
        _ => panic!("Expected Message output item"),
    };

    // ...holding exactly one text part with the original text.
    assert_eq!(content.len(), 1);
    match &content[0] {
        OutputContent::OutputText { text, .. } => {
            assert_eq!(text, "Hello! How can I help you?");
        }
        _ => panic!("Expected OutputText content"),
    }
}
|
||||
|
||||
#[test]
fn test_chat_completions_to_responses_api_with_tool_calls() {
    use crate::apis::openai::{FunctionCall, ResponseMessage, ToolCall};
    use crate::apis::openai_responses::{OutputItem, ResponsesAPIResponse};

    // Fixture: an assistant turn that carries both text and one function tool call.
    let weather_call = ToolCall {
        id: "call_abc123".to_string(),
        call_type: "function".to_string(),
        function: FunctionCall {
            name: "get_weather".to_string(),
            arguments: r#"{"location":"San Francisco"}"#.to_string(),
        },
    };
    let assistant_message = ResponseMessage {
        role: Role::Assistant,
        content: Some("Let me check the weather.".to_string()),
        refusal: None,
        annotations: None,
        audio: None,
        function_call: None,
        tool_calls: Some(vec![weather_call]),
    };
    let token_usage = Usage {
        prompt_tokens: 15,
        completion_tokens: 25,
        total_tokens: 40,
        prompt_tokens_details: None,
        completion_tokens_details: None,
    };
    let chat_response = ChatCompletionsResponse {
        id: "chatcmpl-456".to_string(),
        object: Some("chat.completion".to_string()),
        created: 1677652300,
        model: "gpt-4".to_string(),
        choices: vec![Choice {
            index: 0,
            message: assistant_message,
            finish_reason: Some(FinishReason::ToolCalls),
            logprobs: None,
        }],
        usage: token_usage,
        system_fingerprint: None,
        service_tier: None,
        metadata: None,
    };

    let responses_api: ResponsesAPIResponse = chat_response.try_into().unwrap();
    let items = &responses_api.output;

    // Expect two output items: the text message followed by the function call.
    assert_eq!(items.len(), 2);

    // First item: the assistant message with a single content part.
    if let OutputItem::Message { content, .. } = &items[0] {
        assert_eq!(content.len(), 1);
    } else {
        panic!("Expected Message output item");
    }

    // Second item: the function call, with id, name and arguments carried over.
    if let OutputItem::FunctionCall {
        call_id,
        name,
        arguments,
        ..
    } = &items[1]
    {
        assert_eq!(call_id, "call_abc123");
        assert_eq!(name.as_ref().unwrap(), "get_weather");
        assert!(arguments.as_ref().unwrap().contains("San Francisco"));
    } else {
        panic!("Expected FunctionCall output item");
    }
}
|
||||
|
||||
#[test]
fn test_chat_completions_to_responses_api_tool_calls_only() {
    use crate::apis::openai::{FunctionCall, ResponseMessage, ToolCall};
    use crate::apis::openai_responses::{OutputItem, ResponseStatus, ResponsesAPIResponse};

    // Fixture for the real-world case where content is null and the assistant
    // message consists solely of a tool call.
    let weather_call = ToolCall {
        id: "call_oJBtqTJmRfBGlFS55QhMfUUV".to_string(),
        call_type: "function".to_string(),
        function: FunctionCall {
            name: "get_weather".to_string(),
            arguments: r#"{"location":"San Francisco, CA"}"#.to_string(),
        },
    };
    let assistant_message = ResponseMessage {
        role: Role::Assistant,
        content: None, // No text content, only tool calls
        refusal: None,
        annotations: None,
        audio: None,
        function_call: None,
        tool_calls: Some(vec![weather_call]),
    };
    let token_usage = Usage {
        prompt_tokens: 84,
        completion_tokens: 17,
        total_tokens: 101,
        prompt_tokens_details: None,
        completion_tokens_details: None,
    };
    let chat_response = ChatCompletionsResponse {
        id: "chatcmpl-789".to_string(),
        object: Some("chat.completion".to_string()),
        created: 1764023939,
        model: "gpt-4o-2024-08-06".to_string(),
        choices: vec![Choice {
            index: 0,
            message: assistant_message,
            finish_reason: Some(FinishReason::ToolCalls),
            logprobs: None,
        }],
        usage: token_usage,
        system_fingerprint: Some("fp_7eeb46f068".to_string()),
        service_tier: Some("default".to_string()),
        metadata: None,
    };

    let responses_api: ResponsesAPIResponse = chat_response.try_into().unwrap();
    let items = &responses_api.output;

    // Only the function call should be emitted; no empty message item.
    assert_eq!(items.len(), 1);

    if let OutputItem::FunctionCall {
        call_id,
        name,
        arguments,
        ..
    } = &items[0]
    {
        assert_eq!(call_id, "call_oJBtqTJmRfBGlFS55QhMfUUV");
        assert_eq!(name.as_ref().unwrap(), "get_weather");
        assert!(arguments.as_ref().unwrap().contains("San Francisco, CA"));
    } else {
        panic!("Expected FunctionCall output item as first item");
    }

    // A tool_calls finish reason must still yield a Completed status.
    assert!(matches!(responses_api.status, ResponseStatus::Completed));
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,2 @@
|
|||
pub mod to_anthropic_streaming;
|
||||
pub mod to_openai_streaming;
|
||||
|
|
@ -0,0 +1,281 @@
|
|||
use crate::apis::amazon_bedrock::{
|
||||
ContentBlockDelta, ConverseStreamEvent,
|
||||
};
|
||||
use crate::apis::anthropic::{
|
||||
MessagesContentBlock, MessagesContentDelta, MessagesMessageDelta,
|
||||
MessagesRole, MessagesStopReason, MessagesStreamEvent, MessagesStreamMessage, MessagesUsage,
|
||||
};
|
||||
use crate::apis::openai::{ ChatCompletionsStreamResponse, ToolCallDelta,
|
||||
};
|
||||
use crate::clients::TransformError;
|
||||
use serde_json::Value;
|
||||
|
||||
impl TryFrom<ChatCompletionsStreamResponse> for MessagesStreamEvent {
|
||||
type Error = TransformError;
|
||||
|
||||
fn try_from(resp: ChatCompletionsStreamResponse) -> Result<Self, Self::Error> {
|
||||
if resp.choices.is_empty() {
|
||||
return Ok(MessagesStreamEvent::Ping);
|
||||
}
|
||||
|
||||
let choice = &resp.choices[0];
|
||||
|
||||
// Handle final chunk with usage
|
||||
let has_usage = resp.usage.is_some();
|
||||
if let Some(usage) = resp.usage {
|
||||
if let Some(finish_reason) = &choice.finish_reason {
|
||||
let anthropic_stop_reason: MessagesStopReason = finish_reason.clone().into();
|
||||
return Ok(MessagesStreamEvent::MessageDelta {
|
||||
delta: MessagesMessageDelta {
|
||||
stop_reason: anthropic_stop_reason,
|
||||
stop_sequence: None,
|
||||
},
|
||||
usage: usage.into(),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// NOTE: We do NOT emit MessageStart here anymore!
|
||||
// The AnthropicMessagesStreamBuffer will inject message_start and content_block_start
|
||||
// when it sees the first content_block_delta. This solves the problem where OpenAI
|
||||
// sends both role and content in the same chunk - we can only return one event here,
|
||||
// so we prioritize the content and let the buffer handle lifecycle events.
|
||||
|
||||
// Handle content delta (even if role is present in the same chunk)
|
||||
if let Some(content) = &choice.delta.content {
|
||||
if !content.is_empty() {
|
||||
return Ok(MessagesStreamEvent::ContentBlockDelta {
|
||||
index: 0,
|
||||
delta: MessagesContentDelta::TextDelta {
|
||||
text: content.clone(),
|
||||
},
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Handle tool calls
|
||||
if let Some(tool_calls) = &choice.delta.tool_calls {
|
||||
return convert_tool_call_deltas(tool_calls.clone());
|
||||
}
|
||||
|
||||
// Handle finish reason - generate MessageDelta only (MessageStop comes later)
|
||||
if let Some(finish_reason) = &choice.finish_reason {
|
||||
// If we have usage data, it was already handled above
|
||||
// If not, we need to generate MessageDelta with default usage
|
||||
if !has_usage {
|
||||
let anthropic_stop_reason: MessagesStopReason = finish_reason.clone().into();
|
||||
return Ok(MessagesStreamEvent::MessageDelta {
|
||||
delta: MessagesMessageDelta {
|
||||
stop_reason: anthropic_stop_reason,
|
||||
stop_sequence: None,
|
||||
},
|
||||
usage: MessagesUsage {
|
||||
input_tokens: 0,
|
||||
output_tokens: 0,
|
||||
cache_creation_input_tokens: None,
|
||||
cache_read_input_tokens: None,
|
||||
},
|
||||
});
|
||||
}
|
||||
// If usage was already handled above, we don't need to do anything more here
|
||||
// MessageStop will be handled when [DONE] is encountered
|
||||
}
|
||||
|
||||
// Default to ping for unhandled cases
|
||||
Ok(MessagesStreamEvent::Ping)
|
||||
}
|
||||
}
|
||||
|
||||
impl Into<String> for MessagesStreamEvent {
|
||||
fn into(self) -> String {
|
||||
let transformed_json = serde_json::to_string(&self).unwrap_or_default();
|
||||
let event_type = match &self {
|
||||
MessagesStreamEvent::MessageStart { .. } => "message_start",
|
||||
MessagesStreamEvent::ContentBlockStart { .. } => "content_block_start",
|
||||
MessagesStreamEvent::ContentBlockDelta { .. } => "content_block_delta",
|
||||
MessagesStreamEvent::ContentBlockStop { .. } => "content_block_stop",
|
||||
MessagesStreamEvent::MessageDelta { .. } => "message_delta",
|
||||
MessagesStreamEvent::MessageStop => "message_stop",
|
||||
MessagesStreamEvent::Ping => "ping",
|
||||
};
|
||||
|
||||
let event = format!("event: {}\n", event_type);
|
||||
let data = format!("data: {}\n\n", transformed_json);
|
||||
event + &data
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<ConverseStreamEvent> for MessagesStreamEvent {
|
||||
type Error = TransformError;
|
||||
|
||||
fn try_from(event: ConverseStreamEvent) -> Result<Self, Self::Error> {
|
||||
match event {
|
||||
// MessageStart - convert to Anthropic MessageStart
|
||||
ConverseStreamEvent::MessageStart(start_event) => {
|
||||
let role = match start_event.role {
|
||||
crate::apis::amazon_bedrock::ConversationRole::User => MessagesRole::User,
|
||||
crate::apis::amazon_bedrock::ConversationRole::Assistant => {
|
||||
MessagesRole::Assistant
|
||||
}
|
||||
};
|
||||
|
||||
Ok(MessagesStreamEvent::MessageStart {
|
||||
message: MessagesStreamMessage {
|
||||
id: format!(
|
||||
"bedrock-stream-{}",
|
||||
std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_nanos()
|
||||
),
|
||||
obj_type: "message".to_string(),
|
||||
role,
|
||||
content: vec![],
|
||||
model: "bedrock-model".to_string(),
|
||||
stop_reason: None,
|
||||
stop_sequence: None,
|
||||
usage: MessagesUsage {
|
||||
input_tokens: 0,
|
||||
output_tokens: 0,
|
||||
cache_creation_input_tokens: None,
|
||||
cache_read_input_tokens: None,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// ContentBlockStart - convert to Anthropic ContentBlockStart
|
||||
ConverseStreamEvent::ContentBlockStart(start_event) => {
|
||||
// Note: Bedrock sends tool_use_id and name at start, with input coming in subsequent deltas
|
||||
// Anthropic expects the same pattern, so we initialize with an empty input object
|
||||
match start_event.start {
|
||||
crate::apis::amazon_bedrock::ContentBlockStart::ToolUse { tool_use } => {
|
||||
Ok(MessagesStreamEvent::ContentBlockStart {
|
||||
index: start_event.content_block_index as u32,
|
||||
content_block: MessagesContentBlock::ToolUse {
|
||||
id: tool_use.tool_use_id,
|
||||
name: tool_use.name,
|
||||
input: Value::Object(serde_json::Map::new()), // Empty - will be filled by deltas
|
||||
cache_control: None,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ContentBlockDelta - convert to Anthropic ContentBlockDelta
|
||||
ConverseStreamEvent::ContentBlockDelta(delta_event) => {
|
||||
let delta = match delta_event.delta {
|
||||
ContentBlockDelta::Text { text } => MessagesContentDelta::TextDelta { text },
|
||||
ContentBlockDelta::ToolUse { tool_use } => {
|
||||
MessagesContentDelta::InputJsonDelta {
|
||||
partial_json: tool_use.input,
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
Ok(MessagesStreamEvent::ContentBlockDelta {
|
||||
index: delta_event.content_block_index as u32,
|
||||
delta,
|
||||
})
|
||||
}
|
||||
|
||||
// ContentBlockStop - convert to Anthropic ContentBlockStop
|
||||
ConverseStreamEvent::ContentBlockStop(stop_event) => {
|
||||
Ok(MessagesStreamEvent::ContentBlockStop {
|
||||
index: stop_event.content_block_index as u32,
|
||||
})
|
||||
}
|
||||
|
||||
// MessageStop - convert to Anthropic MessageDelta with stop reason
|
||||
// Note: Bedrock sends Metadata separately with usage info, creating a second MessageDelta
|
||||
// The client should merge these or use the final one with complete usage
|
||||
ConverseStreamEvent::MessageStop(stop_event) => {
|
||||
let anthropic_stop_reason = match stop_event.stop_reason {
|
||||
crate::apis::amazon_bedrock::StopReason::EndTurn => MessagesStopReason::EndTurn,
|
||||
crate::apis::amazon_bedrock::StopReason::ToolUse => MessagesStopReason::ToolUse,
|
||||
crate::apis::amazon_bedrock::StopReason::MaxTokens => MessagesStopReason::MaxTokens,
|
||||
crate::apis::amazon_bedrock::StopReason::StopSequence => MessagesStopReason::EndTurn,
|
||||
crate::apis::amazon_bedrock::StopReason::GuardrailIntervened => MessagesStopReason::Refusal,
|
||||
crate::apis::amazon_bedrock::StopReason::ContentFiltered => MessagesStopReason::Refusal,
|
||||
};
|
||||
|
||||
Ok(MessagesStreamEvent::MessageDelta {
|
||||
delta: MessagesMessageDelta {
|
||||
stop_reason: anthropic_stop_reason,
|
||||
stop_sequence: None,
|
||||
},
|
||||
usage: MessagesUsage {
|
||||
input_tokens: 0,
|
||||
output_tokens: 0,
|
||||
cache_creation_input_tokens: None,
|
||||
cache_read_input_tokens: None,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// Metadata - convert usage information to MessageDelta
|
||||
ConverseStreamEvent::Metadata(metadata_event) => {
|
||||
Ok(MessagesStreamEvent::MessageDelta {
|
||||
delta: MessagesMessageDelta {
|
||||
stop_reason: MessagesStopReason::EndTurn,
|
||||
stop_sequence: None,
|
||||
},
|
||||
usage: MessagesUsage {
|
||||
input_tokens: metadata_event.usage.input_tokens,
|
||||
output_tokens: metadata_event.usage.output_tokens,
|
||||
cache_creation_input_tokens: metadata_event.usage.cache_write_input_tokens,
|
||||
cache_read_input_tokens: metadata_event.usage.cache_read_input_tokens,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// Exception events - convert to Ping (could be enhanced to return error events)
|
||||
ConverseStreamEvent::InternalServerException(_)
|
||||
| ConverseStreamEvent::ModelStreamErrorException(_)
|
||||
| ConverseStreamEvent::ServiceUnavailableException(_)
|
||||
| ConverseStreamEvent::ThrottlingException(_)
|
||||
| ConverseStreamEvent::ValidationException(_) => {
|
||||
// TODO: Consider adding proper error handling/events
|
||||
Ok(MessagesStreamEvent::Ping)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert tool call deltas to Anthropic stream events
|
||||
fn convert_tool_call_deltas(
|
||||
tool_calls: Vec<ToolCallDelta>,
|
||||
) -> Result<MessagesStreamEvent, TransformError> {
|
||||
for tool_call in tool_calls {
|
||||
if let Some(id) = &tool_call.id {
|
||||
// Tool call start
|
||||
if let Some(function) = &tool_call.function {
|
||||
if let Some(name) = &function.name {
|
||||
return Ok(MessagesStreamEvent::ContentBlockStart {
|
||||
index: tool_call.index,
|
||||
content_block: MessagesContentBlock::ToolUse {
|
||||
id: id.clone(),
|
||||
name: name.clone(),
|
||||
input: Value::Object(serde_json::Map::new()),
|
||||
cache_control: None,
|
||||
},
|
||||
});
|
||||
}
|
||||
}
|
||||
} else if let Some(function) = &tool_call.function {
|
||||
if let Some(arguments) = &function.arguments {
|
||||
// Tool arguments delta
|
||||
return Ok(MessagesStreamEvent::ContentBlockDelta {
|
||||
index: tool_call.index,
|
||||
delta: MessagesContentDelta::InputJsonDelta {
|
||||
partial_json: arguments.clone(),
|
||||
},
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to ping if no valid tool call found
|
||||
Ok(MessagesStreamEvent::Ping)
|
||||
}
|
||||
|
|
@ -0,0 +1,527 @@
|
|||
use crate::apis::amazon_bedrock::{ ConverseStreamEvent, StopReason};
|
||||
use crate::apis::anthropic::{
|
||||
MessagesContentBlock, MessagesContentDelta, MessagesStopReason, MessagesStreamEvent};
|
||||
use crate::apis::openai::{ ChatCompletionsStreamResponse,FinishReason,
|
||||
FunctionCallDelta, MessageDelta, Role, StreamChoice, ToolCallDelta, Usage,
|
||||
};
|
||||
use crate::apis::openai_responses::ResponsesAPIStreamEvent;
|
||||
|
||||
use crate::clients::TransformError;
|
||||
use crate::transforms::lib::*;
|
||||
|
||||
// ============================================================================
|
||||
// PROVIDER STREAMING TRANSFORMATIONS TO OPENAI FORMAT
|
||||
// ============================================================================
|
||||
//
|
||||
// This module handles business logic for converting streaming events from
|
||||
// various providers (Anthropic, Bedrock, etc.) into OpenAI's ChatCompletions format.
|
||||
//
|
||||
// # Architecture Separation
|
||||
//
|
||||
// **Provider Transformations** (this module):
|
||||
// - Business logic for converting between provider formats
|
||||
// - Uses Rust traits (TryFrom, Into) for type-safe conversions
|
||||
// - Stateless event-by-event transformation
|
||||
// - Example: MessagesStreamEvent → ChatCompletionsStreamResponse
|
||||
//
|
||||
// **Wire Format Buffering** (`apis/streaming_shapes/`):
|
||||
// - SSE protocol handling (data:, event: lines)
|
||||
// - State accumulation and lifecycle management
|
||||
// - Buffering for stateful APIs (v1/responses)
|
||||
// - Example: ChatCompletionsToResponsesTransformer
|
||||
//
|
||||
// # Flow
|
||||
//
|
||||
// ```text
|
||||
// Anthropic Event → [Provider Transform] → OpenAI Event → [Wire Buffer] → SSE Wire Format
|
||||
// (business) (this module) (protocol) (streaming_shapes) (network)
|
||||
// ```
|
||||
//
|
||||
// ============================================================================
|
||||
|
||||
impl TryFrom<MessagesStreamEvent> for ChatCompletionsStreamResponse {
|
||||
type Error = TransformError;
|
||||
|
||||
fn try_from(event: MessagesStreamEvent) -> Result<Self, Self::Error> {
|
||||
match event {
|
||||
MessagesStreamEvent::MessageStart { message } => Ok(create_openai_chunk(
|
||||
&message.id,
|
||||
&message.model,
|
||||
MessageDelta {
|
||||
role: Some(Role::Assistant),
|
||||
content: None,
|
||||
refusal: None,
|
||||
function_call: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
None,
|
||||
None,
|
||||
)),
|
||||
|
||||
MessagesStreamEvent::ContentBlockStart { content_block, .. } => {
|
||||
convert_content_block_start(content_block)
|
||||
}
|
||||
|
||||
MessagesStreamEvent::ContentBlockDelta { delta, .. } => convert_content_delta(delta),
|
||||
|
||||
MessagesStreamEvent::ContentBlockStop { .. } => Ok(create_empty_openai_chunk()),
|
||||
|
||||
MessagesStreamEvent::MessageDelta { delta, usage } => {
|
||||
let finish_reason: Option<FinishReason> = Some(delta.stop_reason.into());
|
||||
let openai_usage: Option<Usage> = Some(usage.into());
|
||||
|
||||
Ok(create_openai_chunk(
|
||||
"stream",
|
||||
"unknown",
|
||||
MessageDelta {
|
||||
role: None,
|
||||
content: None,
|
||||
refusal: None,
|
||||
function_call: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
finish_reason,
|
||||
openai_usage,
|
||||
))
|
||||
}
|
||||
|
||||
MessagesStreamEvent::MessageStop => Ok(create_openai_chunk(
|
||||
"stream",
|
||||
"unknown",
|
||||
MessageDelta {
|
||||
role: None,
|
||||
content: None,
|
||||
refusal: None,
|
||||
function_call: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
Some(FinishReason::Stop),
|
||||
None,
|
||||
)),
|
||||
|
||||
MessagesStreamEvent::Ping => Ok(ChatCompletionsStreamResponse {
|
||||
id: "stream".to_string(),
|
||||
object: Some("chat.completion.chunk".to_string()),
|
||||
created: current_timestamp(),
|
||||
model: "unknown".to_string(),
|
||||
choices: vec![],
|
||||
usage: None,
|
||||
system_fingerprint: None,
|
||||
service_tier: None,
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<ConverseStreamEvent> for ChatCompletionsStreamResponse {
|
||||
type Error = TransformError;
|
||||
|
||||
fn try_from(event: ConverseStreamEvent) -> Result<Self, Self::Error> {
|
||||
match event {
|
||||
ConverseStreamEvent::MessageStart(start_event) => {
|
||||
let role = match start_event.role {
|
||||
crate::apis::amazon_bedrock::ConversationRole::User => Role::User,
|
||||
crate::apis::amazon_bedrock::ConversationRole::Assistant => Role::Assistant,
|
||||
};
|
||||
|
||||
Ok(create_openai_chunk(
|
||||
"stream",
|
||||
"unknown",
|
||||
MessageDelta {
|
||||
role: Some(role),
|
||||
content: None,
|
||||
refusal: None,
|
||||
function_call: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
None,
|
||||
None,
|
||||
))
|
||||
}
|
||||
|
||||
ConverseStreamEvent::ContentBlockStart(start_event) => {
|
||||
use crate::apis::amazon_bedrock::ContentBlockStart;
|
||||
|
||||
match start_event.start {
|
||||
ContentBlockStart::ToolUse { tool_use } => Ok(create_openai_chunk(
|
||||
"stream",
|
||||
"unknown",
|
||||
MessageDelta {
|
||||
role: None,
|
||||
content: None,
|
||||
refusal: None,
|
||||
function_call: None,
|
||||
tool_calls: Some(vec![ToolCallDelta {
|
||||
index: start_event.content_block_index as u32,
|
||||
id: Some(tool_use.tool_use_id),
|
||||
call_type: Some("function".to_string()),
|
||||
function: Some(FunctionCallDelta {
|
||||
name: Some(tool_use.name),
|
||||
arguments: Some("".to_string()),
|
||||
}),
|
||||
}]),
|
||||
},
|
||||
None,
|
||||
None,
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
ConverseStreamEvent::ContentBlockDelta(delta_event) => {
|
||||
use crate::apis::amazon_bedrock::ContentBlockDelta;
|
||||
|
||||
match delta_event.delta {
|
||||
ContentBlockDelta::Text { text } => Ok(create_openai_chunk(
|
||||
"stream",
|
||||
"unknown",
|
||||
MessageDelta {
|
||||
role: None,
|
||||
content: Some(text),
|
||||
refusal: None,
|
||||
function_call: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
None,
|
||||
None,
|
||||
)),
|
||||
ContentBlockDelta::ToolUse { tool_use } => Ok(create_openai_chunk(
|
||||
"stream",
|
||||
"unknown",
|
||||
MessageDelta {
|
||||
role: None,
|
||||
content: None,
|
||||
refusal: None,
|
||||
function_call: None,
|
||||
tool_calls: Some(vec![ToolCallDelta {
|
||||
index: delta_event.content_block_index as u32,
|
||||
id: None,
|
||||
call_type: None,
|
||||
function: Some(FunctionCallDelta {
|
||||
name: None,
|
||||
arguments: Some(tool_use.input),
|
||||
}),
|
||||
}]),
|
||||
},
|
||||
None,
|
||||
None,
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
ConverseStreamEvent::ContentBlockStop(_) => Ok(create_empty_openai_chunk()),
|
||||
|
||||
ConverseStreamEvent::MessageStop(stop_event) => {
|
||||
let finish_reason = match stop_event.stop_reason {
|
||||
StopReason::EndTurn => FinishReason::Stop,
|
||||
StopReason::ToolUse => FinishReason::ToolCalls,
|
||||
StopReason::MaxTokens => FinishReason::Length,
|
||||
StopReason::StopSequence => FinishReason::Stop,
|
||||
StopReason::GuardrailIntervened => FinishReason::ContentFilter,
|
||||
StopReason::ContentFiltered => FinishReason::ContentFilter,
|
||||
};
|
||||
|
||||
Ok(create_openai_chunk(
|
||||
"stream",
|
||||
"unknown",
|
||||
MessageDelta {
|
||||
role: None,
|
||||
content: None,
|
||||
refusal: None,
|
||||
function_call: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
Some(finish_reason),
|
||||
None,
|
||||
))
|
||||
}
|
||||
|
||||
ConverseStreamEvent::Metadata(metadata_event) => {
|
||||
let usage = Usage {
|
||||
prompt_tokens: metadata_event.usage.input_tokens,
|
||||
completion_tokens: metadata_event.usage.output_tokens,
|
||||
total_tokens: metadata_event.usage.total_tokens,
|
||||
prompt_tokens_details: None,
|
||||
completion_tokens_details: None,
|
||||
};
|
||||
|
||||
Ok(create_openai_chunk(
|
||||
"stream",
|
||||
"unknown",
|
||||
MessageDelta {
|
||||
role: None,
|
||||
content: None,
|
||||
refusal: None,
|
||||
function_call: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
None,
|
||||
Some(usage),
|
||||
))
|
||||
}
|
||||
|
||||
// Error events - convert to empty chunks (errors should be handled elsewhere)
|
||||
ConverseStreamEvent::InternalServerException(_)
|
||||
| ConverseStreamEvent::ModelStreamErrorException(_)
|
||||
| ConverseStreamEvent::ServiceUnavailableException(_)
|
||||
| ConverseStreamEvent::ThrottlingException(_)
|
||||
| ConverseStreamEvent::ValidationException(_) => Ok(create_empty_openai_chunk()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert content block start to OpenAI chunk
|
||||
fn convert_content_block_start(
|
||||
content_block: MessagesContentBlock,
|
||||
) -> Result<ChatCompletionsStreamResponse, TransformError> {
|
||||
match content_block {
|
||||
MessagesContentBlock::Text { .. } => {
|
||||
// No immediate output for text block start
|
||||
Ok(create_empty_openai_chunk())
|
||||
}
|
||||
MessagesContentBlock::ToolUse { id, name, .. }
|
||||
| MessagesContentBlock::ServerToolUse { id, name, .. }
|
||||
| MessagesContentBlock::McpToolUse { id, name, .. } => {
|
||||
// Tool use start → OpenAI chunk with tool_calls
|
||||
Ok(create_openai_chunk(
|
||||
"stream",
|
||||
"unknown",
|
||||
MessageDelta {
|
||||
role: None,
|
||||
content: None,
|
||||
refusal: None,
|
||||
function_call: None,
|
||||
tool_calls: Some(vec![ToolCallDelta {
|
||||
index: 0,
|
||||
id: Some(id),
|
||||
call_type: Some("function".to_string()),
|
||||
function: Some(FunctionCallDelta {
|
||||
name: Some(name),
|
||||
arguments: Some("".to_string()),
|
||||
}),
|
||||
}]),
|
||||
},
|
||||
None,
|
||||
None,
|
||||
))
|
||||
}
|
||||
_ => Err(TransformError::UnsupportedContent(
|
||||
"Unsupported content block type in stream start".to_string(),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert content delta to OpenAI chunk
|
||||
fn convert_content_delta(
|
||||
delta: MessagesContentDelta,
|
||||
) -> Result<ChatCompletionsStreamResponse, TransformError> {
|
||||
match delta {
|
||||
MessagesContentDelta::TextDelta { text } => Ok(create_openai_chunk(
|
||||
"stream",
|
||||
"unknown",
|
||||
MessageDelta {
|
||||
role: None,
|
||||
content: Some(text),
|
||||
refusal: None,
|
||||
function_call: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
None,
|
||||
None,
|
||||
)),
|
||||
MessagesContentDelta::ThinkingDelta { thinking } => Ok(create_openai_chunk(
|
||||
"stream",
|
||||
"unknown",
|
||||
MessageDelta {
|
||||
role: None,
|
||||
content: Some(format!("thinking: {}", thinking)),
|
||||
refusal: None,
|
||||
function_call: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
None,
|
||||
None,
|
||||
)),
|
||||
MessagesContentDelta::InputJsonDelta { partial_json } => Ok(create_openai_chunk(
|
||||
"stream",
|
||||
"unknown",
|
||||
MessageDelta {
|
||||
role: None,
|
||||
content: None,
|
||||
refusal: None,
|
||||
function_call: None,
|
||||
tool_calls: Some(vec![ToolCallDelta {
|
||||
index: 0,
|
||||
id: None,
|
||||
call_type: None,
|
||||
function: Some(FunctionCallDelta {
|
||||
name: None,
|
||||
arguments: Some(partial_json),
|
||||
}),
|
||||
}]),
|
||||
},
|
||||
None,
|
||||
None,
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper to create OpenAI streaming chunk
|
||||
fn create_openai_chunk(
|
||||
id: &str,
|
||||
model: &str,
|
||||
delta: MessageDelta,
|
||||
finish_reason: Option<FinishReason>,
|
||||
usage: Option<Usage>,
|
||||
) -> ChatCompletionsStreamResponse {
|
||||
ChatCompletionsStreamResponse {
|
||||
id: id.to_string(),
|
||||
object: Some("chat.completion.chunk".to_string()),
|
||||
created: current_timestamp(),
|
||||
model: model.to_string(),
|
||||
choices: vec![StreamChoice {
|
||||
index: 0,
|
||||
delta,
|
||||
finish_reason,
|
||||
logprobs: None,
|
||||
}],
|
||||
usage,
|
||||
system_fingerprint: None,
|
||||
service_tier: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper to create empty OpenAI streaming chunk
|
||||
fn create_empty_openai_chunk() -> ChatCompletionsStreamResponse {
|
||||
create_openai_chunk(
|
||||
"stream",
|
||||
"unknown",
|
||||
MessageDelta {
|
||||
role: None,
|
||||
content: None,
|
||||
refusal: None,
|
||||
function_call: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
None,
|
||||
None,
|
||||
)
|
||||
}
|
||||
|
||||
// Stop Reason Conversions
|
||||
impl Into<FinishReason> for MessagesStopReason {
|
||||
fn into(self) -> FinishReason {
|
||||
match self {
|
||||
MessagesStopReason::EndTurn => FinishReason::Stop,
|
||||
MessagesStopReason::MaxTokens => FinishReason::Length,
|
||||
MessagesStopReason::StopSequence => FinishReason::Stop,
|
||||
MessagesStopReason::ToolUse => FinishReason::ToolCalls,
|
||||
MessagesStopReason::PauseTurn => FinishReason::Stop,
|
||||
MessagesStopReason::Refusal => FinishReason::ContentFilter,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<ChatCompletionsStreamResponse> for ResponsesAPIStreamEvent {
|
||||
type Error = TransformError;
|
||||
|
||||
fn try_from(chunk: ChatCompletionsStreamResponse) -> Result<Self, TransformError> {
|
||||
// Stateless conversion - just extract the delta information
|
||||
// The buffer will manage state, item IDs, and sequence numbers
|
||||
|
||||
// Extract first choice if available
|
||||
if let Some(choice) = chunk.choices.first() {
|
||||
let delta = &choice.delta;
|
||||
|
||||
// Tool call with function name and/or arguments
|
||||
if let Some(tool_calls) = &delta.tool_calls {
|
||||
if let Some(tool_call) = tool_calls.first() {
|
||||
// Extract call_id and name if available (metadata from initial event)
|
||||
let call_id = tool_call.id.clone();
|
||||
let function_name = tool_call.function.as_ref()
|
||||
.and_then(|f| f.name.clone());
|
||||
|
||||
// Check if we have function metadata (name, id)
|
||||
if let Some(function) = &tool_call.function {
|
||||
// If we have arguments delta, return that
|
||||
if let Some(args) = &function.arguments {
|
||||
return Ok(ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDelta {
|
||||
output_index: choice.index as i32,
|
||||
item_id: "".to_string(), // Buffer will fill this
|
||||
delta: args.clone(),
|
||||
sequence_number: 0, // Buffer will fill this
|
||||
call_id,
|
||||
name: function_name,
|
||||
});
|
||||
}
|
||||
|
||||
// If we have function name but no arguments yet (initial tool call event)
|
||||
// Return an empty arguments delta so the buffer knows to create the item
|
||||
if function.name.is_some() {
|
||||
return Ok(ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDelta {
|
||||
output_index: choice.index as i32,
|
||||
item_id: "".to_string(), // Buffer will fill this
|
||||
delta: "".to_string(), // Empty delta signals this is the initial event
|
||||
sequence_number: 0, // Buffer will fill this
|
||||
call_id,
|
||||
name: function_name,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Text content delta
|
||||
if let Some(content) = &delta.content {
|
||||
if !content.is_empty() {
|
||||
return Ok(ResponsesAPIStreamEvent::ResponseOutputTextDelta {
|
||||
item_id: "".to_string(), // Buffer will fill this
|
||||
output_index: choice.index as i32,
|
||||
content_index: 0,
|
||||
delta: content.clone(),
|
||||
logprobs: vec![],
|
||||
obfuscation: None,
|
||||
sequence_number: 0, // Buffer will fill this
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Handle finish_reason - this is a completion signal
|
||||
// Return an empty delta that the buffer can use to detect completion
|
||||
if choice.finish_reason.is_some() {
|
||||
// Return a minimal text delta to signal completion
|
||||
// The buffer will handle the finish_reason and generate response.completed
|
||||
return Ok(ResponsesAPIStreamEvent::ResponseOutputTextDelta {
|
||||
item_id: "".to_string(), // Buffer will fill this
|
||||
output_index: choice.index as i32,
|
||||
content_index: 0,
|
||||
delta: "".to_string(), // Empty delta signals completion
|
||||
logprobs: vec![],
|
||||
obfuscation: None,
|
||||
sequence_number: 0, // Buffer will fill this
|
||||
});
|
||||
}
|
||||
|
||||
// Empty delta with role only (common at stream start)
|
||||
if delta.role.is_some() {
|
||||
// This is typically the first chunk establishing the assistant role
|
||||
// Return an empty text delta that the buffer can use to initialize state
|
||||
return Ok(ResponsesAPIStreamEvent::ResponseOutputTextDelta {
|
||||
item_id: "".to_string(),
|
||||
output_index: choice.index as i32,
|
||||
content_index: 0,
|
||||
delta: "".to_string(),
|
||||
logprobs: vec![],
|
||||
obfuscation: None,
|
||||
sequence_number: 0,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Empty chunk or no convertible content (e.g., keep-alive chunks with delta: {})
|
||||
// These are valid in OpenAI streaming and should be silently ignored
|
||||
// Return error so the caller can skip these chunks without warnings
|
||||
Err(TransformError::UnsupportedConversion(
|
||||
"Empty or keep-alive chunk with no convertible content".to_string(),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
|
@ -22,11 +22,13 @@ use common::ratelimit::Header;
|
|||
use common::stats::{IncrementingMetric, RecordingMetric};
|
||||
use common::tracing::{Event, Span, TraceData, Traceparent};
|
||||
use common::{ratelimit, routing, tokenizer};
|
||||
use hermesllm::apis::amazon_bedrock_binary_frame::BedrockBinaryFrameDecoder;
|
||||
use hermesllm::apis::anthropic::{MessagesContentBlock, MessagesStreamEvent};
|
||||
use hermesllm::apis::sse::{SseEvent, SseStreamIter};
|
||||
use hermesllm::clients::endpoints::SupportedAPIs;
|
||||
use hermesllm::apis::streaming_shapes::amazon_bedrock_binary_frame::BedrockBinaryFrameDecoder;
|
||||
use hermesllm::apis::streaming_shapes::sse::{
|
||||
SseEvent, SseStreamBuffer, SseStreamBufferTrait, SseStreamIter,
|
||||
};
|
||||
use hermesllm::clients::endpoints::SupportedAPIsFromClient;
|
||||
use hermesllm::providers::response::ProviderResponse;
|
||||
use hermesllm::providers::streaming_response::ProviderStreamResponse;
|
||||
use hermesllm::{
|
||||
DecodedFrame, ProviderId, ProviderRequest, ProviderRequestType, ProviderResponseType,
|
||||
ProviderStreamResponseType,
|
||||
|
|
@ -38,7 +40,7 @@ pub struct StreamContext {
|
|||
streaming_response: bool,
|
||||
response_tokens: usize,
|
||||
/// The API that is requested by the client (before compatibility mapping)
|
||||
client_api: Option<SupportedAPIs>,
|
||||
client_api: Option<SupportedAPIsFromClient>,
|
||||
/// The API that should be used for the upstream provider (after compatibility mapping)
|
||||
resolved_api: Option<SupportedUpstreamAPIs>,
|
||||
llm_providers: Rc<LlmProviders>,
|
||||
|
|
@ -56,6 +58,7 @@ pub struct StreamContext {
|
|||
binary_frame_decoder: Option<BedrockBinaryFrameDecoder<bytes::BytesMut>>,
|
||||
http_method: Option<String>,
|
||||
http_protocol: Option<String>,
|
||||
sse_buffer: Option<SseStreamBuffer>,
|
||||
}
|
||||
|
||||
impl StreamContext {
|
||||
|
|
@ -87,6 +90,7 @@ impl StreamContext {
|
|||
binary_frame_decoder: None,
|
||||
http_method: None,
|
||||
http_protocol: None,
|
||||
sse_buffer: None,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -172,7 +176,8 @@ impl StreamContext {
|
|||
Some(
|
||||
SupportedUpstreamAPIs::OpenAIChatCompletions(_)
|
||||
| SupportedUpstreamAPIs::AmazonBedrockConverse(_)
|
||||
| SupportedUpstreamAPIs::AmazonBedrockConverseStream(_),
|
||||
| SupportedUpstreamAPIs::AmazonBedrockConverseStream(_)
|
||||
| SupportedUpstreamAPIs::OpenAIResponsesAPI(_),
|
||||
)
|
||||
| None => {
|
||||
// OpenAI and default: use Authorization Bearer token
|
||||
|
|
@ -476,7 +481,17 @@ impl StreamContext {
|
|||
}
|
||||
};
|
||||
|
||||
let mut response_buffer = Vec::new();
|
||||
// Initialize SSE buffer if not present
|
||||
if self.sse_buffer.is_none() {
|
||||
self.sse_buffer = match SseStreamBuffer::try_from((&client_api, &upstream_api))
|
||||
{
|
||||
Ok(buffer) => Some(buffer),
|
||||
Err(e) => {
|
||||
warn!("Failed to create SSE buffer: {}", e);
|
||||
return Err(Action::Continue);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
// Process each SSE event
|
||||
for sse_event in sse_iter {
|
||||
|
|
@ -527,12 +542,32 @@ impl StreamContext {
|
|||
}
|
||||
}
|
||||
|
||||
// Add transformed event to response buffer
|
||||
let bytes: Vec<u8> = transformed_event.into();
|
||||
response_buffer.extend_from_slice(&bytes);
|
||||
// Add transformed event to buffer (buffer may inject lifecycle events)
|
||||
if let Some(buffer) = self.sse_buffer.as_mut() {
|
||||
buffer.add_transformed_event(transformed_event);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(response_buffer)
|
||||
// Get accumulated bytes from buffer and return
|
||||
match self.sse_buffer.as_mut() {
|
||||
Some(buffer) => {
|
||||
let bytes = buffer.into_bytes();
|
||||
if !bytes.is_empty() {
|
||||
let content = String::from_utf8_lossy(&bytes);
|
||||
debug!(
|
||||
"[ARCHGW_REQ_ID:{}] UPSTREAM_TRANSFORMED_CLIENT_RESPONSE: size={} content={}",
|
||||
self.request_identifier(),
|
||||
bytes.len(),
|
||||
content
|
||||
);
|
||||
}
|
||||
Ok(bytes)
|
||||
}
|
||||
None => {
|
||||
warn!("SSE buffer unexpectedly missing after initialization");
|
||||
Err(Action::Continue)
|
||||
}
|
||||
}
|
||||
}
|
||||
None => {
|
||||
warn!("Missing client_api for non-streaming response");
|
||||
|
|
@ -544,7 +579,7 @@ impl StreamContext {
|
|||
fn handle_bedrock_binary_stream(
|
||||
&mut self,
|
||||
body: &[u8],
|
||||
client_api: &SupportedAPIs,
|
||||
client_api: &SupportedAPIsFromClient,
|
||||
upstream_api: &SupportedUpstreamAPIs,
|
||||
) -> Result<Vec<u8>, Action> {
|
||||
// Initialize decoder if not present
|
||||
|
|
@ -552,83 +587,57 @@ impl StreamContext {
|
|||
self.binary_frame_decoder = Some(BedrockBinaryFrameDecoder::from_bytes(&[]));
|
||||
}
|
||||
|
||||
// Add incoming bytes to buffer
|
||||
// Initialize SSE buffer if not present
|
||||
if self.sse_buffer.is_none() {
|
||||
self.sse_buffer = match SseStreamBuffer::try_from((client_api, upstream_api)) {
|
||||
Ok(buffer) => Some(buffer),
|
||||
Err(e) => {
|
||||
warn!(
|
||||
"[ARCHGW_REQ_ID:{}] BEDROCK_BUFFER_INIT_ERROR: {}",
|
||||
self.request_identifier(),
|
||||
e
|
||||
);
|
||||
return Err(Action::Continue);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
// Add incoming bytes to decoder buffer
|
||||
let decoder = self.binary_frame_decoder.as_mut().unwrap();
|
||||
decoder.buffer_mut().extend_from_slice(body);
|
||||
|
||||
let mut response_buffer = Vec::new();
|
||||
// Process all complete frames
|
||||
loop {
|
||||
let decoded_frame = self.binary_frame_decoder.as_mut().unwrap().decode_frame();
|
||||
match decoded_frame {
|
||||
Some(DecodedFrame::Complete(ref frame_ref)) => {
|
||||
let frame = DecodedFrame::Complete(frame_ref.clone());
|
||||
|
||||
// Convert frame to provider response type
|
||||
match ProviderStreamResponseType::try_from((&frame, client_api, upstream_api)) {
|
||||
Ok(provider_response) => {
|
||||
self.record_ttft_if_needed();
|
||||
|
||||
// Handle ContentBlockStart and ContentBlockDelta events
|
||||
match &provider_response {
|
||||
ProviderStreamResponseType::MessagesStreamEvent(evt) => {
|
||||
match evt {
|
||||
MessagesStreamEvent::ContentBlockStart {
|
||||
index, ..
|
||||
} => {
|
||||
// Mark that we've seen ContentBlockStart for this index
|
||||
self.binary_frame_decoder
|
||||
.as_mut()
|
||||
.unwrap()
|
||||
.set_content_block_start_sent(*index as i32);
|
||||
debug!(
|
||||
"[ARCHGW_REQ_ID:{}] BEDROCK_CONTENT_BLOCK_START_TRACKED: index={}",
|
||||
self.request_identifier(),
|
||||
*index
|
||||
);
|
||||
}
|
||||
MessagesStreamEvent::ContentBlockDelta {
|
||||
index, ..
|
||||
} => {
|
||||
// Check if ContentBlockStart was sent for this index
|
||||
let needs_start = !self
|
||||
.binary_frame_decoder
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.has_content_block_start_been_sent(*index as i32);
|
||||
|
||||
if needs_start {
|
||||
// Emit empty ContentBlockStart before delta
|
||||
let content_block_start =
|
||||
MessagesStreamEvent::ContentBlockStart {
|
||||
index: *index,
|
||||
content_block: MessagesContentBlock::Text {
|
||||
text: String::new(),
|
||||
cache_control: None,
|
||||
},
|
||||
};
|
||||
let start_sse: String = content_block_start.into();
|
||||
response_buffer
|
||||
.extend_from_slice(start_sse.as_bytes());
|
||||
|
||||
// Mark that we've now sent it
|
||||
self.binary_frame_decoder
|
||||
.as_mut()
|
||||
.unwrap()
|
||||
.set_content_block_start_sent(*index as i32);
|
||||
|
||||
debug!(
|
||||
"[ARCHGW_REQ_ID:{}] BEDROCK_INJECTED_CONTENT_BLOCK_START: index={}",
|
||||
self.request_identifier(),
|
||||
*index
|
||||
);
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
// Track token usage
|
||||
if let Some(content) = provider_response.content_delta() {
|
||||
let estimated_tokens = content.len() / 4;
|
||||
self.response_tokens += estimated_tokens.max(1);
|
||||
debug!(
|
||||
"[ARCHGW_REQ_ID:{}] BEDROCK_TOKEN_UPDATE: delta_chars={} estimated_tokens={} total_tokens={}",
|
||||
self.request_identifier(),
|
||||
content.len(),
|
||||
estimated_tokens.max(1),
|
||||
self.response_tokens
|
||||
);
|
||||
}
|
||||
|
||||
let sse_string: String = provider_response.into();
|
||||
response_buffer.extend_from_slice(sse_string.as_bytes());
|
||||
// Create SseEvent from provider response
|
||||
let event = SseEvent::from_provider_response(provider_response);
|
||||
|
||||
// Add to buffer (buffer handles all shim logic including ContentBlockStart injection)
|
||||
if let Some(buffer) = self.sse_buffer.as_mut() {
|
||||
buffer.add_transformed_event(event);
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!(
|
||||
|
|
@ -658,8 +667,29 @@ impl StreamContext {
|
|||
}
|
||||
}
|
||||
|
||||
// Return accumulated complete frames (may be empty if all frames incomplete)
|
||||
Ok(response_buffer)
|
||||
// Get accumulated bytes from buffer and return
|
||||
match self.sse_buffer.as_mut() {
|
||||
Some(buffer) => {
|
||||
let bytes = buffer.into_bytes();
|
||||
if !bytes.is_empty() {
|
||||
let content = String::from_utf8_lossy(&bytes);
|
||||
debug!(
|
||||
"[ARCHGW_REQ_ID:{}] UPSTREAM_TRANSFORMED_CLIENT_RESPONSE: size={} content={}",
|
||||
self.request_identifier(),
|
||||
bytes.len(),
|
||||
content
|
||||
);
|
||||
}
|
||||
Ok(bytes)
|
||||
}
|
||||
None => {
|
||||
warn!(
|
||||
"[ARCHGW_REQ_ID:{}] BEDROCK_BUFFER_MISSING",
|
||||
self.request_identifier()
|
||||
);
|
||||
Err(Action::Continue)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_non_streaming_response(
|
||||
|
|
@ -782,13 +812,14 @@ impl HttpContext for StreamContext {
|
|||
self.select_llm_provider();
|
||||
|
||||
// Check if this is a supported API endpoint
|
||||
if SupportedAPIs::from_endpoint(&request_path).is_none() {
|
||||
if SupportedAPIsFromClient::from_endpoint(&request_path).is_none() {
|
||||
self.send_http_response(404, vec![], Some(b"Unsupported endpoint"));
|
||||
return Action::Continue;
|
||||
}
|
||||
|
||||
// Get the SupportedApi for routing decisions
|
||||
let supported_api: Option<SupportedAPIs> = SupportedAPIs::from_endpoint(&request_path);
|
||||
let supported_api: Option<SupportedAPIsFromClient> =
|
||||
SupportedAPIsFromClient::from_endpoint(&request_path);
|
||||
self.client_api = supported_api;
|
||||
|
||||
// Debug: log provider, client API, resolved API, and request path
|
||||
|
|
@ -1131,8 +1162,9 @@ impl HttpContext for StreamContext {
|
|||
}
|
||||
|
||||
match self.client_api {
|
||||
Some(SupportedAPIs::OpenAIChatCompletions(_)) => {}
|
||||
Some(SupportedAPIs::AnthropicMessagesAPI(_)) => {}
|
||||
Some(SupportedAPIsFromClient::OpenAIChatCompletions(_)) => {}
|
||||
Some(SupportedAPIsFromClient::AnthropicMessagesAPI(_)) => {}
|
||||
Some(SupportedAPIsFromClient::OpenAIResponsesAPI(_)) => {}
|
||||
_ => {
|
||||
let api_info = match &self.client_api {
|
||||
Some(api) => format!("{}", api),
|
||||
|
|
|
|||
|
|
@ -47,6 +47,9 @@ llm_providers:
|
|||
- model: ollama/llama3.1
|
||||
base_url: http://host.docker.internal:11434
|
||||
|
||||
# Grok (xAI) Models
|
||||
- model: xai/grok-4-0709
|
||||
access_key: $GROK_API_KEY
|
||||
|
||||
# Model aliases - friendly names that map to actual provider names
|
||||
model_aliases:
|
||||
|
|
@ -83,5 +86,9 @@ model_aliases:
|
|||
coding-model:
|
||||
target: us.amazon.nova-premier-v1:0
|
||||
|
||||
# Alias for grok testing
|
||||
arch.grok.v1:
|
||||
target: grok-4-0709
|
||||
|
||||
tracing:
|
||||
random_sampling: 100
|
||||
|
|
|
|||
|
|
@ -65,6 +65,10 @@ log running e2e tests for model alias routing
|
|||
log ========================================
|
||||
poetry run pytest test_model_alias_routing.py
|
||||
|
||||
log running e2e tests for openai responses api client
|
||||
log ========================================
|
||||
poetry run pytest test_openai_responses_api_client.py
|
||||
|
||||
log shutting down the weather_forecast demo
|
||||
log =======================================
|
||||
cd ../../demos/samples_python/weather_forecast
|
||||
|
|
|
|||
630
tests/e2e/test_openai_responses_api_client.py
Normal file
630
tests/e2e/test_openai_responses_api_client.py
Normal file
|
|
@ -0,0 +1,630 @@
|
|||
import openai
|
||||
import pytest
|
||||
import os
|
||||
import logging
|
||||
import sys
|
||||
|
||||
# Set up logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
handlers=[logging.StreamHandler(sys.stdout)],
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
LLM_GATEWAY_ENDPOINT = os.getenv(
|
||||
"LLM_GATEWAY_ENDPOINT", "http://localhost:12000/v1/chat/completions"
|
||||
)
|
||||
|
||||
|
||||
# -----------------------
|
||||
# v1/responses API tests
|
||||
# -----------------------
|
||||
def test_openai_responses_api_non_streaming_passthrough():
|
||||
"""Build a v1/responses API request (pass-through) and ensure gateway accepts it"""
|
||||
base_url = LLM_GATEWAY_ENDPOINT.replace("/v1/chat/completions", "")
|
||||
client = openai.OpenAI(api_key="test-key", base_url=f"{base_url}/v1")
|
||||
|
||||
# Simple responses API request using a direct model (pass-through)
|
||||
resp = client.responses.create(
|
||||
model="gpt-4o", input="Hello via responses passthrough"
|
||||
)
|
||||
|
||||
# Print the response content - handle both responses format and chat completions format
|
||||
print(f"\n{'='*80}")
|
||||
print(f"Model: {resp.model}")
|
||||
print(f"Output: {resp.output_text}")
|
||||
print(f"{'='*80}\n")
|
||||
|
||||
# Minimal sanity checks
|
||||
assert resp is not None
|
||||
assert (
|
||||
getattr(resp, "id", None) is not None
|
||||
or getattr(resp, "output", None) is not None
|
||||
)
|
||||
|
||||
|
||||
def test_openai_responses_api_with_streaming_passthrough():
|
||||
"""Build a v1/responses API streaming request (pass-through) and ensure gateway accepts it"""
|
||||
base_url = LLM_GATEWAY_ENDPOINT.replace("/v1/chat/completions", "")
|
||||
client = openai.OpenAI(api_key="test-key", base_url=f"{base_url}/v1")
|
||||
|
||||
# Simple streaming responses API request using a direct model (pass-through)
|
||||
stream = client.responses.create(
|
||||
model="gpt-4o",
|
||||
input="Write a short haiku about coding",
|
||||
stream=True,
|
||||
)
|
||||
|
||||
# Collect streamed content using the official Responses API streaming shape
|
||||
text_chunks = []
|
||||
final_message = None
|
||||
|
||||
for event in stream:
|
||||
# The Python SDK surfaces a high-level Responses streaming interface.
|
||||
# We rely on its typed helpers instead of digging into model_extra.
|
||||
if getattr(event, "type", None) == "response.output_text.delta" and getattr(
|
||||
event, "delta", None
|
||||
):
|
||||
# Each delta contains a text fragment
|
||||
text_chunks.append(event.delta)
|
||||
|
||||
# Track the final response message if provided by the SDK
|
||||
if getattr(event, "type", None) == "response.completed" and getattr(
|
||||
event, "response", None
|
||||
):
|
||||
final_message = event.response
|
||||
|
||||
full_content = "".join(text_chunks)
|
||||
|
||||
# Print the streaming response
|
||||
print(f"\n{'='*80}")
|
||||
print(
|
||||
f"Model: {getattr(final_message, 'model', 'unknown') if final_message else 'unknown'}"
|
||||
)
|
||||
print(f"Streamed Output: {full_content}")
|
||||
print(f"{'='*80}\n")
|
||||
|
||||
assert len(text_chunks) > 0, "Should have received streaming text deltas"
|
||||
assert len(full_content) > 0, "Should have received content"
|
||||
|
||||
|
||||
def test_openai_responses_api_non_streaming_with_tools_passthrough():
|
||||
"""Responses API with a function/tool definition (pass-through)"""
|
||||
base_url = LLM_GATEWAY_ENDPOINT.replace("/v1/chat/completions", "")
|
||||
client = openai.OpenAI(api_key="test-key", base_url=f"{base_url}/v1", max_retries=0)
|
||||
|
||||
# Define a simple tool/function for the Responses API
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"name": "echo_tool",
|
||||
"description": "Echo back the provided input",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"text": {"type": "string"}},
|
||||
"required": ["text"],
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
resp = client.responses.create(
|
||||
model="gpt-5",
|
||||
input="Call the echo tool",
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
assert resp is not None
|
||||
assert (
|
||||
getattr(resp, "id", None) is not None
|
||||
or getattr(resp, "output", None) is not None
|
||||
)
|
||||
|
||||
|
||||
def test_openai_responses_api_with_streaming_with_tools_passthrough():
|
||||
"""Responses API with a function/tool definition (streaming, pass-through)"""
|
||||
base_url = LLM_GATEWAY_ENDPOINT.replace("/v1/chat/completions", "")
|
||||
client = openai.OpenAI(api_key="test-key", base_url=f"{base_url}/v1", max_retries=0)
|
||||
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"name": "echo_tool",
|
||||
"description": "Echo back the provided input",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"text": {"type": "string"}},
|
||||
"required": ["text"],
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
stream = client.responses.create(
|
||||
model="gpt-5",
|
||||
input="Call the echo tool",
|
||||
tools=tools,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
text_chunks = []
|
||||
tool_calls = []
|
||||
|
||||
for event in stream:
|
||||
etype = getattr(event, "type", None)
|
||||
|
||||
# Collect streamed text output
|
||||
if etype == "response.output_text.delta" and getattr(event, "delta", None):
|
||||
text_chunks.append(event.delta)
|
||||
|
||||
# Collect streamed tool call arguments
|
||||
if etype == "response.function_call_arguments.delta" and getattr(
|
||||
event, "delta", None
|
||||
):
|
||||
tool_calls.append(event.delta)
|
||||
|
||||
full_text = "".join(text_chunks)
|
||||
|
||||
print(f"\n{'='*80}")
|
||||
print("Responses tools streaming test")
|
||||
print(f"Streamed text: {full_text}")
|
||||
print(f"Tool call argument chunks: {len(tool_calls)}")
|
||||
print(f"{'='*80}\n")
|
||||
|
||||
# We expect either streamed text output or streamed tool-call arguments
|
||||
assert (
|
||||
full_text or tool_calls
|
||||
), "Expected streamed text or tool call argument deltas from Responses tools stream"
|
||||
|
||||
|
||||
def test_openai_responses_api_non_streaming_upstream_chat_completions():
|
||||
"""Send a v1/responses request using the grok alias to verify translation/routing"""
|
||||
base_url = LLM_GATEWAY_ENDPOINT.replace("/v1/chat/completions", "")
|
||||
client = openai.OpenAI(api_key="test-key", base_url=f"{base_url}/v1")
|
||||
|
||||
resp = client.responses.create(
|
||||
model="arch.grok.v1", input="Hello, translate this via grok alias"
|
||||
)
|
||||
|
||||
# Print the response content - handle both responses format and chat completions format
|
||||
print(f"\n{'='*80}")
|
||||
print(f"Model: {resp.model}")
|
||||
print(f"Output: {resp.output_text}")
|
||||
print(f"{'='*80}\n")
|
||||
|
||||
assert resp is not None
|
||||
assert resp.id is not None
|
||||
|
||||
|
||||
def test_openai_responses_api_with_streaming_upstream_chat_completions():
|
||||
"""Build a v1/responses API streaming request routed to grok via the arch.grok.v1 alias and ensure gateway accepts it"""
|
||||
base_url = LLM_GATEWAY_ENDPOINT.replace("/v1/chat/completions", "")
|
||||
client = openai.OpenAI(api_key="test-key", base_url=f"{base_url}/v1")
|
||||
|
||||
# Simple streaming responses API request using the arch.grok.v1 alias
|
||||
stream = client.responses.create(
|
||||
model="arch.grok.v1",
|
||||
input="Write a short haiku about coding",
|
||||
stream=True,
|
||||
)
|
||||
|
||||
# Collect streamed content using the official Responses API streaming shape
|
||||
text_chunks = []
|
||||
final_message = None
|
||||
|
||||
for event in stream:
|
||||
# The Python SDK surfaces a high-level Responses streaming interface.
|
||||
# We rely on its typed helpers instead of digging into model_extra.
|
||||
if getattr(event, "type", None) == "response.output_text.delta" and getattr(
|
||||
event, "delta", None
|
||||
):
|
||||
# Each delta contains a text fragment
|
||||
text_chunks.append(event.delta)
|
||||
|
||||
# Track the final response message if provided by the SDK
|
||||
if getattr(event, "type", None) == "response.completed" and getattr(
|
||||
event, "response", None
|
||||
):
|
||||
final_message = event.response
|
||||
|
||||
full_content = "".join(text_chunks)
|
||||
|
||||
# Print the streaming response
|
||||
print(f"\n{'='*80}")
|
||||
print(
|
||||
f"Model: {getattr(final_message, 'model', 'unknown') if final_message else 'unknown'}"
|
||||
)
|
||||
print(f"Streamed Output: {full_content}")
|
||||
print(f"{'='*80}\n")
|
||||
|
||||
assert len(text_chunks) > 0, "Should have received streaming text deltas"
|
||||
assert len(full_content) > 0, "Should have received content"
|
||||
|
||||
|
||||
def test_openai_responses_api_non_streaming_with_tools_upstream_chat_completions():
|
||||
"""Responses API with tools routed to grok via alias"""
|
||||
base_url = LLM_GATEWAY_ENDPOINT.replace("/v1/chat/completions", "")
|
||||
client = openai.OpenAI(api_key="test-key", base_url=f"{base_url}/v1")
|
||||
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"name": "echo_tool",
|
||||
"description": "Echo back the provided input",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"text": {"type": "string"}},
|
||||
"required": ["text"],
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
resp = client.responses.create(
|
||||
model="arch.grok.v1",
|
||||
input="Call the echo tool",
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
assert resp.id is not None
|
||||
|
||||
print(f"\n{'='*80}")
|
||||
print(f"Model: {resp.model}")
|
||||
print(f"Output: {resp.output_text}")
|
||||
print(f"{'='*80}\n")
|
||||
|
||||
|
||||
def test_openai_responses_api_streaming_with_tools_upstream_chat_completions():
|
||||
"""Responses API with a function/tool definition (streaming, routed to grok via alias)"""
|
||||
base_url = LLM_GATEWAY_ENDPOINT.replace("/v1/chat/completions", "")
|
||||
client = openai.OpenAI(api_key="test-key", base_url=f"{base_url}/v1", max_retries=0)
|
||||
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"name": "echo_tool",
|
||||
"description": "Echo back the provided input",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"text": {"type": "string"}},
|
||||
"required": ["text"],
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
stream = client.responses.create(
|
||||
model="arch.grok.v1",
|
||||
input="Call the echo tool",
|
||||
tools=tools,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
text_chunks = []
|
||||
tool_calls = []
|
||||
|
||||
for event in stream:
|
||||
etype = getattr(event, "type", None)
|
||||
|
||||
# Collect streamed text output
|
||||
if etype == "response.output_text.delta" and getattr(event, "delta", None):
|
||||
text_chunks.append(event.delta)
|
||||
|
||||
# Collect streamed tool call arguments
|
||||
if etype == "response.function_call_arguments.delta" and getattr(
|
||||
event, "delta", None
|
||||
):
|
||||
tool_calls.append(event.delta)
|
||||
|
||||
full_text = "".join(text_chunks)
|
||||
|
||||
print(f"\n{'='*80}")
|
||||
print("Responses tools streaming test")
|
||||
print(f"Streamed text: {full_text}")
|
||||
print(f"Tool call argument chunks: {len(tool_calls)}")
|
||||
print(f"{'='*80}\n")
|
||||
|
||||
# We expect either streamed text output or streamed tool-call arguments
|
||||
assert (
|
||||
full_text or tool_calls
|
||||
), "Expected streamed text or tool call argument deltas from Responses tools stream"
|
||||
|
||||
|
||||
def test_openai_responses_api_non_streaming_upstream_bedrock():
|
||||
"""Send a v1/responses request using the coding-model alias to verify Bedrock translation/routing"""
|
||||
base_url = LLM_GATEWAY_ENDPOINT.replace("/v1/chat/completions", "")
|
||||
client = openai.OpenAI(api_key="test-key", base_url=f"{base_url}/v1")
|
||||
|
||||
resp = client.responses.create(
|
||||
model="coding-model",
|
||||
input="Hello, translate this via coding-model alias to Bedrock",
|
||||
)
|
||||
|
||||
# Print the response content - handle both responses format and chat completions format
|
||||
print(f"\n{'='*80}")
|
||||
print(f"Model: {resp.model}")
|
||||
print(f"Output: {resp.output_text}")
|
||||
print(f"{'='*80}\n")
|
||||
|
||||
assert resp is not None
|
||||
assert resp.id is not None
|
||||
|
||||
|
||||
def test_openai_responses_api_with_streaming_upstream_bedrock():
|
||||
"""Build a v1/responses API streaming request routed to Bedrock via coding-model alias"""
|
||||
base_url = LLM_GATEWAY_ENDPOINT.replace("/v1/chat/completions", "")
|
||||
client = openai.OpenAI(api_key="test-key", base_url=f"{base_url}/v1")
|
||||
|
||||
# Simple streaming responses API request using coding-model alias
|
||||
stream = client.responses.create(
|
||||
model="coding-model",
|
||||
input="Write a short haiku about coding",
|
||||
stream=True,
|
||||
)
|
||||
|
||||
# Collect streamed content using the official Responses API streaming shape
|
||||
text_chunks = []
|
||||
final_message = None
|
||||
|
||||
for event in stream:
|
||||
# The Python SDK surfaces a high-level Responses streaming interface.
|
||||
# We rely on its typed helpers instead of digging into model_extra.
|
||||
if getattr(event, "type", None) == "response.output_text.delta" and getattr(
|
||||
event, "delta", None
|
||||
):
|
||||
# Each delta contains a text fragment
|
||||
text_chunks.append(event.delta)
|
||||
|
||||
# Track the final response message if provided by the SDK
|
||||
if getattr(event, "type", None) == "response.completed" and getattr(
|
||||
event, "response", None
|
||||
):
|
||||
final_message = event.response
|
||||
|
||||
full_content = "".join(text_chunks)
|
||||
|
||||
# Print the streaming response
|
||||
print(f"\n{'='*80}")
|
||||
print(
|
||||
f"Model: {getattr(final_message, 'model', 'unknown') if final_message else 'unknown'}"
|
||||
)
|
||||
print(f"Streamed Output: {full_content}")
|
||||
print(f"{'='*80}\n")
|
||||
|
||||
assert len(text_chunks) > 0, "Should have received streaming text deltas"
|
||||
assert len(full_content) > 0, "Should have received content"
|
||||
|
||||
|
||||
def test_openai_responses_api_non_streaming_with_tools_upstream_bedrock():
|
||||
"""Responses API with tools routed to Bedrock via coding-model alias"""
|
||||
base_url = LLM_GATEWAY_ENDPOINT.replace("/v1/chat/completions", "")
|
||||
client = openai.OpenAI(api_key="test-key", base_url=f"{base_url}/v1")
|
||||
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"name": "echo_tool",
|
||||
"description": "Echo back the provided input",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"text": {"type": "string"}},
|
||||
"required": ["text"],
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
resp = client.responses.create(
|
||||
model="coding-model",
|
||||
input="Call the echo tool",
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
assert resp.id is not None
|
||||
|
||||
print(f"\n{'='*80}")
|
||||
print(f"Model: {resp.model}")
|
||||
print(f"Output: {resp.output_text}")
|
||||
print(f"{'='*80}\n")
|
||||
|
||||
|
||||
def test_openai_responses_api_streaming_with_tools_upstream_bedrock():
|
||||
"""Responses API with a function/tool definition streaming to Bedrock via coding-model alias"""
|
||||
base_url = LLM_GATEWAY_ENDPOINT.replace("/v1/chat/completions", "")
|
||||
client = openai.OpenAI(api_key="test-key", base_url=f"{base_url}/v1", max_retries=0)
|
||||
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"name": "echo_tool",
|
||||
"description": "Echo back the provided input",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"text": {"type": "string"}},
|
||||
"required": ["text"],
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
stream = client.responses.create(
|
||||
model="coding-model",
|
||||
input="Call the echo tool",
|
||||
tools=tools,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
text_chunks = []
|
||||
tool_calls = []
|
||||
|
||||
for event in stream:
|
||||
etype = getattr(event, "type", None)
|
||||
|
||||
# Collect streamed text output
|
||||
if etype == "response.output_text.delta" and getattr(event, "delta", None):
|
||||
text_chunks.append(event.delta)
|
||||
|
||||
# Collect streamed tool call arguments
|
||||
if etype == "response.function_call_arguments.delta" and getattr(
|
||||
event, "delta", None
|
||||
):
|
||||
tool_calls.append(event.delta)
|
||||
|
||||
full_text = "".join(text_chunks)
|
||||
|
||||
print(f"\n{'='*80}")
|
||||
print("Responses tools streaming test (Bedrock)")
|
||||
print(f"Streamed text: {full_text}")
|
||||
print(f"Tool call argument chunks: {len(tool_calls)}")
|
||||
print(f"{'='*80}\n")
|
||||
|
||||
# We expect either streamed text output or streamed tool-call arguments
|
||||
assert (
|
||||
full_text or tool_calls
|
||||
), "Expected streamed text or tool call argument deltas from Responses tools stream"
|
||||
|
||||
|
||||
def test_openai_responses_api_non_streaming_upstream_anthropic():
|
||||
"""Send a v1/responses request to an Anthropic model (claude-sonnet-4-20250514) to verify translation/routing"""
|
||||
base_url = LLM_GATEWAY_ENDPOINT.replace("/v1/chat/completions", "")
|
||||
client = openai.OpenAI(api_key="test-key", base_url=f"{base_url}/v1")
|
||||
|
||||
resp = client.responses.create(
|
||||
model="claude-sonnet-4-20250514", input="Hello, translate this via grok alias"
|
||||
)
|
||||
|
||||
# Print the response content - handle both responses format and chat completions format
|
||||
print(f"\n{'='*80}")
|
||||
print(f"Model: {resp.model}")
|
||||
print(f"Output: {resp.output_text}")
|
||||
print(f"{'='*80}\n")
|
||||
|
||||
assert resp is not None
|
||||
assert resp.id is not None
|
||||
|
||||
|
||||
def test_openai_responses_api_with_streaming_upstream_anthropic():
|
||||
"""Build a v1/responses API streaming request routed to an Anthropic model and ensure gateway accepts it"""
|
||||
base_url = LLM_GATEWAY_ENDPOINT.replace("/v1/chat/completions", "")
|
||||
client = openai.OpenAI(api_key="test-key", base_url=f"{base_url}/v1")
|
||||
|
||||
# Simple streaming responses API request using an Anthropic model (claude-sonnet-4-20250514)
|
||||
stream = client.responses.create(
|
||||
model="claude-sonnet-4-20250514",
|
||||
input="Write a short haiku about coding",
|
||||
stream=True,
|
||||
)
|
||||
|
||||
# Collect streamed content using the official Responses API streaming shape
|
||||
text_chunks = []
|
||||
final_message = None
|
||||
|
||||
for event in stream:
|
||||
# The Python SDK surfaces a high-level Responses streaming interface.
|
||||
# We rely on its typed helpers instead of digging into model_extra.
|
||||
if getattr(event, "type", None) == "response.output_text.delta" and getattr(
|
||||
event, "delta", None
|
||||
):
|
||||
# Each delta contains a text fragment
|
||||
text_chunks.append(event.delta)
|
||||
|
||||
# Track the final response message if provided by the SDK
|
||||
if getattr(event, "type", None) == "response.completed" and getattr(
|
||||
event, "response", None
|
||||
):
|
||||
final_message = event.response
|
||||
|
||||
full_content = "".join(text_chunks)
|
||||
|
||||
# Print the streaming response
|
||||
print(f"\n{'='*80}")
|
||||
print(
|
||||
f"Model: {getattr(final_message, 'model', 'unknown') if final_message else 'unknown'}"
|
||||
)
|
||||
print(f"Streamed Output: {full_content}")
|
||||
print(f"{'='*80}\n")
|
||||
|
||||
assert len(text_chunks) > 0, "Should have received streaming text deltas"
|
||||
assert len(full_content) > 0, "Should have received content"
|
||||
|
||||
|
||||
def test_openai_responses_api_non_streaming_with_tools_upstream_anthropic():
    """Responses API with a function/tool definition (non-streaming) sent to a Claude model.

    Sends a single `echo_tool` function definition together with a prompt asking
    the model to call it, using model `claude-sonnet-4-20250514`, and checks that
    the gateway returns a response object carrying an id. The test does not
    require that the model actually issues a tool call.
    """
    # The configured endpoint targets chat completions; strip that suffix to get the gateway root.
    base_url = LLM_GATEWAY_ENDPOINT.replace("/v1/chat/completions", "")
    client = openai.OpenAI(api_key="test-key", base_url=f"{base_url}/v1")

    # Responses-API style function tool: name/parameters live at the top level of the entry.
    tools = [
        {
            "type": "function",
            "name": "echo_tool",
            "description": "Echo back the provided input: hello_world",
            "parameters": {
                "type": "object",
                "properties": {"text": {"type": "string"}},
                "required": ["text"],
            },
        }
    ]

    resp = client.responses.create(
        model="claude-sonnet-4-20250514",
        input="Call the echo tool",
        tools=tools,
    )

    # Guard against a missing response before touching its attributes,
    # matching the checks used by the sibling non-streaming test.
    assert resp is not None
    assert resp.id is not None

    # Dump the result to make failures easier to diagnose.
    print(f"\n{'='*80}")
    print(f"Model: {resp.model}")
    print(f"Output: {resp.output_text}")
    print(f"{'='*80}\n")
|
||||
|
||||
|
||||
def test_openai_responses_api_streaming_with_tools_upstream_anthropic():
    """Stream a Responses API request that carries a function tool, using a Claude model.

    Passes when the stream yields either text deltas or function-call argument deltas.
    """
    # The configured endpoint targets chat completions; strip that suffix to get the gateway root.
    base_url = LLM_GATEWAY_ENDPOINT.replace("/v1/chat/completions", "")
    # Retries are disabled so a failing request surfaces immediately.
    client = openai.OpenAI(api_key="test-key", base_url=f"{base_url}/v1", max_retries=0)

    echo_tool = {
        "type": "function",
        "name": "echo_tool",
        "description": "Echo back the provided input: hello_world",
        "parameters": {
            "type": "object",
            "properties": {"text": {"type": "string"}},
            "required": ["text"],
        },
    }
    tools = [echo_tool]

    stream = client.responses.create(
        model="claude-sonnet-4-20250514",
        input="Call the echo tool",
        tools=tools,
        stream=True,
    )

    text_chunks = []
    tool_calls = []

    # Route each delta-bearing event type to the list that collects it.
    collectors = {
        "response.output_text.delta": text_chunks,
        "response.function_call_arguments.delta": tool_calls,
    }
    for item in stream:
        target = collectors.get(getattr(item, "type", None))
        # Only events of a tracked type with a non-empty delta are recorded.
        if target is not None and getattr(item, "delta", None):
            target.append(item.delta)

    full_text = "".join(text_chunks)

    # Dump what was streamed to make failures easier to diagnose.
    print(f"\n{'='*80}")
    print("Responses tools streaming test")
    print(f"Streamed text: {full_text}")
    print(f"Tool call argument chunks: {len(tool_calls)}")
    print(f"{'='*80}\n")

    # Either streamed text or streamed tool-call arguments is acceptable.
    assert (
        full_text or tool_calls
    ), "Expected streamed text or tool call argument deltas from Responses tools stream"
|
||||
Loading…
Add table
Add a link
Reference in a new issue