diff --git a/.github/workflows/e2e_tests.yml b/.github/workflows/e2e_tests.yml
index 5988c27e..be601264 100644
--- a/.github/workflows/e2e_tests.yml
+++ b/.github/workflows/e2e_tests.yml
@@ -59,6 +59,7 @@ jobs:
           ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
           AZURE_API_KEY: ${{ secrets.AZURE_API_KEY }}
           AWS_BEARER_TOKEN_BEDROCK: ${{ secrets.AWS_BEARER_TOKEN_BEDROCK }}
+          GROK_API_KEY: ${{ secrets.GROK_API_KEY }}
         run: |
           python -mvenv venv
           source venv/bin/activate && cd tests/e2e && bash run_e2e_tests.sh
diff --git a/crates/Cargo.lock b/crates/Cargo.lock
index 5797d5a2..b243c365 100644
--- a/crates/Cargo.lock
+++ b/crates/Cargo.lock
@@ -912,10 +912,12 @@ version = "0.1.0"
 dependencies = [
  "aws-smithy-eventstream",
  "bytes",
+ "log",
  "serde",
  "serde_json",
  "serde_with",
  "thiserror 2.0.12",
+ "uuid",
 ]
 
 [[package]]
diff --git a/crates/brightstaff/src/handlers/router.rs b/crates/brightstaff/src/handlers/router.rs
index d27bab55..c369729a 100644
--- a/crates/brightstaff/src/handlers/router.rs
+++ b/crates/brightstaff/src/handlers/router.rs
@@ -3,7 +3,7 @@ use common::configuration::{ModelAlias, ModelUsagePreference};
 use common::consts::{ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER};
 use hermesllm::apis::openai::ChatCompletionsRequest;
 use hermesllm::clients::endpoints::SupportedUpstreamAPIs;
-use hermesllm::clients::SupportedAPIs;
+use hermesllm::clients::SupportedAPIsFromClient;
 use hermesllm::{ProviderRequest, ProviderRequestType};
 use http_body_util::combinators::BoxBody;
 use http_body_util::{BodyExt, Full};
@@ -39,7 +39,7 @@ pub async fn router_chat(
     let mut client_request = match ProviderRequestType::try_from((
         &chat_request_bytes[..],
-        &SupportedAPIs::from_endpoint(request_path.as_str()).unwrap(),
+        &SupportedAPIsFromClient::from_endpoint(request_path.as_str()).unwrap(),
     )) {
         Ok(request) => request,
         Err(err) => {
@@ -58,7 +58,7 @@ pub async fn router_chat(
     let resolved_model = if let Some(model_aliases) = model_aliases.as_ref() {
         if let Some(model_alias) = model_aliases.get(&model_from_request) {
             debug!(
-                "Model Alias: 'From {}' -> 'To{}'",
+                "Model Alias: 'From {}' -> 'To {}'",
                 model_from_request, model_alias.target
             );
             model_alias.target.clone()
@@ -91,10 +91,11 @@ pub async fn router_chat(
         Ok(
             ProviderRequestType::MessagesRequest(_)
             | ProviderRequestType::BedrockConverse(_)
-            | ProviderRequestType::BedrockConverseStream(_),
+            | ProviderRequestType::BedrockConverseStream(_)
+            | ProviderRequestType::ResponsesAPIRequest(_),
         ) => {
             // This should not happen after conversion to OpenAI format
-            warn!("Unexpected: got MessagesRequest after converting to OpenAI format");
+            warn!("Unexpected: got non-ChatCompletions request after converting to OpenAI format");
             let err_msg = "Request conversion failed".to_string();
             let mut bad_request = Response::new(full(err_msg));
             *bad_request.status_mut() = StatusCode::BAD_REQUEST;
diff --git a/crates/brightstaff/src/main.rs b/crates/brightstaff/src/main.rs
index 87bdea36..73f5ef58 100644
--- a/crates/brightstaff/src/main.rs
+++ b/crates/brightstaff/src/main.rs
@@ -6,7 +6,7 @@ use brightstaff::router::llm_router::RouterService;
 use brightstaff::utils::tracing::init_tracer;
 use bytes::Bytes;
 use common::configuration::Configuration;
-use common::consts::{CHAT_COMPLETIONS_PATH, MESSAGES_PATH};
+use common::consts::{CHAT_COMPLETIONS_PATH, MESSAGES_PATH, OPENAI_RESPONSES_API_PATH};
 use http_body_util::{combinators::BoxBody, BodyExt, Empty};
 use hyper::body::Incoming;
 use hyper::server::conn::http1;
@@ -123,7 +123,7 @@
async move { match (req.method(), req.uri().path()) { - (&Method::POST, CHAT_COMPLETIONS_PATH | MESSAGES_PATH) => { + (&Method::POST, CHAT_COMPLETIONS_PATH | MESSAGES_PATH | OPENAI_RESPONSES_API_PATH) => { let fully_qualified_url = format!("{}{}", llm_provider_url, req.uri().path()); router_chat(req, router_service, fully_qualified_url, model_aliases) diff --git a/crates/common/src/consts.rs b/crates/common/src/consts.rs index 8edbff1a..12d35ab4 100644 --- a/crates/common/src/consts.rs +++ b/crates/common/src/consts.rs @@ -13,6 +13,7 @@ pub const MESSAGES_KEY: &str = "messages"; pub const ARCH_PROVIDER_HINT_HEADER: &str = "x-arch-llm-provider-hint"; pub const ARCH_IS_STREAMING_HEADER: &str = "x-arch-streaming-request"; pub const CHAT_COMPLETIONS_PATH: &str = "/v1/chat/completions"; +pub const OPENAI_RESPONSES_API_PATH: &str = "/v1/responses"; pub const MESSAGES_PATH: &str = "/v1/messages"; pub const HEALTHZ_PATH: &str = "/healthz"; pub const X_ARCH_STATE_HEADER: &str = "x-arch-state"; diff --git a/crates/hermesllm/Cargo.toml b/crates/hermesllm/Cargo.toml index ab2390bf..d877fc00 100644 --- a/crates/hermesllm/Cargo.toml +++ b/crates/hermesllm/Cargo.toml @@ -10,3 +10,5 @@ serde_with = {version = "3.12.0", features = ["base64"]} thiserror = "2.0.12" aws-smithy-eventstream = "0.60" bytes = "1.10" +uuid = { version = "1.11", features = ["v4"] } +log = "0.4" diff --git a/crates/hermesllm/src/apis/amazon_bedrock.rs b/crates/hermesllm/src/apis/amazon_bedrock.rs index eb1f3ddf..252bd0f1 100644 --- a/crates/hermesllm/src/apis/amazon_bedrock.rs +++ b/crates/hermesllm/src/apis/amazon_bedrock.rs @@ -7,7 +7,7 @@ use thiserror::Error; use super::ApiDefinition; use crate::providers::request::{ProviderRequest, ProviderRequestError}; -use crate::providers::response::ProviderStreamResponse; +use crate::providers::streaming_response::ProviderStreamResponse; // ============================================================================ // AMAZON BEDROCK CONVERSE API ENUMERATION diff --git a/crates/hermesllm/src/apis/anthropic.rs b/crates/hermesllm/src/apis/anthropic.rs index f91b381c..7e1951e4 100644 --- a/crates/hermesllm/src/apis/anthropic.rs +++ b/crates/hermesllm/src/apis/anthropic.rs @@ -6,7 +6,8 @@ use std::collections::HashMap; use super::ApiDefinition; use crate::providers::request::{ProviderRequest, ProviderRequestError}; -use crate::providers::response::{ProviderResponse, ProviderStreamResponse}; +use crate::providers::response::ProviderResponse; +use crate::providers::streaming_response::ProviderStreamResponse; use crate::transforms::lib::ExtractText; use crate::MESSAGES_PATH; diff --git a/crates/hermesllm/src/apis/mod.rs b/crates/hermesllm/src/apis/mod.rs index 7d84e3ab..ea056392 100644 --- a/crates/hermesllm/src/apis/mod.rs +++ b/crates/hermesllm/src/apis/mod.rs @@ -1,8 +1,8 @@ pub mod amazon_bedrock; -pub mod amazon_bedrock_binary_frame; pub mod anthropic; pub mod openai; -pub mod sse; +pub mod openai_responses; +pub mod streaming_shapes; // Explicit exports to avoid naming conflicts pub use amazon_bedrock::{AmazonBedrockApi, ConverseRequest, ConverseStreamRequest}; @@ -88,8 +88,9 @@ mod tests { fn test_all_variants_method() { // Test that all_variants returns the expected variants let openai_variants = OpenAIApi::all_variants(); - assert_eq!(openai_variants.len(), 1); + assert_eq!(openai_variants.len(), 2); assert!(openai_variants.contains(&OpenAIApi::ChatCompletions)); + assert!(openai_variants.contains(&OpenAIApi::Responses)); let anthropic_variants = AnthropicApi::all_variants(); 
         assert_eq!(anthropic_variants.len(), 1);
diff --git a/crates/hermesllm/src/apis/openai.rs b/crates/hermesllm/src/apis/openai.rs
index 44b64485..d7f7a07d 100644
--- a/crates/hermesllm/src/apis/openai.rs
+++ b/crates/hermesllm/src/apis/openai.rs
@@ -7,9 +7,10 @@ use thiserror::Error;
 
 use super::ApiDefinition;
 use crate::providers::request::{ProviderRequest, ProviderRequestError};
-use crate::providers::response::{ProviderResponse, ProviderStreamResponse, TokenUsage};
+use crate::providers::response::{ProviderResponse, TokenUsage};
+use crate::providers::streaming_response::ProviderStreamResponse;
 use crate::transforms::lib::ExtractText;
-use crate::CHAT_COMPLETIONS_PATH;
+use crate::{CHAT_COMPLETIONS_PATH, OPENAI_RESPONSES_API_PATH};
 
 // ============================================================================
 // OPENAI API ENUMERATION
@@ -19,6 +20,7 @@ use crate::CHAT_COMPLETIONS_PATH;
 #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
 pub enum OpenAIApi {
     ChatCompletions,
+    Responses,
     // Future APIs can be added here:
     // Embeddings,
     // FineTuning,
@@ -29,12 +31,14 @@ impl ApiDefinition for OpenAIApi {
     fn endpoint(&self) -> &'static str {
         match self {
             OpenAIApi::ChatCompletions => CHAT_COMPLETIONS_PATH,
+            OpenAIApi::Responses => OPENAI_RESPONSES_API_PATH,
         }
     }
 
     fn from_endpoint(endpoint: &str) -> Option<Self> {
         match endpoint {
             CHAT_COMPLETIONS_PATH => Some(OpenAIApi::ChatCompletions),
+            OPENAI_RESPONSES_API_PATH => Some(OpenAIApi::Responses),
             _ => None,
         }
     }
@@ -42,23 +46,26 @@ impl ApiDefinition for OpenAIApi {
     fn supports_streaming(&self) -> bool {
         match self {
             OpenAIApi::ChatCompletions => true,
+            OpenAIApi::Responses => true,
         }
     }
 
     fn supports_tools(&self) -> bool {
         match self {
             OpenAIApi::ChatCompletions => true,
+            OpenAIApi::Responses => true,
         }
     }
 
     fn supports_vision(&self) -> bool {
         match self {
             OpenAIApi::ChatCompletions => true,
+            OpenAIApi::Responses => true,
         }
     }
 
     fn all_variants() -> Vec<Self> {
-        vec![OpenAIApi::ChatCompletions]
+        vec![OpenAIApi::ChatCompletions, OpenAIApi::Responses]
     }
 }
@@ -1077,8 +1084,9 @@ mod tests {
 
         // Test all_variants
         let all_variants = OpenAIApi::all_variants();
-        assert_eq!(all_variants.len(), 1);
-        assert_eq!(all_variants[0], OpenAIApi::ChatCompletions);
+        assert_eq!(all_variants.len(), 2);
+        assert!(all_variants.contains(&OpenAIApi::ChatCompletions));
+        assert!(all_variants.contains(&OpenAIApi::Responses));
     }
 
     #[test]
diff --git a/crates/hermesllm/src/apis/openai_responses.rs b/crates/hermesllm/src/apis/openai_responses.rs
new file mode 100644
index 00000000..4f0cf663
--- /dev/null
+++ b/crates/hermesllm/src/apis/openai_responses.rs
@@ -0,0 +1,1386 @@
+use std::collections::HashMap;
+use serde::{Deserialize, Serialize};
+use serde_with::skip_serializing_none;
+use crate::providers::request::{ProviderRequest, ProviderRequestError};
+
+impl TryFrom<&[u8]> for ResponsesAPIRequest {
+    type Error = serde_json::Error;
+
+    fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
+        serde_json::from_slice(bytes)
+    }
+}
+
+/// Parameterized conversion for ResponsesAPIResponse
+impl TryFrom<&[u8]> for ResponsesAPIResponse {
+    type Error = crate::apis::openai::OpenAIStreamError;
+
+    fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
+        serde_json::from_slice(bytes).map_err(crate::apis::openai::OpenAIStreamError::from)
+    }
+}
+
+// ============================================================================
+// Request Structs - CreateResponse
+// ============================================================================
+
+/// Request to create a model response
+#[skip_serializing_none]
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ResponsesAPIRequest {
+    /// The model to use for generating the response
+    pub model: String,
+
+    /// Text, image, or file inputs to the model
+    pub input: InputParam,
+
+    /// Specify additional output data to include in the model response
+    pub include: Option<Vec<IncludeEnum>>,
+
+    /// Whether to allow the model to run tool calls in parallel
+    pub parallel_tool_calls: Option<bool>,
+
+    /// Whether to store the generated model response for later retrieval via API
+    pub store: Option<bool>,
+
+    /// A system (or developer) message inserted into the model's context
+    pub instructions: Option<String>,
+
+    /// If set to true, the model response data will be streamed to the client
+    pub stream: Option<bool>,
+
+    /// Stream options configuration
+    pub stream_options: Option<ResponseStreamOptions>,
+
+    /// Conversation state
+    pub conversation: Option<ConversationParam>,
+
+    /// Tools available to the model
+    pub tools: Option<Vec<Tool>>,
+
+    /// Tool choice option
+    pub tool_choice: Option<ToolChoice>,
+
+    /// Maximum number of output tokens
+    pub max_output_tokens: Option<i32>,
+
+    /// Temperature for sampling (0-2)
+    pub temperature: Option<f32>,
+
+    /// Top-p nucleus sampling parameter
+    pub top_p: Option<f32>,
+
+    /// Metadata for the response
+    pub metadata: Option<HashMap<String, String>>,
+
+    /// Previous response ID for conversation continuation
+    pub previous_response_id: Option<String>,
+
+    /// Response modalities
+    pub modalities: Option<Vec<Modality>>,
+
+    /// Audio output configuration
+    pub audio: Option<AudioConfig>,
+
+    /// Text output format configuration
+    pub text: Option<TextConfig>,
+
+    /// Reasoning effort level
+    pub reasoning_effort: Option<ReasoningEffort>,
+
+    /// Truncation strategy
+    pub truncation: Option<String>,
+
+    /// User identifier
+    pub user: Option<String>,
+
+    /// Maximum number of tool calls
+    pub max_tool_calls: Option<i32>,
+
+    /// Service tier
+    pub service_tier: Option<String>,
+
+    /// Whether to run in background
+    pub background: Option<bool>,
+
+    /// Number of top logprobs to include
+    pub top_logprobs: Option<i32>,
+}
+
+/// Input parameter - can be a simple string or array of input items
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(untagged)]
+pub enum InputParam {
+    /// Simple text input
+    Text(String),
+    /// Array of input items
+    Items(Vec<InputItem>),
+}
+
+/// Input item discriminated by type
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(tag = "type", rename_all = "snake_case")]
+pub enum InputItem {
+    /// Input message
+    Message(InputMessage),
+}
+
+/// Input message with role and content
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct InputMessage {
+    /// Message role
+    pub role: MessageRole,
+    /// Message content
+    pub content: Vec<InputContent>,
+}
+
+/// Message roles
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "lowercase")]
+pub enum MessageRole {
+    User,
+    Assistant,
+    System,
+    Developer,
+}
+
+/// Input content types
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(tag = "type", rename_all = "snake_case")]
+pub enum InputContent {
+    /// Text input
+    InputText {
+        text: String,
+    },
+    /// Image input via URL
+    InputImage {
+        image_url: String,
+        detail: Option<String>,
+    },
+    /// File input via URL
+    InputFile {
+        file_url: String,
+    },
+    /// Audio input
+    InputAudio {
+        data: Option<String>,
+        format: Option<String>,
+    },
+}
+
+/// Modality options
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "lowercase")]
+pub enum Modality {
+    Text,
+    Audio,
+}
+
+/// Audio configuration
+#[skip_serializing_none]
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct AudioConfig {
+    /// Voice to use for audio output
+    pub voice: String,
+    /// Audio output format
+    pub format: Option<String>,
+}
+
+/// Text configuration
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct TextConfig {
+    /// Text format configuration
+    pub format: TextFormat,
+}
+
+/// Text format
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(tag = "type", rename_all = "snake_case")]
+pub enum TextFormat {
+    Text,
+    JsonObject,
+    JsonSchema {
+        json_schema: serde_json::Value,
+    },
+}
+
+/// Reasoning effort levels
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "lowercase")]
+pub enum ReasoningEffort {
+    Low,
+    Medium,
+    High,
+}
+
+/// Include enum for additional output data
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub enum IncludeEnum {
+    #[serde(rename = "web_search_call.action.sources")]
+    WebSearchCallActionSources,
+    #[serde(rename = "code_interpreter_call.outputs")]
+    CodeInterpreterCallOutputs,
+    #[serde(rename = "computer_call_output.output.image_url")]
+    ComputerCallOutputImageUrl,
+    #[serde(rename = "file_search_call.results")]
+    FileSearchCallResults,
+    #[serde(rename = "message.input_image.image_url")]
+    MessageInputImageImageUrl,
+    #[serde(rename = "message.output_text.logprobs")]
+    MessageOutputTextLogprobs,
+    #[serde(rename = "reasoning.encrypted_content")]
+    ReasoningEncryptedContent,
+}
+
+/// Response stream options
+#[skip_serializing_none]
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ResponseStreamOptions {
+    /// Whether to include usage in stream
+    pub include_usage: Option<bool>,
+}
+
+/// Conversation parameter
+#[skip_serializing_none]
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ConversationParam {
+    /// Conversation ID
+    pub id: Option<String>,
+}
+
+/// Tool definitions
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(tag = "type", rename_all = "snake_case")]
+pub enum Tool {
+    /// Function tool - flat structure in Responses API
+    Function {
+        name: String,
+        description: Option<String>,
+        parameters: Option<serde_json::Value>,
+        strict: Option<bool>,
+    },
+    /// File search tool
+    FileSearch {
+        vector_store_ids: Option<Vec<String>>,
+        max_num_results: Option<i32>,
+        ranking_options: Option<RankingOptions>,
+        filters: Option<serde_json::Value>,
+    },
+    /// Web search tool
+    WebSearchPreview {
+        domains: Option<Vec<String>>,
+        search_context_size: Option<String>,
+        user_location: Option<UserLocation>,
+    },
+    /// Code interpreter tool
+    CodeInterpreter,
+    /// Computer tool
+    Computer {
+        display_width_px: Option<i32>,
+        display_height_px: Option<i32>,
+        display_number: Option<i32>,
+    },
+}
+
+/// Ranking options for file search
+#[skip_serializing_none]
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct RankingOptions {
+    /// Ranker type
+    pub ranker: String,
+    /// Score threshold
+    pub score_threshold: Option<f32>,
+}
+
+/// User location for web search
+#[skip_serializing_none]
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct UserLocation {
+    #[serde(rename = "type")]
+    pub location_type: String,
+    pub city: Option<String>,
+    pub country: Option<String>,
+    pub region: Option<String>,
+    pub timezone: Option<String>,
+}
+
+/// Tool choice options
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(untagged)]
+pub enum ToolChoice {
+    /// Auto, none, or required
+    String(String),
+    /// Named tool choice
+    Named {
+        #[serde(rename = "type")]
+        tool_type: String,
+        function: NamedFunction,
+    },
+}
+
+/// Named function for tool choice
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct NamedFunction {
+    pub name: String,
+}
+
+// ============================================================================
+// Response Structs - Response Object
+// ============================================================================
+
+/// The response object returned from the API
+#[skip_serializing_none]
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ResponsesAPIResponse {
+    /// Unique identifier for this Response
+    pub id: String,
+
+    /// The object type - always "response"
+    pub object: String,
+
+    /// Unix timestamp (in seconds) of when this Response was created
+    pub created_at: i64,
+
+    /// The status of the response generation
+    pub status: ResponseStatus,
+
+    /// Error information if the response failed
+    pub error: Option<ResponseError>,
+
+    /// Details about why the response is incomplete
+    pub incomplete_details: Option<IncompleteDetails>,
+
+    /// System/developer instructions
+    pub instructions: Option<String>,
+
+    /// The model used
+    pub model: String,
+
+    /// An array of content items generated by the model
+    pub output: Vec<OutputItem>,
+
+    /// Usage statistics
+    pub usage: Option<ResponseUsage>,
+
+    /// Whether to allow parallel tool calls
+    pub parallel_tool_calls: bool,
+
+    /// Conversation state
+    pub conversation: Option<Conversation>,
+
+    /// Previous response ID
+    pub previous_response_id: Option<String>,
+
+    /// Tools available
+    pub tools: Vec<Tool>,
+
+    /// Tool choice setting
+    pub tool_choice: String,
+
+    /// Temperature setting
+    pub temperature: f32,
+
+    /// Top-p setting
+    pub top_p: f32,
+
+    /// Metadata
+    pub metadata: HashMap<String, String>,
+
+    /// Truncation setting
+    pub truncation: Option<String>,
+
+    /// Maximum output tokens
+    pub max_output_tokens: Option<i32>,
+
+    /// Reasoning configuration
+    pub reasoning: Option<Reasoning>,
+
+    /// Whether response is stored
+    pub store: Option<bool>,
+
+    /// Text configuration
+    pub text: Option<TextConfig>,
+
+    /// Audio configuration
+    pub audio: Option<AudioConfig>,
+
+    /// Modalities
+    pub modalities: Option<Vec<Modality>>,
+
+    /// Service tier
+    pub service_tier: Option<String>,
+
+    /// Background execution
+    pub background: Option<bool>,
+
+    /// Top logprobs count
+    pub top_logprobs: Option<i32>,
+
+    /// Maximum tool calls
+    pub max_tool_calls: Option<i32>,
+}
+
+/// Response status
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
+#[serde(rename_all = "snake_case")]
+pub enum ResponseStatus {
+    Completed,
+    Failed,
+    InProgress,
+    Cancelled,
+    Queued,
+    Incomplete,
+}
+
+/// Response error information
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ResponseError {
+    /// Error code
+    pub code: ResponseErrorCode,
+    /// Human-readable error message
+    pub message: String,
+}
+
+/// Response error codes
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "snake_case")]
+pub enum ResponseErrorCode {
+    ServerError,
+    RateLimitExceeded,
+    InvalidPrompt,
+    VectorStoreTimeout,
+    InvalidImage,
+    InvalidImageFormat,
+    InvalidBase64Image,
+    InvalidImageUrl,
+    ImageTooLarge,
+    ImageTooSmall,
+    ImageParseError,
+    ImageContentPolicyViolation,
+    InvalidImageMode,
+    ImageFileTooLarge,
+    UnsupportedImageMediaType,
+    EmptyImageFile,
+    FailedToDownloadImage,
+    ImageFileNotFound,
+}
+
+/// Incomplete details
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct IncompleteDetails {
+    /// The reason why the response is incomplete
+    pub reason: IncompleteReason,
+}
+
+/// Incomplete reasons
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "snake_case")]
+pub enum IncompleteReason {
+    MaxOutputTokens,
+    ContentFilter,
+}
+
+/// Output items from the model
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(tag = "type", rename_all = "snake_case")]
+pub enum OutputItem {
+    /// Output message
+    Message {
+        id: String,
+        status: OutputItemStatus,
+        role: String,
+        content: Vec<OutputContent>,
+    },
+    /// Function tool call
+    FunctionCall {
+        id: String,
+        status: OutputItemStatus,
+        call_id: String,
+        name: Option<String>,
+        arguments: Option<String>,
+    },
+    /// Function call output
+    FunctionCallOutput {
+        id: String,
+        call_id: String,
+        output: String,
+        status: Option<OutputItemStatus>,
+    },
+    /// File search tool call
+    FileSearchCall {
+        id: String,
+        status: OutputItemStatus,
+        queries: Option<Vec<String>>,
+        results: Option<Vec<FileSearchResult>>,
+    },
+    /// Web search tool call
+    WebSearchCall {
+        id: String,
+        status: OutputItemStatus,
+    },
+    /// Code interpreter tool call
+    CodeInterpreterCall {
+        id: String,
+        status: OutputItemStatus,
+        code: Option<String>,
+        outputs: Option<Vec<CodeInterpreterOutput>>,
+    },
+    /// Computer tool call
+    ComputerCall {
+        id: String,
+        status: OutputItemStatus,
+        action: Option<serde_json::Value>,
+    },
+    /// Computer call output
+    ComputerCallOutput {
+        id: String,
+        call_id: String,
+        output: Option<serde_json::Value>,
+        status: Option<OutputItemStatus>,
+    },
+    /// Custom tool call
+    CustomToolCall {
+        id: String,
+        status: OutputItemStatus,
+        call_id: String,
+        input: Option<String>,
+    },
+    /// Custom tool call output
+    CustomToolCallOutput {
+        id: String,
+        call_id: String,
+        output: String,
+        status: Option<OutputItemStatus>,
+    },
+    /// Reasoning item
+    Reasoning {
+        id: String,
+        summary: Vec<serde_json::Value>,
+    },
+}
+
+/// Output item status
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "snake_case")]
+pub enum OutputItemStatus {
+    InProgress,
+    Completed,
+    Incomplete,
+}
+
+/// Output content types
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(tag = "type", rename_all = "snake_case")]
+pub enum OutputContent {
+    /// Text output
+    OutputText {
+        text: String,
+        annotations: Vec<Annotation>,
+        logprobs: Option<Vec<LogProb>>,
+    },
+    /// Audio output
+    OutputAudio {
+        data: Option<String>,
+        transcript: Option<String>,
+    },
+    /// Refusal output
+    Refusal {
+        refusal: String,
+    },
+}
+
+/// Annotations for output text
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(tag = "type", rename_all = "snake_case")]
+pub enum Annotation {
+    /// File citation
+    FileCitation {
+        index: i32,
+        file_id: String,
+        filename: String,
+        quote: Option<String>,
+    },
+    /// URL citation
+    UrlCitation {
+        start_index: i32,
+        end_index: i32,
+        url: String,
+        title: String,
+    },
+}
+
+/// Log probability information
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct LogProb {
+    /// The token
+    pub token: String,
+    /// Log probability value
+    pub logprob: f32,
+    /// Token bytes
+    pub bytes: Vec<u8>,
+}
+
+/// File search result
+#[skip_serializing_none]
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct FileSearchResult {
+    /// File ID
+    pub file_id: String,
+    /// File name
+    pub filename: String,
+    /// Score
+    pub score: Option<f32>,
+    /// Content excerpt
+    pub content: Option<String>,
+}
+
+/// Code interpreter output
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(tag = "type", rename_all = "snake_case")]
+pub enum CodeInterpreterOutput {
+    /// Text output
+    Text {
+        text: String,
+    },
+    /// Image output
+    Image {
+        image: String,
+    },
+}
+
+/// Response usage statistics
+#[skip_serializing_none]
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ResponseUsage {
+    /// Input tokens used
+    pub input_tokens: i32,
+    /// Output tokens generated
+    pub output_tokens: i32,
+    /// Total tokens (input + output)
+    pub total_tokens: i32,
+    /// Input token details
+    pub input_tokens_details: Option<TokenDetails>,
+    /// Output token details
+    pub output_tokens_details: Option<OutputTokenDetails>,
+}
+
+impl crate::providers::response::TokenUsage for ResponseUsage {
+    fn completion_tokens(&self) -> usize {
+        self.output_tokens as usize
+    }
+
+    fn prompt_tokens(&self) -> usize {
+        self.input_tokens as usize
+    }
+
+    fn total_tokens(&self) -> usize {
+        self.total_tokens as usize
+    }
+}
+
+/// Token details
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct TokenDetails {
+    /// Cached tokens
+    pub cached_tokens: i32,
+}
+
+/// Output token details
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct OutputTokenDetails {
+    /// Reasoning tokens
+    pub reasoning_tokens: i32,
+}
+
+/// Reasoning configuration and summary
+#[skip_serializing_none]
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct Reasoning {
+    /// Reasoning effort level
+    pub effort: Option<ReasoningEffort>,
+    /// Summary of reasoning
+    pub summary: Option<String>,
+}
+
+/// Conversation object
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct Conversation {
+    /// Conversation ID
+    pub id: String,
+    /// Conversation object type
+    pub object: String,
+}
+
+// ============================================================================
+// Streaming Response Events
+// ============================================================================
+
+/// Stream events for responses
+#[skip_serializing_none]
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(tag = "type", rename_all = "snake_case")]
+pub enum ResponsesAPIStreamEvent {
+    /// Response created
+    #[serde(rename = "response.created")]
+    ResponseCreated {
+        response: ResponsesAPIResponse,
+        sequence_number: i32,
+    },
+
+    /// Response in progress
+    #[serde(rename = "response.in_progress")]
+    ResponseInProgress {
+        response: ResponsesAPIResponse,
+        sequence_number: i32,
+    },
+
+    /// Response completed
+    #[serde(rename = "response.completed")]
+    ResponseCompleted {
+        response: ResponsesAPIResponse,
+        sequence_number: i32,
+    },
+
+    /// Output item added
+    #[serde(rename = "response.output_item.added")]
+    ResponseOutputItemAdded {
+        output_index: i32,
+        item: OutputItem,
+        sequence_number: i32,
+    },
+
+    /// Output item done
+    #[serde(rename = "response.output_item.done")]
+    ResponseOutputItemDone {
+        output_index: i32,
+        item: OutputItem,
+        sequence_number: i32,
+    },
+
+    /// Content part added
+    #[serde(rename = "response.content_part.added")]
+    ResponseContentPartAdded {
+        item_id: String,
+        output_index: i32,
+        content_index: i32,
+        part: OutputContent,
+        sequence_number: i32,
+    },
+
+    /// Content part done
+    #[serde(rename = "response.content_part.done")]
+    ResponseContentPartDone {
+        item_id: String,
+        output_index: i32,
+        content_index: i32,
+        part: OutputContent,
+        sequence_number: i32,
+    },
+
+    /// Output text delta (incremental text streaming)
+    #[serde(rename = "response.output_text.delta")]
+    ResponseOutputTextDelta {
+        item_id: String,
+        output_index: i32,
+        content_index: i32,
+        delta: String,
+        logprobs: Vec<LogProb>,
+        obfuscation: Option<String>,
+        sequence_number: i32,
+    },
+
+    /// Output text done (final complete text)
+    #[serde(rename = "response.output_text.done")]
+    ResponseOutputTextDone {
+        item_id: String,
+        output_index: i32,
+        content_index: i32,
+        text: String,
+        logprobs: Vec<LogProb>,
+        sequence_number: i32,
+    },
+
+    /// Audio delta
+    #[serde(rename = "response.audio.delta")]
+    ResponseAudioDelta {
+        item_id: Option<String>,
+        output_index: Option<i32>,
+        content_index: Option<i32>,
+        delta: String,
+        sequence_number: i32,
+    },
+
+    /// Audio done
+    #[serde(rename = "response.audio.done")]
+    ResponseAudioDone {
+        item_id: Option<String>,
+        output_index: Option<i32>,
+        content_index: Option<i32>,
+        sequence_number: i32,
+    },
+
+    /// Audio transcript delta
+    #[serde(rename = "response.audio_transcript.delta")]
+    ResponseAudioTranscriptDelta {
+        item_id: Option<String>,
+        output_index: Option<i32>,
+        content_index: Option<i32>,
+        delta: String,
+        sequence_number: i32,
+    },
+
+    /// Audio transcript done
+    #[serde(rename = "response.audio_transcript.done")]
+    ResponseAudioTranscriptDone {
+        item_id: Option<String>,
+        output_index: Option<i32>,
+        content_index: Option<i32>,
+        transcript: Option<String>,
+        sequence_number: i32,
+    },
+
+    /// Function call arguments delta
+    #[serde(rename = "response.function_call_arguments.delta")]
+    ResponseFunctionCallArgumentsDelta {
+        output_index: i32,
+        item_id: String,
+        delta: String,
+        sequence_number: i32,
+        call_id: Option<String>,
+        name: Option<String>,
+    },
+
+    /// Function call arguments done
+    #[serde(rename = "response.function_call_arguments.done")]
+    ResponseFunctionCallArgumentsDone {
+        output_index: i32,
+        item_id: String,
+        arguments: String,
+        sequence_number: i32,
+    },
+
+    /// Code interpreter call code delta
+    #[serde(rename = "response.code_interpreter_call.code.delta")]
+    ResponseCodeInterpreterCallCodeDelta {
+        output_index: i32,
+        item_id: String,
+        delta: String,
+        sequence_number: i32,
+    },
+
+    /// Code interpreter call code done
+    #[serde(rename = "response.code_interpreter_call.code.done")]
+    ResponseCodeInterpreterCallCodeDone {
+        output_index: i32,
+        item_id: String,
+        code: String,
+        sequence_number: i32,
+    },
+
+    /// Code interpreter call in progress
+    #[serde(rename = "response.code_interpreter_call.in_progress")]
+    ResponseCodeInterpreterCallInProgress {
+        output_index: i32,
+        item_id: String,
+        sequence_number: i32,
+    },
+
+    /// Code interpreter call interpreting
+    #[serde(rename = "response.code_interpreter_call.interpreting")]
+    ResponseCodeInterpreterCallInterpreting {
+        output_index: i32,
+        item_id: String,
+        sequence_number: i32,
+    },
+
+    /// Code interpreter call completed
+    #[serde(rename = "response.code_interpreter_call.completed")]
+    ResponseCodeInterpreterCallCompleted {
+        output_index: i32,
+        item_id: String,
+        sequence_number: i32,
+    },
+
+    /// Custom tool call input delta
+    #[serde(rename = "response.custom_tool_call.input.delta")]
+    ResponseCustomToolCallInputDelta {
+        output_index: i32,
+        item_id: String,
+        delta: String,
+        sequence_number: i32,
+    },
+
+    /// Custom tool call input done
+    #[serde(rename = "response.custom_tool_call.input.done")]
+    ResponseCustomToolCallInputDone {
+        output_index: i32,
+        item_id: String,
+        input: String,
+        sequence_number: i32,
+    },
+
+    /// Error event
+    Error {
+        code: String,
+        message: String,
+        sequence_number: i32,
+    },
+
+    /// Done event (end of stream)
+    Done {
+        sequence_number: i32,
+    },
+}
+
+// ============================================================================
+// Additional Response Operations
+// ============================================================================
+
+/// Retrieve response request (GET /responses/{response_id})
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct GetResponseRequest {
+    /// Response ID to retrieve
+    pub response_id: String,
+}
+
+/// Delete response request (DELETE /responses/{response_id})
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct DeleteResponseRequest {
+    /// Response ID to delete
+    pub response_id: String,
+}
+
+/// Delete response response
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct DeleteResponseResponse {
+    /// Response ID that was deleted
+    pub id: String,
+    /// Object type
+    pub object: String,
+    /// Whether deletion was successful
+    pub deleted: bool,
+}
+
+/// Cancel response request (POST /responses/{response_id}/cancel)
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct CancelResponseRequest {
+    /// Response ID to cancel
+    pub response_id: String,
+}
+
+/// List input items request (GET /responses/{response_id}/input_items)
+#[skip_serializing_none]
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ListInputItemsRequest {
+    /// Response ID
+    pub response_id: String,
+    /// Limit for pagination
+    pub limit: Option<i32>,
+    /// Order for pagination
+    pub order: Option<String>,
+    /// After cursor for pagination
+    pub after: Option<String>,
+    /// Before cursor for pagination
+    pub before: Option<String>,
+}
+
+/// List input items response
+#[skip_serializing_none]
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ListInputItemsResponse {
+    /// Object type - always "list"
+    pub object: String,
+    /// Array of input items
+    pub data: Vec<InputItem>,
+    /// First ID in the list
+    pub first_id: Option<String>,
+    /// Last ID in the list
+    pub last_id: Option<String>,
+    /// Whether there are more items
+    pub has_more: bool,
+}
+
+// ============================================================================
+// ProviderRequest Implementation
+// ============================================================================
+
+impl ProviderRequest for ResponsesAPIRequest {
+    fn model(&self) -> &str {
+        &self.model
+    }
+
+    fn set_model(&mut self, model: String) {
+        self.model = model;
+    }
+
+    fn is_streaming(&self) -> bool {
+        self.stream.unwrap_or_default()
+    }
+
+    fn extract_messages_text(&self) -> String {
+        match &self.input {
+            InputParam::Text(text) => text.clone(),
+            InputParam::Items(items) => {
+                items.iter().fold(String::new(), |acc, item| {
+                    match item {
+                        InputItem::Message(msg) => {
+                            let content_text = msg.content.iter().fold(String::new(), |acc, content| {
+                                acc + " " + &match content {
+                                    InputContent::InputText { text } => text.clone(),
+                                    InputContent::InputImage { .. } => "[Image]".to_string(),
+                                    InputContent::InputFile { .. } => "[File]".to_string(),
+                                    InputContent::InputAudio { .. } => "[Audio]".to_string(),
+                                }
+                            });
+                            acc + " " + &content_text
+                        }
+                    }
+                })
+            }
+        }
+    }
+
+    fn get_recent_user_message(&self) -> Option<String> {
+        match &self.input {
+            InputParam::Text(text) => Some(text.clone()),
+            InputParam::Items(items) => {
+                items.iter().rev().find_map(|item| {
+                    match item {
+                        InputItem::Message(msg) if matches!(msg.role, MessageRole::User) => {
+                            // Extract text from the first text content
+                            msg.content.iter().find_map(|content| {
+                                match content {
+                                    InputContent::InputText { text } => Some(text.clone()),
+                                    _ => None,
+                                }
+                            })
+                        }
+                        _ => None,
+                    }
+                })
+            }
+        }
+    }
+
+    fn to_bytes(&self) -> Result<Vec<u8>, ProviderRequestError> {
+        serde_json::to_vec(&self).map_err(|e| ProviderRequestError {
+            message: format!("Failed to serialize Responses API request: {}", e),
+            source: Some(Box::new(e)),
+        })
+    }
+
+    fn metadata(&self) -> &Option<HashMap<String, String>> {
+        &self.metadata
+    }
+
+    fn remove_metadata_key(&mut self, key: &str) -> bool {
+        if let Some(ref mut metadata) = self.metadata {
+            metadata.remove(key).is_some()
+        } else {
+            false
+        }
+    }
+}
+
+// ============================================================================
+// Into Implementation for SSE Formatting
+// ============================================================================
+
+impl Into<String> for ResponsesAPIStreamEvent {
+    fn into(self) -> String {
+        let transformed_json = serde_json::to_string(&self).unwrap_or_default();
+        let event_type = match &self {
+            ResponsesAPIStreamEvent::ResponseCreated { .. } => "response.created",
+            ResponsesAPIStreamEvent::ResponseInProgress { .. } => "response.in_progress",
+            ResponsesAPIStreamEvent::ResponseCompleted { ..
} => "response.completed", + ResponsesAPIStreamEvent::ResponseOutputItemAdded { .. } => "response.output_item.added", + ResponsesAPIStreamEvent::ResponseOutputItemDone { .. } => "response.output_item.done", + ResponsesAPIStreamEvent::ResponseContentPartAdded { .. } => { + "response.content_part.added" + } + ResponsesAPIStreamEvent::ResponseContentPartDone { .. } => "response.content_part.done", + ResponsesAPIStreamEvent::ResponseOutputTextDelta { .. } => "response.output_text.delta", + ResponsesAPIStreamEvent::ResponseOutputTextDone { .. } => "response.output_text.done", + ResponsesAPIStreamEvent::ResponseAudioDelta { .. } => "response.audio.delta", + ResponsesAPIStreamEvent::ResponseAudioDone { .. } => "response.audio.done", + ResponsesAPIStreamEvent::ResponseAudioTranscriptDelta { .. } => { + "response.audio_transcript.delta" + } + ResponsesAPIStreamEvent::ResponseAudioTranscriptDone { .. } => { + "response.audio_transcript.done" + } + ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDelta { .. } => { + "response.function_call_arguments.delta" + } + ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDone { .. } => { + "response.function_call_arguments.done" + } + ResponsesAPIStreamEvent::ResponseCodeInterpreterCallCodeDelta { .. } => { + "response.code_interpreter_call.code.delta" + } + ResponsesAPIStreamEvent::ResponseCodeInterpreterCallCodeDone { .. } => { + "response.code_interpreter_call.code.done" + } + ResponsesAPIStreamEvent::ResponseCodeInterpreterCallInProgress { .. } => { + "response.code_interpreter_call.in_progress" + } + ResponsesAPIStreamEvent::ResponseCodeInterpreterCallInterpreting { .. } => { + "response.code_interpreter_call.interpreting" + } + ResponsesAPIStreamEvent::ResponseCodeInterpreterCallCompleted { .. } => { + "response.code_interpreter_call.completed" + } + ResponsesAPIStreamEvent::ResponseCustomToolCallInputDelta { .. } => { + "response.custom_tool_call.input.delta" + } + ResponsesAPIStreamEvent::ResponseCustomToolCallInputDone { .. } => { + "response.custom_tool_call.input.done" + } + ResponsesAPIStreamEvent::Error { .. } => "error", + ResponsesAPIStreamEvent::Done { .. } => "done", + }; + + let event = format!("event: {}\n", event_type); + let data = format!("data: {}\n\n", transformed_json); + event + &data + } +} + +// ============================================================================ +// ProviderStreamResponse Implementation +// ============================================================================ + +impl crate::providers::streaming_response::ProviderStreamResponse for ResponsesAPIStreamEvent { + fn content_delta(&self) -> Option<&str> { + match self { + ResponsesAPIStreamEvent::ResponseOutputTextDelta { delta, .. } => Some(delta), + ResponsesAPIStreamEvent::ResponseAudioDelta { delta, .. } => Some(delta), + ResponsesAPIStreamEvent::ResponseAudioTranscriptDelta { delta, .. } => Some(delta), + ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDelta { delta, .. } => { + Some(delta) + } + ResponsesAPIStreamEvent::ResponseCodeInterpreterCallCodeDelta { delta, .. } => { + Some(delta) + } + ResponsesAPIStreamEvent::ResponseCustomToolCallInputDelta { delta, .. } => Some(delta), + _ => None, + } + } + + fn is_final(&self) -> bool { + matches!( + self, + ResponsesAPIStreamEvent::ResponseCompleted { .. } + | ResponsesAPIStreamEvent::Done { .. } + ) + } + + fn role(&self) -> Option<&str> { + match self { + ResponsesAPIStreamEvent::ResponseOutputItemDone { item, .. } => match item { + OutputItem::Message { role, .. 
} => Some(role.as_str()), + _ => None, + }, + _ => None, + } + } + + fn event_type(&self) -> Option<&str> { + Some(match self { + ResponsesAPIStreamEvent::ResponseCreated { .. } => "response.created", + ResponsesAPIStreamEvent::ResponseInProgress { .. } => "response.in_progress", + ResponsesAPIStreamEvent::ResponseCompleted { .. } => "response.completed", + ResponsesAPIStreamEvent::ResponseOutputItemAdded { .. } => "response.output_item.added", + ResponsesAPIStreamEvent::ResponseOutputItemDone { .. } => "response.output_item.done", + ResponsesAPIStreamEvent::ResponseContentPartAdded { .. } => { + "response.content_part.added" + } + ResponsesAPIStreamEvent::ResponseContentPartDone { .. } => "response.content_part.done", + ResponsesAPIStreamEvent::ResponseOutputTextDelta { .. } => "response.output_text.delta", + ResponsesAPIStreamEvent::ResponseOutputTextDone { .. } => "response.output_text.done", + ResponsesAPIStreamEvent::ResponseAudioDelta { .. } => "response.audio.delta", + ResponsesAPIStreamEvent::ResponseAudioDone { .. } => "response.audio.done", + ResponsesAPIStreamEvent::ResponseAudioTranscriptDelta { .. } => { + "response.audio_transcript.delta" + } + ResponsesAPIStreamEvent::ResponseAudioTranscriptDone { .. } => { + "response.audio_transcript.done" + } + ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDelta { .. } => { + "response.function_call_arguments.delta" + } + ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDone { .. } => { + "response.function_call_arguments.done" + } + ResponsesAPIStreamEvent::ResponseCodeInterpreterCallCodeDelta { .. } => { + "response.code_interpreter_call.code.delta" + } + ResponsesAPIStreamEvent::ResponseCodeInterpreterCallCodeDone { .. } => { + "response.code_interpreter_call.code.done" + } + ResponsesAPIStreamEvent::ResponseCodeInterpreterCallInProgress { .. } => { + "response.code_interpreter_call.in_progress" + } + ResponsesAPIStreamEvent::ResponseCodeInterpreterCallInterpreting { .. } => { + "response.code_interpreter_call.interpreting" + } + ResponsesAPIStreamEvent::ResponseCodeInterpreterCallCompleted { .. } => { + "response.code_interpreter_call.completed" + } + ResponsesAPIStreamEvent::ResponseCustomToolCallInputDelta { .. } => { + "response.custom_tool_call.input.delta" + } + ResponsesAPIStreamEvent::ResponseCustomToolCallInputDone { .. } => { + "response.custom_tool_call.input.done" + } + ResponsesAPIStreamEvent::Error { .. } => "error", + ResponsesAPIStreamEvent::Done { .. 
} => "done", + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_response_output_text_delta_deserialization() { + let json = r#"{ + "type":"response.output_text.delta", + "sequence_number":811, + "item_id":"msg_0d87415661475591006924ce5465748190bdc8874257743b5c", + "output_index":1, + "content_index":0, + "delta":" first", + "logprobs":[], + "obfuscation":"sRhca4PA06" + }"#; + + let event: ResponsesAPIStreamEvent = + serde_json::from_str(json).expect("Failed to deserialize"); + + match event { + ResponsesAPIStreamEvent::ResponseOutputTextDelta { + item_id, + output_index, + content_index, + delta, + sequence_number, + logprobs, + obfuscation, + } => { + assert_eq!( + item_id, + "msg_0d87415661475591006924ce5465748190bdc8874257743b5c" + ); + assert_eq!(output_index, 1); + assert_eq!(content_index, 0); + assert_eq!(delta, " first"); + assert_eq!(sequence_number, 811); + assert_eq!(logprobs.len(), 0); + assert_eq!(obfuscation, Some("sRhca4PA06".to_string())); + } + _ => panic!("Expected ResponseOutputTextDelta event"), + } + } + + #[test] + fn test_response_output_text_done_deserialization() { + let json = r#"{ + "type":"response.output_text.done", + "sequence_number":818, + "item_id":"msg_0d87415661475591006924ce5465748190bdc8874257743b5c", + "output_index":1, + "content_index":0, + "text":"The otters linked paws and laughed.", + "logprobs":[] + }"#; + + let event: ResponsesAPIStreamEvent = + serde_json::from_str(json).expect("Failed to deserialize"); + + match event { + ResponsesAPIStreamEvent::ResponseOutputTextDone { + item_id, + output_index, + content_index, + text, + sequence_number, + logprobs, + } => { + assert_eq!( + item_id, + "msg_0d87415661475591006924ce5465748190bdc8874257743b5c" + ); + assert_eq!(output_index, 1); + assert_eq!(content_index, 0); + assert_eq!(text, "The otters linked paws and laughed."); + assert_eq!(sequence_number, 818); + assert_eq!(logprobs.len(), 0); + } + _ => panic!("Expected ResponseOutputTextDone event"), + } + } + + #[test] + fn test_response_completed_deserialization() { + // Simplified response.completed event + let json = r#"{ + "type":"response.completed", + "sequence_number":821, + "response":{ + "id":"resp_test123", + "object":"response", + "created_at":1764019793, + "status":"completed", + "background":false, + "error":null, + "incomplete_details":null, + "instructions":null, + "max_output_tokens":null, + "max_tool_calls":null, + "model":"o3-2025-04-16", + "output":[], + "output_text":null, + "usage":{ + "input_tokens":17, + "output_tokens":946, + "total_tokens":963 + }, + "parallel_tool_calls":true, + "conversation":null, + "previous_response_id":null, + "tools":[], + "tool_choice":"auto", + "temperature":1.0, + "top_p":1.0, + "metadata":{}, + "truncation":null, + "user":null, + "reasoning":null, + "store":true, + "text":null, + "audio":null, + "modalities":null, + "service_tier":"default", + "top_logprobs":0 + } + }"#; + + let event: ResponsesAPIStreamEvent = + serde_json::from_str(json).expect("Failed to deserialize"); + + match event { + ResponsesAPIStreamEvent::ResponseCompleted { + response, + sequence_number, + } => { + assert_eq!(response.id, "resp_test123"); + assert_eq!(sequence_number, 821); + assert_eq!(response.model, "o3-2025-04-16"); + } + _ => panic!("Expected ResponseCompleted event"), + } + } +} diff --git a/crates/hermesllm/src/apis/amazon_bedrock_binary_frame.rs b/crates/hermesllm/src/apis/streaming_shapes/amazon_bedrock_binary_frame.rs similarity index 66% rename from 
crates/hermesllm/src/apis/amazon_bedrock_binary_frame.rs
rename to crates/hermesllm/src/apis/streaming_shapes/amazon_bedrock_binary_frame.rs
index bacbad62..7f68bb26 100644
--- a/crates/hermesllm/src/apis/amazon_bedrock_binary_frame.rs
+++ b/crates/hermesllm/src/apis/streaming_shapes/amazon_bedrock_binary_frame.rs
@@ -1,7 +1,6 @@
 use aws_smithy_eventstream::frame::DecodedFrame;
 use aws_smithy_eventstream::frame::MessageFrameDecoder;
 use bytes::Buf;
-use std::collections::HashSet;
 
 /// AWS Event Stream frame decoder wrapper
 pub struct BedrockBinaryFrameDecoder<B>
@@ -10,7 +9,6 @@ where
 {
     decoder: MessageFrameDecoder,
     buffer: B,
-    content_block_start_indices: HashSet<i32>,
 }
 
 impl BedrockBinaryFrameDecoder {
@@ -20,7 +18,6 @@ impl BedrockBinaryFrameDecoder {
         Self {
             decoder: MessageFrameDecoder::new(),
             buffer,
-            content_block_start_indices: std::collections::HashSet::new(),
         }
     }
 }
@@ -33,7 +30,6 @@ where
         Self {
             decoder: MessageFrameDecoder::new(),
             buffer,
-            content_block_start_indices: HashSet::new(),
        }
     }
@@ -52,14 +48,4 @@ where
     pub fn has_remaining(&self) -> bool {
         self.buffer.has_remaining()
     }
-
-    /// Check if a content_block_start event has been sent for the given index
-    pub fn has_content_block_start_been_sent(&self, index: i32) -> bool {
-        self.content_block_start_indices.contains(&index)
-    }
-
-    /// Mark that a content_block_start event has been sent for the given index
-    pub fn set_content_block_start_sent(&mut self, index: i32) {
-        self.content_block_start_indices.insert(index);
-    }
 }
diff --git a/crates/hermesllm/src/apis/streaming_shapes/anthropic_streaming_buffer.rs b/crates/hermesllm/src/apis/streaming_shapes/anthropic_streaming_buffer.rs
new file mode 100644
index 00000000..818ee37d
--- /dev/null
+++ b/crates/hermesllm/src/apis/streaming_shapes/anthropic_streaming_buffer.rs
@@ -0,0 +1,507 @@
+use crate::apis::streaming_shapes::sse::{SseEvent, SseStreamBufferTrait};
+use crate::apis::anthropic::MessagesStreamEvent;
+use crate::providers::streaming_response::ProviderStreamResponseType;
+use std::collections::HashSet;
+
+/// SSE Stream Buffer for Anthropic Messages API streaming.
+///
+/// This buffer manages the wire format for Anthropic Messages API streaming,
+/// handling the specific event sequencing requirements:
+/// - MessageStart → ContentBlockStart → ContentBlockDelta(s) → ContentBlockStop → MessageDelta → MessageStop
+///
+/// When converting from OpenAI to Anthropic format, this buffer injects the required
+/// ContentBlockStart and ContentBlockStop events to maintain proper Anthropic protocol.
+pub struct AnthropicMessagesStreamBuffer {
+    /// Buffered SSE events ready to be written to wire
+    buffered_events: Vec<SseEvent>,
+
+    /// Track if we've seen a message_start event
+    message_started: bool,
+
+    /// Track content block indices that have received ContentBlockStart events
+    content_block_start_indices: HashSet<i32>,
+
+    /// Track if we need to inject ContentBlockStop before message_delta
+    needs_content_block_stop: bool,
+
+    /// Track if we've seen a MessageDelta (so we need to send MessageStop at the end)
+    seen_message_delta: bool,
+
+    /// Model name to use when generating message_start events
+    model: Option<String>,
+}
+
+impl AnthropicMessagesStreamBuffer {
+    pub fn new() -> Self {
+        Self {
+            buffered_events: Vec::new(),
+            message_started: false,
+            content_block_start_indices: HashSet::new(),
+            needs_content_block_stop: false,
+            seen_message_delta: false,
+            model: None,
+        }
+    }
+
+    /// Check if a content_block_start event has been sent for the given index
+    fn has_content_block_start_been_sent(&self, index: i32) -> bool {
+        self.content_block_start_indices.contains(&index)
+    }
+
+    /// Mark that a content_block_start event has been sent for the given index
+    fn set_content_block_start_sent(&mut self, index: i32) {
+        self.content_block_start_indices.insert(index);
+    }
+
+    /// Helper to create and format a ContentBlockStart SSE event
+    fn create_content_block_start_event() -> SseEvent {
+        let content_block_start = MessagesStreamEvent::ContentBlockStart {
+            index: 0,
+            content_block: crate::apis::anthropic::MessagesContentBlock::Text {
+                text: String::new(),
+                cache_control: None,
+            },
+        };
+        let sse_string: String = content_block_start.into();
+
+        SseEvent {
+            data: None,
+            event: Some("content_block_start".to_string()),
+            raw_line: sse_string.clone(),
+            sse_transformed_lines: sse_string,
+            provider_stream_response: None,
+        }
+    }
+
+    /// Helper to create and format a MessageStart SSE event
+    fn create_message_start_event(model: &str) -> SseEvent {
+        let message_start = MessagesStreamEvent::MessageStart {
+            message: crate::apis::anthropic::MessagesStreamMessage {
+                id: format!("msg_{}", uuid::Uuid::new_v4().to_string().replace("-", "")),
+                obj_type: "message".to_string(),
+                role: crate::apis::anthropic::MessagesRole::Assistant,
+                content: vec![],
+                model: model.to_string(),
+                stop_reason: None,
+                stop_sequence: None,
+                usage: crate::apis::anthropic::MessagesUsage {
+                    input_tokens: 0,
+                    output_tokens: 0,
+                    cache_creation_input_tokens: None,
+                    cache_read_input_tokens: None,
+                },
+            },
+        };
+        let sse_string: String = message_start.into();
+
+        SseEvent {
+            data: None,
+            event: Some("message_start".to_string()),
+            raw_line: sse_string.clone(),
+            sse_transformed_lines: sse_string,
+            provider_stream_response: None,
+        }
+    }
+
+    /// Helper to create and format a ContentBlockStop SSE event
+    fn create_content_block_stop_event() -> SseEvent {
+        let content_block_stop = MessagesStreamEvent::ContentBlockStop { index: 0 };
+        let sse_string: String = content_block_stop.into();
+
+        SseEvent {
+            data: None,
+            event: Some("content_block_stop".to_string()),
+            raw_line: sse_string.clone(),
+            sse_transformed_lines: sse_string,
+            provider_stream_response: None,
+        }
+    }
+}
+
+impl SseStreamBufferTrait for AnthropicMessagesStreamBuffer {
+    fn add_transformed_event(&mut self, event: SseEvent) {
+        // Skip ping messages
+        if event.should_skip() {
+            return;
+        }
+
+        // FIRST: Try to extract model name from the raw event data before transformation
+        // The provider_stream_response has already been transformed to Anthropic format,
+        // so we need to extract the model from the original raw data if available
+        if self.model.is_none() {
+            if let Some(data) = &event.data {
+                // Try to parse as JSON and extract model field
+                if let Ok(json) = serde_json::from_str::<serde_json::Value>(data) {
+                    if let Some(model) = json.get("model").and_then(|m| m.as_str()) {
+                        self.model = Some(model.to_string());
+                    }
+                }
+            }
+        }
+
+        // Match directly on the provider response type to handle event processing
+        // We match on a reference first to determine the type, then move the event
+        match &event.provider_stream_response {
+            Some(ProviderStreamResponseType::MessagesStreamEvent(evt)) => {
+                match evt {
+                    MessagesStreamEvent::MessageStart { .. } => {
+                        // Add the message_start event
+                        self.buffered_events.push(event);
+                        self.message_started = true;
+                    }
+                    MessagesStreamEvent::ContentBlockStart { index, .. } => {
+                        let index = *index as i32;
+                        // Inject message_start if needed
+                        if !self.message_started {
+                            let model = self.model.as_deref().unwrap_or("unknown");
+                            let message_start = AnthropicMessagesStreamBuffer::create_message_start_event(model);
+                            self.buffered_events.push(message_start);
+                            self.message_started = true;
+                        }
+
+                        // Add the content_block_start event (from tool calls or other sources)
+                        self.buffered_events.push(event);
+                        self.set_content_block_start_sent(index);
+                        self.needs_content_block_stop = true;
+                    }
+                    MessagesStreamEvent::ContentBlockDelta { index, .. } => {
+                        let index = *index as i32;
+                        // Inject message_start if needed
+                        if !self.message_started {
+                            let model = self.model.as_deref().unwrap_or("unknown");
+                            let message_start = AnthropicMessagesStreamBuffer::create_message_start_event(model);
+                            self.buffered_events.push(message_start);
+                            self.message_started = true;
+                        }
+
+                        // Check if ContentBlockStart was sent for this index
+                        if !self.has_content_block_start_been_sent(index) {
+                            // Inject ContentBlockStart before delta
+                            let content_block_start = AnthropicMessagesStreamBuffer::create_content_block_start_event();
+                            self.buffered_events.push(content_block_start);
+                            self.set_content_block_start_sent(index);
+                            self.needs_content_block_stop = true;
+                        }
+
+                        // Content deltas are between ContentBlockStart and ContentBlockStop
+                        self.buffered_events.push(event);
+                    }
+                    MessagesStreamEvent::MessageDelta { usage, .. } => {
+                        // Inject ContentBlockStop before message_delta
+                        if self.needs_content_block_stop {
+                            let content_block_stop = AnthropicMessagesStreamBuffer::create_content_block_stop_event();
+                            self.buffered_events.push(content_block_stop);
+                            self.needs_content_block_stop = false;
+                        }
+
+                        // Check if the last event was also a MessageDelta - if so, merge them
+                        // This handles Bedrock's split of stop_reason (MessageStop) and usage (Metadata)
+                        if let Some(last_event) = self.buffered_events.last_mut() {
+                            if let Some(ProviderStreamResponseType::MessagesStreamEvent(
+                                MessagesStreamEvent::MessageDelta {
+                                    usage: last_usage,
+                                    ..
+                                }
+                            )) = &mut last_event.provider_stream_response {
+                                // Merge: take stop_reason from first, usage from second (if non-zero)
+                                if usage.input_tokens > 0 || usage.output_tokens > 0 {
+                                    *last_usage = usage.clone();
+                                }
+                                // Mark that we've seen MessageDelta (need to send MessageStop later)
+                                self.seen_message_delta = true;
+                                // Don't push the new event, we've merged it
+                                return;
+                            }
+                        }
+
+                        // No previous MessageDelta to merge with, add this one
+                        self.buffered_events.push(event);
+                        self.seen_message_delta = true;
+                    }
+                    MessagesStreamEvent::ContentBlockStop { .. } => {
+                        // ContentBlockStop received from upstream (e.g., Bedrock)
+                        // Clear the flag so we don't inject another one
+                        self.needs_content_block_stop = false;
+                        self.buffered_events.push(event);
+                    }
+                    MessagesStreamEvent::MessageStop => {
+                        // MessageStop received from upstream (e.g., OpenAI via [DONE])
+                        // Clear the flag so we don't inject another one
+                        self.seen_message_delta = false;
+                        self.buffered_events.push(event);
+                    }
+                    _ => {
+                        // Other Anthropic event types (Ping, etc.), just accumulate
+                        self.buffered_events.push(event);
+                    }
+                }
+            }
+            _ => {
+                // Non-Anthropic events or events without provider_stream_response, just accumulate
+                self.buffered_events.push(event);
+            }
+        }
+    }
+
+    fn into_bytes(&mut self) -> Vec<u8> {
+        // Convert all accumulated events to bytes and clear buffer
+        // NOTE: We do NOT inject ContentBlockStop here because it's injected when we see MessageDelta
+        // or MessageStop. Injecting it here causes premature ContentBlockStop in the middle of streaming.
+
+        // Inject MessageStop after MessageDelta if we've seen one
+        // This completes the Anthropic Messages API event sequence
+        if self.seen_message_delta {
+            let message_stop = MessagesStreamEvent::MessageStop;
+            let sse_string: String = message_stop.into();
+            let message_stop_event = SseEvent {
+                data: None,
+                event: Some("message_stop".to_string()),
+                raw_line: sse_string.clone(),
+                sse_transformed_lines: sse_string,
+                provider_stream_response: None,
+            };
+            self.buffered_events.push(message_stop_event);
+            self.seen_message_delta = false;
+        }
+
+        let mut buffer = Vec::new();
+        for event in self.buffered_events.drain(..) {
+            let event_bytes: Vec<u8> = event.into();
+            buffer.extend_from_slice(&event_bytes);
+        }
+        buffer
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::clients::{SupportedAPIsFromClient, SupportedUpstreamAPIs};
+    use crate::apis::anthropic::AnthropicApi;
+    use crate::apis::openai::OpenAIApi;
+    use crate::apis::streaming_shapes::sse::SseStreamIter;
+
+    #[test]
+    fn test_openai_to_anthropic_complete_transformation() {
+        // OpenAI ChatCompletions input that will be transformed to Anthropic Messages API
+        let raw_input = r#"data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{"role":"assistant","content":"Hello"},"finish_reason":null}]}
+
+data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{"content":" world"},"finish_reason":null}]}
+
+data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}
+
+data: [DONE]"#;
+
+        println!("\n{}", "=".repeat(80));
+        println!("TEST 1: OpenAI → Anthropic Messages API Complete Transformation");
+        println!("{}", "=".repeat(80));
+        println!("\nRAW INPUT (OpenAI ChatCompletions):");
+        println!("{}", "-".repeat(80));
+        println!("{}", raw_input);
+
+        // Setup API configuration for transformation (client wants Anthropic, upstream is OpenAI)
+        let client_api = SupportedAPIsFromClient::AnthropicMessagesAPI(AnthropicApi::Messages);
+        let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions);
+
+        // Parse events and apply transformation
+        let stream_iter = SseStreamIter::try_from(raw_input.as_bytes()).unwrap();
+        let mut buffer = AnthropicMessagesStreamBuffer::new();
+
+        for raw_event in stream_iter {
+            let transformed_event = SseEvent::try_from((raw_event, &client_api, &upstream_api)).unwrap();
+            buffer.add_transformed_event(transformed_event);
+        }
+
+        let output_bytes = buffer.into_bytes();
+        let output = String::from_utf8_lossy(&output_bytes);
+
+        println!("\nTRANSFORMED OUTPUT (Anthropic Messages API):");
+        println!("{}", "-".repeat(80));
+        println!("{}", output);
+
+        // Assertions
+        assert!(!output_bytes.is_empty(), "Should have output");
+        assert!(output.contains("event: message_start"), "Should have message_start");
+        assert!(output.contains("event: content_block_start"), "Should have content_block_start (injected)");
+
+        let delta_count = output.matches("event: content_block_delta").count();
+        assert_eq!(delta_count, 2, "Should have exactly 2 content_block_delta events");
+
+        // Verify both pieces of content are present
+        assert!(output.contains("\"text\":\"Hello\""), "Should have first content delta 'Hello'");
+        assert!(output.contains("\"text\":\" world\""), "Should have second content delta ' world'");
+
+        assert!(output.contains("event: content_block_stop"), "Should have content_block_stop (injected)");
+        assert!(output.contains("event: message_delta"), "Should have message_delta");
+        assert!(output.contains("event: message_stop"), "Should have message_stop");
+
+        println!("\nVALIDATION SUMMARY:");
+        println!("{}", "-".repeat(80));
+        println!("✓ Complete transformation: OpenAI ChatCompletions → Anthropic Messages API");
+        println!("✓ Injected lifecycle events: message_start, content_block_start, content_block_stop");
+        println!("✓ Content deltas: {} events (BOTH 'Hello' and ' world' preserved!)", delta_count);
+        println!("✓ Complete stream with message_stop");
+        println!("✓ Proper Anthropic protocol sequencing\n");
+    }
+
+    #[test]
+    fn test_openai_to_anthropic_partial_transformation() {
+        // Partial OpenAI ChatCompletions stream - no [DONE]
+        let raw_input = r#"data: {"id":"chatcmpl-456","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{"role":"assistant","content":"The weather"},"finish_reason":null}]}
+
+data: {"id":"chatcmpl-456","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{"content":" in San Francisco"},"finish_reason":null}]}
+
+data: {"id":"chatcmpl-456","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{"content":" is"},"finish_reason":null}]}"#;
+
+        println!("\n{}", "=".repeat(80));
+        println!("TEST 2: OpenAI → Anthropic Partial Transformation (NO [DONE])");
+        println!("{}", "=".repeat(80));
+        println!("\nRAW INPUT (OpenAI ChatCompletions - NO [DONE]):");
+        println!("{}", "-".repeat(80));
+        println!("{}", raw_input);
+
+        // Setup API configuration for transformation
+        let client_api = SupportedAPIsFromClient::AnthropicMessagesAPI(AnthropicApi::Messages);
+        let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions);
+
+        // Parse and transform events
+        let stream_iter = SseStreamIter::try_from(raw_input.as_bytes()).unwrap();
+        let mut buffer = AnthropicMessagesStreamBuffer::new();
+
+        for raw_event in stream_iter {
+            let transformed_event = SseEvent::try_from((raw_event, &client_api,
&upstream_api)).unwrap(); + buffer.add_transformed_event(transformed_event); + } + + let output_bytes = buffer.into_bytes(); + let output = String::from_utf8_lossy(&output_bytes); + + println!("\nTRANSFORMED OUTPUT (Anthropic Messages API):"); + println!("{}", "-".repeat(80)); + println!("{}", output); + + // Assertions + assert!(!output_bytes.is_empty(), "Should have output"); + assert!(output.contains("event: message_start"), "Should have message_start"); + assert!(output.contains("event: content_block_start"), "Should have content_block_start (injected)"); + + let delta_count = output.matches("event: content_block_delta").count(); + assert_eq!(delta_count, 2, "Should have exactly 2 content_block_delta events"); + + // Verify both pieces of content are present + assert!(output.contains("\"text\":\"Hello\""), "Should have first content delta 'Hello'"); + assert!(output.contains("\"text\":\" world\""), "Should have second content delta ' world'"); + + assert!(output.contains("event: content_block_stop"), "Should have content_block_stop (injected)"); + assert!(output.contains("event: message_delta"), "Should have message_delta"); + assert!(output.contains("event: message_stop"), "Should have message_stop"); + + println!("\nVALIDATION SUMMARY:"); + println!("{}", "-".repeat(80)); + println!("✓ Complete transformation: OpenAI ChatCompletions → Anthropic Messages API"); + println!("✓ Injected lifecycle events: message_start, content_block_start, content_block_stop"); + println!("✓ Content deltas: {} events (BOTH 'Hello' and ' world' preserved!)", delta_count); + println!("✓ Complete stream with message_stop"); + println!("✓ Proper Anthropic protocol sequencing\n"); + } + + #[test] + fn test_openai_to_anthropic_partial_transformation() { + // Partial OpenAI ChatCompletions stream - no [DONE] + let raw_input = r#"data: {"id":"chatcmpl-456","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{"role":"assistant","content":"The weather"},"finish_reason":null}]} + +data: {"id":"chatcmpl-456","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{"content":" in San Francisco"},"finish_reason":null}]} + +data: {"id":"chatcmpl-456","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{"content":" is"},"finish_reason":null}]}"#; + + println!("\n{}", "=".repeat(80)); + println!("TEST 2: OpenAI → Anthropic Partial Transformation (NO [DONE])"); + println!("{}", "=".repeat(80)); + println!("\nRAW INPUT (OpenAI ChatCompletions - NO [DONE]):"); + println!("{}", "-".repeat(80)); + println!("{}", raw_input); + + // Setup API configuration for transformation + let client_api = SupportedAPIsFromClient::AnthropicMessagesAPI(AnthropicApi::Messages); + let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions); + + // Parse and transform events + let stream_iter = SseStreamIter::try_from(raw_input.as_bytes()).unwrap(); + let mut buffer = AnthropicMessagesStreamBuffer::new(); + + for raw_event in stream_iter { + let transformed_event = SseEvent::try_from((raw_event, &client_api, &upstream_api)).unwrap(); + buffer.add_transformed_event(transformed_event); + } + + let output_bytes = buffer.into_bytes(); + let output = String::from_utf8_lossy(&output_bytes); + + println!("\nTRANSFORMED OUTPUT (Anthropic Messages API):"); + println!("{}", "-".repeat(80)); + println!("{}", output); + + // Assertions + assert!(!output_bytes.is_empty(), "Should have 
output"); + assert!(output.contains("event: message_start"), "Should have message_start"); + assert!(output.contains("event: content_block_start"), "Should have content_block_start (injected)"); + + let delta_count = output.matches("event: content_block_delta").count(); + assert_eq!(delta_count, 3, "Should have exactly 3 content_block_delta events"); + + // Verify all three pieces of content are present + assert!(output.contains("\"text\":\"The weather\""), "Should have first content delta"); + assert!(output.contains("\"text\":\" in San Francisco\""), "Should have second content delta"); + assert!(output.contains("\"text\":\" is\""), "Should have third content delta"); + + // For partial streams (no finish_reason, no [DONE]), we do NOT inject content_block_stop + // because the stream may continue. This is correct behavior - only inject lifecycle events + // when we have explicit signals from upstream (finish_reason, [DONE], etc.) + assert!(!output.contains("event: content_block_stop"), "Should NOT have content_block_stop for partial stream"); + + // Should NOT have completion events + assert!(!output.contains("event: message_delta"), "Should NOT have message_delta"); + assert!(!output.contains("event: message_stop"), "Should NOT have message_stop"); + + println!("\nVALIDATION SUMMARY:"); + println!("{}", "-".repeat(80)); + println!("✓ Partial transformation: OpenAI → Anthropic (stream interrupted)"); + println!("✓ Injected: message_start, content_block_start at beginning"); + println!("✓ Incremental deltas: {} events (ALL content preserved!)", delta_count); + println!("✓ NO completion events (partial stream, no [DONE])"); + println!("✓ Buffer maintains Anthropic protocol for active streams\n"); + } + + #[test] + fn test_openai_tool_calling_to_anthropic_transformation() { + // OpenAI ChatCompletions tool calling stream + let raw_input = r#"data: {"id":"chatcmpl-Cgx6pZPBgfLcMqfT0ILIH2mID2zWQ","object":"chat.completion.chunk","created":1764353027,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_7eeb46f068","choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_2Uzw0AEZQeOex2CP2TKjcLKc","type":"function","function":{"name":"get_weather","arguments":""}}],"refusal":null},"logprobs":null,"finish_reason":null}],"obfuscation":"uSpCcO"} + +data: {"id":"chatcmpl-Cgx6pZPBgfLcMqfT0ILIH2mID2zWQ","object":"chat.completion.chunk","created":1764353027,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_7eeb46f068","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\""}}]},"logprobs":null,"finish_reason":null}],"obfuscation":""} + +data: {"id":"chatcmpl-Cgx6pZPBgfLcMqfT0ILIH2mID2zWQ","object":"chat.completion.chunk","created":1764353027,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_7eeb46f068","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"location"}}]},"logprobs":null,"finish_reason":null}],"obfuscation":"24WSqt08jtf"} + +data: {"id":"chatcmpl-Cgx6pZPBgfLcMqfT0ILIH2mID2zWQ","object":"chat.completion.chunk","created":1764353027,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_7eeb46f068","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\":\""}}]},"logprobs":null,"finish_reason":null}],"obfuscation":"6CleV8twTxkKYg"} + +data: 
{"id":"chatcmpl-Cgx6pZPBgfLcMqfT0ILIH2mID2zWQ","object":"chat.completion.chunk","created":1764353027,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_7eeb46f068","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"San"}}]},"logprobs":null,"finish_reason":null}],"obfuscation":""} + +data: {"id":"chatcmpl-Cgx6pZPBgfLcMqfT0ILIH2mID2zWQ","object":"chat.completion.chunk","created":1764353027,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_7eeb46f068","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":" Francisco"}}]},"logprobs":null,"finish_reason":null}],"obfuscation":"1XLz89l3v"} + +data: {"id":"chatcmpl-Cgx6pZPBgfLcMqfT0ILIH2mID2zWQ","object":"chat.completion.chunk","created":1764353027,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_7eeb46f068","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":","}}]},"logprobs":null,"finish_reason":null}],"obfuscation":"sh"} + +data: {"id":"chatcmpl-Cgx6pZPBgfLcMqfT0ILIH2mID2zWQ","object":"chat.completion.chunk","created":1764353027,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_7eeb46f068","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":" CA"}}]},"logprobs":null,"finish_reason":null}],"obfuscation":""} + +data: {"id":"chatcmpl-Cgx6pZPBgfLcMqfT0ILIH2mID2zWQ","object":"chat.completion.chunk","created":1764353027,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_7eeb46f068","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"}"}}]},"logprobs":null,"finish_reason":null}],"obfuscation":""} + +data: {"id":"chatcmpl-Cgx6pZPBgfLcMqfT0ILIH2mID2zWQ","object":"chat.completion.chunk","created":1764353027,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_7eeb46f068","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}],"obfuscation":"I"} + +data: [DONE]"#; + + println!("\n{}", "=".repeat(80)); + println!("TEST 3: OpenAI Tool Calling → Anthropic Messages API Transformation"); + println!("{}", "=".repeat(80)); + println!("\nRAW INPUT (OpenAI ChatCompletions with Tool Calls):"); + println!("{}", "-".repeat(80)); + println!("{}", raw_input); + + // Setup API configuration for transformation + let client_api = SupportedAPIsFromClient::AnthropicMessagesAPI(AnthropicApi::Messages); + let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions); + + // Parse and transform events + let stream_iter = SseStreamIter::try_from(raw_input.as_bytes()).unwrap(); + let mut buffer = AnthropicMessagesStreamBuffer::new(); + + for raw_event in stream_iter { + let transformed_event = SseEvent::try_from((raw_event, &client_api, &upstream_api)).unwrap(); + buffer.add_transformed_event(transformed_event); + } + + let output_bytes = buffer.into_bytes(); + let output = String::from_utf8_lossy(&output_bytes); + + println!("\nTRANSFORMED OUTPUT (Anthropic Messages API):"); + println!("{}", "-".repeat(80)); + println!("{}", output); + + // Assertions for tool calling transformation + assert!(!output_bytes.is_empty(), "Should have output"); + + // Should have lifecycle events (injected by buffer) + assert!(output.contains("event: message_start"), "Should have message_start (injected)"); + assert!(output.contains("event: content_block_start"), "Should have content_block_start"); + assert!(output.contains("event: 
content_block_stop"), "Should have content_block_stop (injected)"); + assert!(output.contains("event: message_delta"), "Should have message_delta"); + assert!(output.contains("event: message_stop"), "Should have message_stop"); + + // Should have tool_use content block + assert!(output.contains("\"type\":\"tool_use\""), "Should have tool_use type"); + assert!(output.contains("\"name\":\"get_weather\""), "Should have correct function name"); + assert!(output.contains("\"id\":\"call_2Uzw0AEZQeOex2CP2TKjcLKc\""), "Should have correct tool call ID"); + + // Count input_json_delta events - should match the number of argument chunks + let delta_count = output.matches("event: content_block_delta").count(); + assert!(delta_count >= 8, "Should have at least 8 input_json_delta events"); + + // Verify argument deltas are present + assert!(output.contains("\"type\":\"input_json_delta\""), "Should have input_json_delta type"); + assert!(output.contains("\"partial_json\":"), "Should have partial_json field"); + + // Verify the accumulated arguments contain the location + assert!(output.contains("San"), "Arguments should contain 'San'"); + assert!(output.contains("Francisco"), "Arguments should contain 'Francisco'"); + assert!(output.contains("CA"), "Arguments should contain 'CA'"); + + // Verify stop reason is tool_use + assert!(output.contains("\"stop_reason\":\"tool_use\""), "Should have stop_reason as tool_use"); + + println!("\nVALIDATION SUMMARY:"); + println!("{}", "-".repeat(80)); + println!("✓ Complete tool calling transformation: OpenAI → Anthropic Messages API"); + println!("✓ Injected lifecycle: message_start, content_block_stop"); + println!("✓ Tool metadata: name='get_weather', id='call_2Uzw0AEZQeOex2CP2TKjcLKc'"); + println!("✓ Argument deltas: {} events", delta_count); + println!("✓ Complete JSON arguments: '{{\"location\":\"San Francisco, CA\"}}'"); + println!("✓ Stop reason: tool_use"); + println!("✓ Proper Anthropic tool_use protocol\n"); + } +} diff --git a/crates/hermesllm/src/apis/streaming_shapes/chat_completions_streaming_buffer.rs b/crates/hermesllm/src/apis/streaming_shapes/chat_completions_streaming_buffer.rs new file mode 100644 index 00000000..0243a5cd --- /dev/null +++ b/crates/hermesllm/src/apis/streaming_shapes/chat_completions_streaming_buffer.rs @@ -0,0 +1,39 @@ +use crate::apis::streaming_shapes::sse::{SseEvent, SseStreamBufferTrait}; + +/// OpenAI Chat Completions SSE Stream Buffer for when client and upstream APIs match. +pub struct OpenAIChatCompletionsStreamBuffer { + /// Buffered SSE events ready to be written to wire + buffered_events: Vec, +} + +impl OpenAIChatCompletionsStreamBuffer { + pub fn new() -> Self { + Self { + buffered_events: Vec::new(), + } + } +} + +impl SseStreamBufferTrait for OpenAIChatCompletionsStreamBuffer { + fn add_transformed_event(&mut self, event: SseEvent) { + // Skip ping messages + if event.should_skip() { + return; + } + + // For OpenAI Chat Completions, events are already properly transformed + // Just accumulate them for later wire transmission + self.buffered_events.push(event); + } + + fn into_bytes(&mut self) -> Vec { + // No finalization needed for OpenAI Chat Completions + // The [DONE] marker is already handled by the transformation layer + let mut buffer = Vec::new(); + for event in self.buffered_events.drain(..) 
{ + let event_bytes: Vec<u8> = event.into(); + buffer.extend_from_slice(&event_bytes); + } + buffer + } +} diff --git a/crates/hermesllm/src/apis/streaming_shapes/mod.rs b/crates/hermesllm/src/apis/streaming_shapes/mod.rs new file mode 100644 index 00000000..1ef7acc7 --- /dev/null +++ b/crates/hermesllm/src/apis/streaming_shapes/mod.rs @@ -0,0 +1,6 @@ +pub mod sse; +pub mod amazon_bedrock_binary_frame; +pub mod anthropic_streaming_buffer; +pub mod chat_completions_streaming_buffer; +pub mod passthrough_streaming_buffer; +pub mod responses_api_streaming_buffer; diff --git a/crates/hermesllm/src/apis/streaming_shapes/passthrough_streaming_buffer.rs b/crates/hermesllm/src/apis/streaming_shapes/passthrough_streaming_buffer.rs new file mode 100644 index 00000000..2ac2a688 --- /dev/null +++ b/crates/hermesllm/src/apis/streaming_shapes/passthrough_streaming_buffer.rs @@ -0,0 +1,95 @@ +use crate::apis::streaming_shapes::sse::{SseEvent, SseStreamBufferTrait}; + +/// Passthrough SSE Stream Buffer for when client and upstream APIs match. +pub struct PassthroughStreamBuffer { + /// Buffered SSE events ready to be written to wire + buffered_events: Vec<SseEvent>, +} + +impl PassthroughStreamBuffer { + pub fn new() -> Self { + Self { + buffered_events: Vec::new(), + } + } +} + +impl SseStreamBufferTrait for PassthroughStreamBuffer { + fn add_transformed_event(&mut self, event: SseEvent) { + // Skip ping messages + if event.should_skip() { + return; + } + + // Skip events with empty transformed lines (e.g., suppressed event-only lines) + if event.sse_transformed_lines.is_empty() { + return; + } + + // Just accumulate events as-is + self.buffered_events.push(event); + } + + fn into_bytes(&mut self) -> Vec<u8> { + // No finalization needed for passthrough - just convert accumulated events to bytes + let mut buffer = Vec::new(); + for event in self.buffered_events.drain(..)
{ + let event_bytes: Vec = event.into(); + buffer.extend_from_slice(&event_bytes); + } + buffer + } +} + +#[cfg(test)] +mod tests { + use crate::apis::streaming_shapes::passthrough_streaming_buffer::PassthroughStreamBuffer; + use crate::apis::streaming_shapes::sse::{SseStreamIter, SseStreamBufferTrait}; + + #[test] + fn test_chat_completions_passthrough_buffer() { + let raw_input = r#"data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_abc","type":"function","function":{"name":"get_weather","arguments":""}}],"refusal":null},"logprobs":null,"finish_reason":null}]} + + data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\""}}]},"logprobs":null,"finish_reason":null}]} + + data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"location"}}]},"logprobs":null,"finish_reason":null}]} + + data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}]} + + data: [DONE]"#; + + println!("\n{}", "=".repeat(80)); + println!("TEST 1: ChatCompletions Passthrough Buffer"); + println!("{}", "=".repeat(80)); + println!("\nRAW INPUT (ChatCompletions):"); + println!("{}", "-".repeat(80)); + println!("{}", raw_input); + + // Parse and process through buffer + let stream_iter = SseStreamIter::try_from(raw_input.as_bytes()).unwrap(); + let mut buffer = PassthroughStreamBuffer::new(); + + for event in stream_iter { + buffer.add_transformed_event(event); + } + + let output_bytes = buffer.into_bytes(); + let output = String::from_utf8_lossy(&output_bytes); + + println!("\nTRANSFORMED OUTPUT (ChatCompletions - Passthrough):"); + println!("{}", "-".repeat(80)); + println!("{}", output); + + // Assertions + assert!(!output_bytes.is_empty()); + assert!(output.contains("chatcmpl-123")); + assert!(output.contains("[DONE]")); + assert_eq!(raw_input.trim(), output.trim(), "Passthrough should preserve input"); + + println!("\nVALIDATION SUMMARY:"); + println!("{}", "-".repeat(80)); + println!("✓ Passthrough buffer: input = output (no transformation)"); + println!("✓ All events preserved including [DONE]"); + println!("✓ Function calling events preserved\n"); + } +} diff --git a/crates/hermesllm/src/apis/streaming_shapes/responses_api_streaming_buffer.rs b/crates/hermesllm/src/apis/streaming_shapes/responses_api_streaming_buffer.rs new file mode 100644 index 00000000..84854af3 --- /dev/null +++ b/crates/hermesllm/src/apis/streaming_shapes/responses_api_streaming_buffer.rs @@ -0,0 +1,600 @@ +use std::collections::HashMap; +use log::debug; +use crate::apis::openai_responses::{ + ResponsesAPIStreamEvent, ResponsesAPIResponse, OutputItem, OutputItemStatus, + ResponseStatus, TextConfig, TextFormat, Reasoning, +}; +use crate::apis::streaming_shapes::sse::{SseEvent, SseStreamBufferTrait}; + +/// Helper to convert ResponseAPIStreamEvent to SseEvent +fn event_to_sse(event: ResponsesAPIStreamEvent) -> SseEvent { + let event_type = match &event { + ResponsesAPIStreamEvent::ResponseCreated { .. } => "response.created", + ResponsesAPIStreamEvent::ResponseInProgress { .. 
} => "response.in_progress", + ResponsesAPIStreamEvent::ResponseCompleted { .. } => "response.completed", + ResponsesAPIStreamEvent::ResponseOutputItemAdded { .. } => "response.output_item.added", + ResponsesAPIStreamEvent::ResponseOutputItemDone { .. } => "response.output_item.done", + ResponsesAPIStreamEvent::ResponseOutputTextDelta { .. } => "response.output_text.delta", + ResponsesAPIStreamEvent::ResponseOutputTextDone { .. } => "response.output_text.done", + ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDelta { .. } => "response.function_call_arguments.delta", + ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDone { .. } => "response.function_call_arguments.done", + unknown => { + debug!("Unknown ResponsesAPIStreamEvent type encountered: {:?}", unknown); + "unknown" + } + }; + + let json_data = match serde_json::to_string(&event) { + Ok(data) => data, + Err(e) => { + debug!("Error serializing ResponsesAPIStreamEvent to JSON: {}", e); + String::new() + } + }; + let wire_format: String = event.into(); + + SseEvent { + data: Some(json_data), + event: Some(event_type.to_string()), + raw_line: wire_format.clone(), + sse_transformed_lines: wire_format, + provider_stream_response: None, + } +} + +/// SSE Stream Buffer for ResponsesAPIStreamEvent with full lifecycle management. +/// +/// This buffer manages the wire format for v1/responses streaming, handling +/// delta events and emitting complete lifecycle events. +/// +pub struct ResponsesAPIStreamBuffer { + /// Sequence number for events + sequence_number: i32, + + /// Track item IDs by output index + item_ids: HashMap, + + /// Response metadata + response_id: Option, + model: Option, + created_at: Option, + + /// Lifecycle state flags + created_emitted: bool, + in_progress_emitted: bool, + + /// Track which output items we've added + output_items_added: HashMap, // output_index -> item_id + + /// Accumulated content by item_id + text_content: HashMap, + function_arguments: HashMap, + + /// Tool call metadata by output_index + tool_call_metadata: HashMap, // output_index -> (call_id, name) + + /// Final completed response (for logging/tracing/persistence) + completed_response: Option, + + /// Buffered SSE events ready to be written to wire + buffered_events: Vec, +} + +impl ResponsesAPIStreamBuffer { + pub fn new() -> Self { + Self { + sequence_number: 0, + item_ids: HashMap::new(), + response_id: None, + model: None, + created_at: None, + created_emitted: false, + in_progress_emitted: false, + output_items_added: HashMap::new(), + text_content: HashMap::new(), + function_arguments: HashMap::new(), + tool_call_metadata: HashMap::new(), + completed_response: None, + buffered_events: Vec::new(), + } + } + + fn next_sequence_number(&mut self) -> i32 { + let seq = self.sequence_number; + self.sequence_number += 1; + seq + } + + fn generate_item_id(prefix: &str) -> String { + format!("{}_{}", prefix, uuid::Uuid::new_v4().to_string().replace("-", "")) + } + + fn get_or_create_item_id(&mut self, output_index: i32, prefix: &str) -> String { + if let Some(id) = self.item_ids.get(&output_index) { + return id.clone(); + } + let id = ResponsesAPIStreamBuffer::generate_item_id(prefix); + self.item_ids.insert(output_index, id.clone()); + id + } + + /// Create response.created event + fn create_response_created_event(&mut self) -> SseEvent { + let response = self.build_response(ResponseStatus::InProgress); + let event = ResponsesAPIStreamEvent::ResponseCreated { + response, + sequence_number: self.next_sequence_number(), + }; + 
event_to_sse(event) + } + + /// Create response.in_progress event + fn create_response_in_progress_event(&mut self) -> SseEvent { + let response = self.build_response(ResponseStatus::InProgress); + let event = ResponsesAPIStreamEvent::ResponseInProgress { + response, + sequence_number: self.next_sequence_number(), + }; + event_to_sse(event) + } + + /// Create output_item.added event for text + fn create_output_item_added_event(&mut self, output_index: i32, item_id: &str) -> SseEvent { + let event = ResponsesAPIStreamEvent::ResponseOutputItemAdded { + output_index, + item: OutputItem::Message { + id: item_id.to_string(), + status: OutputItemStatus::InProgress, + role: "assistant".to_string(), + content: vec![], + }, + sequence_number: self.next_sequence_number(), + }; + event_to_sse(event) + } + + /// Create output_item.added event for tool call + fn create_tool_call_added_event(&mut self, output_index: i32, item_id: &str, call_id: &str, name: &str) -> SseEvent { + let event = ResponsesAPIStreamEvent::ResponseOutputItemAdded { + output_index, + item: OutputItem::FunctionCall { + id: item_id.to_string(), + status: OutputItemStatus::InProgress, + call_id: call_id.to_string(), + name: Some(name.to_string()), + arguments: Some(String::new()), + }, + sequence_number: self.next_sequence_number(), + }; + event_to_sse(event) + } + + /// Build the base response object with current state + fn build_response(&self, status: ResponseStatus) -> ResponsesAPIResponse { + ResponsesAPIResponse { + id: self.response_id.clone().unwrap_or_default(), + object: "response".to_string(), + created_at: self.created_at.unwrap_or(0), + status, + error: None, + incomplete_details: None, + instructions: None, + model: self.model.clone().unwrap_or_else(|| "unknown".to_string()), + output: vec![], + usage: None, + parallel_tool_calls: true, + conversation: None, + previous_response_id: None, + tools: vec![], + tool_choice: "auto".to_string(), + temperature: 1.0, + top_p: 1.0, + metadata: HashMap::new(), + truncation: Some("disabled".to_string()), + max_output_tokens: None, + reasoning: Some(Reasoning { + effort: None, + summary: None, + }), + store: Some(true), + text: Some(TextConfig { + format: TextFormat::Text, + }), + audio: None, + modalities: None, + service_tier: Some("auto".to_string()), + background: Some(false), + top_logprobs: Some(0), + max_tool_calls: None, + } + } + + /// Get the completed response after finalization (for logging/tracing/persistence) + pub fn get_completed_response(&self) -> Option<&ResponsesAPIResponse> { + self.completed_response.as_ref() + } + + /// Finalize the response by emitting all *.done events and response.completed. + /// Call this when the stream is complete (after seeing [DONE] or end_of_stream). 
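A sketch of the expected call pattern, hedged: the driver loop and the `wire` sink are assumed names, not part of this patch, and a [DONE] marker already triggers finalization inside add_transformed_event (see below), so the explicit call is only for streams that end without one.

    let mut buf = ResponsesAPIStreamBuffer::new();
    for event in transformed_events {
        buf.add_transformed_event(event); // a [DONE] marker triggers finalize() internally
        wire.extend(buf.into_bytes());    // flush per chunk; accumulation state persists
    }
    if stream_ended_without_done {
        buf.finalize();                   // emit *.done events and response.completed
        wire.extend(buf.into_bytes());
    }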
+ pub fn finalize(&mut self) { + let mut events = Vec::new(); + + // Emit done events for all accumulated content + + // Text content done events + let text_items: Vec<_> = self.text_content.iter().map(|(id, content)| (id.clone(), content.clone())).collect(); + for (item_id, content) in text_items { + let output_index = self.output_items_added.iter() + .find(|(_, id)| **id == item_id) + .map(|(idx, _)| *idx) + .unwrap_or(0); + + let seq1 = self.next_sequence_number(); + let text_done_event = ResponsesAPIStreamEvent::ResponseOutputTextDone { + item_id: item_id.clone(), + output_index, + content_index: 0, + text: content.clone(), + logprobs: vec![], + sequence_number: seq1, + }; + events.push(event_to_sse(text_done_event)); + + let seq2 = self.next_sequence_number(); + let item_done_event = ResponsesAPIStreamEvent::ResponseOutputItemDone { + output_index, + item: OutputItem::Message { + id: item_id.clone(), + status: OutputItemStatus::Completed, + role: "assistant".to_string(), + content: vec![], + }, + sequence_number: seq2, + }; + events.push(event_to_sse(item_done_event)); + } + + // Function call done events + let func_items: Vec<_> = self.function_arguments.iter().map(|(id, args)| (id.clone(), args.clone())).collect(); + for (item_id, arguments) in func_items { + let output_index = self.output_items_added.iter() + .find(|(_, id)| **id == item_id) + .map(|(idx, _)| *idx) + .unwrap_or(0); + + let seq1 = self.next_sequence_number(); + let args_done_event = ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDone { + output_index, + item_id: item_id.clone(), + arguments: arguments.clone(), + sequence_number: seq1, + }; + events.push(event_to_sse(args_done_event)); + + let (call_id, name) = self.tool_call_metadata.get(&output_index) + .cloned() + .unwrap_or_else(|| (format!("call_{}", uuid::Uuid::new_v4()), "unknown".to_string())); + + let seq2 = self.next_sequence_number(); + let item_done_event = ResponsesAPIStreamEvent::ResponseOutputItemDone { + output_index, + item: OutputItem::FunctionCall { + id: item_id.clone(), + status: OutputItemStatus::Completed, + call_id, + name: Some(name), + arguments: Some(arguments.clone()), + }, + sequence_number: seq2, + }; + events.push(event_to_sse(item_done_event)); + } + + // Build final response + let mut output_items = Vec::new(); + + // Add tool calls to output + for (item_id, arguments) in &self.function_arguments { + let output_index = self.output_items_added.iter() + .find(|(_, id)| *id == item_id) + .map(|(idx, _)| *idx) + .unwrap_or(0); + + let (call_id, name) = self.tool_call_metadata.get(&output_index) + .cloned() + .unwrap_or_else(|| (format!("call_{}", uuid::Uuid::new_v4()), "unknown".to_string())); + + output_items.push(OutputItem::FunctionCall { + id: item_id.clone(), + status: OutputItemStatus::Completed, + call_id, + name: Some(name), + arguments: Some(arguments.clone()), + }); + } + + let mut final_response = self.build_response(ResponseStatus::Completed); + final_response.output = output_items; + + // Store completed response + self.completed_response = Some(final_response.clone()); + + // Emit response.completed + let seq_final = self.next_sequence_number(); + let completed_event = ResponsesAPIStreamEvent::ResponseCompleted { + response: final_response, + sequence_number: seq_final, + }; + events.push(event_to_sse(completed_event)); + + // Add all finalization events to the buffer + self.buffered_events.extend(events); + } +} + +impl SseStreamBufferTrait for ResponsesAPIStreamBuffer { + fn add_transformed_event(&mut self, event: 
SseEvent) { + // Skip ping messages + if event.should_skip() { + return; + } + + // Handle [DONE] marker - trigger finalization + if event.is_done() { + self.finalize(); + return; + } + + // Extract the ResponseAPIStreamEvent from the SseEvent's provider_stream_response + let provider_response = match event.provider_stream_response.as_ref() { + Some(response) => response, + None => { + eprintln!("Warning: Event missing provider_stream_response"); + return; + } + }; + + // Extract ResponseAPIStreamEvent from the enum + let stream_event = match provider_response { + crate::providers::streaming_response::ProviderStreamResponseType::ResponseAPIStreamEvent(evt) => evt, + _ => { + eprintln!("Warning: Expected ResponseAPIStreamEvent in provider_stream_response"); + return; + } + }; + + let mut events = Vec::new(); + + // Emit lifecycle events if not yet emitted + if !self.created_emitted { + // Initialize metadata from first event if needed + if self.response_id.is_none() { + self.response_id = Some(format!("resp_{}", uuid::Uuid::new_v4().to_string().replace("-", ""))); + self.created_at = Some(std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs() as i64); + self.model = Some("unknown".to_string()); // Will be set by caller if available + } + + events.push(self.create_response_created_event()); + self.created_emitted = true; + } + + if !self.in_progress_emitted { + events.push(self.create_response_in_progress_event()); + self.in_progress_emitted = true; + } + + // Process the delta event + match stream_event { + ResponsesAPIStreamEvent::ResponseOutputTextDelta { output_index, delta, .. } => { + let item_id = self.get_or_create_item_id(*output_index, "msg"); + + // Emit output_item.added if this is the first time we see this output index + if !self.output_items_added.contains_key(output_index) { + self.output_items_added.insert(*output_index, item_id.clone()); + events.push(self.create_output_item_added_event(*output_index, &item_id)); + } + + // Accumulate text content + self.text_content.entry(item_id.clone()) + .and_modify(|content| content.push_str(delta)) + .or_insert_with(|| delta.clone()); + + // Emit text delta with filled-in item_id and sequence_number + let mut delta_event = stream_event.clone(); + if let ResponsesAPIStreamEvent::ResponseOutputTextDelta { item_id: ref mut id, sequence_number: ref mut seq, .. } = delta_event { + *id = item_id; + *seq = self.next_sequence_number(); + } + events.push(event_to_sse(delta_event)); + } + ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDelta { output_index, delta, call_id, name, .. 
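// Note on the fields bound above: call_id and name are only present on the
// first OpenAI tool_call chunk (the one that names the function); later
// chunks carry argument fragments only. Caching them in tool_call_metadata
// below is what lets finalize() emit a complete function_call output item.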
} => { + let item_id = self.get_or_create_item_id(*output_index, "fc"); + + // Store metadata if provided (from initial tool call event) + if let (Some(cid), Some(n)) = (call_id, name) { + self.tool_call_metadata.insert(*output_index, (cid.clone(), n.clone())); + } + + // Emit output_item.added if this is the first time we see this tool call + if !self.output_items_added.contains_key(output_index) { + self.output_items_added.insert(*output_index, item_id.clone()); + + // For tool calls, we need call_id and name from metadata + // These should now be populated from the event itself + let (call_id, name) = self.tool_call_metadata.get(output_index) + .cloned() + .unwrap_or_else(|| (format!("call_{}", uuid::Uuid::new_v4()), "unknown".to_string())); + + events.push(self.create_tool_call_added_event(*output_index, &item_id, &call_id, &name)); + } + + // Accumulate function arguments + self.function_arguments.entry(item_id.clone()) + .and_modify(|args| args.push_str(delta)) + .or_insert_with(|| delta.clone()); + + // Emit function call arguments delta with filled-in item_id and sequence_number + let mut delta_event = stream_event.clone(); + if let ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDelta { item_id: ref mut id, sequence_number: ref mut seq, .. } = delta_event { + *id = item_id; + *seq = self.next_sequence_number(); + } + events.push(event_to_sse(delta_event)); + } + _ => { + // For other event types, just pass through with sequence number + let other_event = stream_event.clone(); + // TODO: Add sequence number to other event types if needed + events.push(event_to_sse(other_event)); + } + } + + // Store all generated events in the buffer + self.buffered_events.extend(events); + } + + + fn into_bytes(&mut self) -> Vec { + // For Responses API, we need special handling: + // - Most events are already in buffered_events from add_transformed_event + // - We should NOT finalize here - finalization happens when we detect [DONE] or end of stream + // - Just flush the accumulated events and clear the buffer + + // Convert all accumulated events to bytes and clear buffer + let mut buffer = Vec::new(); + for event in self.buffered_events.drain(..) 
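// Only the queued events are drained; sequence numbers, item ids, and the
// partially accumulated text/argument state all survive this flush, which is
// what makes the add/flush cycle safe to run once per upstream chunk.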
{ + let event_bytes: Vec = event.into(); + buffer.extend_from_slice(&event_bytes); + } + buffer + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::clients::{SupportedAPIsFromClient, SupportedUpstreamAPIs}; + use crate::apis::openai::OpenAIApi; + use crate::apis::streaming_shapes::sse::SseStreamIter; + + #[test] + fn test_chat_completions_to_responses_api_transformation() { + // ChatCompletions input that will be transformed to ResponsesAPI + let raw_input = r#"data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{"role":"assistant","content":"Hello"},"finish_reason":null}]} + + data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{"content":" world"},"finish_reason":null}]} + + data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{},"finish_reason":"stop"}]} + + data: [DONE]"#; + + println!("\n{}", "=".repeat(80)); + println!("TEST 2: ChatCompletions → ResponsesAPI Transformation (with [DONE])"); + println!("{}", "=".repeat(80)); + println!("\nRAW INPUT (ChatCompletions):"); + println!("{}", "-".repeat(80)); + println!("{}", raw_input); + + // Setup API configuration for transformation + let client_api = SupportedAPIsFromClient::OpenAIResponsesAPI(OpenAIApi::Responses); + let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions); + + // Parse events and apply transformation + let stream_iter = SseStreamIter::try_from(raw_input.as_bytes()).unwrap(); + let mut buffer = ResponsesAPIStreamBuffer::new(); + + for raw_event in stream_iter { + // Transform the event using the client/upstream APIs + let transformed_event = SseEvent::try_from((raw_event, &client_api, &upstream_api)).unwrap(); + buffer.add_transformed_event(transformed_event); + } + + let output_bytes = buffer.into_bytes(); + let output = String::from_utf8_lossy(&output_bytes); + + println!("\nTRANSFORMED OUTPUT (ResponsesAPI):"); + println!("{}", "-".repeat(80)); + println!("{}", output); + + // Assertions + assert!(!output_bytes.is_empty(), "Should have output"); + assert!(output.contains("response.created"), "Should have response.created"); + assert!(output.contains("response.in_progress"), "Should have response.in_progress"); + assert!(output.contains("response.output_item.added"), "Should have output_item.added"); + assert!(output.contains("response.output_text.delta"), "Should have text deltas"); + assert!(output.contains("response.output_text.done"), "Should have text.done"); + assert!(output.contains("response.output_item.done"), "Should have output_item.done"); + assert!(output.contains("response.completed"), "Should have response.completed"); + + println!("\nVALIDATION SUMMARY:"); + println!("{}", "-".repeat(80)); + println!("✓ Lifecycle events: response.created, response.in_progress, response.completed"); + println!("✓ Output item lifecycle: output_item.added, output_item.done"); + println!("✓ Text streaming: output_text.delta (2 deltas), output_text.done"); + println!("✓ Complete transformation with finalization ([DONE] processed)\n"); + } + + #[test] + fn test_partial_streaming_incremental_output() { + let raw_input = r#"data: 
{"id":"chatcmpl-CfpqklihniLRuuQfP7inMb2ghtGmT","object":"chat.completion.chunk","created":1764086794,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_7eeb46f068","choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_mD5ggLKk3SMKGPFqFdcpKg6q","type":"function","function":{"name":"get_weather","arguments":""}}],"refusal":null},"logprobs":null,"finish_reason":null}],"obfuscation":"PCFrpy"} + + data: {"id":"chatcmpl-CfpqklihniLRuuQfP7inMb2ghtGmT","object":"chat.completion.chunk","created":1764086794,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_7eeb46f068","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\""}}]},"logprobs":null,"finish_reason":null}],"obfuscation":""} + + data: {"id":"chatcmpl-CfpqklihniLRuuQfP7inMb2ghtGmT","object":"chat.completion.chunk","created":1764086794,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_7eeb46f068","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"location"}}]},"logprobs":null,"finish_reason":null}],"obfuscation":"TC58A3QEIx8"} + + data: {"id":"chatcmpl-CfpqklihniLRuuQfP7inMb2ghtGmT","object":"chat.completion.chunk","created":1764086794,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_7eeb46f068","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\":\""}}]},"logprobs":null,"finish_reason":null}],"obfuscation":"PK4oFzlVlGTUP5"}"#; + + println!("\n{}", "=".repeat(80)); + println!("TEST 3: Partial Streaming - Function Calling (NO [DONE])"); + println!("{}", "=".repeat(80)); + println!("\nRAW INPUT (ChatCompletions - NO [DONE]):"); + println!("{}", "-".repeat(80)); + println!("{}", raw_input); + + // Setup API configuration for transformation + let client_api = SupportedAPIsFromClient::OpenAIResponsesAPI(OpenAIApi::Responses); + let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions); + + // Transform all events + let stream_iter = SseStreamIter::try_from(raw_input.as_bytes()).unwrap(); + let mut buffer = ResponsesAPIStreamBuffer::new(); + + for raw_event in stream_iter { + let transformed = SseEvent::try_from((raw_event, &client_api, &upstream_api)).unwrap(); + buffer.add_transformed_event(transformed); + } + + let output_bytes = buffer.into_bytes(); + let output = String::from_utf8_lossy(&output_bytes); + + println!("\nTRANSFORMED OUTPUT (ResponsesAPI):"); + println!("{}", "-".repeat(80)); + println!("{}", output); + + // Assertions + assert!(output.contains("response.created"), "Should have response.created"); + assert!(output.contains("response.in_progress"), "Should have response.in_progress"); + assert!(output.contains("response.output_item.added"), "Should have output_item.added"); + assert!(output.contains("\"type\":\"function_call\""), "Should be function_call type"); + assert!(output.contains("\"name\":\"get_weather\""), "Should have function name"); + assert!(output.contains("\"call_id\":\"call_mD5ggLKk3SMKGPFqFdcpKg6q\""), "Should have correct call_id"); + + let delta_count = output.matches("event: response.function_call_arguments.delta").count(); + assert_eq!(delta_count, 4, "Should have 4 delta events"); + + assert!(!output.contains("response.function_call_arguments.done"), "Should NOT have arguments.done"); + assert!(!output.contains("response.output_item.done"), "Should NOT have output_item.done"); + assert!(!output.contains("response.completed"), 
"Should NOT have response.completed"); + + println!("\nVALIDATION SUMMARY:"); + println!("{}", "-".repeat(80)); + println!("✓ Lifecycle events: response.created, response.in_progress"); + println!("✓ Function call metadata: name='get_weather', call_id='call_mD5ggLKk3SMKGPFqFdcpKg6q'"); + println!("✓ Incremental deltas: 4 events (1 initial + 3 argument chunks)"); + println!("✓ NO completion events (partial stream, no [DONE])"); + println!("✓ Arguments accumulated: '{{\"location\":\"'\n"); + } +} diff --git a/crates/hermesllm/src/apis/sse.rs b/crates/hermesllm/src/apis/streaming_shapes/sse.rs similarity index 52% rename from crates/hermesllm/src/apis/sse.rs rename to crates/hermesllm/src/apis/streaming_shapes/sse.rs index b8a9b492..17c6873a 100644 --- a/crates/hermesllm/src/apis/sse.rs +++ b/crates/hermesllm/src/apis/streaming_shapes/sse.rs @@ -1,10 +1,73 @@ -use crate::providers::response::ProviderStreamResponse; -use crate::providers::response::ProviderStreamResponseType; +use crate::providers::streaming_response::ProviderStreamResponse; +use crate::providers::streaming_response::ProviderStreamResponseType; +use crate::apis::streaming_shapes::chat_completions_streaming_buffer::OpenAIChatCompletionsStreamBuffer; +use crate::apis::streaming_shapes::anthropic_streaming_buffer::AnthropicMessagesStreamBuffer; +use crate::apis::streaming_shapes::passthrough_streaming_buffer::PassthroughStreamBuffer; +use crate::apis::streaming_shapes::responses_api_streaming_buffer::ResponsesAPIStreamBuffer; use serde::{Deserialize, Serialize}; use std::error::Error; use std::fmt; use std::str::FromStr; +/// Trait defining the interface for SSE stream buffers. +/// +/// This trait is implemented by both the enum `SseStreamBuffer` (for zero-cost dispatch) +/// and individual buffer implementations (for direct use). +/// +pub trait SseStreamBufferTrait: Send + Sync { + /// Add a transformed SSE event to the buffer. + /// + /// The buffer may inject additional events as needed based on internal state. + /// For example, Anthropic buffers inject ContentBlockStart before the first ContentBlockDelta. + /// + /// All events (original + injected) are accumulated internally for the next `into_bytes()` call. + /// + /// # Arguments + /// * `event` - A transformed SSE event to accumulate + fn add_transformed_event(&mut self, event: SseEvent); + + /// Get bytes for all accumulated events since the last call. + /// + /// This method: + /// - Converts all buffered events to wire format bytes + /// - Clears the internal event buffer + /// - Preserves state for subsequent `add_transformed_event()` calls + /// + /// Call this after processing each chunk of upstream events to get bytes for immediate transmission. 
+ /// + /// # Returns + /// Bytes ready for wire transmission (may be empty if no events were accumulated) + fn into_bytes(&mut self) -> Vec; +} + +/// Unified SSE Stream Buffer enum that provides a zero-cost abstraction +pub enum SseStreamBuffer { + Passthrough(PassthroughStreamBuffer), + OpenAIChatCompletions(OpenAIChatCompletionsStreamBuffer), + AnthropicMessages(AnthropicMessagesStreamBuffer), + OpenAIResponses(ResponsesAPIStreamBuffer), +} + +impl SseStreamBufferTrait for SseStreamBuffer { + fn add_transformed_event(&mut self, event: SseEvent) { + match self { + Self::Passthrough(buffer) => buffer.add_transformed_event(event), + Self::OpenAIChatCompletions(buffer) => buffer.add_transformed_event(event), + Self::AnthropicMessages(buffer) => buffer.add_transformed_event(event), + Self::OpenAIResponses(buffer) => buffer.add_transformed_event(event), + } + } + + fn into_bytes(&mut self) -> Vec { + match self { + Self::Passthrough(buffer) => buffer.into_bytes(), + Self::OpenAIChatCompletions(buffer) => buffer.into_bytes(), + Self::AnthropicMessages(buffer) => buffer.into_bytes(), + Self::OpenAIResponses(buffer) => buffer.into_bytes(), + } + } +} + // ============================================================================ // SSE EVENT CONTAINER // ============================================================================ @@ -22,16 +85,31 @@ pub struct SseEvent { pub raw_line: String, // The complete line as received including "data: " prefix and "\n\n" #[serde(skip_serializing, skip_deserializing)] - pub sse_transform_buffer: String, // The complete line as received including "data: " prefix and "\n\n" + pub sse_transformed_lines: String, // The complete line as received including "data: " prefix and "\n\n" #[serde(skip_serializing, skip_deserializing)] pub provider_stream_response: Option, // Parsed provider stream response object } impl SseEvent { + /// Create an SseEvent from a ProviderStreamResponseType + /// This is useful for binary frame formats (like Bedrock) that need to be converted to SSE + pub fn from_provider_response(response: ProviderStreamResponseType) -> Self { + // Convert the provider response to SSE format string + let sse_string: String = response.clone().into(); + + SseEvent { + data: None, // Data is embedded in sse_transformed_lines + event: None, // Event type is embedded in sse_transformed_lines + raw_line: sse_string.clone(), + sse_transformed_lines: sse_string, + provider_stream_response: Some(response), + } + } + /// Check if this event represents the end of the stream pub fn is_done(&self) -> bool { - self.data == Some("[DONE]".into()) + self.data == Some("[DONE]".into()) || self.event == Some("message_stop".into()) } /// Check if this event should be skipped during processing @@ -61,23 +139,35 @@ impl FromStr for SseEvent { type Err = SseParseError; fn from_str(line: &str) -> Result { - if line.starts_with("data: ") { - let data: String = line[6..].to_string(); // Remove "data: " prefix - if data.is_empty() { + // Trim leading/trailing whitespace for parsing + let trimmed_line = line.trim(); + + // Skip empty or whitespace-only lines (SSE event separators) + if trimmed_line.is_empty() { + return Err(SseParseError { + message: "Empty line (SSE event separator)".to_string(), + }); + } + + if trimmed_line.starts_with("data: ") { + let data: String = trimmed_line[6..].to_string(); // Remove "data: " prefix + // Allow empty data content after "data: " prefix + // This handles cases like "data: " followed by newline + if data.trim().is_empty() { return 
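// Despite the "Allow empty data content" note above, a payload that is empty
// after trimming is still rejected via the Err below; presumably the stream
// iterator skips such lines (as with the empty-line separator case) rather
// than treating them as fatal.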
Err(SseParseError { - message: "Empty data field is not a valid SSE event".to_string(), + message: "Empty data field after 'data: ' prefix".to_string(), }); } Ok(SseEvent { data: Some(data), event: None, raw_line: line.to_string(), - sse_transform_buffer: line.to_string(), + // Preserve original line format for passthrough, use trimmed for transformations + sse_transformed_lines: line.to_string(), provider_stream_response: None, }) - } else if line.starts_with("event: ") { - //used by Anthropic - let event_type = line[7..].to_string(); + } else if trimmed_line.starts_with("event: ") { + let event_type = trimmed_line[7..].to_string(); if event_type.is_empty() { return Err(SseParseError { message: "Empty event field is not a valid SSE event".to_string(), @@ -87,12 +177,13 @@ impl FromStr for SseEvent { data: None, event: Some(event_type), raw_line: line.to_string(), - sse_transform_buffer: line.to_string(), + // Preserve original line format for passthrough, use trimmed for transformations + sse_transformed_lines: line.to_string(), provider_stream_response: None, }) } else { Err(SseParseError { - message: format!("Line does not start with 'data: ' or 'event: ': {}", line), + message: format!("Line does not start with 'data: ' or 'event: ': {}", trimmed_line), }) } } @@ -100,14 +191,14 @@ impl FromStr for SseEvent { impl fmt::Display for SseEvent { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}", self.sse_transform_buffer) + write!(f, "{}", self.sse_transformed_lines) } } // Into implementation to convert SseEvent to bytes for response buffer impl Into> for SseEvent { fn into(self) -> Vec { - format!("{}\n\n", self.sse_transform_buffer).into_bytes() + format!("{}\n\n", self.sse_transformed_lines).into_bytes() } } diff --git a/crates/hermesllm/src/clients/endpoints.rs b/crates/hermesllm/src/clients/endpoints.rs index e0ad47d3..09ab262d 100644 --- a/crates/hermesllm/src/clients/endpoints.rs +++ b/crates/hermesllm/src/clients/endpoints.rs @@ -4,9 +4,10 @@ use std::fmt; /// Unified enum representing all supported API endpoints across providers #[derive(Debug, Clone, PartialEq)] -pub enum SupportedAPIs { +pub enum SupportedAPIsFromClient { OpenAIChatCompletions(OpenAIApi), AnthropicMessagesAPI(AnthropicApi), + OpenAIResponsesAPI(OpenAIApi), } #[derive(Debug, Clone, PartialEq)] @@ -15,17 +16,21 @@ pub enum SupportedUpstreamAPIs { AnthropicMessagesAPI(AnthropicApi), AmazonBedrockConverse(AmazonBedrockApi), AmazonBedrockConverseStream(AmazonBedrockApi), + OpenAIResponsesAPI(OpenAIApi), } -impl fmt::Display for SupportedAPIs { +impl fmt::Display for SupportedAPIsFromClient { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - SupportedAPIs::OpenAIChatCompletions(api) => { + SupportedAPIsFromClient::OpenAIChatCompletions(api) => { write!(f, "OpenAI ({})", api.endpoint()) } - SupportedAPIs::AnthropicMessagesAPI(api) => { + SupportedAPIsFromClient::AnthropicMessagesAPI(api) => { write!(f, "Anthropic AI ({})", api.endpoint()) } + SupportedAPIsFromClient::OpenAIResponsesAPI(api) => { + write!(f, "OpenAI Responses ({})", api.endpoint()) + } } } } @@ -45,19 +50,27 @@ impl fmt::Display for SupportedUpstreamAPIs { SupportedUpstreamAPIs::AmazonBedrockConverseStream(api) => { write!(f, "Amazon Bedrock ({})", api.endpoint()) } + SupportedUpstreamAPIs::OpenAIResponsesAPI(api) => { + write!(f, "OpenAI Responses ({})", api.endpoint()) + } } } } -impl SupportedAPIs { +impl SupportedAPIsFromClient { /// Create a SupportedApi from an endpoint path pub fn 
from_endpoint(endpoint: &str) -> Option { if let Some(openai_api) = OpenAIApi::from_endpoint(endpoint) { - return Some(SupportedAPIs::OpenAIChatCompletions(openai_api)); + // Check if this is the Responses API endpoint + if openai_api == OpenAIApi::Responses { + return Some(SupportedAPIsFromClient::OpenAIResponsesAPI(openai_api)); + } + // Otherwise it's ChatCompletions + return Some(SupportedAPIsFromClient::OpenAIChatCompletions(openai_api)); } if let Some(anthropic_api) = AnthropicApi::from_endpoint(endpoint) { - return Some(SupportedAPIs::AnthropicMessagesAPI(anthropic_api)); + return Some(SupportedAPIsFromClient::AnthropicMessagesAPI(anthropic_api)); } None @@ -66,8 +79,9 @@ impl SupportedAPIs { /// Get the endpoint path for this API pub fn endpoint(&self) -> &'static str { match self { - SupportedAPIs::OpenAIChatCompletions(api) => api.endpoint(), - SupportedAPIs::AnthropicMessagesAPI(api) => api.endpoint(), + SupportedAPIsFromClient::OpenAIChatCompletions(api) => api.endpoint(), + SupportedAPIsFromClient::AnthropicMessagesAPI(api) => api.endpoint(), + SupportedAPIsFromClient::OpenAIResponsesAPI(api) => api.endpoint(), } } @@ -94,8 +108,62 @@ impl SupportedAPIs { } }; + // Helper function to route based on provider with a specific endpoint suffix + let route_by_provider = |endpoint_suffix: &str| -> String { + match provider_id { + ProviderId::Groq => { + if request_path.starts_with("/v1/") { + build_endpoint("/openai", request_path) + } else { + build_endpoint("/v1", endpoint_suffix) + } + } + ProviderId::Zhipu => { + if request_path.starts_with("/v1/") { + build_endpoint("/api/paas/v4", endpoint_suffix) + } else { + build_endpoint("/v1", endpoint_suffix) + } + } + ProviderId::Qwen => { + if request_path.starts_with("/v1/") { + build_endpoint("/compatible-mode/v1", endpoint_suffix) + } else { + build_endpoint("/v1", endpoint_suffix) + } + } + ProviderId::AzureOpenAI => { + if request_path.starts_with("/v1/") { + let suffix = endpoint_suffix.trim_start_matches('/'); + build_endpoint("/openai/deployments", &format!("/{}/{}?api-version=2025-01-01-preview", model_id, suffix)) + } else { + build_endpoint("/v1", endpoint_suffix) + } + } + ProviderId::Gemini => { + if request_path.starts_with("/v1/") { + build_endpoint("/v1beta/openai", endpoint_suffix) + } else { + build_endpoint("/v1", endpoint_suffix) + } + } + ProviderId::AmazonBedrock => { + if request_path.starts_with("/v1/") { + if !is_streaming { + build_endpoint("", &format!("/model/{}/converse", model_id)) + } else { + build_endpoint("", &format!("/model/{}/converse-stream", model_id)) + } + } else { + build_endpoint("/v1", endpoint_suffix) + } + } + _ => build_endpoint("/v1", endpoint_suffix), + } + }; + match self { - SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages) => match provider_id { + SupportedAPIsFromClient::AnthropicMessagesAPI(AnthropicApi::Messages) => match provider_id { ProviderId::Anthropic => build_endpoint("/v1", "/messages"), ProviderId::AmazonBedrock => { if request_path.starts_with("/v1/") && !is_streaming { @@ -108,55 +176,19 @@ impl SupportedAPIs { } _ => build_endpoint("/v1", "/chat/completions"), }, - _ => match provider_id { - ProviderId::Groq => { - if request_path.starts_with("/v1/") { - build_endpoint("/openai", request_path) - } else { - build_endpoint("/v1", "/chat/completions") - } + SupportedAPIsFromClient::OpenAIResponsesAPI(_) => { + // For Responses API, check if provider supports it, otherwise translate to chat/completions + match provider_id { + // OpenAI and compatible providers that 
support /v1/responses + ProviderId::OpenAI => route_by_provider("/responses"), + // All other providers: translate to /chat/completions + _ => route_by_provider("/chat/completions"), } - ProviderId::Zhipu => { - if request_path.starts_with("/v1/") { - build_endpoint("/api/paas/v4", "/chat/completions") - } else { - build_endpoint("/v1", "/chat/completions") - } - } - ProviderId::Qwen => { - if request_path.starts_with("/v1/") { - build_endpoint("/compatible-mode/v1", "/chat/completions") - } else { - build_endpoint("/v1", "/chat/completions") - } - } - ProviderId::AzureOpenAI => { - if request_path.starts_with("/v1/") { - build_endpoint("/openai/deployments", &format!("/{}/chat/completions?api-version=2025-01-01-preview", model_id)) - } else { - build_endpoint("/v1", "/chat/completions") - } - } - ProviderId::Gemini => { - if request_path.starts_with("/v1/") { - build_endpoint("/v1beta/openai", "/chat/completions") - } else { - build_endpoint("/v1", "/chat/completions") - } - } - ProviderId::AmazonBedrock => { - if request_path.starts_with("/v1/") { - if !is_streaming { - build_endpoint("", &format!("/model/{}/converse", model_id)) - } else { - build_endpoint("", &format!("/model/{}/converse-stream", model_id)) - } - } else { - build_endpoint("/v1", "/chat/completions") - } - } - _ => build_endpoint("/v1", "/chat/completions"), - }, + } + SupportedAPIsFromClient::OpenAIChatCompletions(_) => { + // For Chat Completions API, use the standard chat/completions path + route_by_provider("/chat/completions") + } } } } @@ -198,22 +230,23 @@ mod tests { #[test] fn test_is_supported_endpoint() { // OpenAI endpoints - assert!(SupportedAPIs::from_endpoint("/v1/chat/completions").is_some()); + assert!(SupportedAPIsFromClient::from_endpoint("/v1/chat/completions").is_some()); // Anthropic endpoints - assert!(SupportedAPIs::from_endpoint("/v1/messages").is_some()); + assert!(SupportedAPIsFromClient::from_endpoint("/v1/messages").is_some()); // Unsupported endpoints - assert!(!SupportedAPIs::from_endpoint("/v1/unknown").is_some()); - assert!(!SupportedAPIs::from_endpoint("/v2/chat").is_some()); - assert!(!SupportedAPIs::from_endpoint("").is_some()); + assert!(!SupportedAPIsFromClient::from_endpoint("/v1/unknown").is_some()); + assert!(!SupportedAPIsFromClient::from_endpoint("/v2/chat").is_some()); + assert!(!SupportedAPIsFromClient::from_endpoint("").is_some()); } #[test] fn test_supported_endpoints() { let endpoints = supported_endpoints(); - assert_eq!(endpoints.len(), 2); // We have 2 APIs defined + assert_eq!(endpoints.len(), 3); // We have 3 APIs defined assert!(endpoints.contains(&"/v1/chat/completions")); assert!(endpoints.contains(&"/v1/messages")); + assert!(endpoints.contains(&"/v1/responses")); } #[test] @@ -263,7 +296,7 @@ mod tests { #[test] fn test_target_endpoint_without_base_url_prefix() { - let api = SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions); + let api = SupportedAPIsFromClient::OpenAIChatCompletions(OpenAIApi::ChatCompletions); // Test default OpenAI provider assert_eq!( @@ -340,7 +373,7 @@ mod tests { #[test] fn test_target_endpoint_with_base_url_prefix() { - let api = SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions); + let api = SupportedAPIsFromClient::OpenAIChatCompletions(OpenAIApi::ChatCompletions); // Test Zhipu with custom base_url_path_prefix assert_eq!( @@ -405,7 +438,7 @@ mod tests { #[test] fn test_target_endpoint_with_empty_base_url_prefix() { - let api = SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions); + let api = 
SupportedAPIsFromClient::OpenAIChatCompletions(OpenAIApi::ChatCompletions); // Test with just slashes - trims to empty, uses provider default assert_eq!( @@ -434,7 +467,7 @@ mod tests { #[test] fn test_amazon_bedrock_endpoints() { - let api = SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages); + let api = SupportedAPIsFromClient::AnthropicMessagesAPI(AnthropicApi::Messages); // Test Bedrock non-streaming without prefix assert_eq!( @@ -487,7 +520,7 @@ mod tests { #[test] fn test_anthropic_messages_endpoint() { - let api = SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages); + let api = SupportedAPIsFromClient::AnthropicMessagesAPI(AnthropicApi::Messages); // Test Anthropic without prefix assert_eq!( @@ -516,7 +549,7 @@ mod tests { #[test] fn test_non_v1_request_paths() { - let api = SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions); + let api = SupportedAPIsFromClient::OpenAIChatCompletions(OpenAIApi::ChatCompletions); // Test Groq with non-v1 path (should use default) assert_eq!( @@ -557,7 +590,7 @@ mod tests { #[test] fn test_azure_openai_with_query_params() { - let api = SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions); + let api = SupportedAPIsFromClient::OpenAIChatCompletions(OpenAIApi::ChatCompletions); // Test Azure without prefix - should include query params assert_eq!( diff --git a/crates/hermesllm/src/clients/mod.rs b/crates/hermesllm/src/clients/mod.rs index b93f910e..6841804f 100644 --- a/crates/hermesllm/src/clients/mod.rs +++ b/crates/hermesllm/src/clients/mod.rs @@ -1,9 +1,8 @@ pub mod endpoints; pub mod lib; -pub mod transformer; // Re-export the main items for easier access -pub use endpoints::{identify_provider, SupportedAPIs}; +pub use endpoints::*; pub use lib::*; // Note: transformer module contains TryFrom trait implementations that are automatically available diff --git a/crates/hermesllm/src/clients/transformer.rs b/crates/hermesllm/src/clients/transformer.rs deleted file mode 100644 index cbb9bbe7..00000000 --- a/crates/hermesllm/src/clients/transformer.rs +++ /dev/null @@ -1,694 +0,0 @@ -// Re-export new transformation modules for backward compatibility - -//KEEPING THE TESTS TO MAKE SURE ALL THE REFACTORING DIDN'T BREAK ANYTHING - -// ============================================================================ -// TESTS -// ============================================================================ - -#[cfg(test)] -mod tests { - use crate::apis::anthropic::*; - use crate::apis::openai::*; - use crate::transforms::*; - use serde_json::json; - type AnthropicMessagesRequest = MessagesRequest; - - #[test] - fn test_anthropic_to_openai_basic_request() { - let anthropic_req = AnthropicMessagesRequest { - model: "claude-3-sonnet-20240229".to_string(), - system: Some(MessagesSystemPrompt::Single("You are helpful".to_string())), - messages: vec![MessagesMessage { - role: MessagesRole::User, - content: MessagesMessageContent::Single("Hello, world!".to_string()), - }], - max_tokens: 1024, - container: None, - mcp_servers: None, - service_tier: None, - thinking: None, - temperature: Some(0.7), - top_p: Some(0.9), - top_k: Some(50), - stream: Some(false), - stop_sequences: Some(vec!["STOP".to_string()]), - tools: None, - tool_choice: None, - metadata: None, - }; - - let openai_req: ChatCompletionsRequest = anthropic_req.try_into().unwrap(); - - assert_eq!(openai_req.model, "claude-3-sonnet-20240229"); - assert_eq!(openai_req.messages.len(), 2); // system + user message - assert_eq!(openai_req.max_completion_tokens, Some(1024)); 
- assert_eq!(openai_req.temperature, Some(0.7)); - assert_eq!(openai_req.top_p, Some(0.9)); - assert_eq!(openai_req.stream, Some(false)); - assert_eq!(openai_req.stop, Some(vec!["STOP".to_string()])); - } - - #[test] - fn test_roundtrip_consistency() { - // Test that converting back and forth maintains consistency - let original_anthropic = AnthropicMessagesRequest { - model: "claude-3-sonnet".to_string(), - system: Some(MessagesSystemPrompt::Single("System prompt".to_string())), - messages: vec![MessagesMessage { - role: MessagesRole::User, - content: MessagesMessageContent::Single("User message".to_string()), - }], - max_tokens: 1000, - container: None, - mcp_servers: None, - service_tier: None, - thinking: None, - temperature: Some(0.5), - top_p: Some(1.0), - top_k: None, - stream: Some(false), - stop_sequences: None, - tools: None, - tool_choice: None, - metadata: None, - }; - - // Convert to OpenAI and back - let openai_req: ChatCompletionsRequest = original_anthropic.clone().try_into().unwrap(); - let roundtrip_anthropic: AnthropicMessagesRequest = openai_req.try_into().unwrap(); - - // Check key fields are preserved - assert_eq!(original_anthropic.model, roundtrip_anthropic.model); - assert_eq!( - original_anthropic.max_tokens, - roundtrip_anthropic.max_tokens - ); - assert_eq!( - original_anthropic.temperature, - roundtrip_anthropic.temperature - ); - assert_eq!(original_anthropic.top_p, roundtrip_anthropic.top_p); - assert_eq!(original_anthropic.stream, roundtrip_anthropic.stream); - assert_eq!( - original_anthropic.messages.len(), - roundtrip_anthropic.messages.len() - ); - } - - #[test] - fn test_tool_choice_auto() { - let anthropic_req = AnthropicMessagesRequest { - model: "claude-3".to_string(), - system: None, - messages: vec![], - max_tokens: 100, - container: None, - mcp_servers: None, - service_tier: None, - thinking: None, - temperature: None, - top_p: None, - top_k: None, - stream: None, - stop_sequences: None, - tools: Some(vec![MessagesTool { - name: "test_tool".to_string(), - description: Some("A test tool".to_string()), - input_schema: json!({"type": "object"}), - }]), - tool_choice: Some(MessagesToolChoice { - kind: MessagesToolChoiceType::Auto, - name: None, - disable_parallel_tool_use: Some(true), - }), - metadata: None, - }; - - let openai_req: ChatCompletionsRequest = anthropic_req.try_into().unwrap(); - - assert!(openai_req.tools.is_some()); - assert_eq!(openai_req.tools.as_ref().unwrap().len(), 1); - - if let Some(ToolChoice::Type(choice)) = openai_req.tool_choice { - assert_eq!(choice, ToolChoiceType::Auto); - } else { - panic!("Expected auto tool choice"); - } - - assert_eq!(openai_req.parallel_tool_calls, Some(false)); - } - - #[test] - fn test_default_max_tokens_used_when_openai_has_none() { - // Test that DEFAULT_MAX_TOKENS is used when OpenAI request has no max_tokens - let openai_req = ChatCompletionsRequest { - model: "gpt-4".to_string(), - messages: vec![Message { - role: Role::User, - content: MessageContent::Text("Hello".to_string()), - name: None, - tool_calls: None, - tool_call_id: None, - }], - max_tokens: None, // No max_tokens specified - ..Default::default() - }; - - let anthropic_req: AnthropicMessagesRequest = openai_req.try_into().unwrap(); - - assert_eq!(anthropic_req.max_tokens, DEFAULT_MAX_TOKENS); - } - - #[test] - fn test_anthropic_message_start_streaming() { - let event = MessagesStreamEvent::MessageStart { - message: MessagesStreamMessage { - id: "msg_stream_123".to_string(), - obj_type: "message".to_string(), - role: 
MessagesRole::Assistant, - content: vec![], - model: "claude-3".to_string(), - stop_reason: None, - stop_sequence: None, - usage: MessagesUsage { - input_tokens: 5, - output_tokens: 0, - cache_creation_input_tokens: None, - cache_read_input_tokens: None, - }, - }, - }; - - let openai_resp: ChatCompletionsStreamResponse = event.try_into().unwrap(); - - assert_eq!(openai_resp.id, "msg_stream_123"); - assert_eq!(openai_resp.object.as_deref(), Some("chat.completion.chunk")); - assert_eq!(openai_resp.model, "claude-3"); - assert_eq!(openai_resp.choices.len(), 1); - - let choice = &openai_resp.choices[0]; - assert_eq!(choice.index, 0); - assert_eq!(choice.delta.role, Some(Role::Assistant)); - assert_eq!(choice.delta.content, None); - assert_eq!(choice.finish_reason, None); - } - - #[test] - fn test_anthropic_content_block_delta_streaming() { - let event = MessagesStreamEvent::ContentBlockDelta { - index: 0, - delta: MessagesContentDelta::TextDelta { - text: "Hello, world!".to_string(), - }, - }; - - let openai_resp: ChatCompletionsStreamResponse = event.try_into().unwrap(); - - assert_eq!(openai_resp.object.as_deref(), Some("chat.completion.chunk")); - assert_eq!(openai_resp.choices.len(), 1); - - let choice = &openai_resp.choices[0]; - assert_eq!(choice.index, 0); - assert_eq!(choice.delta.content, Some("Hello, world!".to_string())); - assert_eq!(choice.delta.role, None); - assert_eq!(choice.finish_reason, None); - } - - #[test] - fn test_anthropic_tool_use_streaming() { - // Test tool use start - let tool_start = MessagesStreamEvent::ContentBlockStart { - index: 0, - content_block: MessagesContentBlock::ToolUse { - id: "call_123".to_string(), - name: "get_weather".to_string(), - input: json!({}), - cache_control: None, - }, - }; - - let openai_resp: ChatCompletionsStreamResponse = tool_start.try_into().unwrap(); - - assert_eq!(openai_resp.choices.len(), 1); - let choice = &openai_resp.choices[0]; - assert!(choice.delta.tool_calls.is_some()); - - let tool_calls = choice.delta.tool_calls.as_ref().unwrap(); - assert_eq!(tool_calls.len(), 1); - assert_eq!(tool_calls[0].id, Some("call_123".to_string())); - assert_eq!( - tool_calls[0].function.as_ref().unwrap().name, - Some("get_weather".to_string()) - ); - } - - #[test] - fn test_anthropic_tool_input_delta_streaming() { - let event = MessagesStreamEvent::ContentBlockDelta { - index: 0, - delta: MessagesContentDelta::InputJsonDelta { - partial_json: r#"{"location": "San Francisco"#.to_string(), - }, - }; - - let openai_resp: ChatCompletionsStreamResponse = event.try_into().unwrap(); - - assert_eq!(openai_resp.choices.len(), 1); - let choice = &openai_resp.choices[0]; - assert!(choice.delta.tool_calls.is_some()); - - let tool_calls = choice.delta.tool_calls.as_ref().unwrap(); - assert_eq!(tool_calls.len(), 1); - assert_eq!( - tool_calls[0].function.as_ref().unwrap().arguments, - Some(r#"{"location": "San Francisco"#.to_string()) - ); - } - - #[test] - fn test_anthropic_message_delta_with_usage() { - let event = MessagesStreamEvent::MessageDelta { - delta: MessagesMessageDelta { - stop_reason: MessagesStopReason::EndTurn, - stop_sequence: None, - }, - usage: MessagesUsage { - input_tokens: 10, - output_tokens: 25, - cache_creation_input_tokens: None, - cache_read_input_tokens: None, - }, - }; - - let openai_resp: ChatCompletionsStreamResponse = event.try_into().unwrap(); - - assert_eq!(openai_resp.choices.len(), 1); - let choice = &openai_resp.choices[0]; - assert_eq!(choice.finish_reason, Some(FinishReason::Stop)); - - 
assert!(openai_resp.usage.is_some()); - let usage = openai_resp.usage.unwrap(); - assert_eq!(usage.prompt_tokens, 10); - assert_eq!(usage.completion_tokens, 25); - assert_eq!(usage.total_tokens, 35); - } - - #[test] - fn test_anthropic_message_stop_streaming() { - let event = MessagesStreamEvent::MessageStop; - - let openai_resp: ChatCompletionsStreamResponse = event.try_into().unwrap(); - - assert_eq!(openai_resp.choices.len(), 1); - let choice = &openai_resp.choices[0]; - assert_eq!(choice.finish_reason, Some(FinishReason::Stop)); - } - - #[test] - fn test_anthropic_ping_streaming() { - let event = MessagesStreamEvent::Ping; - - let openai_resp: ChatCompletionsStreamResponse = event.try_into().unwrap(); - - assert_eq!(openai_resp.object.as_deref(), Some("chat.completion.chunk")); - assert_eq!(openai_resp.choices.len(), 0); // Ping has no choices - } - - #[test] - fn test_openai_to_anthropic_streaming_role_start() { - let openai_resp = ChatCompletionsStreamResponse { - id: "chatcmpl-123".to_string(), - object: Some("chat.completion.chunk".to_string()), - created: 1234567890, - model: "gpt-4".to_string(), - choices: vec![StreamChoice { - index: 0, - delta: MessageDelta { - role: Some(Role::Assistant), - content: None, - refusal: None, - function_call: None, - tool_calls: None, - }, - finish_reason: None, - logprobs: None, - }], - usage: None, - system_fingerprint: None, - service_tier: None, - }; - - let anthropic_event: MessagesStreamEvent = openai_resp.try_into().unwrap(); - - match anthropic_event { - MessagesStreamEvent::MessageStart { message } => { - assert_eq!(message.id, "chatcmpl-123"); - assert_eq!(message.role, MessagesRole::Assistant); - assert_eq!(message.model, "gpt-4"); - } - _ => panic!("Expected MessageStart event"), - } - } - - #[test] - fn test_openai_to_anthropic_streaming_content_delta() { - let openai_resp = ChatCompletionsStreamResponse { - id: "chatcmpl-123".to_string(), - object: Some("chat.completion.chunk".to_string()), - created: 1234567890, - model: "gpt-4".to_string(), - choices: vec![StreamChoice { - index: 0, - delta: MessageDelta { - role: None, - content: Some("Hello there!".to_string()), - refusal: None, - function_call: None, - tool_calls: None, - }, - finish_reason: None, - logprobs: None, - }], - usage: None, - system_fingerprint: None, - service_tier: None, - }; - - let anthropic_event: MessagesStreamEvent = openai_resp.try_into().unwrap(); - - match anthropic_event { - MessagesStreamEvent::ContentBlockDelta { index, delta } => { - assert_eq!(index, 0); - match delta { - MessagesContentDelta::TextDelta { text } => { - assert_eq!(text, "Hello there!"); - } - _ => panic!("Expected TextDelta"), - } - } - _ => panic!("Expected ContentBlockDelta event"), - } - } - - #[test] - fn test_openai_to_anthropic_streaming_tool_calls() { - let openai_resp = ChatCompletionsStreamResponse { - id: "chatcmpl-123".to_string(), - object: Some("chat.completion.chunk".to_string()), - created: 1234567890, - model: "gpt-4".to_string(), - choices: vec![StreamChoice { - index: 0, - delta: MessageDelta { - role: None, - content: None, - refusal: None, - function_call: None, - tool_calls: Some(vec![ToolCallDelta { - index: 0, - id: Some("call_abc123".to_string()), - call_type: Some("function".to_string()), - function: Some(FunctionCallDelta { - name: Some("get_current_weather".to_string()), - arguments: Some("".to_string()), - }), - }]), - }, - finish_reason: None, - logprobs: None, - }], - usage: None, - system_fingerprint: None, - service_tier: None, - }; - - let 
anthropic_event: MessagesStreamEvent = openai_resp.try_into().unwrap(); - - match anthropic_event { - MessagesStreamEvent::ContentBlockStart { - index, - content_block, - } => { - assert_eq!(index, 0); - match content_block { - MessagesContentBlock::ToolUse { id, name, .. } => { - assert_eq!(id, "call_abc123"); - assert_eq!(name, "get_current_weather"); - } - _ => panic!("Expected ToolUse content block"), - } - } - _ => panic!("Expected ContentBlockStart event"), - } - } - - #[test] - fn test_openai_to_anthropic_streaming_final_usage() { - let openai_resp = ChatCompletionsStreamResponse { - id: "chatcmpl-123".to_string(), - object: Some("chat.completion.chunk".to_string()), - created: 1234567890, - model: "gpt-4".to_string(), - choices: vec![StreamChoice { - index: 0, - delta: MessageDelta { - role: None, - content: None, - refusal: None, - function_call: None, - tool_calls: None, - }, - finish_reason: Some(FinishReason::Stop), - logprobs: None, - }], - usage: Some(Usage { - prompt_tokens: 15, - completion_tokens: 30, - total_tokens: 45, - prompt_tokens_details: None, - completion_tokens_details: None, - }), - system_fingerprint: None, - service_tier: None, - }; - - let anthropic_event: MessagesStreamEvent = openai_resp.try_into().unwrap(); - - match anthropic_event { - MessagesStreamEvent::MessageDelta { delta, usage } => { - assert_eq!(delta.stop_reason, MessagesStopReason::EndTurn); - assert_eq!(usage.input_tokens, 15); - assert_eq!(usage.output_tokens, 30); - } - _ => panic!("Expected MessageDelta event"), - } - } - - #[test] - fn test_openai_empty_choices_to_anthropic_ping() { - let openai_resp = ChatCompletionsStreamResponse { - id: "chatcmpl-123".to_string(), - object: Some("chat.completion.chunk".to_string()), - created: 1234567890, - model: "gpt-4".to_string(), - choices: vec![], // Empty choices - usage: None, - system_fingerprint: None, - service_tier: None, - }; - - let anthropic_event: MessagesStreamEvent = openai_resp.try_into().unwrap(); - - match anthropic_event { - MessagesStreamEvent::Ping => { - // Expected behavior - } - _ => panic!("Expected Ping event for empty choices"), - } - } - - #[test] - fn test_streaming_roundtrip_consistency() { - // Test that streaming events can roundtrip through conversions - let original_event = MessagesStreamEvent::ContentBlockDelta { - index: 0, - delta: MessagesContentDelta::TextDelta { - text: "Test message".to_string(), - }, - }; - - // Convert to OpenAI and back - let openai_resp: ChatCompletionsStreamResponse = original_event.try_into().unwrap(); - let roundtrip_event: MessagesStreamEvent = openai_resp.try_into().unwrap(); - - // Verify the roundtrip maintains the essential information - match roundtrip_event { - MessagesStreamEvent::ContentBlockDelta { index, delta } => { - assert_eq!(index, 0); - match delta { - MessagesContentDelta::TextDelta { text } => { - assert_eq!(text, "Test message"); - } - _ => panic!("Expected TextDelta after roundtrip"), - } - } - _ => panic!("Expected ContentBlockDelta after roundtrip"), - } - } - - #[test] - fn test_streaming_tool_argument_accumulation() { - // Test multiple tool argument deltas that should accumulate - let tool_start = MessagesStreamEvent::ContentBlockStart { - index: 0, - content_block: MessagesContentBlock::ToolUse { - id: "call_weather".to_string(), - name: "get_weather".to_string(), - input: json!({}), - cache_control: None, - }, - }; - - let arg_delta1 = MessagesStreamEvent::ContentBlockDelta { - index: 0, - delta: MessagesContentDelta::InputJsonDelta { - partial_json: 
r#"{"location": "#.to_string(), - }, - }; - - let arg_delta2 = MessagesStreamEvent::ContentBlockDelta { - index: 0, - delta: MessagesContentDelta::InputJsonDelta { - partial_json: r#"San Francisco", "unit": "fahrenheit"}"#.to_string(), - }, - }; - - // Test that each delta converts properly to OpenAI format - let openai_start: ChatCompletionsStreamResponse = tool_start.try_into().unwrap(); - let openai_delta1: ChatCompletionsStreamResponse = arg_delta1.try_into().unwrap(); - let openai_delta2: ChatCompletionsStreamResponse = arg_delta2.try_into().unwrap(); - - // Verify tool start - let tool_calls = &openai_start.choices[0].delta.tool_calls.as_ref().unwrap(); - assert_eq!(tool_calls[0].id, Some("call_weather".to_string())); - assert_eq!( - tool_calls[0].function.as_ref().unwrap().name, - Some("get_weather".to_string()) - ); - - // Verify argument deltas - let args1 = &openai_delta1.choices[0].delta.tool_calls.as_ref().unwrap()[0] - .function - .as_ref() - .unwrap() - .arguments; - assert_eq!(args1, &Some(r#"{"location": "#.to_string())); - - let args2 = &openai_delta2.choices[0].delta.tool_calls.as_ref().unwrap()[0] - .function - .as_ref() - .unwrap() - .arguments; - assert_eq!( - args2, - &Some(r#"San Francisco", "unit": "fahrenheit"}"#.to_string()) - ); - } - - #[test] - fn test_streaming_multiple_finish_reasons() { - // Test different finish reasons in streaming - let test_cases = vec![ - (MessagesStopReason::EndTurn, FinishReason::Stop), - (MessagesStopReason::MaxTokens, FinishReason::Length), - (MessagesStopReason::ToolUse, FinishReason::ToolCalls), - (MessagesStopReason::StopSequence, FinishReason::Stop), - ]; - - for (anthropic_reason, expected_openai_reason) in test_cases { - let event = MessagesStreamEvent::MessageDelta { - delta: MessagesMessageDelta { - stop_reason: anthropic_reason.clone(), - stop_sequence: None, - }, - usage: MessagesUsage { - input_tokens: 10, - output_tokens: 20, - cache_creation_input_tokens: None, - cache_read_input_tokens: None, - }, - }; - - let openai_resp: ChatCompletionsStreamResponse = event.try_into().unwrap(); - assert_eq!( - openai_resp.choices[0].finish_reason, - Some(expected_openai_reason) - ); - - // Test reverse conversion - let roundtrip_event: MessagesStreamEvent = openai_resp.try_into().unwrap(); - match roundtrip_event { - MessagesStreamEvent::MessageDelta { delta, .. 
} => { - // Note: Some precision may be lost in roundtrip due to mapping differences - assert!(matches!( - delta.stop_reason, - MessagesStopReason::EndTurn - | MessagesStopReason::MaxTokens - | MessagesStopReason::ToolUse - | MessagesStopReason::StopSequence - )); - } - _ => panic!("Expected MessageDelta after roundtrip"), - } - } - } - - #[test] - fn test_streaming_error_handling() { - // Test that malformed streaming events are handled gracefully - let openai_resp_with_missing_data = ChatCompletionsStreamResponse { - id: "test".to_string(), - object: Some("chat.completion.chunk".to_string()), - created: 1234567890, - model: "test".to_string(), - choices: vec![StreamChoice { - index: 0, - delta: MessageDelta { - role: None, - content: None, - refusal: None, - function_call: None, - tool_calls: None, - }, - finish_reason: None, - logprobs: None, - }], - usage: None, - system_fingerprint: None, - service_tier: None, - }; - - // Should convert to Ping when no meaningful content - let anthropic_event: MessagesStreamEvent = - openai_resp_with_missing_data.try_into().unwrap(); - assert!(matches!(anthropic_event, MessagesStreamEvent::Ping)); - } - - #[test] - fn test_streaming_content_block_stop() { - let event = MessagesStreamEvent::ContentBlockStop { index: 0 }; - - let openai_resp: ChatCompletionsStreamResponse = event.try_into().unwrap(); - - // ContentBlockStop should produce an empty chunk - assert_eq!(openai_resp.object.as_deref(), Some("chat.completion.chunk")); - assert_eq!(openai_resp.choices.len(), 1); - - let choice = &openai_resp.choices[0]; - assert_eq!(choice.delta.role, None); - assert_eq!(choice.delta.content, None); - assert_eq!(choice.delta.tool_calls, None); - assert_eq!(choice.finish_reason, None); - } -} diff --git a/crates/hermesllm/src/lib.rs b/crates/hermesllm/src/lib.rs index 77289f4b..918fd4e9 100644 --- a/crates/hermesllm/src/lib.rs +++ b/crates/hermesllm/src/lib.rs @@ -6,18 +6,21 @@ pub mod clients; pub mod providers; pub mod transforms; // Re-export important types and traits -pub use apis::amazon_bedrock_binary_frame::BedrockBinaryFrameDecoder; -pub use apis::sse::{SseEvent, SseStreamIter}; +pub use apis::streaming_shapes::amazon_bedrock_binary_frame::BedrockBinaryFrameDecoder; +pub use apis::streaming_shapes::sse::{SseEvent, SseStreamIter}; pub use aws_smithy_eventstream::frame::DecodedFrame; pub use providers::id::ProviderId; pub use providers::request::{ProviderRequest, ProviderRequestError, ProviderRequestType}; pub use providers::response::{ - ProviderResponse, ProviderResponseError, ProviderResponseType, ProviderStreamResponse, - ProviderStreamResponseType, TokenUsage, + ProviderResponse, ProviderResponseType, TokenUsage, ProviderResponseError +}; +pub use providers::streaming_response::{ + ProviderStreamResponse, ProviderStreamResponseType }; //TODO: Refactor such that commons doesn't depend on Hermes. 
For now this will clean up strings pub const CHAT_COMPLETIONS_PATH: &str = "/v1/chat/completions"; +pub const OPENAI_RESPONSES_API_PATH: &str = "/v1/responses"; pub const MESSAGES_PATH: &str = "/v1/messages"; #[cfg(test)] @@ -42,9 +45,9 @@ mod tests { data: [DONE] "#; - use crate::clients::endpoints::SupportedAPIs; + use crate::clients::endpoints::SupportedAPIsFromClient; let client_api = - SupportedAPIs::OpenAIChatCompletions(crate::apis::OpenAIApi::ChatCompletions); + SupportedAPIsFromClient::OpenAIChatCompletions(crate::apis::OpenAIApi::ChatCompletions); let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(crate::apis::OpenAIApi::ChatCompletions); @@ -79,9 +82,16 @@ mod tests { assert_eq!(stream_response.content_delta(), Some("Hello")); assert!(!stream_response.is_final()); - // Test that stream ends properly with [DONE] (SseStreamIter should stop before [DONE]) + // Test that stream ends properly with [DONE] + // The iterator should return the [DONE] event, then None + let done_event = streaming_iter.next(); + assert!(done_event.is_some(), "Should get [DONE] event"); + let done_event = done_event.unwrap(); + assert!(done_event.is_done(), "[DONE] event should be marked as done"); + + // After [DONE], iterator should return None let final_event = streaming_iter.next(); - assert!(final_event.is_none()); // Should be None because iterator stops at [DONE] + assert!(final_event.is_none(), "Iterator should return None after [DONE]"); } /// Test AWS Event Stream decoding for Bedrock ConverseStream responses. diff --git a/crates/hermesllm/src/providers/id.rs b/crates/hermesllm/src/providers/id.rs index 94a6205a..344a795f 100644 --- a/crates/hermesllm/src/providers/id.rs +++ b/crates/hermesllm/src/providers/id.rs @@ -1,5 +1,5 @@ use crate::apis::{AmazonBedrockApi, AnthropicApi, OpenAIApi}; -use crate::clients::endpoints::{SupportedAPIs, SupportedUpstreamAPIs}; +use crate::clients::endpoints::{SupportedAPIsFromClient, SupportedUpstreamAPIs}; use std::fmt::Display; /// Provider identifier enum - simple enum for identifying providers @@ -51,19 +51,24 @@ impl ProviderId { /// Given a client API, return the compatible upstream API for this provider pub fn compatible_api_for_client( &self, - client_api: &SupportedAPIs, + client_api: &SupportedAPIsFromClient, is_streaming: bool, ) -> SupportedUpstreamAPIs { match (self, client_api) { // Claude/Anthropic providers natively support Anthropic APIs - (ProviderId::Anthropic, SupportedAPIs::AnthropicMessagesAPI(_)) => { + (ProviderId::Anthropic, SupportedAPIsFromClient::AnthropicMessagesAPI(_)) => { SupportedUpstreamAPIs::AnthropicMessagesAPI(AnthropicApi::Messages) } ( ProviderId::Anthropic, - SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions), + SupportedAPIsFromClient::OpenAIChatCompletions(_), ) => SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions), + // Anthropic doesn't support Responses API, fall back to chat completions + (ProviderId::Anthropic, SupportedAPIsFromClient::OpenAIResponsesAPI(_)) => { + SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions) + } + // OpenAI-compatible providers only support OpenAI chat completions ( ProviderId::OpenAI @@ -80,7 +85,7 @@ impl ProviderId { | ProviderId::Moonshotai | ProviderId::Zhipu | ProviderId::Qwen, - SupportedAPIs::AnthropicMessagesAPI(_), + SupportedAPIsFromClient::AnthropicMessagesAPI(_), ) => SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions), ( @@ -98,11 +103,16 @@ impl ProviderId { | ProviderId::Moonshotai | 
ProviderId::Zhipu | ProviderId::Qwen, - SupportedAPIs::OpenAIChatCompletions(_), + SupportedAPIsFromClient::OpenAIChatCompletions(_), ) => SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions), + // OpenAI Responses API - only OpenAI supports this + (ProviderId::OpenAI, SupportedAPIsFromClient::OpenAIResponsesAPI(_)) => { + SupportedUpstreamAPIs::OpenAIResponsesAPI(OpenAIApi::Responses) + } + // Amazon Bedrock natively supports Bedrock APIs - (ProviderId::AmazonBedrock, SupportedAPIs::OpenAIChatCompletions(_)) => { + (ProviderId::AmazonBedrock, SupportedAPIsFromClient::OpenAIChatCompletions(_)) => { if is_streaming { SupportedUpstreamAPIs::AmazonBedrockConverseStream( AmazonBedrockApi::ConverseStream, @@ -111,7 +121,7 @@ impl ProviderId { SupportedUpstreamAPIs::AmazonBedrockConverse(AmazonBedrockApi::Converse) } } - (ProviderId::AmazonBedrock, SupportedAPIs::AnthropicMessagesAPI(_)) => { + (ProviderId::AmazonBedrock, SupportedAPIsFromClient::AnthropicMessagesAPI(_)) => { if is_streaming { SupportedUpstreamAPIs::AmazonBedrockConverseStream( AmazonBedrockApi::ConverseStream, @@ -120,6 +130,20 @@ impl ProviderId { SupportedUpstreamAPIs::AmazonBedrockConverse(AmazonBedrockApi::Converse) } } + (ProviderId::AmazonBedrock, SupportedAPIsFromClient::OpenAIResponsesAPI(_)) => { + if is_streaming { + SupportedUpstreamAPIs::AmazonBedrockConverseStream( + AmazonBedrockApi::ConverseStream, + ) + } else { + SupportedUpstreamAPIs::AmazonBedrockConverse(AmazonBedrockApi::Converse) + } + } + + // Non-OpenAI providers: if client requested the Responses API, fall back to Chat Completions + (_, SupportedAPIsFromClient::OpenAIResponsesAPI(_)) => { + SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions) + } } } } diff --git a/crates/hermesllm/src/providers/mod.rs b/crates/hermesllm/src/providers/mod.rs index 97b14285..4343022f 100644 --- a/crates/hermesllm/src/providers/mod.rs +++ b/crates/hermesllm/src/providers/mod.rs @@ -6,7 +6,9 @@ pub mod id; pub mod request; pub mod response; +pub mod streaming_response; pub use id::ProviderId; pub use request::{ProviderRequest, ProviderRequestError, ProviderRequestType}; -pub use response::{ProviderResponse, ProviderResponseType, ProviderStreamResponse, TokenUsage}; +pub use response::{ProviderResponse, ProviderResponseType, TokenUsage}; +pub use streaming_response::{ProviderStreamResponse, ProviderStreamResponseType}; diff --git a/crates/hermesllm/src/providers/request.rs b/crates/hermesllm/src/providers/request.rs index a8bcfa29..daeebe70 100644 --- a/crates/hermesllm/src/providers/request.rs +++ b/crates/hermesllm/src/providers/request.rs @@ -2,19 +2,21 @@ use crate::apis::anthropic::MessagesRequest; use crate::apis::openai::ChatCompletionsRequest; use crate::apis::amazon_bedrock::{ConverseRequest, ConverseStreamRequest}; -use crate::clients::endpoints::SupportedAPIs; +use crate::apis::openai_responses::ResponsesAPIRequest; +use crate::clients::endpoints::SupportedAPIsFromClient; use crate::clients::endpoints::SupportedUpstreamAPIs; use serde_json::Value; use std::collections::HashMap; use std::error::Error; use std::fmt; -#[derive(Clone)] +#[derive(Clone, Debug)] pub enum ProviderRequestType { ChatCompletionsRequest(ChatCompletionsRequest), MessagesRequest(MessagesRequest), BedrockConverse(ConverseRequest), BedrockConverseStream(ConverseStreamRequest), + ResponsesAPIRequest(ResponsesAPIRequest), //add more request types here } pub trait ProviderRequest: Send + Sync { @@ -49,6 +51,7 @@ impl ProviderRequest for ProviderRequestType { 
Self::MessagesRequest(r) => r.model(), Self::BedrockConverse(r) => r.model(), Self::BedrockConverseStream(r) => r.model(), + Self::ResponsesAPIRequest(r) => r.model(), } } @@ -58,6 +61,7 @@ impl ProviderRequest for ProviderRequestType { Self::MessagesRequest(r) => r.set_model(model), Self::BedrockConverse(r) => r.set_model(model), Self::BedrockConverseStream(r) => r.set_model(model), + Self::ResponsesAPIRequest(r) => r.set_model(model), } } @@ -67,6 +71,7 @@ impl ProviderRequest for ProviderRequestType { Self::MessagesRequest(r) => r.is_streaming(), Self::BedrockConverse(_) => false, Self::BedrockConverseStream(_) => true, + Self::ResponsesAPIRequest(r) => r.is_streaming(), } } @@ -76,6 +81,7 @@ impl ProviderRequest for ProviderRequestType { Self::MessagesRequest(r) => r.extract_messages_text(), Self::BedrockConverse(r) => r.extract_messages_text(), Self::BedrockConverseStream(r) => r.extract_messages_text(), + Self::ResponsesAPIRequest(r) => r.extract_messages_text(), } } @@ -85,6 +91,7 @@ impl ProviderRequest for ProviderRequestType { Self::MessagesRequest(r) => r.get_recent_user_message(), Self::BedrockConverse(r) => r.get_recent_user_message(), Self::BedrockConverseStream(r) => r.get_recent_user_message(), + Self::ResponsesAPIRequest(r) => r.get_recent_user_message(), } } @@ -94,6 +101,7 @@ impl ProviderRequest for ProviderRequestType { Self::MessagesRequest(r) => r.to_bytes(), Self::BedrockConverse(r) => r.to_bytes(), Self::BedrockConverseStream(r) => r.to_bytes(), + Self::ResponsesAPIRequest(r) => r.to_bytes(), } } @@ -103,6 +111,7 @@ impl ProviderRequest for ProviderRequestType { Self::MessagesRequest(r) => r.metadata(), Self::BedrockConverse(r) => r.metadata(), Self::BedrockConverseStream(r) => r.metadata(), + Self::ResponsesAPIRequest(r) => r.metadata(), } } @@ -112,18 +121,19 @@ impl ProviderRequest for ProviderRequestType { Self::MessagesRequest(r) => r.remove_metadata_key(key), Self::BedrockConverse(r) => r.remove_metadata_key(key), Self::BedrockConverseStream(r) => r.remove_metadata_key(key), + Self::ResponsesAPIRequest(r) => r.remove_metadata_key(key), } } } /// Parse the client API from a byte slice. 
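+/// A minimal usage sketch (illustrative only; `body_bytes` stands in for the raw
+/// request body and is not a name from this crate):
+///
+///     let client_api = SupportedAPIsFromClient::from_endpoint("/v1/responses")
+///         .expect("supported endpoint");
+///     let request = ProviderRequestType::try_from((body_bytes, &client_api))?;
+///     // For this endpoint the result is ProviderRequestType::ResponsesAPIRequest(_).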
-impl TryFrom<(&[u8], &SupportedAPIs)> for ProviderRequestType { +impl TryFrom<(&[u8], &SupportedAPIsFromClient)> for ProviderRequestType { type Error = std::io::Error; - fn try_from((bytes, client_api): (&[u8], &SupportedAPIs)) -> Result { + fn try_from((bytes, client_api): (&[u8], &SupportedAPIsFromClient)) -> Result { // Use SupportedApi to determine the appropriate request type match client_api { - SupportedAPIs::OpenAIChatCompletions(_) => { + SupportedAPIsFromClient::OpenAIChatCompletions(_) => { let chat_completion_request: ChatCompletionsRequest = ChatCompletionsRequest::try_from(bytes) .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; @@ -131,11 +141,20 @@ impl TryFrom<(&[u8], &SupportedAPIs)> for ProviderRequestType { chat_completion_request, )) } - SupportedAPIs::AnthropicMessagesAPI(_) => { + SupportedAPIsFromClient::AnthropicMessagesAPI(_) => { let messages_request: MessagesRequest = MessagesRequest::try_from(bytes) .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; Ok(ProviderRequestType::MessagesRequest(messages_request)) } + + SupportedAPIsFromClient::OpenAIResponsesAPI(_) => { + let responses_apirequest: ResponsesAPIRequest = + ResponsesAPIRequest::try_from(bytes) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; + Ok(ProviderRequestType::ResponsesAPIRequest( + responses_apirequest, + )) + } } } } @@ -148,17 +167,13 @@ impl TryFrom<(ProviderRequestType, &SupportedUpstreamAPIs)> for ProviderRequestT (client_request, upstream_api): (ProviderRequestType, &SupportedUpstreamAPIs), ) -> Result { match (client_request, upstream_api) { - // Same API - no conversion needed, just clone the reference + // ============================================================================ + // ChatCompletionsRequest conversions + // ============================================================================ ( ProviderRequestType::ChatCompletionsRequest(chat_req), SupportedUpstreamAPIs::OpenAIChatCompletions(_), ) => Ok(ProviderRequestType::ChatCompletionsRequest(chat_req)), - ( - ProviderRequestType::MessagesRequest(messages_req), - SupportedUpstreamAPIs::AnthropicMessagesAPI(_), - ) => Ok(ProviderRequestType::MessagesRequest(messages_req)), - - // Cross-API conversion - cloning is necessary for transformation ( ProviderRequestType::ChatCompletionsRequest(chat_req), SupportedUpstreamAPIs::AnthropicMessagesAPI(_), @@ -173,7 +188,45 @@ impl TryFrom<(ProviderRequestType, &SupportedUpstreamAPIs)> for ProviderRequestT })?; Ok(ProviderRequestType::MessagesRequest(messages_req)) } + ( + ProviderRequestType::ChatCompletionsRequest(chat_req), + SupportedUpstreamAPIs::AmazonBedrockConverse(_), + ) => { + let bedrock_req = ConverseRequest::try_from(chat_req) + .map_err(|e| ProviderRequestError { + message: format!("Failed to convert ChatCompletionsRequest to Amazon Bedrock request: {}", e), + source: Some(Box::new(e)) + })?; + Ok(ProviderRequestType::BedrockConverse(bedrock_req)) + } + ( + ProviderRequestType::ChatCompletionsRequest(chat_req), + SupportedUpstreamAPIs::AmazonBedrockConverseStream(_), + ) => { + let bedrock_req = ConverseStreamRequest::try_from(chat_req) + .map_err(|e| ProviderRequestError { + message: format!("Failed to convert ChatCompletionsRequest to Amazon Bedrock Stream request: {}", e), + source: Some(Box::new(e)) + })?; + Ok(ProviderRequestType::BedrockConverseStream(bedrock_req)) + } + ( + ProviderRequestType::ChatCompletionsRequest(_), + SupportedUpstreamAPIs::OpenAIResponsesAPI(_), + ) => { + 
Err(ProviderRequestError { + message: "Conversion from ChatCompletionsRequest to ResponsesAPIRequest is not supported. ResponsesAPI can only be used as a client API, not as an upstream API.".to_string(), + source: None, + }) + } + // ============================================================================ + // MessagesRequest conversions + // ============================================================================ + ( + ProviderRequestType::MessagesRequest(messages_req), + SupportedUpstreamAPIs::AnthropicMessagesAPI(_), + ) => Ok(ProviderRequestType::MessagesRequest(messages_req)), ( ProviderRequestType::MessagesRequest(messages_req), SupportedUpstreamAPIs::OpenAIChatCompletions(_), @@ -189,31 +242,6 @@ impl TryFrom<(ProviderRequestType, &SupportedUpstreamAPIs)> for ProviderRequestT })?; Ok(ProviderRequestType::ChatCompletionsRequest(chat_req)) } - - // Cross-API conversions: OpenAI/Anthropic to Amazon Bedrock - ( - ProviderRequestType::ChatCompletionsRequest(chat_req), - SupportedUpstreamAPIs::AmazonBedrockConverse(_), - ) => { - let bedrock_req = ConverseRequest::try_from(chat_req) - .map_err(|e| ProviderRequestError { - message: format!("Failed to convert ChatCompletionsRequest to Amazon Bedrock request: {}", e), - source: Some(Box::new(e)) - })?; - Ok(ProviderRequestType::BedrockConverse(bedrock_req)) - } - - ( - ProviderRequestType::ChatCompletionsRequest(chat_req), - SupportedUpstreamAPIs::AmazonBedrockConverseStream(_), - ) => { - let bedrock_req = ConverseStreamRequest::try_from(chat_req) - .map_err(|e| ProviderRequestError { - message: format!("Failed to convert ChatCompletionsRequest to Amazon Bedrock request: {}", e), - source: Some(Box::new(e)) - })?; - Ok(ProviderRequestType::BedrockConverse(bedrock_req)) - } ( ProviderRequestType::MessagesRequest(messages_req), SupportedUpstreamAPIs::AmazonBedrockConverse(_), @@ -235,7 +263,97 @@ impl TryFrom<(ProviderRequestType, &SupportedUpstreamAPIs)> for ProviderRequestT let bedrock_req = ConverseStreamRequest::try_from(messages_req).map_err(|e| { ProviderRequestError { message: format!( - "Failed to convert MessagesRequest to Amazon Bedrock request: {}", + "Failed to convert MessagesRequest to Amazon Bedrock Stream request: {}", + e + ), + source: Some(Box::new(e)), + } + })?; + Ok(ProviderRequestType::BedrockConverseStream(bedrock_req)) + } + ( + ProviderRequestType::MessagesRequest(_), + SupportedUpstreamAPIs::OpenAIResponsesAPI(_), + ) => { + Err(ProviderRequestError { + message: "Conversion from MessagesRequest to ResponsesAPIRequest is not supported. 
ResponsesAPI can only be used as a client API, not as an upstream API.".to_string(), + source: None, + }) + } + + // ============================================================================ + // ResponsesAPIRequest conversions (only converts TO other formats) + // ============================================================================ + ( + ProviderRequestType::ResponsesAPIRequest(responses_req), + SupportedUpstreamAPIs::OpenAIResponsesAPI(_), + ) => Ok(ProviderRequestType::ResponsesAPIRequest(responses_req)), + + // ResponsesAPI -> ChatCompletions (direct conversion) + ( + ProviderRequestType::ResponsesAPIRequest(responses_req), + SupportedUpstreamAPIs::OpenAIChatCompletions(_), + ) => { + let chat_req = ChatCompletionsRequest::try_from(responses_req).map_err(|e| { + ProviderRequestError { + message: format!( + "Failed to convert ResponsesAPIRequest to ChatCompletionsRequest: {}", + e + ), + source: Some(Box::new(e)), + } + })?; + Ok(ProviderRequestType::ChatCompletionsRequest(chat_req)) + } + + // ResponsesAPI -> Anthropic Messages (via ChatCompletions) + ( + ProviderRequestType::ResponsesAPIRequest(responses_req), + SupportedUpstreamAPIs::AnthropicMessagesAPI(_), + ) => { + // Chain: ResponsesAPI -> ChatCompletions -> MessagesRequest + let chat_req = ChatCompletionsRequest::try_from(responses_req).map_err(|e| { + ProviderRequestError { + message: format!( + "Failed to convert ResponsesAPIRequest to ChatCompletionsRequest: {}", + e + ), + source: Some(Box::new(e)), + } + })?; + + let messages_req = MessagesRequest::try_from(chat_req).map_err(|e| { + ProviderRequestError { + message: format!( + "Failed to convert ChatCompletionsRequest to MessagesRequest: {}", + e + ), + source: Some(Box::new(e)), + } + })?; + Ok(ProviderRequestType::MessagesRequest(messages_req)) + } + + // ResponsesAPI -> Bedrock Converse (via ChatCompletions) + ( + ProviderRequestType::ResponsesAPIRequest(responses_req), + SupportedUpstreamAPIs::AmazonBedrockConverse(_), + ) => { + // Chain: ResponsesAPI -> ChatCompletions -> ConverseRequest + let chat_req = ChatCompletionsRequest::try_from(responses_req).map_err(|e| { + ProviderRequestError { + message: format!( + "Failed to convert ResponsesAPIRequest to ChatCompletionsRequest: {}", + e + ), + source: Some(Box::new(e)), + } + })?; + + let bedrock_req = ConverseRequest::try_from(chat_req).map_err(|e| { + ProviderRequestError { + message: format!( + "Failed to convert ChatCompletionsRequest to Amazon Bedrock request: {}", e ), source: Some(Box::new(e)), @@ -244,13 +362,50 @@ impl TryFrom<(ProviderRequestType, &SupportedUpstreamAPIs)> for ProviderRequestT Ok(ProviderRequestType::BedrockConverse(bedrock_req)) } - // Amazon Bedrock to other APIs conversions + // ResponsesAPI -> Bedrock Converse Stream (via ChatCompletions) + ( + ProviderRequestType::ResponsesAPIRequest(responses_req), + SupportedUpstreamAPIs::AmazonBedrockConverseStream(_), + ) => { + // Chain: ResponsesAPI -> ChatCompletions -> ConverseStreamRequest + let chat_req = ChatCompletionsRequest::try_from(responses_req).map_err(|e| { + ProviderRequestError { + message: format!( + "Failed to convert ResponsesAPIRequest to ChatCompletionsRequest: {}", + e + ), + source: Some(Box::new(e)), + } + })?; + + let bedrock_req = ConverseStreamRequest::try_from(chat_req).map_err(|e| { + ProviderRequestError { + message: format!( + "Failed to convert ChatCompletionsRequest to Amazon Bedrock Stream request: {}", + e + ), + source: Some(Box::new(e)), + } + })?; + 
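+                // Note: both hops above pivot through ChatCompletions because no direct
+                // ResponsesAPI -> Bedrock transform exists; the Converse and ConverseStream
+                // arms reuse the existing ChatCompletions -> Bedrock conversions.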
Ok(ProviderRequestType::BedrockConverseStream(bedrock_req)) + } + + // ============================================================================ + // Amazon Bedrock conversions (not supported as client API) + // ============================================================================ + (ProviderRequestType::BedrockConverse(_), _) => { - todo!("Amazon Bedrock to ChatCompletionsRequest conversion not implemented yet") + Err(ProviderRequestError { + message: "Amazon Bedrock Converse is not supported as a client API. Only OpenAI ChatCompletions, Anthropic Messages, and OpenAI Responses APIs are supported as client APIs.".to_string(), + source: None, + }) } (ProviderRequestType::BedrockConverseStream(_), _) => { - todo!("Amazon Bedrock Stream to ChatCompletionsRequest conversion not implemented yet") + Err(ProviderRequestError { + message: "Amazon Bedrock Converse Stream is not supported as a client API. Only OpenAI ChatCompletions, Anthropic Messages, and OpenAI Responses APIs are supported as client APIs.".to_string(), + source: None, + }) } } } @@ -284,7 +439,7 @@ mod tests { use crate::apis::anthropic::MessagesRequest as AnthropicMessagesRequest; use crate::apis::openai::ChatCompletionsRequest; use crate::apis::openai::OpenAIApi::ChatCompletions; - use crate::clients::endpoints::SupportedAPIs; + use crate::clients::endpoints::SupportedAPIsFromClient; use crate::transforms::lib::ExtractText; use serde_json::json; @@ -298,7 +453,7 @@ mod tests { ] }); let bytes = serde_json::to_vec(&req).unwrap(); - let api = SupportedAPIs::OpenAIChatCompletions(ChatCompletions); + let api = SupportedAPIsFromClient::OpenAIChatCompletions(ChatCompletions); let result = ProviderRequestType::try_from((bytes.as_slice(), &api)); assert!(result.is_ok()); match result.unwrap() { @@ -321,7 +476,7 @@ mod tests { ] }); let bytes = serde_json::to_vec(&req).unwrap(); - let endpoint = SupportedAPIs::AnthropicMessagesAPI(Messages); + let endpoint = SupportedAPIsFromClient::AnthropicMessagesAPI(Messages); let result = ProviderRequestType::try_from((bytes.as_slice(), &endpoint)); assert!(result.is_ok()); match result.unwrap() { @@ -343,7 +498,7 @@ mod tests { ] }); let bytes = serde_json::to_vec(&req).unwrap(); - let endpoint = SupportedAPIs::OpenAIChatCompletions(ChatCompletions); + let endpoint = SupportedAPIsFromClient::OpenAIChatCompletions(ChatCompletions); let result = ProviderRequestType::try_from((bytes.as_slice(), &endpoint)); assert!(result.is_ok()); match result.unwrap() { @@ -366,7 +521,7 @@ mod tests { }); let bytes = serde_json::to_vec(&req).unwrap(); // Intentionally use OpenAI endpoint for Anthropic payload - let endpoint = SupportedAPIs::OpenAIChatCompletions(ChatCompletions); + let endpoint = SupportedAPIsFromClient::OpenAIChatCompletions(ChatCompletions); let result = ProviderRequestType::try_from((bytes.as_slice(), &endpoint)); // Should parse as ChatCompletionsRequest, not error assert!(result.is_ok()); @@ -486,4 +641,272 @@ mod tests { let roundtrip_max_tokens = openai_req2.max_completion_tokens.or(openai_req2.max_tokens); assert_eq!(original_max_tokens, roundtrip_max_tokens); } + + #[test] + fn test_responses_api_request_from_bytes() { + use crate::apis::openai::OpenAIApi::Responses; + + let req = json!({ + "model": "gpt-4o", + "input": "Hello, how are you?" 
+ }); + let bytes = serde_json::to_vec(&req).unwrap(); + let api = SupportedAPIsFromClient::OpenAIResponsesAPI(Responses); + let result = ProviderRequestType::try_from((bytes.as_slice(), &api)); + assert!(result.is_ok()); + match result.unwrap() { + ProviderRequestType::ResponsesAPIRequest(r) => { + assert_eq!(r.model, "gpt-4o"); + } + _ => panic!("Expected ResponsesAPIRequest variant"), + } + } + + #[test] + fn test_responses_api_to_chat_completions_conversion() { + use crate::apis::openai::OpenAIApi::ChatCompletions; + use crate::apis::openai_responses::{InputParam, ResponsesAPIRequest}; + + let responses_req = ResponsesAPIRequest { + model: "gpt-4o".to_string(), + input: InputParam::Text("Hello, world!".to_string()), + temperature: Some(0.7), + top_p: Some(0.9), + max_output_tokens: Some(100), + stream: Some(false), + metadata: None, + tools: None, + tool_choice: None, + parallel_tool_calls: None, + instructions: None, + modalities: None, + user: None, + store: None, + reasoning_effort: None, + include: None, + audio: None, + text: None, + service_tier: None, + top_logprobs: None, + stream_options: None, + truncation: None, + conversation: None, + previous_response_id: None, + max_tool_calls: None, + background: None, + }; + + let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(ChatCompletions); + let result = ProviderRequestType::try_from(( + ProviderRequestType::ResponsesAPIRequest(responses_req), + &upstream_api, + )); + + assert!(result.is_ok()); + match result.unwrap() { + ProviderRequestType::ChatCompletionsRequest(chat_req) => { + assert_eq!(chat_req.model, "gpt-4o"); + assert_eq!(chat_req.temperature, Some(0.7)); + assert_eq!(chat_req.top_p, Some(0.9)); + assert_eq!(chat_req.max_completion_tokens, Some(100)); + assert_eq!(chat_req.messages.len(), 1); + } + _ => panic!("Expected ChatCompletionsRequest variant"), + } + } + + #[test] + fn test_responses_api_to_anthropic_messages_conversion() { + use crate::apis::anthropic::AnthropicApi::Messages; + use crate::apis::openai_responses::{InputParam, ResponsesAPIRequest}; + + let responses_req = ResponsesAPIRequest { + model: "gpt-4o".to_string(), + input: InputParam::Text("Hello, Claude!".to_string()), + temperature: Some(0.8), + max_output_tokens: Some(150), + stream: Some(false), + metadata: None, + tools: None, + tool_choice: None, + parallel_tool_calls: None, + instructions: Some("You are a helpful assistant".to_string()), + modalities: None, + user: None, + store: None, + reasoning_effort: None, + include: None, + audio: None, + text: None, + service_tier: None, + top_p: None, + top_logprobs: None, + stream_options: None, + truncation: None, + conversation: None, + previous_response_id: None, + max_tool_calls: None, + background: None, + }; + + let upstream_api = SupportedUpstreamAPIs::AnthropicMessagesAPI(Messages); + let result = ProviderRequestType::try_from(( + ProviderRequestType::ResponsesAPIRequest(responses_req), + &upstream_api, + )); + + assert!(result.is_ok()); + match result.unwrap() { + ProviderRequestType::MessagesRequest(messages_req) => { + assert_eq!(messages_req.model, "gpt-4o"); + assert_eq!(messages_req.temperature, Some(0.8)); + assert_eq!(messages_req.max_tokens, 150); + // Instructions should be converted to system prompt via ChatCompletions conversion + // The conversion chain: ResponsesAPI -> ChatCompletions (system message) -> Anthropic (system prompt) + // But we need to check if the system prompt was actually set + assert_eq!(messages_req.messages.len(), 1); + } + _ => panic!("Expected 
MessagesRequest variant"), + } + } + + #[test] + fn test_responses_api_to_bedrock_conversion() { + use crate::apis::amazon_bedrock::AmazonBedrockApi::Converse; + use crate::apis::openai_responses::{InputParam, ResponsesAPIRequest}; + + let responses_req = ResponsesAPIRequest { + model: "gpt-4o".to_string(), + input: InputParam::Text("Hello, Bedrock!".to_string()), + temperature: Some(0.5), + max_output_tokens: Some(200), + stream: Some(false), + metadata: None, + tools: None, + tool_choice: None, + parallel_tool_calls: None, + instructions: None, + modalities: None, + user: None, + store: None, + reasoning_effort: None, + include: None, + audio: None, + text: None, + service_tier: None, + top_p: None, + top_logprobs: None, + stream_options: None, + truncation: None, + conversation: None, + previous_response_id: None, + max_tool_calls: None, + background: None, + }; + + let upstream_api = SupportedUpstreamAPIs::AmazonBedrockConverse(Converse); + let result = ProviderRequestType::try_from(( + ProviderRequestType::ResponsesAPIRequest(responses_req), + &upstream_api, + )); + + assert!(result.is_ok()); + match result.unwrap() { + ProviderRequestType::BedrockConverse(bedrock_req) => { + assert_eq!(bedrock_req.model_id, "gpt-4o"); + // Bedrock receives the converted request through ChatCompletions + assert!(!bedrock_req.messages.is_none()); + } + _ => panic!("Expected BedrockConverse variant"), + } + } + + #[test] + fn test_chat_completions_to_responses_api_not_supported() { + use crate::apis::openai::OpenAIApi::Responses; + use crate::apis::openai::{Message, MessageContent, Role}; + + let chat_req = ChatCompletionsRequest { + model: "gpt-4".to_string(), + messages: vec![Message { + role: Role::User, + content: MessageContent::Text("Hello!".to_string()), + name: None, + tool_calls: None, + tool_call_id: None, + }], + ..Default::default() + }; + + let upstream_api = SupportedUpstreamAPIs::OpenAIResponsesAPI(Responses); + let result = ProviderRequestType::try_from(( + ProviderRequestType::ChatCompletionsRequest(chat_req), + &upstream_api, + )); + + assert!(result.is_err()); + let err = result.unwrap_err(); + assert!(err.message.contains("ResponsesAPI can only be used as a client API")); + } + + #[test] + fn test_anthropic_messages_to_responses_api_not_supported() { + use crate::apis::anthropic::MessagesRequest as AnthropicMessagesRequest; + use crate::apis::openai::OpenAIApi::Responses; + + let messages_req = AnthropicMessagesRequest { + model: "claude-3-sonnet".to_string(), + messages: vec![crate::apis::anthropic::MessagesMessage { + role: crate::apis::anthropic::MessagesRole::User, + content: crate::apis::anthropic::MessagesMessageContent::Single( + "Hello!".to_string(), + ), + }], + max_tokens: 100, + container: None, + mcp_servers: None, + service_tier: None, + thinking: None, + temperature: None, + top_p: None, + top_k: None, + stream: None, + stop_sequences: None, + system: None, + tools: None, + tool_choice: None, + metadata: None, + }; + + let upstream_api = SupportedUpstreamAPIs::OpenAIResponsesAPI(Responses); + let result = ProviderRequestType::try_from(( + ProviderRequestType::MessagesRequest(messages_req), + &upstream_api, + )); + + assert!(result.is_err()); + let err = result.unwrap_err(); + assert!(err.message.contains("ResponsesAPI can only be used as a client API")); + } + + #[test] + fn test_bedrock_as_client_api_not_supported() { + use crate::apis::openai::OpenAIApi::ChatCompletions; + + // Create a simple Bedrock request (we'll use Default if available, or minimal construction) + 
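+        // Any minimal request suffices here: the BedrockConverse arm rejects the
+        // request by variant alone, without inspecting its fields.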
let bedrock_req = ConverseRequest::default(); + + let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(ChatCompletions); + let result = ProviderRequestType::try_from(( + ProviderRequestType::BedrockConverse(bedrock_req), + &upstream_api, + )); + + assert!(result.is_err()); + let err = result.unwrap_err(); + assert!(err.message.contains("not supported as a client API")); + assert!(err + .message + .contains("OpenAI ChatCompletions, Anthropic Messages, and OpenAI Responses")); + } } diff --git a/crates/hermesllm/src/providers/response.rs b/crates/hermesllm/src/providers/response.rs index 54fda8c4..a2494c6d 100644 --- a/crates/hermesllm/src/providers/response.rs +++ b/crates/hermesllm/src/providers/response.rs @@ -2,38 +2,28 @@ use serde::Serialize; use std::convert::TryFrom; use std::error::Error; use std::fmt; - use crate::apis::amazon_bedrock::ConverseResponse; -use crate::apis::amazon_bedrock::ConverseStreamEvent; use crate::apis::anthropic::MessagesResponse; -use crate::apis::anthropic::MessagesStreamEvent; use crate::apis::openai::ChatCompletionsResponse; -use crate::apis::openai::ChatCompletionsStreamResponse; -use crate::apis::sse::SseEvent; -use crate::clients::endpoints::SupportedAPIs; +use crate::apis::openai_responses::ResponsesAPIResponse; +use crate::clients::endpoints::SupportedAPIsFromClient; use crate::clients::endpoints::SupportedUpstreamAPIs; use crate::providers::id::ProviderId; -/// Trait for token usage information -pub trait TokenUsage { - fn completion_tokens(&self) -> usize; - fn prompt_tokens(&self) -> usize; - fn total_tokens(&self) -> usize; -} #[derive(Serialize, Debug, Clone)] #[serde(untagged)] pub enum ProviderResponseType { ChatCompletionsResponse(ChatCompletionsResponse), MessagesResponse(MessagesResponse), + ResponsesAPIResponse(ResponsesAPIResponse), } -#[derive(Serialize, Debug, Clone)] -#[serde(untagged)] -pub enum ProviderStreamResponseType { - ChatCompletionsStreamResponse(ChatCompletionsStreamResponse), - MessagesStreamEvent(MessagesStreamEvent), - ConverseStreamEvent(ConverseStreamEvent), +/// Trait for token usage information +pub trait TokenUsage { + fn completion_tokens(&self) -> usize; + fn prompt_tokens(&self) -> usize; + fn total_tokens(&self) -> usize; } pub trait ProviderResponse: Send + Sync { @@ -52,6 +42,7 @@ impl ProviderResponse for ProviderResponseType { match self { ProviderResponseType::ChatCompletionsResponse(resp) => resp.usage(), ProviderResponseType::MessagesResponse(resp) => resp.usage(), + ProviderResponseType::ResponsesAPIResponse(resp) => resp.usage.as_ref().map(|u| u as &dyn TokenUsage), } } @@ -59,89 +50,27 @@ impl ProviderResponse for ProviderResponseType { match self { ProviderResponseType::ChatCompletionsResponse(resp) => resp.extract_usage_counts(), ProviderResponseType::MessagesResponse(resp) => resp.extract_usage_counts(), - } - } -} -pub trait ProviderStreamResponse: Send + Sync { - /// Get the content delta for this chunk - fn content_delta(&self) -> Option<&str>; - - /// Check if this is the final chunk in the stream - fn is_final(&self) -> bool; - - /// Get role information if available - fn role(&self) -> Option<&str>; - - /// Get event type for SSE streaming (used by Anthropic) - fn event_type(&self) -> Option<&str>; -} - -impl ProviderStreamResponse for ProviderStreamResponseType { - fn content_delta(&self) -> Option<&str> { - match self { - ProviderStreamResponseType::ChatCompletionsStreamResponse(resp) => resp.content_delta(), - ProviderStreamResponseType::MessagesStreamEvent(resp) => 
@@ -59,89 +50,27 @@ impl ProviderResponse for ProviderResponseType {
         match self {
             ProviderResponseType::ChatCompletionsResponse(resp) => resp.extract_usage_counts(),
             ProviderResponseType::MessagesResponse(resp) => resp.extract_usage_counts(),
-        }
-    }
-}
-pub trait ProviderStreamResponse: Send + Sync {
-    /// Get the content delta for this chunk
-    fn content_delta(&self) -> Option<&str>;
-
-    /// Check if this is the final chunk in the stream
-    fn is_final(&self) -> bool;
-
-    /// Get role information if available
-    fn role(&self) -> Option<&str>;
-
-    /// Get event type for SSE streaming (used by Anthropic)
-    fn event_type(&self) -> Option<&str>;
-}
-
-impl ProviderStreamResponse for ProviderStreamResponseType {
-    fn content_delta(&self) -> Option<&str> {
-        match self {
-            ProviderStreamResponseType::ChatCompletionsStreamResponse(resp) => resp.content_delta(),
-            ProviderStreamResponseType::MessagesStreamEvent(resp) => resp.content_delta(),
-            ProviderStreamResponseType::ConverseStreamEvent(resp) => resp.content_delta(),
-        }
-    }
-
-    fn is_final(&self) -> bool {
-        match self {
-            ProviderStreamResponseType::ChatCompletionsStreamResponse(resp) => resp.is_final(),
-            ProviderStreamResponseType::MessagesStreamEvent(resp) => resp.is_final(),
-            ProviderStreamResponseType::ConverseStreamEvent(resp) => resp.is_final(),
-        }
-    }
-
-    fn role(&self) -> Option<&str> {
-        match self {
-            ProviderStreamResponseType::ChatCompletionsStreamResponse(resp) => resp.role(),
-            ProviderStreamResponseType::MessagesStreamEvent(resp) => resp.role(),
-            ProviderStreamResponseType::ConverseStreamEvent(resp) => resp.role(),
-        }
-    }
-
-    fn event_type(&self) -> Option<&str> {
-        match self {
-            ProviderStreamResponseType::ChatCompletionsStreamResponse(_resp) => None, // OpenAI doesn't use event types
-            ProviderStreamResponseType::MessagesStreamEvent(resp) => resp.event_type(),
-            ProviderStreamResponseType::ConverseStreamEvent(resp) => resp.event_type(),
-        }
-    }
-}
-
-impl Into<String> for ProviderStreamResponseType {
-    fn into(self) -> String {
-        match self {
-            ProviderStreamResponseType::MessagesStreamEvent(event) => {
-                // Use the Into<String> implementation for proper SSE formatting with event lines
-                event.into()
-            }
-            ProviderStreamResponseType::ConverseStreamEvent(event) => {
-                // Use the Into<String> implementation for proper SSE formatting with event lines
-                event.into()
-            }
-            ProviderStreamResponseType::ChatCompletionsStreamResponse(_) => {
-                // For OpenAI, use simple data line format
-                let json = serde_json::to_string(&self).unwrap_or_default();
-                format!("data: {}\n\n", json)
+            ProviderResponseType::ResponsesAPIResponse(resp) => {
+                resp.usage.as_ref().map(|u| {
+                    (u.input_tokens as usize, u.output_tokens as usize, u.total_tokens as usize)
+                })
             }
         }
     }
 }
 
 // --- Response transformation logic for client API compatibility ---
-impl TryFrom<(&[u8], &SupportedAPIs, &ProviderId)> for ProviderResponseType {
+impl TryFrom<(&[u8], &SupportedAPIsFromClient, &ProviderId)> for ProviderResponseType {
     type Error = std::io::Error;
     fn try_from(
-        (bytes, client_api, provider_id): (&[u8], &SupportedAPIs, &ProviderId),
+        (bytes, client_api, provider_id): (&[u8], &SupportedAPIsFromClient, &ProviderId),
     ) -> Result<Self, Self::Error> {
         let upstream_api = provider_id.compatible_api_for_client(client_api, false);
         match (&upstream_api, client_api) {
             (
                 SupportedUpstreamAPIs::OpenAIChatCompletions(_),
-                SupportedAPIs::OpenAIChatCompletions(_),
+                SupportedAPIsFromClient::OpenAIChatCompletions(_),
             ) => {
                 let resp: ChatCompletionsResponse = ChatCompletionsResponse::try_from(bytes)
                     .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
@@ -149,7 +78,7 @@ impl TryFrom<(&[u8], &SupportedAPIs, &ProviderId)> for ProviderResponseType {
         }
         (
             SupportedUpstreamAPIs::AnthropicMessagesAPI(_),
-            SupportedAPIs::AnthropicMessagesAPI(_),
+            SupportedAPIsFromClient::AnthropicMessagesAPI(_),
         ) => {
             let resp: MessagesResponse = serde_json::from_slice(bytes)
                 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
@@ -157,7 +86,7 @@ impl TryFrom<(&[u8], &SupportedAPIs, &ProviderId)> for ProviderResponseType {
         }
         (
             SupportedUpstreamAPIs::AnthropicMessagesAPI(_),
-            SupportedAPIs::OpenAIChatCompletions(_),
+            SupportedAPIsFromClient::OpenAIChatCompletions(_),
         ) => {
             let anthropic_resp: MessagesResponse = serde_json::from_slice(bytes)
                 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
@@ -174,7 +103,7 @@ impl TryFrom<(&[u8], &SupportedAPIs, &ProviderId)> for ProviderResponseType {
         }
         (
             SupportedUpstreamAPIs::OpenAIChatCompletions(_),
-            SupportedAPIs::AnthropicMessagesAPI(_),
+            SupportedAPIsFromClient::AnthropicMessagesAPI(_),
         ) => {
             let openai_resp: ChatCompletionsResponse = ChatCompletionsResponse::try_from(bytes)
                 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
@@ -191,7 +120,7 @@ impl TryFrom<(&[u8], &SupportedAPIs, &ProviderId)> for ProviderResponseType {
         // Amazon Bedrock transformations
         (
             SupportedUpstreamAPIs::AmazonBedrockConverse(_),
-            SupportedAPIs::OpenAIChatCompletions(_),
+            SupportedAPIsFromClient::OpenAIChatCompletions(_),
         ) => {
             let bedrock_resp: ConverseResponse = serde_json::from_slice(bytes)
                 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
@@ -207,7 +136,7 @@ impl TryFrom<(&[u8], &SupportedAPIs, &ProviderId)> for ProviderResponseType {
         }
         (
             SupportedUpstreamAPIs::AmazonBedrockConverse(_),
-            SupportedAPIs::AnthropicMessagesAPI(_),
+            SupportedAPIsFromClient::AnthropicMessagesAPI(_),
         ) => {
             let bedrock_resp: ConverseResponse = serde_json::from_slice(bytes)
                 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
@@ -221,6 +150,80 @@ impl TryFrom<(&[u8], &SupportedAPIs, &ProviderId)> for ProviderResponseType {
                 })?;
             Ok(ProviderResponseType::MessagesResponse(messages_resp))
         }
+        (
+            SupportedUpstreamAPIs::OpenAIResponsesAPI(_),
+            SupportedAPIsFromClient::OpenAIResponsesAPI(_),
+        ) => {
+            let resp: ResponsesAPIResponse = ResponsesAPIResponse::try_from(bytes)
+                .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
+            Ok(ProviderResponseType::ResponsesAPIResponse(resp))
+        }
+        (
+            SupportedUpstreamAPIs::OpenAIChatCompletions(_),
+            SupportedAPIsFromClient::OpenAIResponsesAPI(_),
+        ) => {
+            let chat_completions_response: ChatCompletionsResponse = ChatCompletionsResponse::try_from(bytes)
+                .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
+
+            // Transform to ResponsesAPI format using the transformer
+            let responses_resp: ResponsesAPIResponse = chat_completions_response.try_into().map_err(|e| {
+                std::io::Error::new(
+                    std::io::ErrorKind::InvalidData,
+                    format!("Transformation error: {}", e),
+                )
+            })?;
+            Ok(ProviderResponseType::ResponsesAPIResponse(responses_resp))
+        }
+        (
+            SupportedUpstreamAPIs::AnthropicMessagesAPI(_),
+            SupportedAPIsFromClient::OpenAIResponsesAPI(_),
+        ) => {
+            // Chain transform: Anthropic Messages -> OpenAI ChatCompletions -> ResponsesAPI
+            let anthropic_resp: MessagesResponse = serde_json::from_slice(bytes)
+                .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
+
+            // Transform to ChatCompletions format using the transformer
+            let chat_resp: ChatCompletionsResponse = anthropic_resp.try_into().map_err(|e| {
+                std::io::Error::new(
+                    std::io::ErrorKind::InvalidData,
+                    format!("Transformation error: {}", e),
+                )
+            })?;
+
+            let response_api: ResponsesAPIResponse = chat_resp.try_into().map_err(|e| {
+                std::io::Error::new(
+                    std::io::ErrorKind::InvalidData,
+                    format!("Transformation error: {}", e),
+                )
+            })?;
+            Ok(ProviderResponseType::ResponsesAPIResponse(response_api))
+        }
+        (
+            SupportedUpstreamAPIs::AmazonBedrockConverse(_),
+            SupportedAPIsFromClient::OpenAIResponsesAPI(_),
+        ) => {
+            // Chain transform: Bedrock Converse -> ChatCompletions -> ResponsesAPI
+            let bedrock_resp: ConverseResponse = serde_json::from_slice(bytes)
+                .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
+
+            // Transform to ChatCompletions format
+            let chat_resp: ChatCompletionsResponse = bedrock_resp.try_into().map_err(|e| {
+                std::io::Error::new(
+                    std::io::ErrorKind::InvalidData,
+                    format!("Bedrock to ChatCompletions transformation error: {}", e),
+                )
+            })?;
+
+            // Transform to ResponsesAPI format
+            let response_api: ResponsesAPIResponse = chat_resp.try_into().map_err(|e| {
+                std::io::Error::new(
+                    std::io::ErrorKind::InvalidData,
+                    format!("ChatCompletions to ResponsesAPI transformation error: {}", e),
+                )
+            })?;
+            Ok(ProviderResponseType::ResponsesAPIResponse(response_api))
+        }
         _ => Err(std::io::Error::new(
             std::io::ErrorKind::InvalidData,
             "Unsupported API combination for response transformation",
@@ -229,247 +232,6 @@ impl TryFrom<(&[u8], &SupportedAPIs, &ProviderId)> for ProviderResponseType {
     }
 }
 
-// Stream response transformation logic for client API compatibility
-impl TryFrom<(&[u8], &SupportedAPIs, &SupportedUpstreamAPIs)> for ProviderStreamResponseType {
-    type Error = Box<dyn Error + Send + Sync>;
-
-    fn try_from(
-        (bytes, client_api, upstream_api): (&[u8], &SupportedAPIs, &SupportedUpstreamAPIs),
-    ) -> Result<Self, Self::Error> {
-        // Special case: Handle [DONE] marker for OpenAI -> Anthropic conversion
-        if bytes == b"[DONE]" && matches!(client_api, SupportedAPIs::AnthropicMessagesAPI(_)) {
-            return Ok(ProviderStreamResponseType::MessagesStreamEvent(
-                crate::apis::anthropic::MessagesStreamEvent::MessageStop,
-            ));
-        }
-        match (upstream_api, client_api) {
-            // OpenAI upstream
-            (
-                SupportedUpstreamAPIs::OpenAIChatCompletions(_),
-                SupportedAPIs::OpenAIChatCompletions(_),
-            ) => {
-                let resp = serde_json::from_slice(bytes)?;
-                Ok(ProviderStreamResponseType::ChatCompletionsStreamResponse(
-                    resp,
-                ))
-            }
-            (
-                SupportedUpstreamAPIs::OpenAIChatCompletions(_),
-                SupportedAPIs::AnthropicMessagesAPI(_),
-            ) => {
-                let openai_resp: crate::apis::openai::ChatCompletionsStreamResponse =
-                    serde_json::from_slice(bytes)?;
-                let anthropic_resp = openai_resp.try_into()?;
-                Ok(ProviderStreamResponseType::MessagesStreamEvent(
-                    anthropic_resp,
-                ))
-            }
-
-            // Anthropic upstream
-            (
-                SupportedUpstreamAPIs::AnthropicMessagesAPI(_),
-                SupportedAPIs::AnthropicMessagesAPI(_),
-            ) => {
-                let resp = serde_json::from_slice(bytes)?;
-                Ok(ProviderStreamResponseType::MessagesStreamEvent(resp))
-            }
-            (
-                SupportedUpstreamAPIs::AnthropicMessagesAPI(_),
-                SupportedAPIs::OpenAIChatCompletions(_),
-            ) => {
-                let anthropic_resp: crate::apis::anthropic::MessagesStreamEvent =
-                    serde_json::from_slice(bytes)?;
-                let openai_resp = anthropic_resp.try_into()?;
-                Ok(ProviderStreamResponseType::ChatCompletionsStreamResponse(
-                    openai_resp,
-                ))
-            }
-
-            // Amazon Bedrock ConverseStream upstream
-            (
-                SupportedUpstreamAPIs::AmazonBedrockConverseStream(_),
-                SupportedAPIs::AnthropicMessagesAPI(_),
-            ) => {
-                let bedrock_resp: crate::apis::amazon_bedrock::ConverseStreamEvent =
-                    serde_json::from_slice(bytes)?;
-                let anthropic_resp = bedrock_resp.try_into()?;
-                Ok(ProviderStreamResponseType::MessagesStreamEvent(
-                    anthropic_resp,
-                ))
-            }
-            _ => Err(std::io::Error::new(
-                std::io::ErrorKind::InvalidData,
-                "Unsupported API combination for response transformation",
-            )
-            .into()),
-        }
-    }
-}
-
-// TryFrom implementation to convert raw bytes to SseEvent with parsed provider response
-impl TryFrom<(SseEvent, &SupportedAPIs, &SupportedUpstreamAPIs)> for SseEvent {
-    type Error = Box<dyn Error + Send + Sync>;
-
-    fn try_from(
-        (sse_event, client_api, upstream_api): (SseEvent, &SupportedAPIs, &SupportedUpstreamAPIs),
-    ) -> Result<Self, Self::Error> {
-        // Create a new transformed event based on the original
-        let mut transformed_event = sse_event;
-
-        // Handle [DONE] marker early - don't try to parse as JSON
-        if transformed_event.is_done() {
-            // For OpenAI client API, keep [DONE] as-is
-            // For Anthropic client API, it will be transformed via ProviderStreamResponseType
-            if matches!(client_api, SupportedAPIs::OpenAIChatCompletions(_)) {
-                // Keep the [DONE] marker as-is for OpenAI clients
-                transformed_event.sse_transform_buffer = "data: [DONE]".to_string();
-                return Ok(transformed_event);
-            }
-        }
-
-        // If has data, parse the data as a provider stream response (business logic layer)
-        if transformed_event.data.is_some() {
-            let data_str = transformed_event.data.as_ref().unwrap();
-            let data_bytes = data_str.as_bytes();
-            let transformed_response: ProviderStreamResponseType =
-                ProviderStreamResponseType::try_from((data_bytes, client_api, upstream_api))?;
-
-            // Convert to SSE string explicitly to avoid type ambiguity
-            let sse_string: String = transformed_response.clone().into();
-            transformed_event.sse_transform_buffer = sse_string;
-            transformed_event.provider_stream_response = Some(transformed_response);
-        }
-
-        match (client_api, upstream_api) {
-            (
-                SupportedAPIs::AnthropicMessagesAPI(_),
-                SupportedUpstreamAPIs::OpenAIChatCompletions(_),
-            ) => {
-                if let Some(provider_response) = &transformed_event.provider_stream_response {
-                    if let Some(event_type) = provider_response.event_type() {
-                        // This ensures the required Anthropic sequence: MessageStart → ContentBlockStart → ContentBlockDelta(s)
-                        if event_type == "message_start" {
-                            // Create ContentBlockStart event and format it using Into<String>
-                            let content_block_start = MessagesStreamEvent::ContentBlockStart {
-                                index: 0,
-                                content_block: crate::apis::anthropic::MessagesContentBlock::Text {
-                                    text: String::new(),
-                                    cache_control: None,
-                                },
-                            };
-                            let content_block_start_sse: String = content_block_start.into();
-
-                            // Format as proper SSE: MessageStart first, then ContentBlockStart
-                            // The sse_transform_buffer already contains the properly formatted MessageStart
-                            transformed_event.sse_transform_buffer = format!(
-                                "{}{}",
-                                transformed_event.sse_transform_buffer, content_block_start_sse,
-                            );
-                        } else if event_type == "message_delta" {
-                            // Create ContentBlockStop event and format it using Into<String>
-                            let content_block_stop =
-                                MessagesStreamEvent::ContentBlockStop { index: 0 };
-                            let content_block_stop_sse: String = content_block_stop.into();
-
-                            // Format as proper SSE: ContentBlockStop first, then MessageDelta
-                            transformed_event.sse_transform_buffer = format!(
-                                "{}{}",
-                                content_block_stop_sse, transformed_event.sse_transform_buffer
-                            );
-                        }
-                        // For other event types, the sse_transform_buffer already has the correct format from Into<String>
-                    }
-                    // If event_type is None, we just keep the data line as-is without an event line
-                    // This handles cases where the transformation might not produce a valid event type
-                }
-            }
-            (
-                SupportedAPIs::OpenAIChatCompletions(_),
-                SupportedUpstreamAPIs::AnthropicMessagesAPI(_),
-            ) => {
-                if transformed_event.is_event_only() && transformed_event.event.is_some() {
-                    transformed_event.sse_transform_buffer = format!("\n"); // suppress the event upstream for OpenAI
-                }
-            }
-            (
-                SupportedAPIs::AnthropicMessagesAPI(_),
-                SupportedUpstreamAPIs::AnthropicMessagesAPI(_),
-            ) => {
-                // When both client and upstream are Anthropic, suppress event-only lines
-                // because the data line transformation already includes the event line
-                if transformed_event.is_event_only() && transformed_event.event.is_some() {
-                    transformed_event.sse_transform_buffer = String::new(); // suppress duplicate event line
-                }
-            }
-            _ => {
-                // Other combinations can be handled here as needed
-            }
-        }
-
-        Ok(transformed_event)
-    }
-}
-
-// TryFrom implementation to convert AWS Event Stream DecodedFrame to ProviderStreamResponseType
-impl
-    TryFrom<(
-        &aws_smithy_eventstream::frame::DecodedFrame,
-        &SupportedAPIs,
-        &SupportedUpstreamAPIs,
-    )> for ProviderStreamResponseType
-{
-    type Error = Box<dyn Error + Send + Sync>;
-
-    fn try_from(
-        (frame, client_api, upstream_api): (
-            &aws_smithy_eventstream::frame::DecodedFrame,
-            &SupportedAPIs,
-            &SupportedUpstreamAPIs,
-        ),
-    ) -> Result<Self, Self::Error> {
-        use aws_smithy_eventstream::frame::DecodedFrame;
-
-        match frame {
-            DecodedFrame::Complete(_) => {
-                // We have a complete frame - parse it based on upstream API
-                match (upstream_api, client_api) {
-                    (
-                        SupportedUpstreamAPIs::AmazonBedrockConverseStream(_),
-                        SupportedAPIs::AnthropicMessagesAPI(_),
-                    ) => {
-                        // Parse the DecodedFrame into ConverseStreamEvent
-                        let bedrock_event =
-                            crate::apis::amazon_bedrock::ConverseStreamEvent::try_from(frame)?;
-                        let anthropic_event: crate::apis::anthropic::MessagesStreamEvent =
-                            bedrock_event.try_into()?;
-
-                        Ok(ProviderStreamResponseType::MessagesStreamEvent(
-                            anthropic_event,
-                        ))
-                    }
-                    (
-                        SupportedUpstreamAPIs::AmazonBedrockConverseStream(_),
-                        SupportedAPIs::OpenAIChatCompletions(_),
-                    ) => {
-                        // Parse the DecodedFrame into ConverseStreamEvent
-                        let bedrock_event =
-                            crate::apis::amazon_bedrock::ConverseStreamEvent::try_from(frame)?;
-                        let openai_event: crate::apis::openai::ChatCompletionsStreamResponse =
-                            bedrock_event.try_into()?;
-                        Ok(ProviderStreamResponseType::ChatCompletionsStreamResponse(
-                            openai_event,
-                        ))
-                    }
-                    _ => Err("Unsupported API combination for event-stream decoding".into()),
-                }
-            }
-            DecodedFrame::Incomplete => {
-                Err("Cannot convert incomplete frame to provider response".into())
-            }
-        }
-    }
-}
-
 #[derive(Debug)]
 pub struct ProviderResponseError {
     pub message: String,
@@ -493,11 +255,9 @@ impl Error for ProviderResponseError {
 #[cfg(test)]
 mod tests {
     use super::*;
-    use crate::apis::amazon_bedrock_binary_frame::BedrockBinaryFrameDecoder;
-    use crate::apis::anthropic::AnthropicApi;
     use crate::apis::openai::OpenAIApi;
-    use crate::apis::sse::SseStreamIter;
-    use crate::clients::endpoints::SupportedAPIs;
+    use crate::apis::anthropic::AnthropicApi;
+    use crate::clients::endpoints::SupportedAPIsFromClient;
     use crate::providers::id::ProviderId;
     use serde_json::json;
 
@@ -521,7 +281,7 @@ mod tests {
         let bytes = serde_json::to_vec(&resp).unwrap();
         let result = ProviderResponseType::try_from((
             bytes.as_slice(),
-            &SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
+            &SupportedAPIsFromClient::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
             &ProviderId::OpenAI,
         ));
         assert!(result.is_ok());
@@ -550,7 +310,7 @@ mod tests {
         let bytes = serde_json::to_vec(&resp).unwrap();
         let result = ProviderResponseType::try_from((
             bytes.as_slice(),
-            &SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages),
+            &SupportedAPIsFromClient::AnthropicMessagesAPI(AnthropicApi::Messages),
             &ProviderId::Anthropic,
         ));
         assert!(result.is_ok());
@@ -584,7 +344,7 @@ mod tests {
         let bytes = serde_json::to_vec(&resp).unwrap();
         let result = ProviderResponseType::try_from((
             bytes.as_slice(),
-            &SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages),
+            &SupportedAPIsFromClient::AnthropicMessagesAPI(AnthropicApi::Messages),
             &ProviderId::OpenAI,
         ));
         assert!(result.is_ok());
@@ -626,7 +386,7 @@ mod tests {
         let bytes = serde_json::to_vec(&resp).unwrap();
         let result = ProviderResponseType::try_from((
             bytes.as_slice(),
-            &SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
+            &SupportedAPIsFromClient::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
             &ProviderId::Anthropic,
         ));
         assert!(result.is_ok());
@@ -639,951 +399,4 @@ mod tests {
             _ => panic!("Expected ChatCompletionsResponse variant"),
         }
     }
-
-    #[test]
-    fn test_sse_event_parsing() {
-        // Test valid SSE data line
-        let line = "data: {\"id\":\"test\",\"object\":\"chat.completion.chunk\"}\n\n";
-        let event: Result<SseEvent, _> = line.parse();
-        assert!(event.is_ok());
-        let event = event.unwrap();
-        assert_eq!(
-            event.data,
-            Some("{\"id\":\"test\",\"object\":\"chat.completion.chunk\"}\n\n".to_string())
-        );
-
-        // Test conversion back to line using Display trait
-        let wire_format = event.to_string();
-        assert_eq!(
-            wire_format,
-            "data: {\"id\":\"test\",\"object\":\"chat.completion.chunk\"}\n\n"
-        );
-
-        // Test [DONE] marker - should be valid SSE event
-        let done_line = "data: [DONE]";
-        let done_result: Result<SseEvent, _> = done_line.parse();
-        assert!(done_result.is_ok());
-        let done_event = done_result.unwrap();
-        assert_eq!(done_event.data, Some("[DONE]".to_string()));
-        assert!(done_event.is_done()); // Test the helper method
-
-        // Test non-DONE event
-        assert!(!event.is_done());
-
-        // Test empty data - should return error
-        let empty_line = "data: ";
-        let empty_result: Result<SseEvent, _> = empty_line.parse();
-        assert!(empty_result.is_err());
-
-        // Test non-data line - should return error
-        let comment_line = ": this is a comment";
-        let comment_result: Result<SseEvent, _> = comment_line.parse();
-        assert!(comment_result.is_err());
-    }
-
-    #[test]
-    fn test_sse_event_serde() {
-        // Test serialization and deserialization with serde
-        let event = SseEvent {
-            data: Some(r#"{"id":"test","object":"chat.completion.chunk"}"#.to_string()),
-            event: None,
-            raw_line: r#"data: {"id":"test","object":"chat.completion.chunk"}
-
-"#
-            .to_string(),
-            sse_transform_buffer: r#"data: {"id":"test","object":"chat.completion.chunk"}
-
-"#
-            .to_string(),
-            provider_stream_response: None,
-        };
-
-        // Test JSON serialization - raw_line should be skipped
-        let json = serde_json::to_string(&event).unwrap();
-        assert!(json.contains("test"));
-        assert!(json.contains("chat.completion.chunk"));
-        assert!(!json.contains("raw_line")); // Should be excluded from serialization
-
-        // Test JSON deserialization
-        let deserialized: SseEvent = serde_json::from_str(&json).unwrap();
-        assert_eq!(deserialized.data, event.data);
-        assert_eq!(deserialized.raw_line, ""); // Should be empty since it's skipped
-
-        // Test round trip for data field only
-        assert_eq!(event.data, deserialized.data);
-    }
-
-    #[test]
-    fn test_sse_event_should_skip() {
-        // Test ping message should be skipped
-        let ping_event = SseEvent {
-            data: Some(r#"{"type": "ping"}"#.to_string()),
-            event: None,
-            raw_line: r#"data: {"type": "ping"}"#.to_string(),
-            sse_transform_buffer: r#"data: {"type": "ping"}"#.to_string(),
-            provider_stream_response: None,
-        };
-        assert!(ping_event.should_skip());
-        assert!(!ping_event.is_done());
-
-        // Test normal event should not be skipped
-        let normal_event = SseEvent {
-            data: Some(r#"{"id": "test", "object": "chat.completion.chunk"}"#.to_string()),
-            event: Some("content_block_delta".to_string()),
-            raw_line: r#"data: {"id": "test", "object": "chat.completion.chunk"}"#.to_string(),
-            sse_transform_buffer: r#"data: {"id": "test", "object": "chat.completion.chunk"}"#
-                .to_string(),
-            provider_stream_response: None,
-        };
-        assert!(!normal_event.should_skip());
-        assert!(!normal_event.is_done());
-
-        // Test [DONE] event should not be skipped (but is handled separately)
-        let done_event = SseEvent {
-            data: Some("[DONE]".to_string()),
-            event: None,
-            raw_line: "data: [DONE]".to_string(),
-            sse_transform_buffer: "data: [DONE]".to_string(),
-            provider_stream_response: None,
-        };
-        assert!(!done_event.should_skip());
-        assert!(done_event.is_done());
-    }
-
-    #[test]
-    fn test_sse_stream_iter_filters_ping_messages() {
-        // Create test data with ping messages mixed in
-        let test_lines = vec![
-            "data: {\"id\": \"msg1\", \"object\": \"chat.completion.chunk\"}".to_string(),
-            "data: {\"type\": \"ping\"}".to_string(), // This should be filtered out
-            "data: {\"id\": \"msg2\", \"object\": \"chat.completion.chunk\"}".to_string(),
-            "data: {\"type\": \"ping\"}".to_string(), // This should be filtered out
-            "data: [DONE]".to_string(), // This should end the stream
-        ];
-
-        let mut iter = SseStreamIter::new(test_lines.into_iter());
-
-        // First event should be msg1 (ping filtered out)
-        let event1 = iter.next().unwrap();
-        assert!(event1.data.as_ref().unwrap().contains("msg1"));
-        assert!(!event1.should_skip());
-
-        // Second event should be msg2 (ping filtered out)
-        let event2 = iter.next().unwrap();
-        assert!(event2.data.as_ref().unwrap().contains("msg2"));
-        assert!(!event2.should_skip());
-
-        // Third event should be [DONE]
-        let done_event = iter.next().unwrap();
-        assert!(done_event.is_done());
-
-        // Iterator should end after [DONE]
-        assert!(iter.next().is_none());
-    }
-
-    #[test]
-    fn test_sse_stream_iter_handles_anthropic_events() {
-        // Create test data with Anthropic-style event/data pairs
-        let test_lines = vec![
-            "event: message_start".to_string(),
-            "data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_123\"}}".to_string(),
-            "event: content_block_delta".to_string(),
-            "data: {\"type\":\"content_block_delta\",\"delta\":{\"text\":\"Hello\"}}".to_string(),
-            "data: [DONE]".to_string(),
-        ];
-
-        let mut iter = SseStreamIter::new(test_lines.into_iter());
-
-        // First event should be the event: line
-        let event1 = iter.next().unwrap();
-        assert!(event1.is_event_only());
-        assert_eq!(event1.event, Some("message_start".to_string()));
-        assert_eq!(event1.data, None);
-
-        // Second event should be the data: line
-        let event2 = iter.next().unwrap();
-        assert!(!event2.is_event_only());
-        assert_eq!(event2.event, None);
-        assert!(event2.data.as_ref().unwrap().contains("message_start"));
-
-        // Third event should be another event: line
-        let event3 = iter.next().unwrap();
-        assert!(event3.is_event_only());
-        assert_eq!(event3.event, Some("content_block_delta".to_string()));
-
-        // Fourth event should be the content delta data
-        let event4 = iter.next().unwrap();
-        assert!(!event4.is_event_only());
-        assert!(event4.data.as_ref().unwrap().contains("Hello"));
-
-        // Fifth event should be [DONE]
-        let done_event = iter.next().unwrap();
-        assert!(done_event.is_done());
-
-        // Iterator should end after [DONE]
-        assert!(iter.next().is_none());
-    }
-
-    #[test]
-    fn test_provider_stream_response_event_type() {
-        use crate::apis::anthropic::{MessagesContentDelta, MessagesStreamEvent};
-        use crate::apis::openai::ChatCompletionsStreamResponse;
-
-        // Test Anthropic event type
-        let anthropic_event = MessagesStreamEvent::ContentBlockDelta {
-            index: 0,
-            delta: MessagesContentDelta::TextDelta {
-                text: "Hello".to_string(),
-            },
-        };
-        let provider_type = ProviderStreamResponseType::MessagesStreamEvent(anthropic_event);
-        assert_eq!(provider_type.event_type(), Some("content_block_delta"));
-
-        // Test OpenAI event type (should be None)
-        let openai_event = ChatCompletionsStreamResponse {
-            id: "test".to_string(),
-            object: Some("chat.completion.chunk".to_string()),
-            created: 123456789,
-            model: "gpt-4".to_string(),
-            choices: vec![],
-            usage: None,
-            system_fingerprint: None,
-            service_tier: None,
-        };
-        let provider_type = ProviderStreamResponseType::ChatCompletionsStreamResponse(openai_event);
-        assert_eq!(provider_type.event_type(), None);
-    }
-
-    #[test]
-    fn test_done_marker_handled_in_stream_response_transformation() {
-        use crate::apis::anthropic::AnthropicApi;
-
-        // Test that [DONE] marker is properly converted to MessageStop in the transformation layer
-        let done_bytes = b"[DONE]";
-        let client_api = SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages);
-        let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(
-            crate::apis::openai::OpenAIApi::ChatCompletions,
-        );
-
-        let result = ProviderStreamResponseType::try_from((
-            done_bytes.as_slice(),
-            &client_api,
-            &upstream_api,
-        ));
-        assert!(result.is_ok());
-
-        if let Ok(ProviderStreamResponseType::MessagesStreamEvent(event)) = result {
-            // Verify it's a MessageStop event
-            assert_eq!(event.event_type(), Some("message_stop"));
-            assert!(matches!(
-                event,
-                crate::apis::anthropic::MessagesStreamEvent::MessageStop
-            ));
-        } else {
-            panic!("Expected MessagesStreamEvent::MessageStop");
-        }
-    }
-
-    #[test]
-    fn test_bedrock_event_stream_decoder_basic() {
-        use bytes::BytesMut;
-
-        // Create a simple test with minimal data
-        let mut buffer = BytesMut::new();
-
-        // Add some arbitrary bytes (not a real event-stream frame, just for testing the decoder)
-        buffer.extend_from_slice(b"test data");
-
-        let mut decoder = BedrockBinaryFrameDecoder::new(&mut buffer);
-
-        // The decoder should return Incomplete for incomplete/invalid data
-        // This signals the caller to wait for more data
-        let result = decoder.decode_frame();
-        assert!(result.is_some());
-        assert!(matches!(
-            result.unwrap(),
-            aws_smithy_eventstream::frame::DecodedFrame::Incomplete
-        ));
-
-        // Verify we can still access the buffer
-        assert!(decoder.has_remaining());
-    }
-
-    #[test]
-    fn test_bedrock_event_stream_decoder_with_real_frames() {
-        use bytes::BytesMut;
-        use std::fs;
-        use std::path::PathBuf;
-
-        // Read the actual response.hex file from tests/e2e directory
-        let test_file =
-            PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("../../tests/e2e/response.hex");
-
-        // Only run this test if the file exists
-        if !test_file.exists() {
-            println!("Skipping test - response.hex not found");
-            return;
-        }
-
-        let response_data = fs::read(&test_file).unwrap();
-        let mut buffer = BytesMut::from(&response_data[..]);
-
-        let mut decoder = BedrockBinaryFrameDecoder::new(&mut buffer);
-        let mut frame_count = 0;
-
-        // Decode all frames
-        loop {
-            match decoder.decode_frame() {
-                Some(aws_smithy_eventstream::frame::DecodedFrame::Complete(message)) => {
-                    frame_count += 1;
-
-                    // Verify we can access headers
-                    let event_type = message
-                        .headers()
-                        .iter()
-                        .find(|h| h.name().as_str() == ":event-type")
-                        .and_then(|h| h.value().as_string().ok());
-
-                    assert!(event_type.is_some(), "Frame should have :event-type header");
-                }
-                Some(aws_smithy_eventstream::frame::DecodedFrame::Incomplete) => {
-                    // End of buffer, no more complete frames available
-                    break;
-                }
-                None => {
-                    // Decode error
-                    panic!("Decode error encountered");
-                }
-            }
-        }
-
-        // We should have decoded multiple frames
-        assert!(frame_count > 0, "Should have decoded at least one frame");
-    }
-
-    #[test]
-    fn test_bedrock_event_stream_decoder_chunked_data() {
-        use bytes::BytesMut;
-        use std::fs;
-        use std::path::PathBuf;
-
-        // Read the actual response.hex file
-        let test_file =
-            PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("../../tests/e2e/response.hex");
-
-        if !test_file.exists() {
-            println!("Skipping test - response.hex not found");
-            return;
-        }
-
-        let response_data = fs::read(&test_file).unwrap();
-
-        // Simulate chunked network arrivals with realistic chunk sizes
-        // Using varying chunk sizes to test partial frame handling
-        let mut buffer = BytesMut::new();
-        let chunk_size_pattern = vec![500, 1000, 750, 1200, 800, 1500];
-        let mut offset = 0;
-        let mut total_frames = 0;
-        let mut chunk_num = 0;
-
-        // CRITICAL: Create ONE decoder and reuse it across chunks
-        // The MessageFrameDecoder maintains state about partial frames
-        let mut decoder = BedrockBinaryFrameDecoder::new(&mut buffer);
-
-        // Process all data in chunks
-        while offset < response_data.len() {
-            let chunk_size = chunk_size_pattern[chunk_num % chunk_size_pattern.len()];
-            chunk_num += 1;
-
-            let end = (offset + chunk_size).min(response_data.len());
-            let chunk = &response_data[offset..end];
-
-            // Add new data to the buffer (accessing via buffer_mut())
-            decoder.buffer_mut().extend_from_slice(chunk);
-            offset = end;
-
-            // Process all available complete frames from this chunk
-            loop {
-                match decoder.decode_frame() {
-                    Some(aws_smithy_eventstream::frame::DecodedFrame::Complete(_)) => {
-                        total_frames += 1;
-                    }
-                    Some(aws_smithy_eventstream::frame::DecodedFrame::Incomplete) => {
-                        // Need more data - wait for next chunk
-                        break;
-                    }
-                    None => {
-                        // Decode error
-                        panic!("Decode error in chunked test");
-                    }
-                }
-            }
-        }
-
-        assert!(
-            total_frames > 0,
-            "Should have decoded frames from chunked data"
-        );
-    }
-
-    #[test]
-    fn test_bedrock_decoded_frame_to_provider_response() {
-        test_bedrock_conversion(false);
-    }
-
-    #[test]
-    #[ignore] // Run with: cargo test -- --ignored --nocapture
-    fn test_bedrock_decoded_frame_to_provider_response_verbose() {
-        test_bedrock_conversion(true);
-    }
-
-    #[test]
-    fn test_bedrock_decoded_frame_with_tool_use() {
-        test_bedrock_conversion_with_tools(false);
-    }
-
-    #[test]
-    #[ignore] // Run with: cargo test -- --ignored --nocapture
-    fn test_bedrock_decoded_frame_with_tool_use_verbose() {
-        test_bedrock_conversion_with_tools(true);
-    }
-
-    fn test_bedrock_conversion(verbose: bool) {
-        use bytes::BytesMut;
-        use std::fs;
-        use std::path::PathBuf;
-
-        // Read the actual response.hex file from tests/e2e directory
-        let test_file =
-            PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("../../tests/e2e/response.hex");
-
-        // Only run this test if the file exists
-        if !test_file.exists() {
-            println!("Skipping test - response.hex not found");
-            return;
-        }
-
-        let response_data = fs::read(&test_file).unwrap();
-        let mut buffer = BytesMut::from(&response_data[..]);
-
-        let mut decoder = BedrockBinaryFrameDecoder::new(&mut buffer);
-
-        let client_api =
-            SupportedAPIs::AnthropicMessagesAPI(crate::apis::anthropic::AnthropicApi::Messages);
-        let upstream_api = SupportedUpstreamAPIs::AmazonBedrockConverseStream(
-            crate::apis::amazon_bedrock::AmazonBedrockApi::ConverseStream,
-        );
-
-        let mut conversion_count = 0;
-        let mut message_start_seen = false;
-
-        // Decode and convert frames
-        loop {
-            match decoder.decode_frame() {
-                Some(frame @ aws_smithy_eventstream::frame::DecodedFrame::Complete(_)) => {
-                    // Convert DecodedFrame to ProviderStreamResponseType
-                    let result =
-                        ProviderStreamResponseType::try_from((&frame, &client_api, &upstream_api));
-
-                    match result {
-                        Ok(provider_response) => {
-                            conversion_count += 1;
-
-                            // Verify we got a MessagesStreamEvent
-                            assert!(matches!(
-                                provider_response,
-                                ProviderStreamResponseType::MessagesStreamEvent(_)
-                            ));
-
-                            if verbose {
-                                // Print the SSE string output
-                                let sse_string: String = provider_response.clone().into();
-                                println!("{}", sse_string);
-                            }
-
-                            // Check for MessageStart event
-                            if let ProviderStreamResponseType::MessagesStreamEvent(ref event) =
-                                provider_response
-                            {
-                                if matches!(
-                                    event,
-                                    crate::apis::anthropic::MessagesStreamEvent::MessageStart { .. }
-                                ) {
-                                    message_start_seen = true;
-                                }
-                            }
-                        }
-                        Err(e) => {
-                            println!("Conversion error (frame {}): {}", conversion_count, e);
-                        }
-                    }
-                }
-                Some(aws_smithy_eventstream::frame::DecodedFrame::Incomplete) => {
-                    // End of buffer
-                    break;
-                }
-                None => {
-                    panic!("Decode error");
-                }
-            }
-        }
-
-        assert!(
-            conversion_count > 0,
-            "Should have converted at least one frame"
-        );
-        assert!(message_start_seen, "Should have seen MessageStart event");
-    }
-
-    fn test_bedrock_conversion_with_tools(verbose: bool) {
-        use bytes::BytesMut;
-        use std::fs;
-        use std::path::PathBuf;
-
-        // Read the actual response_with_tools.hex file from tests/e2e directory
-        let test_file = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
-            .join("../../tests/e2e/response_with_tools.hex");
-
-        // Only run this test if the file exists
-        if !test_file.exists() {
-            println!("Skipping test - response_with_tools.hex not found");
-            return;
-        }
-
-        let response_data = fs::read(&test_file).unwrap();
-        let mut buffer = BytesMut::from(&response_data[..]);
-
-        let mut decoder = BedrockBinaryFrameDecoder::new(&mut buffer);
-
-        let client_api =
-            SupportedAPIs::AnthropicMessagesAPI(crate::apis::anthropic::AnthropicApi::Messages);
-        let upstream_api = SupportedUpstreamAPIs::AmazonBedrockConverseStream(
-            crate::apis::amazon_bedrock::AmazonBedrockApi::ConverseStream,
-        );
-
-        let mut conversion_count = 0;
-        let mut message_start_seen = false;
-        let mut content_block_start_seen = false;
-        let mut content_block_delta_tool_use_seen = false;
-
-        // Decode and convert frames
-        loop {
-            match decoder.decode_frame() {
-                Some(frame @ aws_smithy_eventstream::frame::DecodedFrame::Complete(_)) => {
-                    // Convert DecodedFrame to ProviderStreamResponseType
-                    let result =
-                        ProviderStreamResponseType::try_from((&frame, &client_api, &upstream_api));
-
-                    match result {
-                        Ok(provider_response) => {
-                            conversion_count += 1;
-
-                            // Verify we got a MessagesStreamEvent
-                            assert!(matches!(
-                                provider_response,
-                                ProviderStreamResponseType::MessagesStreamEvent(_)
-                            ));
-
-                            if verbose {
-                                // Print the SSE string output
-                                let sse_string: String = provider_response.clone().into();
-                                println!("{}", sse_string);
-                            }
-
-                            // Check for specific events related to tool use
-                            if let ProviderStreamResponseType::MessagesStreamEvent(ref event) =
-                                provider_response
-                            {
-                                match event {
-                                    crate::apis::anthropic::MessagesStreamEvent::MessageStart { .. } => {
-                                        message_start_seen = true;
-                                    }
-                                    crate::apis::anthropic::MessagesStreamEvent::ContentBlockStart { .. } => {
-                                        content_block_start_seen = true;
-                                    }
-                                    crate::apis::anthropic::MessagesStreamEvent::ContentBlockDelta { delta, .. } => {
-                                        if matches!(delta, crate::apis::anthropic::MessagesContentDelta::InputJsonDelta { .. }) {
-                                            content_block_delta_tool_use_seen = true;
-                                        }
-                                    }
-                                    _ => {}
-                                }
-                            }
-                        }
-                        Err(e) => {
-                            println!("Conversion error (frame {}): {}", conversion_count, e);
-                        }
-                    }
-                }
-                Some(aws_smithy_eventstream::frame::DecodedFrame::Incomplete) => {
-                    // End of buffer
-                    break;
-                }
-                None => {
-                    panic!("Decode error");
-                }
-            }
-        }
-
-        assert!(
-            conversion_count > 0,
-            "Should have converted at least one frame"
-        );
-        assert!(message_start_seen, "Should have seen MessageStart event");
-        assert!(
-            content_block_start_seen,
-            "Should have seen ContentBlockStart event for tool use"
-        );
-        assert!(
-            content_block_delta_tool_use_seen,
-            "Should have seen ContentBlockDelta with ToolUseDelta"
-        );
-    }
-
-    #[test]
-    fn test_sse_event_transformation_openai_to_anthropic_message_start() {
-        use crate::apis::anthropic::AnthropicApi;
-        use crate::apis::openai::OpenAIApi;
-
-        // Create an OpenAI stream response that represents a role start (which becomes message_start in Anthropic)
-        let openai_stream_chunk = json!({
-            "id": "chatcmpl-123",
-            "object": "chat.completion.chunk",
-            "created": 1234567890,
-            "model": "gpt-4",
-            "choices": [{
-                "index": 0,
-                "delta": {"role": "assistant"},
-                "finish_reason": null
-            }]
-        });
-
-        // Create SSE event with this data
-        let sse_event = SseEvent {
-            data: Some(openai_stream_chunk.to_string()),
-            event: None,
-            raw_line: format!("data: {}", openai_stream_chunk.to_string()),
-            sse_transform_buffer: format!("data: {}", openai_stream_chunk.to_string()),
-            provider_stream_response: None,
-        };
-
-        let client_api = SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages);
-        let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions);
-
-        // Transform the event
-        let result = SseEvent::try_from((sse_event, &client_api, &upstream_api));
-        assert!(result.is_ok());
-
-        let transformed = result.unwrap();
-
-        // Verify the transformation includes both message_start and content_block_start
-        let buffer = transformed.sse_transform_buffer;
-        assert!(
-            buffer.contains("event: message_start"),
-            "Should contain message_start event"
-        );
-        assert!(
-            buffer.contains("event: content_block_start"),
-            "Should contain content_block_start event"
-        );
-
-        // Verify proper SSE format with event lines before data lines
-        assert!(buffer.find("event: message_start").unwrap() < buffer.find("data:").unwrap());
-        assert!(buffer.find("content_block_start").is_some());
-    }
-
-    #[test]
-    fn test_sse_event_transformation_openai_to_anthropic_message_delta() {
-        use crate::apis::anthropic::AnthropicApi;
-        use crate::apis::openai::OpenAIApi;
-
-        // Create an OpenAI stream response with finish_reason (which becomes message_delta in Anthropic)
-        let openai_stream_chunk = json!({
-            "id": "chatcmpl-123",
-            "object": "chat.completion.chunk",
-            "created": 1234567890,
-            "model": "gpt-4",
-            "choices": [{
-                "index": 0,
-                "delta": {},
-                "finish_reason": "stop"
-            }],
-            "usage": {
-                "prompt_tokens": 10,
-                "completion_tokens": 25,
-                "total_tokens": 35
-            }
-        });
-
-        // Create SSE event with this data
-        let sse_event = SseEvent {
-            data: Some(openai_stream_chunk.to_string()),
-            event: None,
-            raw_line: format!("data: {}", openai_stream_chunk.to_string()),
-            sse_transform_buffer: format!("data: {}", openai_stream_chunk.to_string()),
-            provider_stream_response: None,
-        };
-
-        let client_api = SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages);
-        let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions);
-
-        // Transform the event
-        let result = SseEvent::try_from((sse_event, &client_api, &upstream_api));
-        assert!(result.is_ok());
-
-        let transformed = result.unwrap();
-
-        // Verify the transformation includes both content_block_stop and message_delta
-        let buffer = transformed.sse_transform_buffer;
-        assert!(
-            buffer.contains("event: content_block_stop"),
-            "Should contain content_block_stop event"
-        );
-        assert!(
-            buffer.contains("event: message_delta"),
-            "Should contain message_delta event"
-        );
-
-        // Verify content_block_stop comes before message_delta
-        let stop_pos = buffer.find("content_block_stop").unwrap();
-        let delta_pos = buffer.find("message_delta").unwrap();
-        assert!(
-            stop_pos < delta_pos,
-            "content_block_stop should come before message_delta"
-        );
-    }
-
-    #[test]
-    fn test_sse_event_transformation_openai_to_anthropic_content_delta() {
-        use crate::apis::anthropic::AnthropicApi;
-        use crate::apis::openai::OpenAIApi;
-
-        // Create an OpenAI stream response with content (which becomes content_block_delta in Anthropic)
-        let openai_stream_chunk = json!({
-            "id": "chatcmpl-123",
-            "object": "chat.completion.chunk",
-            "created": 1234567890,
-            "model": "gpt-4",
-            "choices": [{
-                "index": 0,
-                "delta": {"content": "Hello"},
-                "finish_reason": null
-            }]
-        });
-
-        // Create SSE event with this data
-        let sse_event = SseEvent {
-            data: Some(openai_stream_chunk.to_string()),
-            event: None,
-            raw_line: format!("data: {}", openai_stream_chunk.to_string()),
-            sse_transform_buffer: format!("data: {}", openai_stream_chunk.to_string()),
-            provider_stream_response: None,
-        };
-
-        let client_api = SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages);
-        let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions);
-
-        // Transform the event
-        let result = SseEvent::try_from((sse_event, &client_api, &upstream_api));
-        assert!(result.is_ok());
-
-        let transformed = result.unwrap();
-
-        // Verify the transformation is a content_block_delta (no extra events injected)
-        let buffer = transformed.sse_transform_buffer;
-        assert!(
-            buffer.contains("event: content_block_delta"),
-            "Should contain content_block_delta event"
-        );
-        assert!(
-            !buffer.contains("content_block_start"),
-            "Should not inject content_block_start for content delta"
-        );
-        assert!(
-            !buffer.contains("content_block_stop"),
-            "Should not inject content_block_stop for content delta"
-        );
-
-        // Verify the content is preserved
-        assert!(buffer.contains("Hello"), "Should preserve the content text");
-    }
-
-    #[test]
-    fn test_sse_event_transformation_anthropic_to_openai_suppresses_event_lines() {
-        use crate::apis::anthropic::AnthropicApi;
-        use crate::apis::openai::OpenAIApi;
-
-        // Create an Anthropic event-only SSE line (no data)
-        let sse_event = SseEvent {
-            data: None,
-            event: Some("message_start".to_string()),
-            raw_line: "event: message_start".to_string(),
-            sse_transform_buffer: "event: message_start".to_string(),
-            provider_stream_response: None,
-        };
-
-        let client_api = SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions);
-        let upstream_api = SupportedUpstreamAPIs::AnthropicMessagesAPI(AnthropicApi::Messages);
-
-        // Transform the event
-        let result = SseEvent::try_from((sse_event, &client_api, &upstream_api));
-        assert!(result.is_ok());
-
-        let transformed = result.unwrap();
-
-        // Verify the event line is suppressed (replaced with just newline)
-        assert_eq!(
-            transformed.sse_transform_buffer, "\n",
-            "Event-only lines should be suppressed to newline for OpenAI"
-        );
-        assert!(
-            transformed.is_event_only(),
-            "Should still be marked as event-only"
-        );
-    }
-
-    #[test]
-    fn test_sse_event_transformation_anthropic_to_openai_preserves_data() {
-        use crate::apis::anthropic::AnthropicApi;
-        use crate::apis::openai::OpenAIApi;
-
-        // Create an Anthropic message_start event with data
-        let anthropic_event = json!({
-            "type": "message_start",
-            "message": {
-                "id": "msg_123",
-                "type": "message",
-                "role": "assistant",
-                "content": [],
-                "model": "claude-3-sonnet",
-                "stop_reason": null,
-                "usage": {"input_tokens": 10, "output_tokens": 0}
-            }
-        });
-
-        let sse_event = SseEvent {
-            data: Some(anthropic_event.to_string()),
-            event: None,
-            raw_line: format!("data: {}", anthropic_event.to_string()),
-            sse_transform_buffer: format!("data: {}", anthropic_event.to_string()),
-            provider_stream_response: None,
-        };
-
-        let client_api = SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions);
-        let upstream_api = SupportedUpstreamAPIs::AnthropicMessagesAPI(AnthropicApi::Messages);
-
-        // Transform the event
-        let result = SseEvent::try_from((sse_event, &client_api, &upstream_api));
-        assert!(result.is_ok());
-
-        let transformed = result.unwrap();
-
-        // Verify data is transformed to OpenAI format
-        let buffer = transformed.sse_transform_buffer;
-        assert!(buffer.starts_with("data: "), "Should have data: prefix");
-        assert!(
-            !buffer.contains("event:"),
-            "Should not have event: lines for OpenAI"
-        );
-
-        // Verify provider response was parsed
-        assert!(transformed.provider_stream_response.is_some());
-    }
-
-    #[test]
-    fn test_sse_event_transformation_no_change_for_matching_apis() {
-        use crate::apis::openai::OpenAIApi;
-
-        // Create an OpenAI stream response
-        let openai_stream_chunk = json!({
-            "id": "chatcmpl-123",
-            "object": "chat.completion.chunk",
-            "created": 1234567890,
-            "model": "gpt-4",
-            "choices": [{
-                "index": 0,
-                "delta": {"content": "Hello"},
-                "finish_reason": null
-            }]
-        });
-
-        let original_data = openai_stream_chunk.to_string();
-        let sse_event = SseEvent {
-            data: Some(original_data.clone()),
-            event: None,
-            raw_line: format!("data: {}", original_data),
-            sse_transform_buffer: format!("data: {}\n\n", original_data),
-            provider_stream_response: None,
-        };
-
-        let client_api = SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions);
-        let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions);
-
-        // Transform the event
-        let result = SseEvent::try_from((sse_event, &client_api, &upstream_api));
-        assert!(result.is_ok());
-
-        let transformed = result.unwrap();
-
-        // Verify minimal transformation - just SSE formatting, no API conversion
-        let buffer = transformed.sse_transform_buffer;
-        assert!(buffer.starts_with("data: "), "Should preserve data: prefix");
-        assert!(!buffer.contains("event:"), "Should not add event: lines");
-
-        // Verify provider response was parsed
-        assert!(transformed.provider_stream_response.is_some());
-    }
-
-    #[test]
-    fn test_sse_event_transformation_preserves_provider_response() {
-        use crate::apis::anthropic::AnthropicApi;
-        use crate::apis::openai::OpenAIApi;
-
-        // Create an OpenAI stream response
-        let openai_stream_chunk = json!({
-            "id": "chatcmpl-123",
-            "object": "chat.completion.chunk",
-            "created": 1234567890,
-            "model": "gpt-4",
-            "choices": [{
-                "index": 0,
-                "delta": {"content": "Test"},
-                "finish_reason": null
-            }]
-        });
-
-        let sse_event = SseEvent {
-            data: Some(openai_stream_chunk.to_string()),
-            event: None,
-            raw_line: format!("data: {}", openai_stream_chunk.to_string()),
-            sse_transform_buffer: format!("data: {}", openai_stream_chunk.to_string()),
-            provider_stream_response: None,
-        };
-
-        let client_api = SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages);
-        let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions);
-
-        // Transform the event
-        let result = SseEvent::try_from((sse_event, &client_api, &upstream_api));
-        assert!(result.is_ok());
-
-        let transformed = result.unwrap();
-
-        // Verify provider_stream_response is populated
-        assert!(
-            transformed.provider_stream_response.is_some(),
-            "Should parse and store provider response"
-        );
-
-        // Verify we can access the provider response
-        let provider_response = transformed.provider_response();
-        assert!(
-            provider_response.is_ok(),
-            "Should be able to access provider response"
-        );
-
-        // Verify the content delta is accessible
-        let content = provider_response.unwrap().content_delta();
-        assert_eq!(content, Some("Test"), "Should preserve content delta");
-    }
 }
diff --git a/crates/hermesllm/src/providers/streaming_response.rs b/crates/hermesllm/src/providers/streaming_response.rs
new file mode 100644
index 00000000..55e52f3d
--- /dev/null
+++ b/crates/hermesllm/src/providers/streaming_response.rs
@@ -0,0 +1,1348 @@
+use serde::Serialize;
+use std::convert::TryFrom;
+
+use crate::apis::openai::ChatCompletionsStreamResponse;
+use crate::apis::openai_responses::ResponsesAPIStreamEvent;
+use crate::apis::streaming_shapes::sse::SseEvent;
+use crate::apis::amazon_bedrock::ConverseStreamEvent;
+use crate::apis::anthropic::MessagesStreamEvent;
+use crate::apis::streaming_shapes::sse::SseStreamBuffer;
+use crate::apis::streaming_shapes::{
+    anthropic_streaming_buffer::AnthropicMessagesStreamBuffer,
+    chat_completions_streaming_buffer::OpenAIChatCompletionsStreamBuffer,
+    passthrough_streaming_buffer::PassthroughStreamBuffer,
+    responses_api_streaming_buffer::ResponsesAPIStreamBuffer,
+};
+
+use crate::clients::endpoints::SupportedAPIsFromClient;
+use crate::clients::endpoints::SupportedUpstreamAPIs;
+
+// ============================================================================
+// SSE STREAM BUFFER FACTORY
+// ============================================================================
+
+/// Check if streaming buffering is needed based on client and upstream API combination.
+pub fn needs_buffering(
+    client_api: &SupportedAPIsFromClient,
+    upstream_api: &SupportedUpstreamAPIs,
+) -> bool {
+    match (client_api, upstream_api) {
+        // Same APIs - no buffering needed
+        (SupportedAPIsFromClient::OpenAIChatCompletions(_), SupportedUpstreamAPIs::OpenAIChatCompletions(_)) => false,
+        (SupportedAPIsFromClient::AnthropicMessagesAPI(_), SupportedUpstreamAPIs::AnthropicMessagesAPI(_)) => false,
+        (SupportedAPIsFromClient::OpenAIResponsesAPI(_), SupportedUpstreamAPIs::OpenAIResponsesAPI(_)) => false,
+
+        // Different APIs - buffering needed
+        _ => true,
+    }
+}
+
+/// Factory pattern for creating SSE stream buffers based on client and upstream API combination.
+///
+/// # Example
+/// ```ignore
+/// use hermesllm::clients::endpoints::{SupportedAPIsFromClient, SupportedUpstreamAPIs};
+/// use hermesllm::apis::streaming_shapes::sse::SseStreamBuffer;
+///
+/// // Transformation needed: OpenAI upstream -> Anthropic client
+/// let mut buffer = SseStreamBuffer::try_from((&client_api, &upstream_api))?;
+///
+/// // Add transformed events
+/// let transformed = SseEvent::try_from((raw_event, &client_api, &upstream_api))?;
+/// buffer.add_transformed_event(transformed);
+///
+/// // Flush to wire
+/// let bytes = buffer.into_bytes();
+/// ```
+impl TryFrom<(&SupportedAPIsFromClient, &SupportedUpstreamAPIs)>
+    for SseStreamBuffer
+{
+    type Error = Box<dyn std::error::Error + Send + Sync>;
+
+    fn try_from(
+        (client_api, upstream_api): (&SupportedAPIsFromClient, &SupportedUpstreamAPIs),
+    ) -> Result<Self, Self::Error> {
+        // If APIs match, use passthrough - no buffering/transformation needed
+        if !needs_buffering(client_api, upstream_api) {
+            return Ok(SseStreamBuffer::Passthrough(PassthroughStreamBuffer::new()));
+        }
+
+        // APIs differ - use appropriate buffer for client API
+        match client_api {
+            SupportedAPIsFromClient::OpenAIChatCompletions(_) => {
+                Ok(SseStreamBuffer::OpenAIChatCompletions(OpenAIChatCompletionsStreamBuffer::new()))
+            }
+            SupportedAPIsFromClient::AnthropicMessagesAPI(_) => {
+                Ok(SseStreamBuffer::AnthropicMessages(AnthropicMessagesStreamBuffer::new()))
+            }
+            SupportedAPIsFromClient::OpenAIResponsesAPI(_) => {
+                Ok(SseStreamBuffer::OpenAIResponses(ResponsesAPIStreamBuffer::new()))
+            }
+        }
+    }
+}
+
+// ============================================================================
+// PROVIDER STREAM RESPONSE TYPES
+// ============================================================================
+
+#[derive(Serialize, Debug, Clone)]
+#[serde(untagged)]
+pub enum ProviderStreamResponseType {
+    ChatCompletionsStreamResponse(ChatCompletionsStreamResponse),
+    MessagesStreamEvent(MessagesStreamEvent),
+    ConverseStreamEvent(ConverseStreamEvent),
+    ResponseAPIStreamEvent(ResponsesAPIStreamEvent),
+}
+
+pub trait ProviderStreamResponse: Send + Sync {
+    /// Get the content delta for this chunk
+    fn content_delta(&self) -> Option<&str>;
+
+    /// Check if this is the final chunk in the stream
+    fn is_final(&self) -> bool;
+
+    /// Get role information if available
+    fn role(&self) -> Option<&str>;
+
+    /// Get event type for SSE streaming (used by Anthropic)
+    fn event_type(&self) -> Option<&str>;
+}
+
+impl ProviderStreamResponse for ProviderStreamResponseType {
+    fn content_delta(&self) -> Option<&str> {
+        match self {
+            ProviderStreamResponseType::ChatCompletionsStreamResponse(resp) => resp.content_delta(),
+            ProviderStreamResponseType::MessagesStreamEvent(resp) => resp.content_delta(),
+            ProviderStreamResponseType::ConverseStreamEvent(resp) => resp.content_delta(),
+            ProviderStreamResponseType::ResponseAPIStreamEvent(_resp) => None, // ResponsesAPI does not have content deltas
+        }
+    }
+
+    fn is_final(&self) -> bool {
+        match self {
+            ProviderStreamResponseType::ChatCompletionsStreamResponse(resp) => resp.is_final(),
+            ProviderStreamResponseType::MessagesStreamEvent(resp) => resp.is_final(),
+            ProviderStreamResponseType::ConverseStreamEvent(resp) => resp.is_final(),
+            ProviderStreamResponseType::ResponseAPIStreamEvent(resp) => resp.is_final(),
+        }
+    }
+
+    fn role(&self) -> Option<&str> {
+        match self {
+            ProviderStreamResponseType::ChatCompletionsStreamResponse(resp) => resp.role(),
+            ProviderStreamResponseType::MessagesStreamEvent(resp) => resp.role(),
+            ProviderStreamResponseType::ConverseStreamEvent(resp) => resp.role(),
+            ProviderStreamResponseType::ResponseAPIStreamEvent(resp) => resp.role(),
+        }
+    }
+
+    fn event_type(&self) -> Option<&str> {
+        match self {
+            ProviderStreamResponseType::ChatCompletionsStreamResponse(_resp) => None, // OpenAI ChatCompletions doesn't use event types
+            ProviderStreamResponseType::MessagesStreamEvent(resp) => resp.event_type(),
+            ProviderStreamResponseType::ConverseStreamEvent(resp) => resp.event_type(),
+            ProviderStreamResponseType::ResponseAPIStreamEvent(resp) => resp.event_type(),
+        }
+    }
+}
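A small sketch of how `needs_buffering` and the buffer factory above are meant to interact. The module paths follow this diff's file layout and the enum payloads are the ones exercised by the tests; the `choose_buffer` wrapper is hypothetical:

```rust
use hermesllm::apis::anthropic::AnthropicApi;
use hermesllm::apis::openai::OpenAIApi;
use hermesllm::apis::streaming_shapes::sse::SseStreamBuffer;
use hermesllm::clients::endpoints::{SupportedAPIsFromClient, SupportedUpstreamAPIs};
use hermesllm::providers::streaming_response::needs_buffering; // assumed path

fn choose_buffer() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    // Matching APIs: no re-sequencing, the factory hands back the passthrough buffer.
    let openai_client = SupportedAPIsFromClient::OpenAIChatCompletions(OpenAIApi::ChatCompletions);
    let openai_upstream = SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions);
    assert!(!needs_buffering(&openai_client, &openai_upstream));

    // Cross-API: an Anthropic client over an OpenAI upstream gets a buffer
    // shaped for the client so events can be re-sequenced before flushing.
    let anthropic_client = SupportedAPIsFromClient::AnthropicMessagesAPI(AnthropicApi::Messages);
    assert!(needs_buffering(&anthropic_client, &openai_upstream));
    let _buffer = SseStreamBuffer::try_from((&anthropic_client, &openai_upstream))?;
    Ok(())
}
```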
+
+impl Into<String> for ProviderStreamResponseType {
+    fn into(self) -> String {
+        match self {
+            ProviderStreamResponseType::MessagesStreamEvent(event) => {
+                // Use the Into<String> implementation for proper SSE formatting with event lines
+                event.into()
+            }
+            ProviderStreamResponseType::ConverseStreamEvent(event) => {
+                // Use the Into<String> implementation for proper SSE formatting with event lines
+                event.into()
+            }
+            ProviderStreamResponseType::ResponseAPIStreamEvent(event) => {
+                // Use the Into<String> implementation for proper SSE formatting with event lines
+                event.into()
+            }
+            ProviderStreamResponseType::ChatCompletionsStreamResponse(_) => {
+                // For OpenAI, use simple data line format
+                let json = serde_json::to_string(&self).unwrap_or_default();
+                format!("data: {}\n\n", json)
+            }
+        }
+    }
+}
+
+// Stream response transformation logic for client API compatibility
+impl TryFrom<(&[u8], &SupportedAPIsFromClient, &SupportedUpstreamAPIs)> for ProviderStreamResponseType {
+    type Error = Box<dyn std::error::Error + Send + Sync>;
+
+    fn try_from(
+        (bytes, client_api, upstream_api): (&[u8], &SupportedAPIsFromClient, &SupportedUpstreamAPIs),
+    ) -> Result<Self, Self::Error> {
+        // Special case: Handle [DONE] marker for OpenAI -> Anthropic conversion
+        if bytes == b"[DONE]" && matches!(client_api, SupportedAPIsFromClient::AnthropicMessagesAPI(_)) {
+            return Ok(ProviderStreamResponseType::MessagesStreamEvent(
+                crate::apis::anthropic::MessagesStreamEvent::MessageStop,
+            ));
+        }
+        match (upstream_api, client_api) {
+            // OpenAI upstream
+            (
+                SupportedUpstreamAPIs::OpenAIChatCompletions(_),
+                SupportedAPIsFromClient::OpenAIChatCompletions(_),
+            ) => {
+                let resp = serde_json::from_slice(bytes)?;
+                Ok(ProviderStreamResponseType::ChatCompletionsStreamResponse(
+                    resp,
+                ))
+            }
+            (
+                SupportedUpstreamAPIs::OpenAIChatCompletions(_),
+                SupportedAPIsFromClient::AnthropicMessagesAPI(_),
+            ) => {
+                let openai_resp: crate::apis::openai::ChatCompletionsStreamResponse =
+                    serde_json::from_slice(bytes)?;
+                let anthropic_resp = openai_resp.try_into()?;
+                Ok(ProviderStreamResponseType::MessagesStreamEvent(
+                    anthropic_resp,
+                ))
+            }
+            (
+                SupportedUpstreamAPIs::OpenAIChatCompletions(_),
+                SupportedAPIsFromClient::OpenAIResponsesAPI(_),
+            ) => {
+                let openai_resp: crate::apis::openai::ChatCompletionsStreamResponse =
+                    serde_json::from_slice(bytes)?;
+                let responses_resp = openai_resp.try_into()?;
+                Ok(ProviderStreamResponseType::ResponseAPIStreamEvent(
+                    responses_resp,
+                ))
+            }
+
+            // OpenAI ResponsesAPI upstream
+            (
+                SupportedUpstreamAPIs::OpenAIResponsesAPI(_),
+                SupportedAPIsFromClient::OpenAIResponsesAPI(_),
+            ) => {
+                let resp = serde_json::from_slice(bytes)?;
+                Ok(ProviderStreamResponseType::ResponseAPIStreamEvent(resp))
+            }
+            // Anthropic upstream
+            (
+                SupportedUpstreamAPIs::AnthropicMessagesAPI(_),
+                SupportedAPIsFromClient::AnthropicMessagesAPI(_),
+            ) => {
+                let resp = serde_json::from_slice(bytes)?;
+                Ok(ProviderStreamResponseType::MessagesStreamEvent(resp))
+            }
+            (
+                SupportedUpstreamAPIs::AnthropicMessagesAPI(_),
+                SupportedAPIsFromClient::OpenAIChatCompletions(_),
+            ) => {
+                let anthropic_resp: crate::apis::anthropic::MessagesStreamEvent =
+                    serde_json::from_slice(bytes)?;
+                let openai_resp = anthropic_resp.try_into()?;
+                Ok(ProviderStreamResponseType::ChatCompletionsStreamResponse(
+                    openai_resp,
+                ))
+            }
+
+            // Amazon Bedrock ConverseStream upstream
+            (
+                SupportedUpstreamAPIs::AmazonBedrockConverseStream(_),
+                SupportedAPIsFromClient::AnthropicMessagesAPI(_),
+            ) => {
+                let bedrock_resp: crate::apis::amazon_bedrock::ConverseStreamEvent =
+                    serde_json::from_slice(bytes)?;
+                let anthropic_resp = bedrock_resp.try_into()?;
+                Ok(ProviderStreamResponseType::MessagesStreamEvent(
+                    anthropic_resp,
+                ))
+            }
+            (
+                SupportedUpstreamAPIs::AmazonBedrockConverseStream(_),
+                SupportedAPIsFromClient::OpenAIResponsesAPI(_),
+            ) => {
+                // Chain: Bedrock -> ChatCompletions -> ResponsesAPI
+                let bedrock_resp: crate::apis::amazon_bedrock::ConverseStreamEvent =
+                    serde_json::from_slice(bytes)?;
+                let chat_resp: crate::apis::openai::ChatCompletionsStreamResponse = bedrock_resp.try_into()?;
+                let responses_resp = chat_resp.try_into()?;
+                Ok(ProviderStreamResponseType::ResponseAPIStreamEvent(
+                    responses_resp,
+                ))
+            }
+            _ => Err(std::io::Error::new(
+                std::io::ErrorKind::InvalidData,
+                "Unsupported API combination for response transformation",
+            )
+            .into()),
+        }
+    }
+}
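A sketch of the byte-level conversion above in action, reusing the chunk shape that the tests in this PR construct; the JSON literal and the wrapper function are illustrative:

```rust
use hermesllm::apis::anthropic::AnthropicApi;
use hermesllm::apis::openai::OpenAIApi;
use hermesllm::clients::endpoints::{SupportedAPIsFromClient, SupportedUpstreamAPIs};
use hermesllm::providers::streaming_response::{ProviderStreamResponse, ProviderStreamResponseType};

fn openai_chunk_to_anthropic() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    let client = SupportedAPIsFromClient::AnthropicMessagesAPI(AnthropicApi::Messages);
    let upstream = SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions);

    // An OpenAI content delta as it arrives on the wire.
    let chunk = br#"{"id":"chatcmpl-1","object":"chat.completion.chunk","created":0,"model":"gpt-4","choices":[{"index":0,"delta":{"content":"Hi"},"finish_reason":null}]}"#;
    let event = ProviderStreamResponseType::try_from((&chunk[..], &client, &upstream))?;
    assert_eq!(event.content_delta(), Some("Hi"));

    // The [DONE] sentinel is special-cased into Anthropic's message_stop.
    let done = ProviderStreamResponseType::try_from((&b"[DONE]"[..], &client, &upstream))?;
    assert_eq!(done.event_type(), Some("message_stop"));
    Ok(())
}
```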
+
+// TryFrom implementation to convert raw bytes to SseEvent with parsed provider response
+impl TryFrom<(SseEvent, &SupportedAPIsFromClient, &SupportedUpstreamAPIs)> for SseEvent {
+    type Error = Box<dyn std::error::Error + Send + Sync>;
+
+    fn try_from(
+        (sse_event, client_api, upstream_api): (SseEvent, &SupportedAPIsFromClient, &SupportedUpstreamAPIs),
+    ) -> Result<Self, Self::Error> {
+        // Create a new transformed event based on the original
+        let mut transformed_event = sse_event;
+
+        // Handle [DONE] marker early - don't try to parse it as JSON
+        if transformed_event.is_done() {
+            // For OpenAI client APIs (ChatCompletions and ResponsesAPI), keep [DONE] as-is
+            // For the Anthropic client API, it will be transformed via ProviderStreamResponseType
+            if matches!(client_api, SupportedAPIsFromClient::OpenAIChatCompletions(_) | SupportedAPIsFromClient::OpenAIResponsesAPI(_)) {
+                // Keep the [DONE] marker as-is for OpenAI clients
+                transformed_event.sse_transformed_lines = "data: [DONE]".to_string();
+                return Ok(transformed_event);
+            }
+        }
+
+        // If it has data, parse the data as a provider stream response (business logic layer)
+        if transformed_event.data.is_some() {
+            let data_str = transformed_event.data.as_ref().unwrap();
+            let data_bytes = data_str.as_bytes();
+            let transformed_response: ProviderStreamResponseType =
+                ProviderStreamResponseType::try_from((data_bytes, client_api, upstream_api))?;
+
+            // Convert to SSE string explicitly to avoid type ambiguity
+            let sse_string: String = transformed_response.clone().into();
+            transformed_event.sse_transformed_lines = sse_string;
+            transformed_event.provider_stream_response = Some(transformed_response);
+        }
+
+        // Apply wire format adjustments for cross-API transformations
+        // Note: When APIs match (passthrough mode), these adjustments are skipped
+        // since PassthroughStreamBuffer will handle events as-is
+        if needs_buffering(client_api, upstream_api) {
+            match (client_api, upstream_api) {
+                (
+                    SupportedAPIsFromClient::OpenAIChatCompletions(_),
+                    SupportedUpstreamAPIs::AnthropicMessagesAPI(_),
+                ) => {
+                    // OpenAI clients don't expect separate event: lines
+                    // Suppress upstream Anthropic event-only lines
+                    if transformed_event.is_event_only() && transformed_event.event.is_some() {
+                        transformed_event.sse_transformed_lines = "\n".to_string();
+                    }
+                }
+                _ => {
+                    // Other cross-API combinations can be handled here as needed
+                }
+            }
+        } else {
+            // Passthrough mode: APIs match, no transformation needed
+            // For Anthropic and ResponsesAPI SSE formats, event-only lines are redundant because
+            // the Into<String> implementation for MessagesStreamEvent and ResponsesAPIStreamEvent
+            // couples event and data lines together. We suppress event-only events to
+            // avoid duplicate event: lines in the output.
+            match (client_api, upstream_api) {
+                (
+                    SupportedAPIsFromClient::AnthropicMessagesAPI(_),
+                    SupportedUpstreamAPIs::AnthropicMessagesAPI(_),
+                ) | (
+                    SupportedAPIsFromClient::OpenAIResponsesAPI(_),
+                    SupportedUpstreamAPIs::OpenAIResponsesAPI(_),
+                ) => {
+                    if transformed_event.is_event_only() && transformed_event.event.is_some() {
+                        // Mark as should-skip by clearing sse_transformed_lines
+                        // The event line is already included when the data line is transformed
+                        transformed_event.sse_transformed_lines = String::new();
+                    }
+                }
+                _ => {
+                    // Other passthrough combinations (OpenAI ChatCompletions, etc.) don't have this issue
+                }
+            }
+        }
+
+        Ok(transformed_event)
+    }
+}
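The passthrough branch above is easiest to see on an event-only line. This sketch mirrors the struct literals used by the tests below; the wrapper function and the asserted outcome follow the comments in this impl, so treat it as illustrative rather than normative:

```rust
use hermesllm::apis::anthropic::AnthropicApi;
use hermesllm::apis::streaming_shapes::sse::SseEvent;
use hermesllm::clients::endpoints::{SupportedAPIsFromClient, SupportedUpstreamAPIs};

fn anthropic_passthrough_dedup() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    // An Anthropic "event:" line with no data payload.
    let event_only = SseEvent {
        data: None,
        event: Some("message_start".to_string()),
        raw_line: "event: message_start".to_string(),
        sse_transformed_lines: "event: message_start".to_string(),
        provider_stream_response: None,
    };
    let client = SupportedAPIsFromClient::AnthropicMessagesAPI(AnthropicApi::Messages);
    let upstream = SupportedUpstreamAPIs::AnthropicMessagesAPI(AnthropicApi::Messages);

    // Anthropic -> Anthropic: the paired data line re-emits the event line on
    // serialization, so the standalone event line is cleared to avoid duplicates.
    let out = SseEvent::try_from((event_only, &client, &upstream))?;
    assert_eq!(out.sse_transformed_lines, "");
    Ok(())
}
```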
+
+// TryFrom implementation to convert AWS Event Stream DecodedFrame to ProviderStreamResponseType
+impl
+    TryFrom<(
+        &aws_smithy_eventstream::frame::DecodedFrame,
+        &SupportedAPIsFromClient,
+        &SupportedUpstreamAPIs,
+    )> for ProviderStreamResponseType
+{
+    type Error = Box<dyn std::error::Error + Send + Sync>;
+
+    fn try_from(
+        (frame, client_api, upstream_api): (
+            &aws_smithy_eventstream::frame::DecodedFrame,
+            &SupportedAPIsFromClient,
+            &SupportedUpstreamAPIs,
+        ),
+    ) -> Result<Self, Self::Error> {
+        use aws_smithy_eventstream::frame::DecodedFrame;
+
+        match frame {
+            DecodedFrame::Complete(_) => {
+                // We have a complete frame - parse it based on upstream API
+                match (upstream_api, client_api) {
+                    (
+                        SupportedUpstreamAPIs::AmazonBedrockConverseStream(_),
+                        SupportedAPIsFromClient::AnthropicMessagesAPI(_),
+                    ) => {
+                        // Parse the DecodedFrame into ConverseStreamEvent
+                        let bedrock_event =
+                            crate::apis::amazon_bedrock::ConverseStreamEvent::try_from(frame)?;
+                        let anthropic_event: crate::apis::anthropic::MessagesStreamEvent =
+                            bedrock_event.try_into()?;
+
+                        Ok(ProviderStreamResponseType::MessagesStreamEvent(
+                            anthropic_event,
+                        ))
+                    }
+                    (
+                        SupportedUpstreamAPIs::AmazonBedrockConverseStream(_),
+                        SupportedAPIsFromClient::OpenAIChatCompletions(_),
+                    ) => {
+                        // Parse the DecodedFrame into ConverseStreamEvent
+                        let bedrock_event =
+                            crate::apis::amazon_bedrock::ConverseStreamEvent::try_from(frame)?;
+                        let openai_event: crate::apis::openai::ChatCompletionsStreamResponse =
+                            bedrock_event.try_into()?;
+                        Ok(ProviderStreamResponseType::ChatCompletionsStreamResponse(
+                            openai_event,
+                        ))
+                    }
+                    (
+                        SupportedUpstreamAPIs::AmazonBedrockConverseStream(_),
+                        SupportedAPIsFromClient::OpenAIResponsesAPI(_),
+                    ) => {
+                        // Parse the DecodedFrame into ConverseStreamEvent
+                        let bedrock_event =
+                            crate::apis::amazon_bedrock::ConverseStreamEvent::try_from(frame)?;
+                        let openai_chat_completions_event: crate::apis::openai::ChatCompletionsStreamResponse =
+                            bedrock_event.try_into()?;
+                        let openai_responses_api_event: crate::apis::openai_responses::ResponsesAPIStreamEvent =
+                            openai_chat_completions_event.try_into()?;
+
+                        Ok(ProviderStreamResponseType::ResponseAPIStreamEvent(
+                            openai_responses_api_event,
+                        ))
+                    }
+                    _ => Err("Unsupported API combination for event-stream decoding".into()),
+                }
+            }
+            DecodedFrame::Incomplete => {
+                Err("Cannot convert incomplete frame to provider response".into())
+            }
+        }
+    }
+}
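A sketch of the decode loop that feeds the frame conversion above, following the pattern of the Bedrock tests elsewhere in this PR; the `drain_frames` helper and its error strings are illustrative:

```rust
use aws_smithy_eventstream::frame::DecodedFrame;
use bytes::BytesMut;
use hermesllm::apis::streaming_shapes::amazon_bedrock_binary_frame::BedrockBinaryFrameDecoder;
use hermesllm::clients::endpoints::{SupportedAPIsFromClient, SupportedUpstreamAPIs};
use hermesllm::providers::streaming_response::ProviderStreamResponseType;

// Drain whatever complete frames are currently buffered; Incomplete means
// "wait for more bytes", None means the stream is corrupt.
fn drain_frames(
    buffer: &mut BytesMut,
    client: &SupportedAPIsFromClient,
    upstream: &SupportedUpstreamAPIs,
) -> Result<Vec<ProviderStreamResponseType>, Box<dyn std::error::Error + Send + Sync>> {
    let mut decoder = BedrockBinaryFrameDecoder::new(buffer);
    let mut out = Vec::new();
    loop {
        match decoder.decode_frame() {
            Some(frame @ DecodedFrame::Complete(_)) => {
                out.push(ProviderStreamResponseType::try_from((&frame, client, upstream))?);
            }
            Some(DecodedFrame::Incomplete) => break,
            None => return Err("event-stream decode error".into()),
        }
    }
    Ok(out)
}
```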
incomplete frame to provider response".into()) + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::apis::streaming_shapes::amazon_bedrock_binary_frame::BedrockBinaryFrameDecoder; + use crate::clients::endpoints::SupportedAPIsFromClient; + use crate::apis::streaming_shapes::sse::SseStreamIter; + use serde_json::json; + + #[test] + fn test_sse_event_parsing() { + // Test valid SSE data line + let line = "data: {\"id\":\"test\",\"object\":\"chat.completion.chunk\"}\n\n"; + let event: Result = line.parse(); + assert!(event.is_ok()); + let event = event.unwrap(); + // The data field should contain only the JSON content, not the trailing newlines + assert_eq!( + event.data, + Some("{\"id\":\"test\",\"object\":\"chat.completion.chunk\"}".to_string()) + ); + + // Test conversion back to line using Display trait + // The sse_transformed_lines preserves the original format including trailing newlines + let wire_format = event.to_string(); + assert_eq!( + wire_format, + "data: {\"id\":\"test\",\"object\":\"chat.completion.chunk\"}\n\n" + ); + + // Test [DONE] marker - should be valid SSE event + let done_line = "data: [DONE]"; + let done_result: Result = done_line.parse(); + assert!(done_result.is_ok()); + let done_event = done_result.unwrap(); + assert_eq!(done_event.data, Some("[DONE]".to_string())); + assert!(done_event.is_done()); // Test the helper method + + // Test non-DONE event + assert!(!event.is_done()); + + // Test empty data - should return error + let empty_line = "data: "; + let empty_result: Result = empty_line.parse(); + assert!(empty_result.is_err()); + + // Test non-data line - should return error + let comment_line = ": this is a comment"; + let comment_result: Result = comment_line.parse(); + assert!(comment_result.is_err()); + } + + #[test] + fn test_sse_event_serde() { + // Test serialization and deserialization with serde + let event = SseEvent { + data: Some(r#"{"id":"test","object":"chat.completion.chunk"}"#.to_string()), + event: None, + raw_line: r#"data: {"id":"test","object":"chat.completion.chunk"} + + "# + .to_string(), + sse_transformed_lines: r#"data: {"id":"test","object":"chat.completion.chunk"} + + "# + .to_string(), + provider_stream_response: None, + }; + + // Test JSON serialization - raw_line should be skipped + let json = serde_json::to_string(&event).unwrap(); + assert!(json.contains("test")); + assert!(json.contains("chat.completion.chunk")); + assert!(!json.contains("raw_line")); // Should be excluded from serialization + + // Test JSON deserialization + let deserialized: SseEvent = serde_json::from_str(&json).unwrap(); + assert_eq!(deserialized.data, event.data); + assert_eq!(deserialized.raw_line, ""); // Should be empty since it's skipped + + // Test round trip for data field only + assert_eq!(event.data, deserialized.data); + } + + #[test] + fn test_sse_event_should_skip() { + // Test ping message should be skipped + let ping_event = SseEvent { + data: Some(r#"{"type": "ping"}"#.to_string()), + event: None, + raw_line: r#"data: {"type": "ping"}"#.to_string(), + sse_transformed_lines: r#"data: {"type": "ping"}"#.to_string(), + provider_stream_response: None, + }; + assert!(ping_event.should_skip()); + assert!(!ping_event.is_done()); + + // Test normal event should not be skipped + let normal_event = SseEvent { + data: Some(r#"{"id": "test", "object": "chat.completion.chunk"}"#.to_string()), + event: Some("content_block_delta".to_string()), + raw_line: r#"data: {"id": "test", "object": "chat.completion.chunk"}"#.to_string(), + 
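+            // raw_line mirrors the original wire input (skipped by serde, as the
+            // serde test above asserts), while sse_transformed_lines carries the
+            // outgoing wire format; they match here because no transformation has
+            // been applied yet.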
sse_transformed_lines: r#"data: {"id": "test", "object": "chat.completion.chunk"}"# + .to_string(), + provider_stream_response: None, + }; + assert!(!normal_event.should_skip()); + assert!(!normal_event.is_done()); + + // Test [DONE] event should not be skipped (but is handled separately) + let done_event = SseEvent { + data: Some("[DONE]".to_string()), + event: None, + raw_line: "data: [DONE]".to_string(), + sse_transformed_lines: "data: [DONE]".to_string(), + provider_stream_response: None, + }; + assert!(!done_event.should_skip()); + assert!(done_event.is_done()); + } + + #[test] + fn test_sse_stream_iter_filters_ping_messages() { + // Create test data with ping messages mixed in + let test_lines = vec![ + "data: {\"id\": \"msg1\", \"object\": \"chat.completion.chunk\"}".to_string(), + "data: {\"type\": \"ping\"}".to_string(), // This should be filtered out + "data: {\"id\": \"msg2\", \"object\": \"chat.completion.chunk\"}".to_string(), + "data: {\"type\": \"ping\"}".to_string(), // This should be filtered out + "data: [DONE]".to_string(), // This should end the stream + ]; + + let mut iter = SseStreamIter::new(test_lines.into_iter()); + + // First event should be msg1 (ping filtered out) + let event1 = iter.next().unwrap(); + assert!(event1.data.as_ref().unwrap().contains("msg1")); + assert!(!event1.should_skip()); + + // Second event should be msg2 (ping filtered out) + let event2 = iter.next().unwrap(); + assert!(event2.data.as_ref().unwrap().contains("msg2")); + assert!(!event2.should_skip()); + + // Third event should be [DONE] + let done_event = iter.next().unwrap(); + assert!(done_event.is_done()); + + // Iterator should end after [DONE] + assert!(iter.next().is_none()); + } + + #[test] + fn test_sse_stream_iter_handles_anthropic_events() { + // Create test data with Anthropic-style event/data pairs + let test_lines = vec![ + "event: message_start".to_string(), + "data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_123\"}}".to_string(), + "event: content_block_delta".to_string(), + "data: {\"type\":\"content_block_delta\",\"delta\":{\"text\":\"Hello\"}}".to_string(), + "data: [DONE]".to_string(), + ]; + + let mut iter = SseStreamIter::new(test_lines.into_iter()); + + // First event should be the event: line + let event1 = iter.next().unwrap(); + assert!(event1.is_event_only()); + assert_eq!(event1.event, Some("message_start".to_string())); + assert_eq!(event1.data, None); + + // Second event should be the data: line + let event2 = iter.next().unwrap(); + assert!(!event2.is_event_only()); + assert_eq!(event2.event, None); + assert!(event2.data.as_ref().unwrap().contains("message_start")); + + // Third event should be another event: line + let event3 = iter.next().unwrap(); + assert!(event3.is_event_only()); + assert_eq!(event3.event, Some("content_block_delta".to_string())); + + // Fourth event should be the content delta data + let event4 = iter.next().unwrap(); + assert!(!event4.is_event_only()); + assert!(event4.data.as_ref().unwrap().contains("Hello")); + + // Fifth event should be [DONE] + let done_event = iter.next().unwrap(); + assert!(done_event.is_done()); + + // Iterator should end after [DONE] + assert!(iter.next().is_none()); + } + + #[test] + fn test_provider_stream_response_event_type() { + use crate::apis::anthropic::{MessagesContentDelta, MessagesStreamEvent}; + use crate::apis::openai::ChatCompletionsStreamResponse; + + // Test Anthropic event type + let anthropic_event = MessagesStreamEvent::ContentBlockDelta { + index: 0, + delta: 
MessagesContentDelta::TextDelta { + text: "Hello".to_string(), + }, + }; + let provider_type = ProviderStreamResponseType::MessagesStreamEvent(anthropic_event); + assert_eq!(provider_type.event_type(), Some("content_block_delta")); + + // Test OpenAI event type (should be None) + let openai_event = ChatCompletionsStreamResponse { + id: "test".to_string(), + object: Some("chat.completion.chunk".to_string()), + created: 123456789, + model: "gpt-4".to_string(), + choices: vec![], + usage: None, + system_fingerprint: None, + service_tier: None, + }; + let provider_type = ProviderStreamResponseType::ChatCompletionsStreamResponse(openai_event); + assert_eq!(provider_type.event_type(), None); + } + + #[test] + fn test_done_marker_handled_in_stream_response_transformation() { + use crate::apis::anthropic::AnthropicApi; + + // Test that [DONE] marker is properly converted to MessageStop in the transformation layer + let done_bytes = b"[DONE]"; + let client_api = SupportedAPIsFromClient::AnthropicMessagesAPI(AnthropicApi::Messages); + let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions( + crate::apis::openai::OpenAIApi::ChatCompletions, + ); + + let result = ProviderStreamResponseType::try_from(( + done_bytes.as_slice(), + &client_api, + &upstream_api, + )); + assert!(result.is_ok()); + + if let Ok(ProviderStreamResponseType::MessagesStreamEvent(event)) = result { + // Verify it's a MessageStop event + assert_eq!(event.event_type(), Some("message_stop")); + assert!(matches!( + event, + crate::apis::anthropic::MessagesStreamEvent::MessageStop + )); + } else { + panic!("Expected MessagesStreamEvent::MessageStop"); + } + } + + #[test] + fn test_bedrock_event_stream_decoder_basic() { + use bytes::BytesMut; + + // Create a simple test with minimal data + let mut buffer = BytesMut::new(); + + // Add some arbitrary bytes (not a real event-stream frame, just for testing the decoder) + buffer.extend_from_slice(b"test data"); + + let mut decoder = BedrockBinaryFrameDecoder::new(&mut buffer); + + // The decoder should return Incomplete for incomplete/invalid data + // This signals the caller to wait for more data + let result = decoder.decode_frame(); + assert!(result.is_some()); + assert!(matches!( + result.unwrap(), + aws_smithy_eventstream::frame::DecodedFrame::Incomplete + )); + + // Verify we can still access the buffer + assert!(decoder.has_remaining()); + } + + #[test] + fn test_bedrock_event_stream_decoder_with_real_frames() { + use bytes::BytesMut; + use std::fs; + use std::path::PathBuf; + + // Read the actual response.hex file from tests/e2e directory + let test_file = + PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("../../tests/e2e/response.hex"); + + // Only run this test if the file exists + if !test_file.exists() { + println!("Skipping test - response.hex not found"); + return; + } + + let response_data = fs::read(&test_file).unwrap(); + let mut buffer = BytesMut::from(&response_data[..]); + + let mut decoder = BedrockBinaryFrameDecoder::new(&mut buffer); + let mut frame_count = 0; + + // Decode all frames + loop { + match decoder.decode_frame() { + Some(aws_smithy_eventstream::frame::DecodedFrame::Complete(message)) => { + frame_count += 1; + + // Verify we can access headers + let event_type = message + .headers() + .iter() + .find(|h| h.name().as_str() == ":event-type") + .and_then(|h| h.value().as_string().ok()); + + assert!(event_type.is_some(), "Frame should have :event-type header"); + } + Some(aws_smithy_eventstream::frame::DecodedFrame::Incomplete) => { + // End of buffer, 
no more complete frames available + break; + } + None => { + // Decode error + panic!("Decode error encountered"); + } + } + } + + // We should have decoded multiple frames + assert!(frame_count > 0, "Should have decoded at least one frame"); + } + + #[test] + fn test_bedrock_event_stream_decoder_chunked_data() { + use bytes::BytesMut; + use std::fs; + use std::path::PathBuf; + + // Read the actual response.hex file + let test_file = + PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("../../tests/e2e/response.hex"); + + if !test_file.exists() { + println!("Skipping test - response.hex not found"); + return; + } + + let response_data = fs::read(&test_file).unwrap(); + + // Simulate chunked network arrivals with realistic chunk sizes + // Using varying chunk sizes to test partial frame handling + let mut buffer = BytesMut::new(); + let chunk_size_pattern = vec![500, 1000, 750, 1200, 800, 1500]; + let mut offset = 0; + let mut total_frames = 0; + let mut chunk_num = 0; + + // CRITICAL: Create ONE decoder and reuse it across chunks + // The MessageFrameDecoder maintains state about partial frames + let mut decoder = BedrockBinaryFrameDecoder::new(&mut buffer); + + // Process all data in chunks + while offset < response_data.len() { + let chunk_size = chunk_size_pattern[chunk_num % chunk_size_pattern.len()]; + chunk_num += 1; + + let end = (offset + chunk_size).min(response_data.len()); + let chunk = &response_data[offset..end]; + + // Add new data to the buffer (accessing via buffer_mut()) + decoder.buffer_mut().extend_from_slice(chunk); + offset = end; + + // Process all available complete frames from this chunk + loop { + match decoder.decode_frame() { + Some(aws_smithy_eventstream::frame::DecodedFrame::Complete(_)) => { + total_frames += 1; + } + Some(aws_smithy_eventstream::frame::DecodedFrame::Incomplete) => { + // Need more data - wait for next chunk + break; + } + None => { + // Decode error + panic!("Decode error in chunked test"); + } + } + } + } + + assert!( + total_frames > 0, + "Should have decoded frames from chunked data" + ); + } + + #[test] + fn test_bedrock_decoded_frame_to_provider_response() { + test_bedrock_conversion(false); + } + + #[test] + #[ignore] // Run with: cargo test -- --ignored --nocapture + fn test_bedrock_decoded_frame_to_provider_response_verbose() { + test_bedrock_conversion(true); + } + + #[test] + fn test_bedrock_decoded_frame_with_tool_use() { + test_bedrock_conversion_with_tools(false); + } + + #[test] + #[ignore] // Run with: cargo test -- --ignored --nocapture + fn test_bedrock_decoded_frame_with_tool_use_verbose() { + test_bedrock_conversion_with_tools(true); + } + + fn test_bedrock_conversion(verbose: bool) { + use bytes::BytesMut; + use std::fs; + use std::path::PathBuf; + + // Read the actual response.hex file from tests/e2e directory + let test_file = + PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("../../tests/e2e/response.hex"); + + // Only run this test if the file exists + if !test_file.exists() { + println!("Skipping test - response.hex not found"); + return; + } + + let response_data = fs::read(&test_file).unwrap(); + let mut buffer = BytesMut::from(&response_data[..]); + + let mut decoder = BedrockBinaryFrameDecoder::new(&mut buffer); + + let client_api = + SupportedAPIsFromClient::AnthropicMessagesAPI(crate::apis::anthropic::AnthropicApi::Messages); + let upstream_api = SupportedUpstreamAPIs::AmazonBedrockConverseStream( + crate::apis::amazon_bedrock::AmazonBedrockApi::ConverseStream, + ); + + let mut conversion_count = 0; + let mut 
message_start_seen = false; + + // Decode and convert frames + loop { + match decoder.decode_frame() { + Some(frame @ aws_smithy_eventstream::frame::DecodedFrame::Complete(_)) => { + // Convert DecodedFrame to ProviderStreamResponseType + let result = + ProviderStreamResponseType::try_from((&frame, &client_api, &upstream_api)); + + match result { + Ok(provider_response) => { + conversion_count += 1; + + // Verify we got a MessagesStreamEvent + assert!(matches!( + provider_response, + ProviderStreamResponseType::MessagesStreamEvent(_) + )); + + if verbose { + // Print the SSE string output + let sse_string: String = provider_response.clone().into(); + println!("{}", sse_string); + } + + // Check for MessageStart event + if let ProviderStreamResponseType::MessagesStreamEvent(ref event) = + provider_response + { + if matches!( + event, + crate::apis::anthropic::MessagesStreamEvent::MessageStart { .. } + ) { + message_start_seen = true; + } + } + } + Err(e) => { + println!("Conversion error (frame {}): {}", conversion_count, e); + } + } + } + Some(aws_smithy_eventstream::frame::DecodedFrame::Incomplete) => { + // End of buffer + break; + } + None => { + panic!("Decode error"); + } + } + } + + assert!( + conversion_count > 0, + "Should have converted at least one frame" + ); + assert!(message_start_seen, "Should have seen MessageStart event"); + } + + fn test_bedrock_conversion_with_tools(verbose: bool) { + use bytes::BytesMut; + use std::fs; + use std::path::PathBuf; + + // Read the actual response_with_tools.hex file from tests/e2e directory + let test_file = PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .join("../../tests/e2e/response_with_tools.hex"); + + // Only run this test if the file exists + if !test_file.exists() { + println!("Skipping test - response_with_tools.hex not found"); + return; + } + + let response_data = fs::read(&test_file).unwrap(); + let mut buffer = BytesMut::from(&response_data[..]); + + let mut decoder = BedrockBinaryFrameDecoder::new(&mut buffer); + + let client_api = + SupportedAPIsFromClient::AnthropicMessagesAPI(crate::apis::anthropic::AnthropicApi::Messages); + let upstream_api = SupportedUpstreamAPIs::AmazonBedrockConverseStream( + crate::apis::amazon_bedrock::AmazonBedrockApi::ConverseStream, + ); + + let mut conversion_count = 0; + let mut message_start_seen = false; + let mut content_block_start_seen = false; + let mut content_block_delta_tool_use_seen = false; + + // Decode and convert frames + loop { + match decoder.decode_frame() { + Some(frame @ aws_smithy_eventstream::frame::DecodedFrame::Complete(_)) => { + // Convert DecodedFrame to ProviderStreamResponseType + let result = + ProviderStreamResponseType::try_from((&frame, &client_api, &upstream_api)); + + match result { + Ok(provider_response) => { + conversion_count += 1; + + // Verify we got a MessagesStreamEvent + assert!(matches!( + provider_response, + ProviderStreamResponseType::MessagesStreamEvent(_) + )); + + if verbose { + // Print the SSE string output + let sse_string: String = provider_response.clone().into(); + println!("{}", sse_string); + } + + // Check for specific events related to tool use + if let ProviderStreamResponseType::MessagesStreamEvent(ref event) = + provider_response + { + match event { + crate::apis::anthropic::MessagesStreamEvent::MessageStart { .. } => { + message_start_seen = true; + } + crate::apis::anthropic::MessagesStreamEvent::ContentBlockStart { .. 
+                                } => {
+                                    content_block_start_seen = true;
+                                }
+                                crate::apis::anthropic::MessagesStreamEvent::ContentBlockDelta { delta, .. } => {
+                                    if matches!(delta, crate::apis::anthropic::MessagesContentDelta::InputJsonDelta { .. }) {
+                                        content_block_delta_tool_use_seen = true;
+                                    }
+                                }
+                                _ => {}
+                            }
+                        }
+                    }
+                    Err(e) => {
+                        println!("Conversion error (frame {}): {}", conversion_count, e);
+                    }
+                }
+            }
+            Some(aws_smithy_eventstream::frame::DecodedFrame::Incomplete) => {
+                // End of buffer
+                break;
+            }
+            None => {
+                panic!("Decode error");
+            }
+        }
+    }
+
+    assert!(
+        conversion_count > 0,
+        "Should have converted at least one frame"
+    );
+    assert!(message_start_seen, "Should have seen MessageStart event");
+    assert!(
+        content_block_start_seen,
+        "Should have seen ContentBlockStart event for tool use"
+    );
+    assert!(
+        content_block_delta_tool_use_seen,
+        "Should have seen ContentBlockDelta with InputJsonDelta"
+    );
+}
+
+#[test]
+fn test_sse_event_transformation_openai_to_anthropic_message_delta() {
+    use crate::apis::anthropic::AnthropicApi;
+    use crate::apis::openai::OpenAIApi;
+
+    // Create an OpenAI stream response with finish_reason (which becomes message_delta in Anthropic)
+    let openai_stream_chunk = json!({
+        "id": "chatcmpl-123",
+        "object": "chat.completion.chunk",
+        "created": 1234567890,
+        "model": "gpt-4",
+        "choices": [{
+            "index": 0,
+            "delta": {},
+            "finish_reason": "stop"
+        }],
+        "usage": {
+            "prompt_tokens": 10,
+            "completion_tokens": 25,
+            "total_tokens": 35
+        }
+    });
+
+    // Create SSE event with this data
+    let sse_event = SseEvent {
+        data: Some(openai_stream_chunk.to_string()),
+        event: None,
+        raw_line: format!("data: {}", openai_stream_chunk),
+        sse_transformed_lines: format!("data: {}", openai_stream_chunk),
+        provider_stream_response: None,
+    };
+
+    let client_api = SupportedAPIsFromClient::AnthropicMessagesAPI(AnthropicApi::Messages);
+    let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions);
+
+    // Transform the event
+    let result = SseEvent::try_from((sse_event, &client_api, &upstream_api));
+    assert!(result.is_ok());
+
+    let transformed = result.unwrap();
+
+    // NOTE: This test now verifies single-event transformation only.
+    // Multi-event injection (content_block_stop + message_delta) is now handled
+    // by AnthropicMessagesStreamBuffer, not by TryFrom transformation.
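+        // For illustration only: the transformed lines are expected to follow the
+        // standard SSE shape (exact JSON fields may vary), e.g.
+        //   event: message_delta
+        //   data: {"type":"message_delta","delta":{"stop_reason":"end_turn",...},"usage":{...}}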
+ let buffer = transformed.sse_transformed_lines; + + // Verify the event was transformed to Anthropic format + // This should contain message_delta with stop_reason and usage + assert!( + buffer.contains("event: message_delta") || buffer.contains("\"type\":\"message_delta\""), + "Should contain message_delta in transformed event" + ); + + // Verify usage information is present + assert!( + buffer.contains("\"usage\""), + "Should contain usage information" + ); + } + + #[test] + fn test_sse_event_transformation_openai_to_anthropic_content_delta() { + use crate::apis::anthropic::AnthropicApi; + use crate::apis::openai::OpenAIApi; + + // Create an OpenAI stream response with content (which becomes content_block_delta in Anthropic) + let openai_stream_chunk = json!({ + "id": "chatcmpl-123", + "object": "chat.completion.chunk", + "created": 1234567890, + "model": "gpt-4", + "choices": [{ + "index": 0, + "delta": {"content": "Hello"}, + "finish_reason": null + }] + }); + + // Create SSE event with this data + let sse_event = SseEvent { + data: Some(openai_stream_chunk.to_string()), + event: None, + raw_line: format!("data: {}", openai_stream_chunk.to_string()), + sse_transformed_lines: format!("data: {}", openai_stream_chunk.to_string()), + provider_stream_response: None, + }; + + let client_api = SupportedAPIsFromClient::AnthropicMessagesAPI(AnthropicApi::Messages); + let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions); + + // Transform the event + let result = SseEvent::try_from((sse_event, &client_api, &upstream_api)); + assert!(result.is_ok()); + + let transformed = result.unwrap(); + + // Verify the transformation is a content_block_delta (no extra events injected) + let buffer = transformed.sse_transformed_lines; + assert!( + buffer.contains("event: content_block_delta"), + "Should contain content_block_delta event" + ); + assert!( + !buffer.contains("content_block_start"), + "Should not inject content_block_start for content delta" + ); + assert!( + !buffer.contains("content_block_stop"), + "Should not inject content_block_stop for content delta" + ); + + // Verify the content is preserved + assert!(buffer.contains("Hello"), "Should preserve the content text"); + } + + #[test] + fn test_sse_event_transformation_anthropic_to_openai_suppresses_event_lines() { + use crate::apis::anthropic::AnthropicApi; + use crate::apis::openai::OpenAIApi; + + // Create an Anthropic event-only SSE line (no data) + let sse_event = SseEvent { + data: None, + event: Some("message_start".to_string()), + raw_line: "event: message_start".to_string(), + sse_transformed_lines: "event: message_start".to_string(), + provider_stream_response: None, + }; + + let client_api = SupportedAPIsFromClient::OpenAIChatCompletions(OpenAIApi::ChatCompletions); + let upstream_api = SupportedUpstreamAPIs::AnthropicMessagesAPI(AnthropicApi::Messages); + + // Transform the event + let result = SseEvent::try_from((sse_event, &client_api, &upstream_api)); + assert!(result.is_ok()); + + let transformed = result.unwrap(); + + // Verify the event line is suppressed (replaced with just newline) + assert_eq!( + transformed.sse_transformed_lines, "\n", + "Event-only lines should be suppressed to newline for OpenAI" + ); + assert!( + transformed.is_event_only(), + "Should still be marked as event-only" + ); + } + + #[test] + fn test_sse_event_transformation_anthropic_to_openai_preserves_data() { + use crate::apis::anthropic::AnthropicApi; + use crate::apis::openai::OpenAIApi; + + // Create an Anthropic 
message_start event with data + let anthropic_event = json!({ + "type": "message_start", + "message": { + "id": "msg_123", + "type": "message", + "role": "assistant", + "content": [], + "model": "claude-3-sonnet", + "stop_reason": null, + "usage": {"input_tokens": 10, "output_tokens": 0} + } + }); + + let sse_event = SseEvent { + data: Some(anthropic_event.to_string()), + event: None, + raw_line: format!("data: {}", anthropic_event.to_string()), + sse_transformed_lines: format!("data: {}", anthropic_event.to_string()), + provider_stream_response: None, + }; + + let client_api = SupportedAPIsFromClient::OpenAIChatCompletions(OpenAIApi::ChatCompletions); + let upstream_api = SupportedUpstreamAPIs::AnthropicMessagesAPI(AnthropicApi::Messages); + + // Transform the event + let result = SseEvent::try_from((sse_event, &client_api, &upstream_api)); + assert!(result.is_ok()); + + let transformed = result.unwrap(); + + // Verify data is transformed to OpenAI format + let buffer = transformed.sse_transformed_lines; + assert!(buffer.starts_with("data: "), "Should have data: prefix"); + assert!( + !buffer.contains("event:"), + "Should not have event: lines for OpenAI" + ); + + // Verify provider response was parsed + assert!(transformed.provider_stream_response.is_some()); + } + + #[test] + fn test_sse_event_transformation_no_change_for_matching_apis() { + use crate::apis::openai::OpenAIApi; + + // Create an OpenAI stream response + let openai_stream_chunk = json!({ + "id": "chatcmpl-123", + "object": "chat.completion.chunk", + "created": 1234567890, + "model": "gpt-4", + "choices": [{ + "index": 0, + "delta": {"content": "Hello"}, + "finish_reason": null + }] + }); + + let original_data = openai_stream_chunk.to_string(); + let sse_event = SseEvent { + data: Some(original_data.clone()), + event: None, + raw_line: format!("data: {}", original_data), + sse_transformed_lines: format!("data: {}\n\n", original_data), + provider_stream_response: None, + }; + + let client_api = SupportedAPIsFromClient::OpenAIChatCompletions(OpenAIApi::ChatCompletions); + let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions); + + // Transform the event + let result = SseEvent::try_from((sse_event, &client_api, &upstream_api)); + assert!(result.is_ok()); + + let transformed = result.unwrap(); + + // Verify minimal transformation - just SSE formatting, no API conversion + let buffer = transformed.sse_transformed_lines; + assert!(buffer.starts_with("data: "), "Should preserve data: prefix"); + assert!(!buffer.contains("event:"), "Should not add event: lines"); + + // Verify provider response was parsed + assert!(transformed.provider_stream_response.is_some()); + } + + #[test] + fn test_sse_event_transformation_preserves_provider_response() { + use crate::apis::anthropic::AnthropicApi; + use crate::apis::openai::OpenAIApi; + + // Create an OpenAI stream response + let openai_stream_chunk = json!({ + "id": "chatcmpl-123", + "object": "chat.completion.chunk", + "created": 1234567890, + "model": "gpt-4", + "choices": [{ + "index": 0, + "delta": {"content": "Test"}, + "finish_reason": null + }] + }); + + let sse_event = SseEvent { + data: Some(openai_stream_chunk.to_string()), + event: None, + raw_line: format!("data: {}", openai_stream_chunk.to_string()), + sse_transformed_lines: format!("data: {}", openai_stream_chunk.to_string()), + provider_stream_response: None, + }; + + let client_api = SupportedAPIsFromClient::AnthropicMessagesAPI(AnthropicApi::Messages); + let upstream_api = 
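+        // An Anthropic Messages client over an OpenAI ChatCompletions upstream
+        // exercises the OpenAI -> Anthropic stream transformation path.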
SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions); + + // Transform the event + let result = SseEvent::try_from((sse_event, &client_api, &upstream_api)); + assert!(result.is_ok()); + + let transformed = result.unwrap(); + + // Verify provider_stream_response is populated + assert!( + transformed.provider_stream_response.is_some(), + "Should parse and store provider response" + ); + + // Verify we can access the provider response + let provider_response = transformed.provider_response(); + assert!( + provider_response.is_ok(), + "Should be able to access provider response" + ); + + // Verify the content delta is accessible + let content = provider_response.unwrap().content_delta(); + assert_eq!(content, Some("Test"), "Should preserve content delta"); + } +} diff --git a/crates/hermesllm/src/transforms/mod.rs b/crates/hermesllm/src/transforms/mod.rs index 3fb4e397..ebb4bf20 100644 --- a/crates/hermesllm/src/transforms/mod.rs +++ b/crates/hermesllm/src/transforms/mod.rs @@ -11,11 +11,13 @@ pub mod lib; pub mod request; pub mod response; +pub mod response_streaming; // Re-export commonly used items for convenience pub use lib::*; pub use request::*; pub use response::*; +pub use response_streaming::*; // ============================================================================ // CONSTANTS diff --git a/crates/hermesllm/src/transforms/request/from_openai.rs b/crates/hermesllm/src/transforms/request/from_openai.rs index df4a9557..83f13fe8 100644 --- a/crates/hermesllm/src/transforms/request/from_openai.rs +++ b/crates/hermesllm/src/transforms/request/from_openai.rs @@ -12,6 +12,10 @@ use crate::apis::anthropic::{ use crate::apis::openai::{ ChatCompletionsRequest, Message, MessageContent, Role, Tool, ToolChoice, ToolChoiceType, }; + +use crate::apis::openai_responses::{ + ResponsesAPIRequest, InputContent, InputItem, InputParam, MessageRole, Modality, ReasoningEffort, Tool as ResponsesTool, ToolChoice as ResponsesToolChoice +}; use crate::clients::TransformError; use crate::transforms::lib::ExtractText; use crate::transforms::lib::*; @@ -244,6 +248,202 @@ impl TryFrom for BedrockMessage { } } +impl TryFrom for ChatCompletionsRequest { + type Error = TransformError; + + fn try_from(req: ResponsesAPIRequest) -> Result { + + // Convert input to messages + let messages = match req.input { + InputParam::Text(text) => { + // Simple text input becomes a user message + vec![Message { + role: Role::User, + content: MessageContent::Text(text), + name: None, + tool_call_id: None, + tool_calls: None, + }] + } + InputParam::Items(items) => { + // Convert input items to messages + let mut converted_messages = Vec::new(); + + // Add instructions as system message if present + if let Some(instructions) = &req.instructions { + converted_messages.push(Message { + role: Role::System, + content: MessageContent::Text(instructions.clone()), + name: None, + tool_call_id: None, + tool_calls: None, + }); + } + + // Convert each input item + for item in items { + match item { + InputItem::Message(input_msg) => { + let role = match input_msg.role { + MessageRole::User => Role::User, + MessageRole::Assistant => Role::Assistant, + MessageRole::System => Role::System, + MessageRole::Developer => Role::System, // Map developer to system + }; + + // Convert content blocks + let content = if input_msg.content.len() == 1 { + // Single content item - check if it's simple text + match &input_msg.content[0] { + InputContent::InputText { text } => MessageContent::Text(text.clone()), + _ => { + // Convert to 
parts for non-text content + MessageContent::Parts( + input_msg.content.iter() + .filter_map(|c| match c { + InputContent::InputText { text } => { + Some(crate::apis::openai::ContentPart::Text { text: text.clone() }) + } + InputContent::InputImage { image_url, .. } => { + Some(crate::apis::openai::ContentPart::ImageUrl { + image_url: crate::apis::openai::ImageUrl { + url: image_url.clone(), + detail: None, + } + }) + } + InputContent::InputFile { .. } => None, // Skip files for now + InputContent::InputAudio { .. } => None, // Skip audio for now + }) + .collect() + ) + } + } + } else { + // Multiple content items - convert to parts + MessageContent::Parts( + input_msg.content.iter() + .filter_map(|c| match c { + InputContent::InputText { text } => { + Some(crate::apis::openai::ContentPart::Text { text: text.clone() }) + } + InputContent::InputImage { image_url, .. } => { + Some(crate::apis::openai::ContentPart::ImageUrl { + image_url: crate::apis::openai::ImageUrl { + url: image_url.clone(), + detail: None, + } + }) + } + InputContent::InputFile { .. } => None, // Skip files for now + InputContent::InputAudio { .. } => None, // Skip audio for now + }) + .collect() + ) + }; + + converted_messages.push(Message { + role, + content, + name: None, + tool_call_id: None, + tool_calls: None, + }); + } + } + } + + converted_messages + } + }; + + // Build the ChatCompletionsRequest + Ok(ChatCompletionsRequest { + model: req.model, + messages, + temperature: req.temperature, + top_p: req.top_p, + max_completion_tokens: req.max_output_tokens.map(|t| t as u32), + stream: req.stream, + metadata: req.metadata, + user: req.user, + store: req.store, + service_tier: req.service_tier, + top_logprobs: req.top_logprobs.map(|t| t as u32), + modalities: req.modalities.map(|mods| { + mods.into_iter().map(|m| { + match m { + Modality::Text => "text".to_string(), + Modality::Audio => "audio".to_string(), + } + }).collect() + }), + stream_options: req.stream_options.map(|opts| { + crate::apis::openai::StreamOptions { + include_usage: opts.include_usage, + } + }), + reasoning_effort: req.reasoning_effort.map(|effort| { + match effort { + ReasoningEffort::Low => "low".to_string(), + ReasoningEffort::Medium => "medium".to_string(), + ReasoningEffort::High => "high".to_string(), + } + }), + tools: req.tools.map(|tools| { + tools.into_iter().map(|tool| { + + // Only convert Function tools - other types are not supported in ChatCompletions + match tool { + ResponsesTool::Function { name, description, parameters, strict } => Ok(Tool { + tool_type: "function".to_string(), + function: crate::apis::openai::Function { + name, + description, + parameters: parameters.unwrap_or_else(|| serde_json::json!({ + "type": "object", + "properties": {} + })), + strict, + } + }), + ResponsesTool::FileSearch { .. } => Err(TransformError::UnsupportedConversion( + "FileSearch tool is not supported in ChatCompletions API. Only function tools are supported.".to_string() + )), + ResponsesTool::WebSearchPreview { .. } => Err(TransformError::UnsupportedConversion( + "WebSearchPreview tool is not supported in ChatCompletions API. Only function tools are supported.".to_string() + )), + ResponsesTool::CodeInterpreter => Err(TransformError::UnsupportedConversion( + "CodeInterpreter tool is not supported in ChatCompletions API. Only function tools are supported.".to_string() + )), + ResponsesTool::Computer { .. } => Err(TransformError::UnsupportedConversion( + "Computer tool is not supported in ChatCompletions API. 
Only function tools are supported.".to_string() + )), + } + }).collect::, _>>() + }).transpose()?, + tool_choice: req.tool_choice.map(|choice| { + match choice { + ResponsesToolChoice::String(s) => { + match s.as_str() { + "auto" => ToolChoice::Type(ToolChoiceType::Auto), + "required" => ToolChoice::Type(ToolChoiceType::Required), + "none" => ToolChoice::Type(ToolChoiceType::None), + _ => ToolChoice::Type(ToolChoiceType::Auto), // Default to auto for unknown strings + } + } + ResponsesToolChoice::Named { function, .. } => ToolChoice::Function { + choice_type: "function".to_string(), + function: crate::apis::openai::FunctionChoice { name: function.name } + } + } + }), + parallel_tool_calls: req.parallel_tool_calls, + ..Default::default() + }) + } +} + impl TryFrom for AnthropicMessagesRequest { type Error = TransformError; diff --git a/crates/hermesllm/src/transforms/response/to_anthropic.rs b/crates/hermesllm/src/transforms/response/to_anthropic.rs index 1c6ce238..0326fdb3 100644 --- a/crates/hermesllm/src/transforms/response/to_anthropic.rs +++ b/crates/hermesllm/src/transforms/response/to_anthropic.rs @@ -1,16 +1,11 @@ -use crate::apis::amazon_bedrock::{ - ContentBlockDelta, ConverseOutput, ConverseResponse, ConverseStreamEvent, StopReason, -}; +use crate::apis::amazon_bedrock::{ConverseOutput, ConverseResponse, StopReason}; use crate::apis::anthropic::{ - MessagesContentBlock, MessagesContentDelta, MessagesMessageDelta, MessagesResponse, - MessagesRole, MessagesStopReason, MessagesStreamEvent, MessagesStreamMessage, MessagesUsage, -}; -use crate::apis::openai::{ - ChatCompletionsResponse, ChatCompletionsStreamResponse, Role, ToolCallDelta, + MessagesContentBlock, MessagesResponse, + MessagesRole, MessagesStopReason, MessagesUsage, }; +use crate::apis::openai::ChatCompletionsResponse; use crate::clients::TransformError; use crate::transforms::lib::*; -use serde_json::Value; // ============================================================================ // STANDARD RUST TRAIT IMPLEMENTATIONS - Using Into/TryFrom for convenience @@ -120,289 +115,6 @@ impl TryFrom for MessagesResponse { } } -impl TryFrom for MessagesStreamEvent { - type Error = TransformError; - - fn try_from(resp: ChatCompletionsStreamResponse) -> Result { - if resp.choices.is_empty() { - return Ok(MessagesStreamEvent::Ping); - } - - let choice = &resp.choices[0]; - - // Handle final chunk with usage - let has_usage = resp.usage.is_some(); - if let Some(usage) = resp.usage { - if let Some(finish_reason) = &choice.finish_reason { - let anthropic_stop_reason: MessagesStopReason = finish_reason.clone().into(); - return Ok(MessagesStreamEvent::MessageDelta { - delta: MessagesMessageDelta { - stop_reason: anthropic_stop_reason, - stop_sequence: None, - }, - usage: usage.into(), - }); - } - } - - // Handle role start - if let Some(Role::Assistant) = choice.delta.role { - return Ok(MessagesStreamEvent::MessageStart { - message: MessagesStreamMessage { - id: resp.id, - obj_type: "message".to_string(), - role: MessagesRole::Assistant, - content: vec![], - model: resp.model, - stop_reason: None, - stop_sequence: None, - usage: MessagesUsage { - input_tokens: 0, - output_tokens: 0, - cache_creation_input_tokens: None, - cache_read_input_tokens: None, - }, - }, - }); - } - - // Handle content delta - if let Some(content) = &choice.delta.content { - if !content.is_empty() { - return Ok(MessagesStreamEvent::ContentBlockDelta { - index: 0, - delta: MessagesContentDelta::TextDelta { - text: content.clone(), - }, - }); - } - } - - // 
Handle tool calls - if let Some(tool_calls) = &choice.delta.tool_calls { - return convert_tool_call_deltas(tool_calls.clone()); - } - - // Handle finish reason - generate MessageDelta only (MessageStop comes later) - if let Some(finish_reason) = &choice.finish_reason { - // If we have usage data, it was already handled above - // If not, we need to generate MessageDelta with default usage - if !has_usage { - let anthropic_stop_reason: MessagesStopReason = finish_reason.clone().into(); - return Ok(MessagesStreamEvent::MessageDelta { - delta: MessagesMessageDelta { - stop_reason: anthropic_stop_reason, - stop_sequence: None, - }, - usage: MessagesUsage { - input_tokens: 0, - output_tokens: 0, - cache_creation_input_tokens: None, - cache_read_input_tokens: None, - }, - }); - } - // If usage was already handled above, we don't need to do anything more here - // MessageStop will be handled when [DONE] is encountered - } - - // Default to ping for unhandled cases - Ok(MessagesStreamEvent::Ping) - } -} - -impl Into for MessagesStreamEvent { - fn into(self) -> String { - let transformed_json = serde_json::to_string(&self).unwrap_or_default(); - let event_type = match &self { - MessagesStreamEvent::MessageStart { .. } => "message_start", - MessagesStreamEvent::ContentBlockStart { .. } => "content_block_start", - MessagesStreamEvent::ContentBlockDelta { .. } => "content_block_delta", - MessagesStreamEvent::ContentBlockStop { .. } => "content_block_stop", - MessagesStreamEvent::MessageDelta { .. } => "message_delta", - MessagesStreamEvent::MessageStop => "message_stop", - MessagesStreamEvent::Ping => "ping", - }; - - let event = format!("event: {}\n", event_type); - let data = format!("data: {}\n\n", transformed_json); - event + &data - } -} - -impl TryFrom for MessagesStreamEvent { - type Error = TransformError; - - fn try_from(event: ConverseStreamEvent) -> Result { - match event { - // MessageStart - convert to Anthropic MessageStart - ConverseStreamEvent::MessageStart(start_event) => { - let role = match start_event.role { - crate::apis::amazon_bedrock::ConversationRole::User => MessagesRole::User, - crate::apis::amazon_bedrock::ConversationRole::Assistant => { - MessagesRole::Assistant - } - }; - - Ok(MessagesStreamEvent::MessageStart { - message: MessagesStreamMessage { - id: format!( - "bedrock-stream-{}", - std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap_or_default() - .as_nanos() - ), - obj_type: "message".to_string(), - role, - content: vec![], - model: "bedrock-model".to_string(), - stop_reason: None, - stop_sequence: None, - usage: MessagesUsage { - input_tokens: 0, - output_tokens: 0, - cache_creation_input_tokens: None, - cache_read_input_tokens: None, - }, - }, - }) - } - - // ContentBlockStart - convert to Anthropic ContentBlockStart - ConverseStreamEvent::ContentBlockStart(start_event) => { - // Note: Bedrock sends tool_use_id and name at start, with input coming in subsequent deltas - // Anthropic expects the same pattern, so we initialize with an empty input object - match start_event.start { - crate::apis::amazon_bedrock::ContentBlockStart::ToolUse { tool_use } => { - Ok(MessagesStreamEvent::ContentBlockStart { - index: start_event.content_block_index as u32, - content_block: MessagesContentBlock::ToolUse { - id: tool_use.tool_use_id, - name: tool_use.name, - input: Value::Object(serde_json::Map::new()), // Empty - will be filled by deltas - cache_control: None, - }, - }) - } - } - } - - // ContentBlockDelta - convert to Anthropic ContentBlockDelta 
- ConverseStreamEvent::ContentBlockDelta(delta_event) => { - let delta = match delta_event.delta { - ContentBlockDelta::Text { text } => MessagesContentDelta::TextDelta { text }, - ContentBlockDelta::ToolUse { tool_use } => { - MessagesContentDelta::InputJsonDelta { - partial_json: tool_use.input, - } - } - }; - - Ok(MessagesStreamEvent::ContentBlockDelta { - index: delta_event.content_block_index as u32, - delta, - }) - } - - // ContentBlockStop - convert to Anthropic ContentBlockStop - ConverseStreamEvent::ContentBlockStop(stop_event) => { - Ok(MessagesStreamEvent::ContentBlockStop { - index: stop_event.content_block_index as u32, - }) - } - - // MessageStop - convert to Anthropic MessageDelta with stop reason + MessageStop - ConverseStreamEvent::MessageStop(stop_event) => { - let anthropic_stop_reason = match stop_event.stop_reason { - StopReason::EndTurn => MessagesStopReason::EndTurn, - StopReason::ToolUse => MessagesStopReason::ToolUse, - StopReason::MaxTokens => MessagesStopReason::MaxTokens, - StopReason::StopSequence => MessagesStopReason::EndTurn, - StopReason::GuardrailIntervened => MessagesStopReason::Refusal, - StopReason::ContentFiltered => MessagesStopReason::Refusal, - }; - - // Return MessageDelta (MessageStop will be sent separately by the streaming handler) - Ok(MessagesStreamEvent::MessageDelta { - delta: MessagesMessageDelta { - stop_reason: anthropic_stop_reason, - stop_sequence: None, - }, - usage: MessagesUsage { - input_tokens: 0, - output_tokens: 0, - cache_creation_input_tokens: None, - cache_read_input_tokens: None, - }, - }) - } - - // Metadata - convert usage information to MessageDelta - ConverseStreamEvent::Metadata(metadata_event) => { - Ok(MessagesStreamEvent::MessageDelta { - delta: MessagesMessageDelta { - stop_reason: MessagesStopReason::EndTurn, - stop_sequence: None, - }, - usage: MessagesUsage { - input_tokens: metadata_event.usage.input_tokens, - output_tokens: metadata_event.usage.output_tokens, - cache_creation_input_tokens: metadata_event.usage.cache_write_input_tokens, - cache_read_input_tokens: metadata_event.usage.cache_read_input_tokens, - }, - }) - } - - // Exception events - convert to Ping (could be enhanced to return error events) - ConverseStreamEvent::InternalServerException(_) - | ConverseStreamEvent::ModelStreamErrorException(_) - | ConverseStreamEvent::ServiceUnavailableException(_) - | ConverseStreamEvent::ThrottlingException(_) - | ConverseStreamEvent::ValidationException(_) => { - // TODO: Consider adding proper error handling/events - Ok(MessagesStreamEvent::Ping) - } - } - } -} - -/// Convert tool call deltas to Anthropic stream events -fn convert_tool_call_deltas( - tool_calls: Vec, -) -> Result { - for tool_call in tool_calls { - if let Some(id) = &tool_call.id { - // Tool call start - if let Some(function) = &tool_call.function { - if let Some(name) = &function.name { - return Ok(MessagesStreamEvent::ContentBlockStart { - index: tool_call.index, - content_block: MessagesContentBlock::ToolUse { - id: id.clone(), - name: name.clone(), - input: Value::Object(serde_json::Map::new()), - cache_control: None, - }, - }); - } - } - } else if let Some(function) = &tool_call.function { - if let Some(arguments) = &function.arguments { - // Tool arguments delta - return Ok(MessagesStreamEvent::ContentBlockDelta { - index: tool_call.index, - delta: MessagesContentDelta::InputJsonDelta { - partial_json: arguments.clone(), - }, - }); - } - } - } - - // Fallback to ping if no valid tool call found - Ok(MessagesStreamEvent::Ping) -} /// 
Convert Bedrock Message to Anthropic content blocks /// diff --git a/crates/hermesllm/src/transforms/response/to_openai.rs b/crates/hermesllm/src/transforms/response/to_openai.rs index b44afc96..e26cc3b4 100644 --- a/crates/hermesllm/src/transforms/response/to_openai.rs +++ b/crates/hermesllm/src/transforms/response/to_openai.rs @@ -1,15 +1,13 @@ use crate::apis::amazon_bedrock::{ - ConverseOutput, ConverseResponse, ConverseStreamEvent, StopReason, + ConverseOutput, ConverseResponse, StopReason, }; use crate::apis::anthropic::{ - MessagesContentBlock, MessagesContentDelta, MessagesResponse, MessagesStopReason, - MessagesStreamEvent, MessagesUsage, + MessagesContentBlock, MessagesResponse, MessagesUsage, }; use crate::apis::openai::{ - ChatCompletionsResponse, ChatCompletionsStreamResponse, Choice, FinishReason, - FunctionCallDelta, MessageContent, MessageDelta, ResponseMessage, Role, StreamChoice, - ToolCallDelta, Usage, + ChatCompletionsResponse, Choice, FinishReason, MessageContent, ResponseMessage, Role, Usage, }; +use crate::apis::openai_responses::ResponsesAPIResponse; use crate::clients::TransformError; use crate::transforms::lib::*; @@ -30,6 +28,163 @@ impl Into for MessagesUsage { } } +impl TryFrom for ResponsesAPIResponse { + type Error = TransformError; + + fn try_from(resp: ChatCompletionsResponse) -> Result { + use crate::apis::openai_responses::{ + IncompleteDetails, IncompleteReason, OutputContent, OutputItem, OutputItemStatus, + ResponseStatus, ResponseUsage, ResponsesAPIResponse, + }; + + // Convert the first choice's message to output items + let output = if let Some(choice) = resp.choices.first() { + let mut items = Vec::new(); + + // Create a message output item from the response message + let mut content = Vec::new(); + + // Add text content if present + if let Some(text) = &choice.message.content { + content.push(OutputContent::OutputText { + text: text.clone(), + annotations: vec![], + logprobs: None, + }); + } + + // Add audio content if present (audio is a Value, need to handle it carefully) + if let Some(audio) = &choice.message.audio { + // Audio is serde_json::Value, try to extract data and transcript + if let Some(audio_obj) = audio.as_object() { + let data = audio_obj + .get("data") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()); + let transcript = audio_obj + .get("transcript") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()); + + content.push(OutputContent::OutputAudio { data, transcript }); + } + } + + // Add refusal content if present + if let Some(refusal) = &choice.message.refusal { + content.push(OutputContent::Refusal { + refusal: refusal.clone(), + }); + } + + // Only add the message item if there's actual content (text, audio, or refusal) + // Don't add empty message items when there are only tool calls + if !content.is_empty() { + items.push(OutputItem::Message { + id: format!("msg_{}", resp.id), + status: OutputItemStatus::Completed, + role: match choice.message.role { + Role::User => "user".to_string(), + Role::Assistant => "assistant".to_string(), + Role::System => "system".to_string(), + Role::Tool => "tool".to_string(), + }, + content, + }); + } + + // Add tool calls as function call items if present + if let Some(tool_calls) = &choice.message.tool_calls { + for tool_call in tool_calls { + items.push(OutputItem::FunctionCall { + id: format!("func_{}", tool_call.id), + status: OutputItemStatus::Completed, + call_id: tool_call.id.clone(), + name: Some(tool_call.function.name.clone()), + arguments: 
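+                        // ChatCompletions already carries arguments as a raw JSON
+                        // string, so it can be passed through unchanged.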
Some(tool_call.function.arguments.clone()), + }); + } + } + + items + } else { + vec![] + }; + + // Convert finish_reason to status + let status = if let Some(choice) = resp.choices.first() { + match choice.finish_reason { + Some(FinishReason::Stop) => ResponseStatus::Completed, + Some(FinishReason::ToolCalls) => ResponseStatus::Completed, + Some(FinishReason::Length) => ResponseStatus::Incomplete, + Some(FinishReason::ContentFilter) => ResponseStatus::Failed, + _ => ResponseStatus::Completed, + } + } else { + ResponseStatus::Completed + }; + + // Convert usage + let usage = ResponseUsage { + input_tokens: resp.usage.prompt_tokens as i32, + output_tokens: resp.usage.completion_tokens as i32, + total_tokens: resp.usage.total_tokens as i32, + input_tokens_details: resp.usage.prompt_tokens_details.map(|details| { + crate::apis::openai_responses::TokenDetails { + cached_tokens: details.cached_tokens.unwrap_or(0) as i32, + } + }), + output_tokens_details: resp.usage.completion_tokens_details.map(|details| { + crate::apis::openai_responses::OutputTokenDetails { + reasoning_tokens: details.reasoning_tokens.unwrap_or(0) as i32, + } + }), + }; + + // Set incomplete_details if status is incomplete + let incomplete_details = if matches!(status, ResponseStatus::Incomplete) { + Some(IncompleteDetails { + reason: IncompleteReason::MaxOutputTokens, + }) + } else { + None + }; + + Ok(ResponsesAPIResponse { + id: resp.id, + object: "response".to_string(), + created_at: resp.created as i64, + status, + background: Some(false), + error: None, + incomplete_details, + instructions: None, + max_output_tokens: None, + max_tool_calls: None, + model: resp.model, + output, + usage: Some(usage), + parallel_tool_calls: true, + conversation: None, + previous_response_id: None, + tools: vec![], + tool_choice: "auto".to_string(), + temperature: 1.0, + top_p: 1.0, + metadata: resp.metadata.unwrap_or_default(), + truncation: None, + reasoning: None, + store: None, + text: None, + audio: None, + modalities: None, + service_tier: resp.service_tier, + top_logprobs: None, + }) + } +} + + impl TryFrom for ChatCompletionsResponse { type Error = TransformError; @@ -173,416 +328,6 @@ impl TryFrom for ChatCompletionsResponse { } } -// ============================================================================ -// STREAMING TRANSFORMATIONS -// ============================================================================ - -impl TryFrom for ChatCompletionsStreamResponse { - type Error = TransformError; - - fn try_from(event: MessagesStreamEvent) -> Result { - match event { - MessagesStreamEvent::MessageStart { message } => Ok(create_openai_chunk( - &message.id, - &message.model, - MessageDelta { - role: Some(Role::Assistant), - content: None, - refusal: None, - function_call: None, - tool_calls: None, - }, - None, - None, - )), - - MessagesStreamEvent::ContentBlockStart { content_block, .. } => { - convert_content_block_start(content_block) - } - - MessagesStreamEvent::ContentBlockDelta { delta, .. } => convert_content_delta(delta), - - MessagesStreamEvent::ContentBlockStop { .. 
} => Ok(create_empty_openai_chunk()), - - MessagesStreamEvent::MessageDelta { delta, usage } => { - let finish_reason: Option = Some(delta.stop_reason.into()); - let openai_usage: Option = Some(usage.into()); - - Ok(create_openai_chunk( - "stream", - "unknown", - MessageDelta { - role: None, - content: None, - refusal: None, - function_call: None, - tool_calls: None, - }, - finish_reason, - openai_usage, - )) - } - - MessagesStreamEvent::MessageStop => Ok(create_openai_chunk( - "stream", - "unknown", - MessageDelta { - role: None, - content: None, - refusal: None, - function_call: None, - tool_calls: None, - }, - Some(FinishReason::Stop), - None, - )), - - MessagesStreamEvent::Ping => Ok(ChatCompletionsStreamResponse { - id: "stream".to_string(), - object: Some("chat.completion.chunk".to_string()), - created: current_timestamp(), - model: "unknown".to_string(), - choices: vec![], - usage: None, - system_fingerprint: None, - service_tier: None, - }), - } - } -} - -impl TryFrom for ChatCompletionsStreamResponse { - type Error = TransformError; - - fn try_from(event: ConverseStreamEvent) -> Result { - match event { - ConverseStreamEvent::MessageStart(start_event) => { - let role = match start_event.role { - crate::apis::amazon_bedrock::ConversationRole::User => Role::User, - crate::apis::amazon_bedrock::ConversationRole::Assistant => Role::Assistant, - }; - - Ok(create_openai_chunk( - "stream", - "unknown", - MessageDelta { - role: Some(role), - content: None, - refusal: None, - function_call: None, - tool_calls: None, - }, - None, - None, - )) - } - - ConverseStreamEvent::ContentBlockStart(start_event) => { - use crate::apis::amazon_bedrock::ContentBlockStart; - - match start_event.start { - ContentBlockStart::ToolUse { tool_use } => Ok(create_openai_chunk( - "stream", - "unknown", - MessageDelta { - role: None, - content: None, - refusal: None, - function_call: None, - tool_calls: Some(vec![ToolCallDelta { - index: start_event.content_block_index as u32, - id: Some(tool_use.tool_use_id), - call_type: Some("function".to_string()), - function: Some(FunctionCallDelta { - name: Some(tool_use.name), - arguments: Some("".to_string()), - }), - }]), - }, - None, - None, - )), - } - } - - ConverseStreamEvent::ContentBlockDelta(delta_event) => { - use crate::apis::amazon_bedrock::ContentBlockDelta; - - match delta_event.delta { - ContentBlockDelta::Text { text } => Ok(create_openai_chunk( - "stream", - "unknown", - MessageDelta { - role: None, - content: Some(text), - refusal: None, - function_call: None, - tool_calls: None, - }, - None, - None, - )), - ContentBlockDelta::ToolUse { tool_use } => Ok(create_openai_chunk( - "stream", - "unknown", - MessageDelta { - role: None, - content: None, - refusal: None, - function_call: None, - tool_calls: Some(vec![ToolCallDelta { - index: delta_event.content_block_index as u32, - id: None, - call_type: None, - function: Some(FunctionCallDelta { - name: None, - arguments: Some(tool_use.input), - }), - }]), - }, - None, - None, - )), - } - } - - ConverseStreamEvent::ContentBlockStop(_) => Ok(create_empty_openai_chunk()), - - ConverseStreamEvent::MessageStop(stop_event) => { - let finish_reason = match stop_event.stop_reason { - StopReason::EndTurn => FinishReason::Stop, - StopReason::ToolUse => FinishReason::ToolCalls, - StopReason::MaxTokens => FinishReason::Length, - StopReason::StopSequence => FinishReason::Stop, - StopReason::GuardrailIntervened => FinishReason::ContentFilter, - StopReason::ContentFiltered => FinishReason::ContentFilter, - }; - - 
Ok(create_openai_chunk( - "stream", - "unknown", - MessageDelta { - role: None, - content: None, - refusal: None, - function_call: None, - tool_calls: None, - }, - Some(finish_reason), - None, - )) - } - - ConverseStreamEvent::Metadata(metadata_event) => { - let usage = Usage { - prompt_tokens: metadata_event.usage.input_tokens, - completion_tokens: metadata_event.usage.output_tokens, - total_tokens: metadata_event.usage.total_tokens, - prompt_tokens_details: None, - completion_tokens_details: None, - }; - - Ok(create_openai_chunk( - "stream", - "unknown", - MessageDelta { - role: None, - content: None, - refusal: None, - function_call: None, - tool_calls: None, - }, - None, - Some(usage), - )) - } - - // Error events - convert to empty chunks (errors should be handled elsewhere) - ConverseStreamEvent::InternalServerException(_) - | ConverseStreamEvent::ModelStreamErrorException(_) - | ConverseStreamEvent::ServiceUnavailableException(_) - | ConverseStreamEvent::ThrottlingException(_) - | ConverseStreamEvent::ValidationException(_) => Ok(create_empty_openai_chunk()), - } - } -} - -/// Convert content block start to OpenAI chunk -fn convert_content_block_start( - content_block: MessagesContentBlock, -) -> Result { - match content_block { - MessagesContentBlock::Text { .. } => { - // No immediate output for text block start - Ok(create_empty_openai_chunk()) - } - MessagesContentBlock::ToolUse { id, name, .. } - | MessagesContentBlock::ServerToolUse { id, name, .. } - | MessagesContentBlock::McpToolUse { id, name, .. } => { - // Tool use start → OpenAI chunk with tool_calls - Ok(create_openai_chunk( - "stream", - "unknown", - MessageDelta { - role: None, - content: None, - refusal: None, - function_call: None, - tool_calls: Some(vec![ToolCallDelta { - index: 0, - id: Some(id), - call_type: Some("function".to_string()), - function: Some(FunctionCallDelta { - name: Some(name), - arguments: Some("".to_string()), - }), - }]), - }, - None, - None, - )) - } - _ => Err(TransformError::UnsupportedContent( - "Unsupported content block type in stream start".to_string(), - )), - } -} - -/// Convert content delta to OpenAI chunk -fn convert_content_delta( - delta: MessagesContentDelta, -) -> Result { - match delta { - MessagesContentDelta::TextDelta { text } => Ok(create_openai_chunk( - "stream", - "unknown", - MessageDelta { - role: None, - content: Some(text), - refusal: None, - function_call: None, - tool_calls: None, - }, - None, - None, - )), - MessagesContentDelta::ThinkingDelta { thinking } => Ok(create_openai_chunk( - "stream", - "unknown", - MessageDelta { - role: None, - content: Some(format!("thinking: {}", thinking)), - refusal: None, - function_call: None, - tool_calls: None, - }, - None, - None, - )), - MessagesContentDelta::InputJsonDelta { partial_json } => Ok(create_openai_chunk( - "stream", - "unknown", - MessageDelta { - role: None, - content: None, - refusal: None, - function_call: None, - tool_calls: Some(vec![ToolCallDelta { - index: 0, - id: None, - call_type: None, - function: Some(FunctionCallDelta { - name: None, - arguments: Some(partial_json), - }), - }]), - }, - None, - None, - )), - } -} - -/// Helper to create OpenAI streaming chunk -fn create_openai_chunk( - id: &str, - model: &str, - delta: MessageDelta, - finish_reason: Option, - usage: Option, -) -> ChatCompletionsStreamResponse { - ChatCompletionsStreamResponse { - id: id.to_string(), - object: Some("chat.completion.chunk".to_string()), - created: current_timestamp(), - model: model.to_string(), - choices: 
vec![StreamChoice { - index: 0, - delta, - finish_reason, - logprobs: None, - }], - usage, - system_fingerprint: None, - service_tier: None, - } -} - -/// Helper to create empty OpenAI streaming chunk -fn create_empty_openai_chunk() -> ChatCompletionsStreamResponse { - create_openai_chunk( - "stream", - "unknown", - MessageDelta { - role: None, - content: None, - refusal: None, - function_call: None, - tool_calls: None, - }, - None, - None, - ) -} - -/// Convert Anthropic content blocks to OpenAI message content -fn convert_anthropic_content_to_openai( - content: &[MessagesContentBlock], -) -> Result { - let mut text_parts = Vec::new(); - - for block in content { - match block { - MessagesContentBlock::Text { text, .. } => { - text_parts.push(text.clone()); - } - MessagesContentBlock::Thinking { thinking, .. } => { - text_parts.push(format!("thinking: {}", thinking)); - } - _ => { - // Skip other content types for basic text conversion - continue; - } - } - } - - Ok(MessageContent::Text(text_parts.join("\n"))) -} - -// Stop Reason Conversions -impl Into for MessagesStopReason { - fn into(self) -> FinishReason { - match self { - MessagesStopReason::EndTurn => FinishReason::Stop, - MessagesStopReason::MaxTokens => FinishReason::Length, - MessagesStopReason::StopSequence => FinishReason::Stop, - MessagesStopReason::ToolUse => FinishReason::ToolCalls, - MessagesStopReason::PauseTurn => FinishReason::Stop, - MessagesStopReason::Refusal => FinishReason::ContentFilter, - } - } -} - /// Convert Bedrock Message to OpenAI content and tool calls /// This function extracts text content and tool calls from a Bedrock message fn convert_bedrock_message_to_openai( @@ -627,6 +372,31 @@ fn convert_bedrock_message_to_openai( Ok((content, tool_calls)) } +/// Convert Anthropic content blocks to OpenAI message content +fn convert_anthropic_content_to_openai( + content: &[MessagesContentBlock], +) -> Result { + let mut text_parts = Vec::new(); + + for block in content { + match block { + MessagesContentBlock::Text { text, .. } => { + text_parts.push(text.clone()); + } + MessagesContentBlock::Thinking { thinking, .. } => { + text_parts.push(format!("thinking: {}", thinking)); + } + _ => { + // Skip other content types for basic text conversion + continue; + } + } + } + + Ok(MessageContent::Text(text_parts.join("\n"))) +} + + #[cfg(test)] mod tests { use super::*; @@ -1166,4 +936,212 @@ mod tests { assert!(content.contains("Here's the analysis:")); // Note: Image blocks are not converted to text in the current implementation } + + #[test] + fn test_chat_completions_to_responses_api_basic() { + use crate::apis::openai_responses::{OutputContent, OutputItem, ResponsesAPIResponse}; + + let chat_response = ChatCompletionsResponse { + id: "chatcmpl-123".to_string(), + object: Some("chat.completion".to_string()), + created: 1677652288, + model: "gpt-4".to_string(), + choices: vec![Choice { + index: 0, + message: crate::apis::openai::ResponseMessage { + role: Role::Assistant, + content: Some("Hello! 
How can I help you?".to_string()), + refusal: None, + annotations: None, + audio: None, + function_call: None, + tool_calls: None, + }, + finish_reason: Some(FinishReason::Stop), + logprobs: None, + }], + usage: Usage { + prompt_tokens: 10, + completion_tokens: 20, + total_tokens: 30, + prompt_tokens_details: None, + completion_tokens_details: None, + }, + system_fingerprint: None, + service_tier: Some("default".to_string()), + metadata: None, + }; + + let responses_api: ResponsesAPIResponse = chat_response.try_into().unwrap(); + + assert_eq!(responses_api.id, "chatcmpl-123"); + assert_eq!(responses_api.object, "response"); + assert_eq!(responses_api.model, "gpt-4"); + + // Check usage conversion + let usage = responses_api.usage.unwrap(); + assert_eq!(usage.input_tokens, 10); + assert_eq!(usage.output_tokens, 20); + assert_eq!(usage.total_tokens, 30); + + // Check output items + assert_eq!(responses_api.output.len(), 1); + match &responses_api.output[0] { + OutputItem::Message { + role, + content, + .. + } => { + assert_eq!(role, "assistant"); + assert_eq!(content.len(), 1); + match &content[0] { + OutputContent::OutputText { text, .. } => { + assert_eq!(text, "Hello! How can I help you?"); + } + _ => panic!("Expected OutputText content"), + } + } + _ => panic!("Expected Message output item"), + } + } + + #[test] + fn test_chat_completions_to_responses_api_with_tool_calls() { + use crate::apis::openai::{FunctionCall, ToolCall}; + use crate::apis::openai_responses::{OutputItem, ResponsesAPIResponse}; + + let chat_response = ChatCompletionsResponse { + id: "chatcmpl-456".to_string(), + object: Some("chat.completion".to_string()), + created: 1677652300, + model: "gpt-4".to_string(), + choices: vec![Choice { + index: 0, + message: crate::apis::openai::ResponseMessage { + role: Role::Assistant, + content: Some("Let me check the weather.".to_string()), + refusal: None, + annotations: None, + audio: None, + function_call: None, + tool_calls: Some(vec![ToolCall { + id: "call_abc123".to_string(), + call_type: "function".to_string(), + function: FunctionCall { + name: "get_weather".to_string(), + arguments: r#"{"location":"San Francisco"}"#.to_string(), + }, + }]), + }, + finish_reason: Some(FinishReason::ToolCalls), + logprobs: None, + }], + usage: Usage { + prompt_tokens: 15, + completion_tokens: 25, + total_tokens: 40, + prompt_tokens_details: None, + completion_tokens_details: None, + }, + system_fingerprint: None, + service_tier: None, + metadata: None, + }; + + let responses_api: ResponsesAPIResponse = chat_response.try_into().unwrap(); + + // Should have 2 output items: message + function call + assert_eq!(responses_api.output.len(), 2); + + // Check message item + match &responses_api.output[0] { + OutputItem::Message { content, .. } => { + assert_eq!(content.len(), 1); + } + _ => panic!("Expected Message output item"), + } + + // Check function call item + match &responses_api.output[1] { + OutputItem::FunctionCall { + call_id, + name, + arguments, + .. 
+ } => { + assert_eq!(call_id, "call_abc123"); + assert_eq!(name.as_ref().unwrap(), "get_weather"); + assert!(arguments.as_ref().unwrap().contains("San Francisco")); + } + _ => panic!("Expected FunctionCall output item"), + } + } + + #[test] + fn test_chat_completions_to_responses_api_tool_calls_only() { + use crate::apis::openai::{FunctionCall, ToolCall}; + use crate::apis::openai_responses::{OutputItem, ResponsesAPIResponse}; + + // Test the real-world case where content is null and there are only tool calls + let chat_response = ChatCompletionsResponse { + id: "chatcmpl-789".to_string(), + object: Some("chat.completion".to_string()), + created: 1764023939, + model: "gpt-4o-2024-08-06".to_string(), + choices: vec![Choice { + index: 0, + message: crate::apis::openai::ResponseMessage { + role: Role::Assistant, + content: None, // No text content, only tool calls + refusal: None, + annotations: None, + audio: None, + function_call: None, + tool_calls: Some(vec![ToolCall { + id: "call_oJBtqTJmRfBGlFS55QhMfUUV".to_string(), + call_type: "function".to_string(), + function: FunctionCall { + name: "get_weather".to_string(), + arguments: r#"{"location":"San Francisco, CA"}"#.to_string(), + }, + }]), + }, + finish_reason: Some(FinishReason::ToolCalls), + logprobs: None, + }], + usage: Usage { + prompt_tokens: 84, + completion_tokens: 17, + total_tokens: 101, + prompt_tokens_details: None, + completion_tokens_details: None, + }, + system_fingerprint: Some("fp_7eeb46f068".to_string()), + service_tier: Some("default".to_string()), + metadata: None, + }; + + let responses_api: ResponsesAPIResponse = chat_response.try_into().unwrap(); + + // Should have only 1 output item: function call (no empty message item) + assert_eq!(responses_api.output.len(), 1); + + // Check function call item + match &responses_api.output[0] { + OutputItem::FunctionCall { + call_id, + name, + arguments, + .. 
+ } => { + assert_eq!(call_id, "call_oJBtqTJmRfBGlFS55QhMfUUV"); + assert_eq!(name.as_ref().unwrap(), "get_weather"); + assert!(arguments.as_ref().unwrap().contains("San Francisco, CA")); + } + _ => panic!("Expected FunctionCall output item as first item"), + } + + // Verify status is Completed for tool_calls finish reason + assert!(matches!(responses_api.status, crate::apis::openai_responses::ResponseStatus::Completed)); + } } diff --git a/crates/hermesllm/src/transforms/response_streaming/mod.rs b/crates/hermesllm/src/transforms/response_streaming/mod.rs new file mode 100644 index 00000000..fb06cce3 --- /dev/null +++ b/crates/hermesllm/src/transforms/response_streaming/mod.rs @@ -0,0 +1,2 @@ +pub mod to_anthropic_streaming; +pub mod to_openai_streaming; diff --git a/crates/hermesllm/src/transforms/response_streaming/to_anthropic_streaming.rs b/crates/hermesllm/src/transforms/response_streaming/to_anthropic_streaming.rs new file mode 100644 index 00000000..b8cac631 --- /dev/null +++ b/crates/hermesllm/src/transforms/response_streaming/to_anthropic_streaming.rs @@ -0,0 +1,281 @@ +use crate::apis::amazon_bedrock::{ + ContentBlockDelta, ConverseStreamEvent, +}; +use crate::apis::anthropic::{ + MessagesContentBlock, MessagesContentDelta, MessagesMessageDelta, + MessagesRole, MessagesStopReason, MessagesStreamEvent, MessagesStreamMessage, MessagesUsage, +}; +use crate::apis::openai::{ ChatCompletionsStreamResponse, ToolCallDelta, +}; +use crate::clients::TransformError; +use serde_json::Value; + +impl TryFrom for MessagesStreamEvent { + type Error = TransformError; + + fn try_from(resp: ChatCompletionsStreamResponse) -> Result { + if resp.choices.is_empty() { + return Ok(MessagesStreamEvent::Ping); + } + + let choice = &resp.choices[0]; + + // Handle final chunk with usage + let has_usage = resp.usage.is_some(); + if let Some(usage) = resp.usage { + if let Some(finish_reason) = &choice.finish_reason { + let anthropic_stop_reason: MessagesStopReason = finish_reason.clone().into(); + return Ok(MessagesStreamEvent::MessageDelta { + delta: MessagesMessageDelta { + stop_reason: anthropic_stop_reason, + stop_sequence: None, + }, + usage: usage.into(), + }); + } + } + + // NOTE: We do NOT emit MessageStart here anymore! + // The AnthropicMessagesStreamBuffer will inject message_start and content_block_start + // when it sees the first content_block_delta. This solves the problem where OpenAI + // sends both role and content in the same chunk - we can only return one event here, + // so we prioritize the content and let the buffer handle lifecycle events. 
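To make the note above concrete, here is a minimal sketch (an editor's illustration, not part of this diff) of the rule it describes: an OpenAI chunk that carries both a role and text content converts to a `content_block_delta`, and the lifecycle events are left to the buffer. Struct shapes are taken from their usage elsewhere in this diff; exact import paths and field types are assumptions.

```rust
use hermesllm::apis::anthropic::MessagesStreamEvent;
use hermesllm::apis::openai::{ChatCompletionsStreamResponse, MessageDelta, Role, StreamChoice};

#[test]
fn role_and_content_in_one_chunk_yields_content_delta() {
    // A first OpenAI chunk that both establishes the role and carries text.
    let chunk = ChatCompletionsStreamResponse {
        id: "stream".to_string(),
        object: Some("chat.completion.chunk".to_string()),
        created: 0,
        model: "gpt-4o".to_string(),
        choices: vec![StreamChoice {
            index: 0,
            delta: MessageDelta {
                role: Some(Role::Assistant), // role AND content in the same chunk
                content: Some("Hel".to_string()),
                refusal: None,
                function_call: None,
                tool_calls: None,
            },
            finish_reason: None,
            logprobs: None,
        }],
        usage: None,
        system_fingerprint: None,
        service_tier: None,
    };

    // The transform prioritizes the content; message_start and
    // content_block_start are injected later by the AnthropicMessagesStreamBuffer.
    let event = MessagesStreamEvent::try_from(chunk).unwrap();
    assert!(matches!(event, MessagesStreamEvent::ContentBlockDelta { .. }));
}
```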
+ + // Handle content delta (even if role is present in the same chunk) + if let Some(content) = &choice.delta.content { + if !content.is_empty() { + return Ok(MessagesStreamEvent::ContentBlockDelta { + index: 0, + delta: MessagesContentDelta::TextDelta { + text: content.clone(), + }, + }); + } + } + + // Handle tool calls + if let Some(tool_calls) = &choice.delta.tool_calls { + return convert_tool_call_deltas(tool_calls.clone()); + } + + // Handle finish reason - generate MessageDelta only (MessageStop comes later) + if let Some(finish_reason) = &choice.finish_reason { + // If we have usage data, it was already handled above + // If not, we need to generate MessageDelta with default usage + if !has_usage { + let anthropic_stop_reason: MessagesStopReason = finish_reason.clone().into(); + return Ok(MessagesStreamEvent::MessageDelta { + delta: MessagesMessageDelta { + stop_reason: anthropic_stop_reason, + stop_sequence: None, + }, + usage: MessagesUsage { + input_tokens: 0, + output_tokens: 0, + cache_creation_input_tokens: None, + cache_read_input_tokens: None, + }, + }); + } + // If usage was already handled above, we don't need to do anything more here + // MessageStop will be handled when [DONE] is encountered + } + + // Default to ping for unhandled cases + Ok(MessagesStreamEvent::Ping) + } +} + +impl Into for MessagesStreamEvent { + fn into(self) -> String { + let transformed_json = serde_json::to_string(&self).unwrap_or_default(); + let event_type = match &self { + MessagesStreamEvent::MessageStart { .. } => "message_start", + MessagesStreamEvent::ContentBlockStart { .. } => "content_block_start", + MessagesStreamEvent::ContentBlockDelta { .. } => "content_block_delta", + MessagesStreamEvent::ContentBlockStop { .. } => "content_block_stop", + MessagesStreamEvent::MessageDelta { .. 
} => "message_delta", + MessagesStreamEvent::MessageStop => "message_stop", + MessagesStreamEvent::Ping => "ping", + }; + + let event = format!("event: {}\n", event_type); + let data = format!("data: {}\n\n", transformed_json); + event + &data + } +} + +impl TryFrom for MessagesStreamEvent { + type Error = TransformError; + + fn try_from(event: ConverseStreamEvent) -> Result { + match event { + // MessageStart - convert to Anthropic MessageStart + ConverseStreamEvent::MessageStart(start_event) => { + let role = match start_event.role { + crate::apis::amazon_bedrock::ConversationRole::User => MessagesRole::User, + crate::apis::amazon_bedrock::ConversationRole::Assistant => { + MessagesRole::Assistant + } + }; + + Ok(MessagesStreamEvent::MessageStart { + message: MessagesStreamMessage { + id: format!( + "bedrock-stream-{}", + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_nanos() + ), + obj_type: "message".to_string(), + role, + content: vec![], + model: "bedrock-model".to_string(), + stop_reason: None, + stop_sequence: None, + usage: MessagesUsage { + input_tokens: 0, + output_tokens: 0, + cache_creation_input_tokens: None, + cache_read_input_tokens: None, + }, + }, + }) + } + + // ContentBlockStart - convert to Anthropic ContentBlockStart + ConverseStreamEvent::ContentBlockStart(start_event) => { + // Note: Bedrock sends tool_use_id and name at start, with input coming in subsequent deltas + // Anthropic expects the same pattern, so we initialize with an empty input object + match start_event.start { + crate::apis::amazon_bedrock::ContentBlockStart::ToolUse { tool_use } => { + Ok(MessagesStreamEvent::ContentBlockStart { + index: start_event.content_block_index as u32, + content_block: MessagesContentBlock::ToolUse { + id: tool_use.tool_use_id, + name: tool_use.name, + input: Value::Object(serde_json::Map::new()), // Empty - will be filled by deltas + cache_control: None, + }, + }) + } + } + } + + // ContentBlockDelta - convert to Anthropic ContentBlockDelta + ConverseStreamEvent::ContentBlockDelta(delta_event) => { + let delta = match delta_event.delta { + ContentBlockDelta::Text { text } => MessagesContentDelta::TextDelta { text }, + ContentBlockDelta::ToolUse { tool_use } => { + MessagesContentDelta::InputJsonDelta { + partial_json: tool_use.input, + } + } + }; + + Ok(MessagesStreamEvent::ContentBlockDelta { + index: delta_event.content_block_index as u32, + delta, + }) + } + + // ContentBlockStop - convert to Anthropic ContentBlockStop + ConverseStreamEvent::ContentBlockStop(stop_event) => { + Ok(MessagesStreamEvent::ContentBlockStop { + index: stop_event.content_block_index as u32, + }) + } + + // MessageStop - convert to Anthropic MessageDelta with stop reason + // Note: Bedrock sends Metadata separately with usage info, creating a second MessageDelta + // The client should merge these or use the final one with complete usage + ConverseStreamEvent::MessageStop(stop_event) => { + let anthropic_stop_reason = match stop_event.stop_reason { + crate::apis::amazon_bedrock::StopReason::EndTurn => MessagesStopReason::EndTurn, + crate::apis::amazon_bedrock::StopReason::ToolUse => MessagesStopReason::ToolUse, + crate::apis::amazon_bedrock::StopReason::MaxTokens => MessagesStopReason::MaxTokens, + crate::apis::amazon_bedrock::StopReason::StopSequence => MessagesStopReason::EndTurn, + crate::apis::amazon_bedrock::StopReason::GuardrailIntervened => MessagesStopReason::Refusal, + crate::apis::amazon_bedrock::StopReason::ContentFiltered => 
MessagesStopReason::Refusal, + }; + + Ok(MessagesStreamEvent::MessageDelta { + delta: MessagesMessageDelta { + stop_reason: anthropic_stop_reason, + stop_sequence: None, + }, + usage: MessagesUsage { + input_tokens: 0, + output_tokens: 0, + cache_creation_input_tokens: None, + cache_read_input_tokens: None, + }, + }) + } + + // Metadata - convert usage information to MessageDelta + ConverseStreamEvent::Metadata(metadata_event) => { + Ok(MessagesStreamEvent::MessageDelta { + delta: MessagesMessageDelta { + stop_reason: MessagesStopReason::EndTurn, + stop_sequence: None, + }, + usage: MessagesUsage { + input_tokens: metadata_event.usage.input_tokens, + output_tokens: metadata_event.usage.output_tokens, + cache_creation_input_tokens: metadata_event.usage.cache_write_input_tokens, + cache_read_input_tokens: metadata_event.usage.cache_read_input_tokens, + }, + }) + } + + // Exception events - convert to Ping (could be enhanced to return error events) + ConverseStreamEvent::InternalServerException(_) + | ConverseStreamEvent::ModelStreamErrorException(_) + | ConverseStreamEvent::ServiceUnavailableException(_) + | ConverseStreamEvent::ThrottlingException(_) + | ConverseStreamEvent::ValidationException(_) => { + // TODO: Consider adding proper error handling/events + Ok(MessagesStreamEvent::Ping) + } + } + } +} + +/// Convert tool call deltas to Anthropic stream events +fn convert_tool_call_deltas( + tool_calls: Vec, +) -> Result { + for tool_call in tool_calls { + if let Some(id) = &tool_call.id { + // Tool call start + if let Some(function) = &tool_call.function { + if let Some(name) = &function.name { + return Ok(MessagesStreamEvent::ContentBlockStart { + index: tool_call.index, + content_block: MessagesContentBlock::ToolUse { + id: id.clone(), + name: name.clone(), + input: Value::Object(serde_json::Map::new()), + cache_control: None, + }, + }); + } + } + } else if let Some(function) = &tool_call.function { + if let Some(arguments) = &function.arguments { + // Tool arguments delta + return Ok(MessagesStreamEvent::ContentBlockDelta { + index: tool_call.index, + delta: MessagesContentDelta::InputJsonDelta { + partial_json: arguments.clone(), + }, + }); + } + } + } + + // Fallback to ping if no valid tool call found + Ok(MessagesStreamEvent::Ping) +} diff --git a/crates/hermesllm/src/transforms/response_streaming/to_openai_streaming.rs b/crates/hermesllm/src/transforms/response_streaming/to_openai_streaming.rs new file mode 100644 index 00000000..9e2f083e --- /dev/null +++ b/crates/hermesllm/src/transforms/response_streaming/to_openai_streaming.rs @@ -0,0 +1,527 @@ +use crate::apis::amazon_bedrock::{ ConverseStreamEvent, StopReason}; +use crate::apis::anthropic::{ + MessagesContentBlock, MessagesContentDelta, MessagesStopReason, MessagesStreamEvent}; +use crate::apis::openai::{ ChatCompletionsStreamResponse,FinishReason, + FunctionCallDelta, MessageDelta, Role, StreamChoice, ToolCallDelta, Usage, +}; +use crate::apis::openai_responses::ResponsesAPIStreamEvent; + +use crate::clients::TransformError; +use crate::transforms::lib::*; + +// ============================================================================ +// PROVIDER STREAMING TRANSFORMATIONS TO OPENAI FORMAT +// ============================================================================ +// +// This module handles business logic for converting streaming events from +// various providers (Anthropic, Bedrock, etc.) into OpenAI's ChatCompletions format. 
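Before the architecture notes continue below, a concrete instance of the conversion this header describes (an editor's sketch, not part of this change; import paths are assumed, and the event constructors mirror their usage elsewhere in this diff):

```rust
use hermesllm::apis::anthropic::{MessagesContentDelta, MessagesStreamEvent};
use hermesllm::apis::openai::ChatCompletionsStreamResponse;

#[test]
fn anthropic_text_delta_becomes_openai_chunk() {
    // Stage 1 (this module): stateless, event-by-event business transform.
    let event = MessagesStreamEvent::ContentBlockDelta {
        index: 0,
        delta: MessagesContentDelta::TextDelta { text: "Hello".to_string() },
    };
    let chunk = ChatCompletionsStreamResponse::try_from(event).unwrap();
    assert_eq!(chunk.choices[0].delta.content.as_deref(), Some("Hello"));

    // Stage 2 (streaming_shapes, not shown here): the chunk is buffered and
    // framed as `data: {...}\n\n` SSE wire format before reaching the client.
}
```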
+// +// # Architecture Separation +// +// **Provider Transformations** (this module): +// - Business logic for converting between provider formats +// - Uses Rust traits (TryFrom, Into) for type-safe conversions +// - Stateless event-by-event transformation +// - Example: MessagesStreamEvent → ChatCompletionsStreamResponse +// +// **Wire Format Buffering** (`apis/streaming_shapes/`): +// - SSE protocol handling (data:, event: lines) +// - State accumulation and lifecycle management +// - Buffering for stateful APIs (v1/responses) +// - Example: ChatCompletionsToResponsesTransformer +// +// # Flow +// +// ```text +// Anthropic Event → [Provider Transform] → OpenAI Event → [Wire Buffer] → SSE Wire Format +// (business) (this module) (protocol) (streaming_shapes) (network) +// ``` +// +// ============================================================================ + +impl TryFrom for ChatCompletionsStreamResponse { + type Error = TransformError; + + fn try_from(event: MessagesStreamEvent) -> Result { + match event { + MessagesStreamEvent::MessageStart { message } => Ok(create_openai_chunk( + &message.id, + &message.model, + MessageDelta { + role: Some(Role::Assistant), + content: None, + refusal: None, + function_call: None, + tool_calls: None, + }, + None, + None, + )), + + MessagesStreamEvent::ContentBlockStart { content_block, .. } => { + convert_content_block_start(content_block) + } + + MessagesStreamEvent::ContentBlockDelta { delta, .. } => convert_content_delta(delta), + + MessagesStreamEvent::ContentBlockStop { .. } => Ok(create_empty_openai_chunk()), + + MessagesStreamEvent::MessageDelta { delta, usage } => { + let finish_reason: Option = Some(delta.stop_reason.into()); + let openai_usage: Option = Some(usage.into()); + + Ok(create_openai_chunk( + "stream", + "unknown", + MessageDelta { + role: None, + content: None, + refusal: None, + function_call: None, + tool_calls: None, + }, + finish_reason, + openai_usage, + )) + } + + MessagesStreamEvent::MessageStop => Ok(create_openai_chunk( + "stream", + "unknown", + MessageDelta { + role: None, + content: None, + refusal: None, + function_call: None, + tool_calls: None, + }, + Some(FinishReason::Stop), + None, + )), + + MessagesStreamEvent::Ping => Ok(ChatCompletionsStreamResponse { + id: "stream".to_string(), + object: Some("chat.completion.chunk".to_string()), + created: current_timestamp(), + model: "unknown".to_string(), + choices: vec![], + usage: None, + system_fingerprint: None, + service_tier: None, + }), + } + } +} + +impl TryFrom for ChatCompletionsStreamResponse { + type Error = TransformError; + + fn try_from(event: ConverseStreamEvent) -> Result { + match event { + ConverseStreamEvent::MessageStart(start_event) => { + let role = match start_event.role { + crate::apis::amazon_bedrock::ConversationRole::User => Role::User, + crate::apis::amazon_bedrock::ConversationRole::Assistant => Role::Assistant, + }; + + Ok(create_openai_chunk( + "stream", + "unknown", + MessageDelta { + role: Some(role), + content: None, + refusal: None, + function_call: None, + tool_calls: None, + }, + None, + None, + )) + } + + ConverseStreamEvent::ContentBlockStart(start_event) => { + use crate::apis::amazon_bedrock::ContentBlockStart; + + match start_event.start { + ContentBlockStart::ToolUse { tool_use } => Ok(create_openai_chunk( + "stream", + "unknown", + MessageDelta { + role: None, + content: None, + refusal: None, + function_call: None, + tool_calls: Some(vec![ToolCallDelta { + index: start_event.content_block_index as u32, + id: 
Some(tool_use.tool_use_id), + call_type: Some("function".to_string()), + function: Some(FunctionCallDelta { + name: Some(tool_use.name), + arguments: Some("".to_string()), + }), + }]), + }, + None, + None, + )), + } + } + + ConverseStreamEvent::ContentBlockDelta(delta_event) => { + use crate::apis::amazon_bedrock::ContentBlockDelta; + + match delta_event.delta { + ContentBlockDelta::Text { text } => Ok(create_openai_chunk( + "stream", + "unknown", + MessageDelta { + role: None, + content: Some(text), + refusal: None, + function_call: None, + tool_calls: None, + }, + None, + None, + )), + ContentBlockDelta::ToolUse { tool_use } => Ok(create_openai_chunk( + "stream", + "unknown", + MessageDelta { + role: None, + content: None, + refusal: None, + function_call: None, + tool_calls: Some(vec![ToolCallDelta { + index: delta_event.content_block_index as u32, + id: None, + call_type: None, + function: Some(FunctionCallDelta { + name: None, + arguments: Some(tool_use.input), + }), + }]), + }, + None, + None, + )), + } + } + + ConverseStreamEvent::ContentBlockStop(_) => Ok(create_empty_openai_chunk()), + + ConverseStreamEvent::MessageStop(stop_event) => { + let finish_reason = match stop_event.stop_reason { + StopReason::EndTurn => FinishReason::Stop, + StopReason::ToolUse => FinishReason::ToolCalls, + StopReason::MaxTokens => FinishReason::Length, + StopReason::StopSequence => FinishReason::Stop, + StopReason::GuardrailIntervened => FinishReason::ContentFilter, + StopReason::ContentFiltered => FinishReason::ContentFilter, + }; + + Ok(create_openai_chunk( + "stream", + "unknown", + MessageDelta { + role: None, + content: None, + refusal: None, + function_call: None, + tool_calls: None, + }, + Some(finish_reason), + None, + )) + } + + ConverseStreamEvent::Metadata(metadata_event) => { + let usage = Usage { + prompt_tokens: metadata_event.usage.input_tokens, + completion_tokens: metadata_event.usage.output_tokens, + total_tokens: metadata_event.usage.total_tokens, + prompt_tokens_details: None, + completion_tokens_details: None, + }; + + Ok(create_openai_chunk( + "stream", + "unknown", + MessageDelta { + role: None, + content: None, + refusal: None, + function_call: None, + tool_calls: None, + }, + None, + Some(usage), + )) + } + + // Error events - convert to empty chunks (errors should be handled elsewhere) + ConverseStreamEvent::InternalServerException(_) + | ConverseStreamEvent::ModelStreamErrorException(_) + | ConverseStreamEvent::ServiceUnavailableException(_) + | ConverseStreamEvent::ThrottlingException(_) + | ConverseStreamEvent::ValidationException(_) => Ok(create_empty_openai_chunk()), + } + } +} + +/// Convert content block start to OpenAI chunk +fn convert_content_block_start( + content_block: MessagesContentBlock, +) -> Result { + match content_block { + MessagesContentBlock::Text { .. } => { + // No immediate output for text block start + Ok(create_empty_openai_chunk()) + } + MessagesContentBlock::ToolUse { id, name, .. } + | MessagesContentBlock::ServerToolUse { id, name, .. } + | MessagesContentBlock::McpToolUse { id, name, .. 
} => { + // Tool use start → OpenAI chunk with tool_calls + Ok(create_openai_chunk( + "stream", + "unknown", + MessageDelta { + role: None, + content: None, + refusal: None, + function_call: None, + tool_calls: Some(vec![ToolCallDelta { + index: 0, + id: Some(id), + call_type: Some("function".to_string()), + function: Some(FunctionCallDelta { + name: Some(name), + arguments: Some("".to_string()), + }), + }]), + }, + None, + None, + )) + } + _ => Err(TransformError::UnsupportedContent( + "Unsupported content block type in stream start".to_string(), + )), + } +} + +/// Convert content delta to OpenAI chunk +fn convert_content_delta( + delta: MessagesContentDelta, +) -> Result { + match delta { + MessagesContentDelta::TextDelta { text } => Ok(create_openai_chunk( + "stream", + "unknown", + MessageDelta { + role: None, + content: Some(text), + refusal: None, + function_call: None, + tool_calls: None, + }, + None, + None, + )), + MessagesContentDelta::ThinkingDelta { thinking } => Ok(create_openai_chunk( + "stream", + "unknown", + MessageDelta { + role: None, + content: Some(format!("thinking: {}", thinking)), + refusal: None, + function_call: None, + tool_calls: None, + }, + None, + None, + )), + MessagesContentDelta::InputJsonDelta { partial_json } => Ok(create_openai_chunk( + "stream", + "unknown", + MessageDelta { + role: None, + content: None, + refusal: None, + function_call: None, + tool_calls: Some(vec![ToolCallDelta { + index: 0, + id: None, + call_type: None, + function: Some(FunctionCallDelta { + name: None, + arguments: Some(partial_json), + }), + }]), + }, + None, + None, + )), + } +} + +/// Helper to create OpenAI streaming chunk +fn create_openai_chunk( + id: &str, + model: &str, + delta: MessageDelta, + finish_reason: Option, + usage: Option, +) -> ChatCompletionsStreamResponse { + ChatCompletionsStreamResponse { + id: id.to_string(), + object: Some("chat.completion.chunk".to_string()), + created: current_timestamp(), + model: model.to_string(), + choices: vec![StreamChoice { + index: 0, + delta, + finish_reason, + logprobs: None, + }], + usage, + system_fingerprint: None, + service_tier: None, + } +} + +/// Helper to create empty OpenAI streaming chunk +fn create_empty_openai_chunk() -> ChatCompletionsStreamResponse { + create_openai_chunk( + "stream", + "unknown", + MessageDelta { + role: None, + content: None, + refusal: None, + function_call: None, + tool_calls: None, + }, + None, + None, + ) +} + +// Stop Reason Conversions +impl Into for MessagesStopReason { + fn into(self) -> FinishReason { + match self { + MessagesStopReason::EndTurn => FinishReason::Stop, + MessagesStopReason::MaxTokens => FinishReason::Length, + MessagesStopReason::StopSequence => FinishReason::Stop, + MessagesStopReason::ToolUse => FinishReason::ToolCalls, + MessagesStopReason::PauseTurn => FinishReason::Stop, + MessagesStopReason::Refusal => FinishReason::ContentFilter, + } + } +} + +impl TryFrom for ResponsesAPIStreamEvent { + type Error = TransformError; + + fn try_from(chunk: ChatCompletionsStreamResponse) -> Result { + // Stateless conversion - just extract the delta information + // The buffer will manage state, item IDs, and sequence numbers + + // Extract first choice if available + if let Some(choice) = chunk.choices.first() { + let delta = &choice.delta; + + // Tool call with function name and/or arguments + if let Some(tool_calls) = &delta.tool_calls { + if let Some(tool_call) = tool_calls.first() { + // Extract call_id and name if available (metadata from initial event) + let call_id = 
tool_call.id.clone(); + let function_name = tool_call.function.as_ref() + .and_then(|f| f.name.clone()); + + // Check if we have function metadata (name, id) + if let Some(function) = &tool_call.function { + // If we have arguments delta, return that + if let Some(args) = &function.arguments { + return Ok(ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDelta { + output_index: choice.index as i32, + item_id: "".to_string(), // Buffer will fill this + delta: args.clone(), + sequence_number: 0, // Buffer will fill this + call_id, + name: function_name, + }); + } + + // If we have function name but no arguments yet (initial tool call event) + // Return an empty arguments delta so the buffer knows to create the item + if function.name.is_some() { + return Ok(ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDelta { + output_index: choice.index as i32, + item_id: "".to_string(), // Buffer will fill this + delta: "".to_string(), // Empty delta signals this is the initial event + sequence_number: 0, // Buffer will fill this + call_id, + name: function_name, + }); + } + } + } + } + + // Text content delta + if let Some(content) = &delta.content { + if !content.is_empty() { + return Ok(ResponsesAPIStreamEvent::ResponseOutputTextDelta { + item_id: "".to_string(), // Buffer will fill this + output_index: choice.index as i32, + content_index: 0, + delta: content.clone(), + logprobs: vec![], + obfuscation: None, + sequence_number: 0, // Buffer will fill this + }); + } + } + + // Handle finish_reason - this is a completion signal + // Return an empty delta that the buffer can use to detect completion + if choice.finish_reason.is_some() { + // Return a minimal text delta to signal completion + // The buffer will handle the finish_reason and generate response.completed + return Ok(ResponsesAPIStreamEvent::ResponseOutputTextDelta { + item_id: "".to_string(), // Buffer will fill this + output_index: choice.index as i32, + content_index: 0, + delta: "".to_string(), // Empty delta signals completion + logprobs: vec![], + obfuscation: None, + sequence_number: 0, // Buffer will fill this + }); + } + + // Empty delta with role only (common at stream start) + if delta.role.is_some() { + // This is typically the first chunk establishing the assistant role + // Return an empty text delta that the buffer can use to initialize state + return Ok(ResponsesAPIStreamEvent::ResponseOutputTextDelta { + item_id: "".to_string(), + output_index: choice.index as i32, + content_index: 0, + delta: "".to_string(), + logprobs: vec![], + obfuscation: None, + sequence_number: 0, + }); + } + } + + // Empty chunk or no convertible content (e.g., keep-alive chunks with delta: {}) + // These are valid in OpenAI streaming and should be silently ignored + // Return error so the caller can skip these chunks without warnings + Err(TransformError::UnsupportedConversion( + "Empty or keep-alive chunk with no convertible content".to_string(), + )) + } +} diff --git a/crates/llm_gateway/src/stream_context.rs b/crates/llm_gateway/src/stream_context.rs index 1098185d..42d7cb31 100644 --- a/crates/llm_gateway/src/stream_context.rs +++ b/crates/llm_gateway/src/stream_context.rs @@ -22,11 +22,13 @@ use common::ratelimit::Header; use common::stats::{IncrementingMetric, RecordingMetric}; use common::tracing::{Event, Span, TraceData, Traceparent}; use common::{ratelimit, routing, tokenizer}; -use hermesllm::apis::amazon_bedrock_binary_frame::BedrockBinaryFrameDecoder; -use hermesllm::apis::anthropic::{MessagesContentBlock, MessagesStreamEvent}; 
-use hermesllm::apis::sse::{SseEvent, SseStreamIter}; -use hermesllm::clients::endpoints::SupportedAPIs; +use hermesllm::apis::streaming_shapes::amazon_bedrock_binary_frame::BedrockBinaryFrameDecoder; +use hermesllm::apis::streaming_shapes::sse::{ + SseEvent, SseStreamBuffer, SseStreamBufferTrait, SseStreamIter, +}; +use hermesllm::clients::endpoints::SupportedAPIsFromClient; use hermesllm::providers::response::ProviderResponse; +use hermesllm::providers::streaming_response::ProviderStreamResponse; use hermesllm::{ DecodedFrame, ProviderId, ProviderRequest, ProviderRequestType, ProviderResponseType, ProviderStreamResponseType, @@ -38,7 +40,7 @@ pub struct StreamContext { streaming_response: bool, response_tokens: usize, /// The API that is requested by the client (before compatibility mapping) - client_api: Option, + client_api: Option, /// The API that should be used for the upstream provider (after compatibility mapping) resolved_api: Option, llm_providers: Rc, @@ -56,6 +58,7 @@ pub struct StreamContext { binary_frame_decoder: Option>, http_method: Option, http_protocol: Option, + sse_buffer: Option, } impl StreamContext { @@ -87,6 +90,7 @@ impl StreamContext { binary_frame_decoder: None, http_method: None, http_protocol: None, + sse_buffer: None, } } @@ -172,7 +176,8 @@ impl StreamContext { Some( SupportedUpstreamAPIs::OpenAIChatCompletions(_) | SupportedUpstreamAPIs::AmazonBedrockConverse(_) - | SupportedUpstreamAPIs::AmazonBedrockConverseStream(_), + | SupportedUpstreamAPIs::AmazonBedrockConverseStream(_) + | SupportedUpstreamAPIs::OpenAIResponsesAPI(_), ) | None => { // OpenAI and default: use Authorization Bearer token @@ -476,7 +481,17 @@ impl StreamContext { } }; - let mut response_buffer = Vec::new(); + // Initialize SSE buffer if not present + if self.sse_buffer.is_none() { + self.sse_buffer = match SseStreamBuffer::try_from((&client_api, &upstream_api)) + { + Ok(buffer) => Some(buffer), + Err(e) => { + warn!("Failed to create SSE buffer: {}", e); + return Err(Action::Continue); + } + }; + } // Process each SSE event for sse_event in sse_iter { @@ -527,12 +542,32 @@ impl StreamContext { } } - // Add transformed event to response buffer - let bytes: Vec = transformed_event.into(); - response_buffer.extend_from_slice(&bytes); + // Add transformed event to buffer (buffer may inject lifecycle events) + if let Some(buffer) = self.sse_buffer.as_mut() { + buffer.add_transformed_event(transformed_event); + } } - Ok(response_buffer) + // Get accumulated bytes from buffer and return + match self.sse_buffer.as_mut() { + Some(buffer) => { + let bytes = buffer.into_bytes(); + if !bytes.is_empty() { + let content = String::from_utf8_lossy(&bytes); + debug!( + "[ARCHGW_REQ_ID:{}] UPSTREAM_TRANSFORMED_CLIENT_RESPONSE: size={} content={}", + self.request_identifier(), + bytes.len(), + content + ); + } + Ok(bytes) + } + None => { + warn!("SSE buffer unexpectedly missing after initialization"); + Err(Action::Continue) + } + } } None => { warn!("Missing client_api for non-streaming response"); @@ -544,7 +579,7 @@ impl StreamContext { fn handle_bedrock_binary_stream( &mut self, body: &[u8], - client_api: &SupportedAPIs, + client_api: &SupportedAPIsFromClient, upstream_api: &SupportedUpstreamAPIs, ) -> Result, Action> { // Initialize decoder if not present @@ -552,83 +587,57 @@ impl StreamContext { self.binary_frame_decoder = Some(BedrockBinaryFrameDecoder::from_bytes(&[])); } - // Add incoming bytes to buffer + // Initialize SSE buffer if not present + if self.sse_buffer.is_none() { + 
self.sse_buffer = match SseStreamBuffer::try_from((client_api, upstream_api)) { + Ok(buffer) => Some(buffer), + Err(e) => { + warn!( + "[ARCHGW_REQ_ID:{}] BEDROCK_BUFFER_INIT_ERROR: {}", + self.request_identifier(), + e + ); + return Err(Action::Continue); + } + }; + } + + // Add incoming bytes to decoder buffer let decoder = self.binary_frame_decoder.as_mut().unwrap(); decoder.buffer_mut().extend_from_slice(body); - let mut response_buffer = Vec::new(); + // Process all complete frames loop { let decoded_frame = self.binary_frame_decoder.as_mut().unwrap().decode_frame(); match decoded_frame { Some(DecodedFrame::Complete(ref frame_ref)) => { let frame = DecodedFrame::Complete(frame_ref.clone()); + + // Convert frame to provider response type match ProviderStreamResponseType::try_from((&frame, client_api, upstream_api)) { Ok(provider_response) => { self.record_ttft_if_needed(); - // Handle ContentBlockStart and ContentBlockDelta events - match &provider_response { - ProviderStreamResponseType::MessagesStreamEvent(evt) => { - match evt { - MessagesStreamEvent::ContentBlockStart { - index, .. - } => { - // Mark that we've seen ContentBlockStart for this index - self.binary_frame_decoder - .as_mut() - .unwrap() - .set_content_block_start_sent(*index as i32); - debug!( - "[ARCHGW_REQ_ID:{}] BEDROCK_CONTENT_BLOCK_START_TRACKED: index={}", - self.request_identifier(), - *index - ); - } - MessagesStreamEvent::ContentBlockDelta { - index, .. - } => { - // Check if ContentBlockStart was sent for this index - let needs_start = !self - .binary_frame_decoder - .as_ref() - .unwrap() - .has_content_block_start_been_sent(*index as i32); - - if needs_start { - // Emit empty ContentBlockStart before delta - let content_block_start = - MessagesStreamEvent::ContentBlockStart { - index: *index, - content_block: MessagesContentBlock::Text { - text: String::new(), - cache_control: None, - }, - }; - let start_sse: String = content_block_start.into(); - response_buffer - .extend_from_slice(start_sse.as_bytes()); - - // Mark that we've now sent it - self.binary_frame_decoder - .as_mut() - .unwrap() - .set_content_block_start_sent(*index as i32); - - debug!( - "[ARCHGW_REQ_ID:{}] BEDROCK_INJECTED_CONTENT_BLOCK_START: index={}", - self.request_identifier(), - *index - ); - } - } - _ => {} - } - } - _ => {} + // Track token usage + if let Some(content) = provider_response.content_delta() { + let estimated_tokens = content.len() / 4; + self.response_tokens += estimated_tokens.max(1); + debug!( + "[ARCHGW_REQ_ID:{}] BEDROCK_TOKEN_UPDATE: delta_chars={} estimated_tokens={} total_tokens={}", + self.request_identifier(), + content.len(), + estimated_tokens.max(1), + self.response_tokens + ); } - let sse_string: String = provider_response.into(); - response_buffer.extend_from_slice(sse_string.as_bytes()); + // Create SseEvent from provider response + let event = SseEvent::from_provider_response(provider_response); + + // Add to buffer (buffer handles all shim logic including ContentBlockStart injection) + if let Some(buffer) = self.sse_buffer.as_mut() { + buffer.add_transformed_event(event); + } } Err(e) => { warn!( @@ -658,8 +667,29 @@ impl StreamContext { } } - // Return accumulated complete frames (may be empty if all frames incomplete) - Ok(response_buffer) + // Get accumulated bytes from buffer and return + match self.sse_buffer.as_mut() { + Some(buffer) => { + let bytes = buffer.into_bytes(); + if !bytes.is_empty() { + let content = String::from_utf8_lossy(&bytes); + debug!( + "[ARCHGW_REQ_ID:{}] 
UPSTREAM_TRANSFORMED_CLIENT_RESPONSE: size={} content={}", + self.request_identifier(), + bytes.len(), + content + ); + } + Ok(bytes) + } + None => { + warn!( + "[ARCHGW_REQ_ID:{}] BEDROCK_BUFFER_MISSING", + self.request_identifier() + ); + Err(Action::Continue) + } + } } fn handle_non_streaming_response( @@ -782,13 +812,14 @@ impl HttpContext for StreamContext { self.select_llm_provider(); // Check if this is a supported API endpoint - if SupportedAPIs::from_endpoint(&request_path).is_none() { + if SupportedAPIsFromClient::from_endpoint(&request_path).is_none() { self.send_http_response(404, vec![], Some(b"Unsupported endpoint")); return Action::Continue; } // Get the SupportedApi for routing decisions - let supported_api: Option = SupportedAPIs::from_endpoint(&request_path); + let supported_api: Option = + SupportedAPIsFromClient::from_endpoint(&request_path); self.client_api = supported_api; // Debug: log provider, client API, resolved API, and request path @@ -1131,8 +1162,9 @@ impl HttpContext for StreamContext { } match self.client_api { - Some(SupportedAPIs::OpenAIChatCompletions(_)) => {} - Some(SupportedAPIs::AnthropicMessagesAPI(_)) => {} + Some(SupportedAPIsFromClient::OpenAIChatCompletions(_)) => {} + Some(SupportedAPIsFromClient::AnthropicMessagesAPI(_)) => {} + Some(SupportedAPIsFromClient::OpenAIResponsesAPI(_)) => {} _ => { let api_info = match &self.client_api { Some(api) => format!("{}", api), diff --git a/demos/use_cases/model_alias_routing/arch_config_with_aliases.yaml b/demos/use_cases/model_alias_routing/arch_config_with_aliases.yaml index b9dcab81..0aaaa537 100644 --- a/demos/use_cases/model_alias_routing/arch_config_with_aliases.yaml +++ b/demos/use_cases/model_alias_routing/arch_config_with_aliases.yaml @@ -47,6 +47,9 @@ llm_providers: - model: ollama/llama3.1 base_url: http://host.docker.internal:11434 + # Grok (xAI) Models + - model: xai/grok-4-0709 + access_key: $GROK_API_KEY # Model aliases - friendly names that map to actual provider names model_aliases: @@ -83,5 +86,9 @@ model_aliases: coding-model: target: us.amazon.nova-premier-v1:0 + # Alias for grok testing + arch.grok.v1: + target: grok-4-0709 + tracing: random_sampling: 100 diff --git a/tests/e2e/run_e2e_tests.sh b/tests/e2e/run_e2e_tests.sh index 856b16ab..f60f79bc 100644 --- a/tests/e2e/run_e2e_tests.sh +++ b/tests/e2e/run_e2e_tests.sh @@ -65,6 +65,10 @@ log running e2e tests for model alias routing log ======================================== poetry run pytest test_model_alias_routing.py +log running e2e tests for openai responses api client +log ======================================== +poetry run pytest test_openai_responses_api_client.py + log shutting down the weather_forecast demo log ======================================= cd ../../demos/samples_python/weather_forecast diff --git a/tests/e2e/test_openai_responses_api_client.py b/tests/e2e/test_openai_responses_api_client.py new file mode 100644 index 00000000..800db93d --- /dev/null +++ b/tests/e2e/test_openai_responses_api_client.py @@ -0,0 +1,630 @@ +import openai +import pytest +import os +import logging +import sys + +# Set up logging +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + handlers=[logging.StreamHandler(sys.stdout)], +) +logger = logging.getLogger(__name__) + +LLM_GATEWAY_ENDPOINT = os.getenv( + "LLM_GATEWAY_ENDPOINT", "http://localhost:12000/v1/chat/completions" +) + + +# ----------------------- +# v1/responses API tests +# ----------------------- +def 
test_openai_responses_api_non_streaming_passthrough(): + """Build a v1/responses API request (pass-through) and ensure gateway accepts it""" + base_url = LLM_GATEWAY_ENDPOINT.replace("/v1/chat/completions", "") + client = openai.OpenAI(api_key="test-key", base_url=f"{base_url}/v1") + + # Simple responses API request using a direct model (pass-through) + resp = client.responses.create( + model="gpt-4o", input="Hello via responses passthrough" + ) + + # Print the response content - handle both responses format and chat completions format + print(f"\n{'='*80}") + print(f"Model: {resp.model}") + print(f"Output: {resp.output_text}") + print(f"{'='*80}\n") + + # Minimal sanity checks + assert resp is not None + assert ( + getattr(resp, "id", None) is not None + or getattr(resp, "output", None) is not None + ) + + +def test_openai_responses_api_with_streaming_passthrough(): + """Build a v1/responses API streaming request (pass-through) and ensure gateway accepts it""" + base_url = LLM_GATEWAY_ENDPOINT.replace("/v1/chat/completions", "") + client = openai.OpenAI(api_key="test-key", base_url=f"{base_url}/v1") + + # Simple streaming responses API request using a direct model (pass-through) + stream = client.responses.create( + model="gpt-4o", + input="Write a short haiku about coding", + stream=True, + ) + + # Collect streamed content using the official Responses API streaming shape + text_chunks = [] + final_message = None + + for event in stream: + # The Python SDK surfaces a high-level Responses streaming interface. + # We rely on its typed helpers instead of digging into model_extra. + if getattr(event, "type", None) == "response.output_text.delta" and getattr( + event, "delta", None + ): + # Each delta contains a text fragment + text_chunks.append(event.delta) + + # Track the final response message if provided by the SDK + if getattr(event, "type", None) == "response.completed" and getattr( + event, "response", None + ): + final_message = event.response + + full_content = "".join(text_chunks) + + # Print the streaming response + print(f"\n{'='*80}") + print( + f"Model: {getattr(final_message, 'model', 'unknown') if final_message else 'unknown'}" + ) + print(f"Streamed Output: {full_content}") + print(f"{'='*80}\n") + + assert len(text_chunks) > 0, "Should have received streaming text deltas" + assert len(full_content) > 0, "Should have received content" + + +def test_openai_responses_api_non_streaming_with_tools_passthrough(): + """Responses API with a function/tool definition (pass-through)""" + base_url = LLM_GATEWAY_ENDPOINT.replace("/v1/chat/completions", "") + client = openai.OpenAI(api_key="test-key", base_url=f"{base_url}/v1", max_retries=0) + + # Define a simple tool/function for the Responses API + tools = [ + { + "type": "function", + "name": "echo_tool", + "description": "Echo back the provided input", + "parameters": { + "type": "object", + "properties": {"text": {"type": "string"}}, + "required": ["text"], + }, + } + ] + + resp = client.responses.create( + model="gpt-5", + input="Call the echo tool", + tools=tools, + ) + + assert resp is not None + assert ( + getattr(resp, "id", None) is not None + or getattr(resp, "output", None) is not None + ) + + +def test_openai_responses_api_with_streaming_with_tools_passthrough(): + """Responses API with a function/tool definition (streaming, pass-through)""" + base_url = LLM_GATEWAY_ENDPOINT.replace("/v1/chat/completions", "") + client = openai.OpenAI(api_key="test-key", base_url=f"{base_url}/v1", max_retries=0) + + tools = [ + { + "type": 
"function", + "name": "echo_tool", + "description": "Echo back the provided input", + "parameters": { + "type": "object", + "properties": {"text": {"type": "string"}}, + "required": ["text"], + }, + } + ] + + stream = client.responses.create( + model="gpt-5", + input="Call the echo tool", + tools=tools, + stream=True, + ) + + text_chunks = [] + tool_calls = [] + + for event in stream: + etype = getattr(event, "type", None) + + # Collect streamed text output + if etype == "response.output_text.delta" and getattr(event, "delta", None): + text_chunks.append(event.delta) + + # Collect streamed tool call arguments + if etype == "response.function_call_arguments.delta" and getattr( + event, "delta", None + ): + tool_calls.append(event.delta) + + full_text = "".join(text_chunks) + + print(f"\n{'='*80}") + print("Responses tools streaming test") + print(f"Streamed text: {full_text}") + print(f"Tool call argument chunks: {len(tool_calls)}") + print(f"{'='*80}\n") + + # We expect either streamed text output or streamed tool-call arguments + assert ( + full_text or tool_calls + ), "Expected streamed text or tool call argument deltas from Responses tools stream" + + +def test_openai_responses_api_non_streaming_upstream_chat_completions(): + """Send a v1/responses request using the grok alias to verify translation/routing""" + base_url = LLM_GATEWAY_ENDPOINT.replace("/v1/chat/completions", "") + client = openai.OpenAI(api_key="test-key", base_url=f"{base_url}/v1") + + resp = client.responses.create( + model="arch.grok.v1", input="Hello, translate this via grok alias" + ) + + # Print the response content - handle both responses format and chat completions format + print(f"\n{'='*80}") + print(f"Model: {resp.model}") + print(f"Output: {resp.output_text}") + print(f"{'='*80}\n") + + assert resp is not None + assert resp.id is not None + + +def test_openai_responses_api_with_streaming_upstream_chat_completions(): + """Build a v1/responses API streaming request (pass-through) and ensure gateway accepts it""" + base_url = LLM_GATEWAY_ENDPOINT.replace("/v1/chat/completions", "") + client = openai.OpenAI(api_key="test-key", base_url=f"{base_url}/v1") + + # Simple streaming responses API request using a direct model (pass-through) + stream = client.responses.create( + model="arch.grok.v1", + input="Write a short haiku about coding", + stream=True, + ) + + # Collect streamed content using the official Responses API streaming shape + text_chunks = [] + final_message = None + + for event in stream: + # The Python SDK surfaces a high-level Responses streaming interface. + # We rely on its typed helpers instead of digging into model_extra. 
+        if getattr(event, "type", None) == "response.output_text.delta" and getattr(
+            event, "delta", None
+        ):
+            # Each delta contains a text fragment
+            text_chunks.append(event.delta)
+
+        # Track the final response message if provided by the SDK
+        if getattr(event, "type", None) == "response.completed" and getattr(
+            event, "response", None
+        ):
+            final_message = event.response
+
+    full_content = "".join(text_chunks)
+
+    # Print the streaming response
+    print(f"\n{'='*80}")
+    print(
+        f"Model: {getattr(final_message, 'model', 'unknown') if final_message else 'unknown'}"
+    )
+    print(f"Streamed Output: {full_content}")
+    print(f"{'='*80}\n")
+
+    assert len(text_chunks) > 0, "Should have received streaming text deltas"
+    assert len(full_content) > 0, "Should have received content"
+
+
+def test_openai_responses_api_non_streaming_with_tools_upstream_chat_completions():
+    """Responses API with tool calling routed to grok via alias"""
+    base_url = LLM_GATEWAY_ENDPOINT.replace("/v1/chat/completions", "")
+    client = openai.OpenAI(api_key="test-key", base_url=f"{base_url}/v1")
+
+    tools = [
+        {
+            "type": "function",
+            "name": "echo_tool",
+            "description": "Echo back the provided input",
+            "parameters": {
+                "type": "object",
+                "properties": {"text": {"type": "string"}},
+                "required": ["text"],
+            },
+        }
+    ]
+
+    resp = client.responses.create(
+        model="arch.grok.v1",
+        input="Call the echo tool",
+        tools=tools,
+    )
+
+    assert resp.id is not None
+
+    print(f"\n{'='*80}")
+    print(f"Model: {resp.model}")
+    print(f"Output: {resp.output_text}")
+    print(f"{'='*80}\n")
+
+
+def test_openai_responses_api_streaming_with_tools_upstream_chat_completions():
+    """Responses API with a function/tool definition (streaming, routed to grok via alias)"""
+    base_url = LLM_GATEWAY_ENDPOINT.replace("/v1/chat/completions", "")
+    client = openai.OpenAI(api_key="test-key", base_url=f"{base_url}/v1", max_retries=0)
+
+    tools = [
+        {
+            "type": "function",
+            "name": "echo_tool",
+            "description": "Echo back the provided input",
+            "parameters": {
+                "type": "object",
+                "properties": {"text": {"type": "string"}},
+                "required": ["text"],
+            },
+        }
+    ]
+
+    stream = client.responses.create(
+        model="arch.grok.v1",
+        input="Call the echo tool",
+        tools=tools,
+        stream=True,
+    )
+
+    text_chunks = []
+    tool_calls = []
+
+    for event in stream:
+        etype = getattr(event, "type", None)
+
+        # Collect streamed text output
+        if etype == "response.output_text.delta" and getattr(event, "delta", None):
+            text_chunks.append(event.delta)
+
+        # Collect streamed tool call arguments
+        if etype == "response.function_call_arguments.delta" and getattr(
+            event, "delta", None
+        ):
+            tool_calls.append(event.delta)
+
+    full_text = "".join(text_chunks)
+
+    print(f"\n{'='*80}")
+    print("Responses tools streaming test")
+    print(f"Streamed text: {full_text}")
+    print(f"Tool call argument chunks: {len(tool_calls)}")
+    print(f"{'='*80}\n")
+
+    # We expect either streamed text output or streamed tool-call arguments
+    assert (
+        full_text or tool_calls
+    ), "Expected streamed text or tool call argument deltas from Responses tools stream"
+
+
+def test_openai_responses_api_non_streaming_upstream_bedrock():
+    """Send a v1/responses request using the coding-model alias to verify Bedrock translation/routing"""
+    base_url = LLM_GATEWAY_ENDPOINT.replace("/v1/chat/completions", "")
+    client = openai.OpenAI(api_key="test-key", base_url=f"{base_url}/v1")
+
+    resp = client.responses.create(
+        model="coding-model",
+        input="Hello, translate this via coding-model alias to Bedrock",
+    )
+
+    # Print
the response content - handle both responses format and chat completions format + print(f"\n{'='*80}") + print(f"Model: {resp.model}") + print(f"Output: {resp.output_text}") + print(f"{'='*80}\n") + + assert resp is not None + assert resp.id is not None + + +def test_openai_responses_api_with_streaming_upstream_bedrock(): + """Build a v1/responses API streaming request routed to Bedrock via coding-model alias""" + base_url = LLM_GATEWAY_ENDPOINT.replace("/v1/chat/completions", "") + client = openai.OpenAI(api_key="test-key", base_url=f"{base_url}/v1") + + # Simple streaming responses API request using coding-model alias + stream = client.responses.create( + model="coding-model", + input="Write a short haiku about coding", + stream=True, + ) + + # Collect streamed content using the official Responses API streaming shape + text_chunks = [] + final_message = None + + for event in stream: + # The Python SDK surfaces a high-level Responses streaming interface. + # We rely on its typed helpers instead of digging into model_extra. + if getattr(event, "type", None) == "response.output_text.delta" and getattr( + event, "delta", None + ): + # Each delta contains a text fragment + text_chunks.append(event.delta) + + # Track the final response message if provided by the SDK + if getattr(event, "type", None) == "response.completed" and getattr( + event, "response", None + ): + final_message = event.response + + full_content = "".join(text_chunks) + + # Print the streaming response + print(f"\n{'='*80}") + print( + f"Model: {getattr(final_message, 'model', 'unknown') if final_message else 'unknown'}" + ) + print(f"Streamed Output: {full_content}") + print(f"{'='*80}\n") + + assert len(text_chunks) > 0, "Should have received streaming text deltas" + assert len(full_content) > 0, "Should have received content" + + +def test_openai_responses_api_non_streaming_with_tools_upstream_bedrock(): + """Responses API with tools routed to Bedrock via coding-model alias""" + base_url = LLM_GATEWAY_ENDPOINT.replace("/v1/chat/completions", "") + client = openai.OpenAI(api_key="test-key", base_url=f"{base_url}/v1") + + tools = [ + { + "type": "function", + "name": "echo_tool", + "description": "Echo back the provided input", + "parameters": { + "type": "object", + "properties": {"text": {"type": "string"}}, + "required": ["text"], + }, + } + ] + + resp = client.responses.create( + model="coding-model", + input="Call the echo tool", + tools=tools, + ) + + assert resp.id is not None + + print(f"\n{'='*80}") + print(f"Model: {resp.model}") + print(f"Output: {resp.output_text}") + print(f"{'='*80}\n") + + +def test_openai_responses_api_streaming_with_tools_upstream_bedrock(): + """Responses API with a function/tool definition streaming to Bedrock via coding-model alias""" + base_url = LLM_GATEWAY_ENDPOINT.replace("/v1/chat/completions", "") + client = openai.OpenAI(api_key="test-key", base_url=f"{base_url}/v1", max_retries=0) + + tools = [ + { + "type": "function", + "name": "echo_tool", + "description": "Echo back the provided input", + "parameters": { + "type": "object", + "properties": {"text": {"type": "string"}}, + "required": ["text"], + }, + } + ] + + stream = client.responses.create( + model="coding-model", + input="Call the echo tool", + tools=tools, + stream=True, + ) + + text_chunks = [] + tool_calls = [] + + for event in stream: + etype = getattr(event, "type", None) + + # Collect streamed text output + if etype == "response.output_text.delta" and getattr(event, "delta", None): + text_chunks.append(event.delta) 
+
+        # Collect streamed tool call arguments
+        if etype == "response.function_call_arguments.delta" and getattr(
+            event, "delta", None
+        ):
+            tool_calls.append(event.delta)
+
+    full_text = "".join(text_chunks)
+
+    print(f"\n{'='*80}")
+    print("Responses tools streaming test (Bedrock)")
+    print(f"Streamed text: {full_text}")
+    print(f"Tool call argument chunks: {len(tool_calls)}")
+    print(f"{'='*80}\n")
+
+    # We expect either streamed text output or streamed tool-call arguments
+    assert (
+        full_text or tool_calls
+    ), "Expected streamed text or tool call argument deltas from Responses tools stream"
+
+
+def test_openai_responses_api_non_streaming_upstream_anthropic():
+    """Send a v1/responses request directly to an Anthropic model to verify translation/routing"""
+    base_url = LLM_GATEWAY_ENDPOINT.replace("/v1/chat/completions", "")
+    client = openai.OpenAI(api_key="test-key", base_url=f"{base_url}/v1")
+
+    resp = client.responses.create(
+        model="claude-sonnet-4-20250514", input="Hello, translate this via the gateway"
+    )
+
+    # Print the response content
+    print(f"\n{'='*80}")
+    print(f"Model: {resp.model}")
+    print(f"Output: {resp.output_text}")
+    print(f"{'='*80}\n")
+
+    assert resp is not None
+    assert resp.id is not None
+
+
+def test_openai_responses_api_with_streaming_upstream_anthropic():
+    """Build a v1/responses API streaming request (pass-through) and ensure the gateway accepts it"""
+    base_url = LLM_GATEWAY_ENDPOINT.replace("/v1/chat/completions", "")
+    client = openai.OpenAI(api_key="test-key", base_url=f"{base_url}/v1")
+
+    # Simple streaming Responses API request using a direct model name (pass-through)
+    stream = client.responses.create(
+        model="claude-sonnet-4-20250514",
+        input="Write a short haiku about coding",
+        stream=True,
+    )
+
+    # Collect streamed content using the official Responses API streaming shape
+    text_chunks = []
+    final_message = None
+
+    for event in stream:
+        # The Python SDK surfaces a high-level Responses streaming interface.
+        # We rely on its typed helpers instead of digging into model_extra.
+        if getattr(event, "type", None) == "response.output_text.delta" and getattr(
+            event, "delta", None
+        ):
+            # Each delta contains a text fragment
+            text_chunks.append(event.delta)
+
+        # Track the final response message if provided by the SDK
+        if getattr(event, "type", None) == "response.completed" and getattr(
+            event, "response", None
+        ):
+            final_message = event.response
+
+    full_content = "".join(text_chunks)
+
+    # Print the streaming response; getattr already falls back to 'unknown' if
+    # no final response event was seen
+    print(f"\n{'='*80}")
+    print(f"Model: {getattr(final_message, 'model', 'unknown')}")
+    print(f"Streamed Output: {full_content}")
+    print(f"{'='*80}\n")
+
+    assert len(text_chunks) > 0, "Should have received streaming text deltas"
+    assert len(full_content) > 0, "Should have received content"
+
+
+def test_openai_responses_api_non_streaming_with_tools_upstream_anthropic():
+    """Responses API with tools sent directly to an Anthropic model (pass-through)"""
+    base_url = LLM_GATEWAY_ENDPOINT.replace("/v1/chat/completions", "")
+    client = openai.OpenAI(api_key="test-key", base_url=f"{base_url}/v1")
+
+    tools = [
+        {
+            "type": "function",
+            "name": "echo_tool",
+            "description": "Echo back the provided input: hello_world",
+            "parameters": {
+                "type": "object",
+                "properties": {"text": {"type": "string"}},
+                "required": ["text"],
+            },
+        }
+    ]
+
+    resp = client.responses.create(
+        model="claude-sonnet-4-20250514",
+        input="Call the echo tool",
+        tools=tools,
+    )
+
+    assert resp.id is not None
+
+    print(f"\n{'='*80}")
+    print(f"Model: {resp.model}")
+    print(f"Output: {resp.output_text}")
+    print(f"{'='*80}\n")
+
+
+def test_openai_responses_api_streaming_with_tools_upstream_anthropic():
+    """Responses API with a function/tool definition (streaming, pass-through)"""
+    base_url = LLM_GATEWAY_ENDPOINT.replace("/v1/chat/completions", "")
+    client = openai.OpenAI(api_key="test-key", base_url=f"{base_url}/v1", max_retries=0)
+
+    tools = [
+        {
+            "type": "function",
+            "name": "echo_tool",
+            "description": "Echo back the provided input: hello_world",
+            "parameters": {
+                "type": "object",
+                "properties": {"text": {"type": "string"}},
+                "required": ["text"],
+            },
+        }
+    ]
+
+    stream = client.responses.create(
+        model="claude-sonnet-4-20250514",
+        input="Call the echo tool",
+        tools=tools,
+        stream=True,
+    )
+
+    text_chunks = []
+    tool_calls = []
+
+    for event in stream:
+        etype = getattr(event, "type", None)
+
+        # Collect streamed text output
+        if etype == "response.output_text.delta" and getattr(event, "delta", None):
+            text_chunks.append(event.delta)
+
+        # Collect streamed tool call arguments
+        if etype == "response.function_call_arguments.delta" and getattr(
+            event, "delta", None
+        ):
+            tool_calls.append(event.delta)
+
+    full_text = "".join(text_chunks)
+
+    print(f"\n{'='*80}")
+    print("Responses tools streaming test (Anthropic)")
+    print(f"Streamed text: {full_text}")
+    print(f"Tool call argument chunks: {len(tool_calls)}")
+    print(f"{'='*80}\n")
+
+    # We expect either streamed text output or streamed tool-call arguments
+    assert (
+        full_text or tool_calls
+    ), "Expected streamed text or tool call argument deltas from Responses tools stream"
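+
+
+# --- Editor's sketch (not part of the original change) ---
+# The four streaming tests above repeat the same event-collection loop. Below
+# is a minimal consolidation sketch, assuming only the typed event fields the
+# tests already rely on ("type", "delta", "response"); the helper name is
+# hypothetical, not an existing function in this suite.
+def _collect_responses_stream(stream):
+    """Collect text deltas, tool-call argument deltas, and the final response."""
+    text_chunks, tool_chunks, final_response = [], [], None
+    for event in stream:
+        etype = getattr(event, "type", None)
+        if etype == "response.output_text.delta" and getattr(event, "delta", None):
+            text_chunks.append(event.delta)
+        if etype == "response.function_call_arguments.delta" and getattr(
+            event, "delta", None
+        ):
+            tool_chunks.append(event.delta)
+        if etype == "response.completed" and getattr(event, "response", None):
+            final_response = event.response
+    return "".join(text_chunks), tool_chunks, final_response
+
+
+# Possible usage inside any of the streaming tests above:
+#     full_text, tool_calls, final_message = _collect_responses_stream(stream)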