diff --git a/.github/workflows/e2e_archgw.yml b/.github/workflows/e2e_archgw.yml
index 4509e99e..f13e126f 100644
--- a/.github/workflows/e2e_archgw.yml
+++ b/.github/workflows/e2e_archgw.yml
@@ -39,6 +39,8 @@ jobs:
           GROQ_API_KEY: ${{ secrets.GROQ_API_KEY }}
           ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
           AZURE_API_KEY: ${{ secrets.AZURE_API_KEY }}
+          AWS_BEARER_TOKEN_BEDROCK: ${{ secrets.AWS_BEARER_TOKEN_BEDROCK }}
+
         run: |
           docker compose up | tee &> archgw.logs &
diff --git a/.github/workflows/e2e_tests.yml b/.github/workflows/e2e_tests.yml
index 3319f145..1538a964 100644
--- a/.github/workflows/e2e_tests.yml
+++ b/.github/workflows/e2e_tests.yml
@@ -32,6 +32,7 @@ jobs:
           GROQ_API_KEY: ${{ secrets.GROQ_API_KEY }}
           ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
           AZURE_API_KEY: ${{ secrets.AZURE_API_KEY }}
+          AWS_BEARER_TOKEN_BEDROCK: ${{ secrets.AWS_BEARER_TOKEN_BEDROCK }}
         run: |
           python -mvenv venv
           source venv/bin/activate && cd tests/e2e && bash run_e2e_tests.sh
diff --git a/arch/tools/cli/config_generator.py b/arch/tools/cli/config_generator.py
index c8479278..60ea8537 100644
--- a/arch/tools/cli/config_generator.py
+++ b/arch/tools/cli/config_generator.py
@@ -23,6 +23,7 @@ SUPPORTED_PROVIDERS = [
     "moonshotai",
     "zhipu",
     "qwen",
+    "amazon_bedrock",
 ]


@@ -188,7 +189,10 @@ def validate_and_render_schema():
         # Validate azure_openai and ollama provider requires base_url
         if (
-            provider == "azure_openai" or provider == "ollama" or provider == "qwen"
+            provider == "azure_openai"
+            or provider == "ollama"
+            or provider == "qwen"
+            or provider == "amazon_bedrock"
         ) and model_provider.get("base_url") is None:
             raise Exception(
                 f"Provider '{provider}' requires 'base_url' to be set for model {model_name}"
diff --git a/crates/Cargo.lock b/crates/Cargo.lock
index d313aa5b..0115151e 100644
--- a/crates/Cargo.lock
+++ b/crates/Cargo.lock
@@ -101,6 +101,35 @@
 version = "1.4.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26"

+[[package]]
+name = "aws-smithy-eventstream"
+version = "0.60.12"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9656b85088f8d9dc7ad40f9a6c7228e1e8447cdf4b046c87e152e0805dea02fa"
+dependencies = [
+ "aws-smithy-types",
+ "bytes",
+ "crc32fast",
+]
+
+[[package]]
+name = "aws-smithy-types"
+version = "1.3.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9f5b3a7486f6690ba25952cabf1e7d75e34d69eaff5081904a47bc79074d6457"
+dependencies = [
+ "base64-simd",
+ "bytes",
+ "bytes-utils",
+ "itoa",
+ "num-integer",
+ "pin-project-lite",
+ "pin-utils",
+ "ryu",
+ "serde",
+ "time",
+]
+
 [[package]]
 name = "backtrace"
 version = "0.3.75"
@@ -128,6 +157,16 @@
 version = "0.22.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6"

+[[package]]
+name = "base64-simd"
+version = "0.8.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "339abbe78e73178762e23bea9dfd08e697eb3f3301cd4be981c0f78ba5859195"
+dependencies = [
+ "outref",
+ "vsimd",
+]
+
 [[package]]
 name = "bit-set"
 version = "0.5.3"
@@ -217,6 +256,16 @@
 version = "1.10.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "d71b6127be86fdcfddb610f7182ac57211d4b18a3e9c82eb2d17662f2227ad6a"

+[[package]]
+name = "bytes-utils"
+version = "0.1.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7dafe3a8757b027e2be6e4e5601ed563c55989fcf1546e933c66c8eb3a058d35"
+dependencies = [
+ "bytes",
+ "either",
+]
+
 [[package]]
 name = "cc"
 version = "1.2.26"
@@ -302,6 +351,15 @@
 dependencies = [
  "libc",
 ]

+[[package]]
+name = "crc32fast"
+version = "1.5.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9481c1c90cbf2ac953f07c8d4a58aa3945c425b7185c9154d67a65e4230da511"
+dependencies = [
+ "cfg-if",
+]
+
 [[package]]
 name = "crypto-common"
 version = "0.1.6"
@@ -738,6 +796,8 @@ dependencies = [
 name = "hermesllm"
 version = "0.1.0"
 dependencies = [
+ "aws-smithy-eventstream",
+ "bytes",
  "serde",
  "serde_json",
  "serde_with",
@@ -1191,6 +1251,7 @@
 name = "llm_gateway"
 version = "0.1.0"
 dependencies = [
  "acap",
+ "bytes",
  "common",
  "derivative",
  "governor",
@@ -1359,6 +1420,15 @@
 version = "0.1.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9"

+[[package]]
+name = "num-integer"
+version = "0.1.46"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f"
+dependencies = [
+ "num-traits",
+]
+
 [[package]]
 name = "num-traits"
 version = "0.2.19"
@@ -1517,6 +1587,12 @@ dependencies = [
  "tracing",
 ]

+[[package]]
+name = "outref"
+version = "0.5.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1a80800c0488c3a21695ea981a54918fbb37abf04f4d0720c453632255e2ff0e"
+
 [[package]]
 name = "overload"
 version = "0.1.1"
@@ -2081,9 +2157,9 @@ dependencies = [

 [[package]]
 name = "security-framework-sys"
-version = "2.14.0"
+version = "2.15.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "49db231d56a190491cb4aeda9527f1ad45345af50b0851622a7adb8c03b01c32"
+checksum = "cc1f0cbffaac4852523ce30d8bd3c5cdc873501d96ff467ca09b6767bb8cd5c0"
 dependencies = [
  "core-foundation-sys",
  "libc",
@@ -2812,6 +2888,12 @@
 version = "0.9.5"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a"

+[[package]]
+name = "vsimd"
+version = "0.8.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5c3082ca00d5a5ef149bb8b555a72ae84c9c59f7250f013ac822ac2e49b19c64"
+
 [[package]]
 name = "want"
 version = "0.3.1"
diff --git a/crates/brightstaff/src/handlers/chat_completions.rs b/crates/brightstaff/src/handlers/chat_completions.rs
index b96f1f52..1b15e389 100644
--- a/crates/brightstaff/src/handlers/chat_completions.rs
+++ b/crates/brightstaff/src/handlers/chat_completions.rs
@@ -1,7 +1,8 @@
 use bytes::Bytes;
 use common::configuration::{ModelAlias, ModelUsagePreference};
-use common::consts::ARCH_PROVIDER_HINT_HEADER;
+use common::consts::{ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER};
 use hermesllm::apis::openai::ChatCompletionsRequest;
+use hermesllm::clients::endpoints::SupportedUpstreamAPIs;
 use hermesllm::clients::SupportedAPIs;
 use hermesllm::{ProviderRequest, ProviderRequestType};
 use http_body_util::combinators::BoxBody;
@@ -56,6 +57,7 @@ pub async fn chat(
     // Model alias resolution: update model field in client_request immediately
     // This ensures all downstream objects use the resolved model
     let model_from_request = client_request.model().to_string();
+    let is_streaming_request = client_request.is_streaming();
     let resolved_model = if let Some(model_aliases) = model_aliases.as_ref() {
         if let Some(model_alias) = model_aliases.get(&model_from_request) {
             debug!(
@@ -84,10 +86,16 @@ pub async fn chat(
     let chat_completions_request_for_arch_router: ChatCompletionsRequest =
         match ProviderRequestType::try_from((
             client_request,
-            &SupportedAPIs::OpenAIChatCompletions(hermesllm::apis::OpenAIApi::ChatCompletions),
+            &SupportedUpstreamAPIs::OpenAIChatCompletions(
+                hermesllm::apis::OpenAIApi::ChatCompletions,
+            ),
         )) {
             Ok(ProviderRequestType::ChatCompletionsRequest(req)) => req,
-            Ok(ProviderRequestType::MessagesRequest(_)) => {
+            Ok(
+                ProviderRequestType::MessagesRequest(_)
+                | ProviderRequestType::BedrockConverse(_)
+                | ProviderRequestType::BedrockConverseStream(_),
+            ) => {
                 // This should not happen after conversion to OpenAI format
                 warn!("Unexpected: got MessagesRequest after converting to OpenAI format");
                 let err_msg = "Request conversion failed".to_string();
@@ -190,6 +198,11 @@ pub async fn chat(
         header::HeaderValue::from_str(&model_name).unwrap(),
     );

+    request_headers.insert(
+        header::HeaderName::from_static(ARCH_IS_STREAMING_HEADER),
+        header::HeaderValue::from_str(&is_streaming_request.to_string()).unwrap(),
+    );
+
     if let Some(trace_parent) = trace_parent {
         request_headers.insert(
             header::HeaderName::from_static("traceparent"),
diff --git a/crates/common/src/configuration.rs b/crates/common/src/configuration.rs
index c881afa4..dc1b74e9 100644
--- a/crates/common/src/configuration.rs
+++ b/crates/common/src/configuration.rs
@@ -206,6 +206,8 @@ pub enum LlmProviderType {
     Zhipu,
     #[serde(rename = "qwen")]
     Qwen,
+    #[serde(rename = "amazon_bedrock")]
+    AmazonBedrock,
 }

 impl Display for LlmProviderType {
@@ -225,6 +227,7 @@
             LlmProviderType::Moonshotai => write!(f, "moonshotai"),
             LlmProviderType::Zhipu => write!(f, "zhipu"),
             LlmProviderType::Qwen => write!(f, "qwen"),
+            LlmProviderType::AmazonBedrock => write!(f, "amazon_bedrock"),
         }
     }
 }
diff --git a/crates/common/src/consts.rs b/crates/common/src/consts.rs
index 14972485..13624d8d 100644
--- a/crates/common/src/consts.rs
+++ b/crates/common/src/consts.rs
@@ -11,6 +11,7 @@ pub const MODEL_SERVER_NAME: &str = "model_server";
 pub const ARCH_ROUTING_HEADER: &str = "x-arch-llm-provider";
 pub const MESSAGES_KEY: &str = "messages";
 pub const ARCH_PROVIDER_HINT_HEADER: &str = "x-arch-llm-provider-hint";
+pub const ARCH_IS_STREAMING_HEADER: &str = "x-arch-streaming-request";
 pub const CHAT_COMPLETIONS_PATH: &str = "/v1/chat/completions";
 pub const MESSAGES_PATH: &str = "/v1/messages";
 pub const HEALTHZ_PATH: &str = "/healthz";
diff --git a/crates/hermesllm/Cargo.toml b/crates/hermesllm/Cargo.toml
index c995e85c..ab2390bf 100644
--- a/crates/hermesllm/Cargo.toml
+++ b/crates/hermesllm/Cargo.toml
@@ -6,5 +6,7 @@ edition = "2021"
 [dependencies]
 serde = {version = "1.0.219", features = ["derive"]}
 serde_json = "1.0.140"
-serde_with = "3.12.0"
+serde_with = {version = "3.12.0", features = ["base64"]}
 thiserror = "2.0.12"
+aws-smithy-eventstream = "0.60"
+bytes = "1.10"
diff --git a/crates/hermesllm/src/apis/amazon_bedrock.rs b/crates/hermesllm/src/apis/amazon_bedrock.rs
new file mode 100644
index 00000000..eb1f3ddf
--- /dev/null
+++ b/crates/hermesllm/src/apis/amazon_bedrock.rs
@@ -0,0 +1,1149 @@
+use serde::{Deserialize, Serialize};
+use serde_json::Value;
+use serde_with::skip_serializing_none;
+
+use std::collections::HashMap;
+use thiserror::Error;
+
+use super::ApiDefinition;
+use crate::providers::request::{ProviderRequest, ProviderRequestError};
+use crate::providers::response::ProviderStreamResponse;
+
+// ============================================================================
+// AMAZON BEDROCK CONVERSE API ENUMERATION
+// ============================================================================
+
+/// Enum for all supported Amazon Bedrock Converse APIs
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
+pub enum AmazonBedrockApi {
+    Converse,
+    ConverseStream,
+}
+
+impl ApiDefinition for AmazonBedrockApi {
+    fn endpoint(&self) -> &'static str {
+        match self {
+            AmazonBedrockApi::Converse => "/model/{modelId}/converse",
+            AmazonBedrockApi::ConverseStream => "/model/{modelId}/converse-stream",
+        }
+    }
+
+    fn from_endpoint(endpoint: &str) -> Option<Self> {
+        if endpoint.ends_with("/converse") {
+            Some(AmazonBedrockApi::Converse)
+        } else if endpoint.ends_with("/converse-stream") {
+            Some(AmazonBedrockApi::ConverseStream)
+        } else {
+            None
+        }
+    }
+
+    fn supports_streaming(&self) -> bool {
+        match self {
+            AmazonBedrockApi::Converse => false,
+            AmazonBedrockApi::ConverseStream => true,
+        }
+    }
+
+    fn supports_tools(&self) -> bool {
+        // Converse API has native tool support
+        true
+    }
+
+    fn supports_vision(&self) -> bool {
+        // Converse API has native vision support
+        true
+    }
+
+    fn all_variants() -> Vec<Self> {
+        vec![AmazonBedrockApi::Converse, AmazonBedrockApi::ConverseStream]
+    }
+}
+
+// ============================================================================
+// CONVERSE REQUEST STRUCTURES
+// ============================================================================
+
+/// Amazon Bedrock Converse request
+#[skip_serializing_none]
+#[derive(Serialize, Deserialize, Debug, Clone)]
+pub struct ConverseRequest {
+    /// The model ID or ARN to invoke
+    pub model_id: String,
+    /// The messages to send to the model
+    pub messages: Option<Vec<Message>>,
+    /// System prompts that provide instructions or context
+    pub system: Option<Vec<SystemContentBlock>>,
+    /// Inference configuration
+    #[serde(rename = "inferenceConfig")]
+    pub inference_config: Option<InferenceConfiguration>,
+    /// Tool configuration for function calling
+    #[serde(rename = "toolConfig")]
+    pub tool_config: Option<ToolConfiguration>,
+    /// Guardrail configuration
+    #[serde(rename = "guardrailConfig")]
+    pub guardrail_config: Option<GuardrailConfiguration>,
+    /// Additional model-specific request fields
+    #[serde(rename = "additionalModelRequestFields")]
+    pub additional_model_request_fields: Option<Value>,
+    /// Additional model response field paths to return
+    #[serde(rename = "additionalModelResponseFieldPaths")]
+    pub additional_model_response_field_paths: Option<Vec<String>>,
+    /// Performance configuration
+    #[serde(rename = "performanceConfig")]
+    pub performance_config: Option<PerformanceConfiguration>,
+    /// Prompt variables for Prompt management
+    #[serde(rename = "promptVariables")]
+    pub prompt_variables: Option<HashMap<String, PromptVariableValues>>,
+    /// Request metadata for filtering logs
+    #[serde(rename = "requestMetadata")]
+    pub request_metadata: Option<HashMap<String, String>>,
+    /// Additional custom metadata (for internal use)
+    pub metadata: Option<HashMap<String, Value>>,
+    /// Whether this request should use streaming endpoint (internal field, not serialized)
+    #[serde(skip)]
+    pub stream: bool,
+}
+
+impl Default for ConverseRequest {
+    fn default() -> Self {
+        Self {
+            model_id: String::new(),
+            messages: None,
+            system: None,
+            inference_config: None,
+            tool_config: None,
+            guardrail_config: None,
+            additional_model_request_fields: None,
+            additional_model_response_field_paths: None,
+            performance_config: None,
+            prompt_variables: None,
+            request_metadata: None,
+            metadata: None,
+            stream: false,
+        }
+    }
+}
+
+/// Amazon Bedrock ConverseStream request (same structure as Converse)
+pub type ConverseStreamRequest = ConverseRequest;
+
+impl ProviderRequest for ConverseRequest {
+    fn model(&self) -> &str {
+        &self.model_id
+    }
+
+    fn set_model(&mut self, model: String) {
+        self.model_id = model;
+    }
+
+    fn is_streaming(&self) -> bool {
+        self.stream
+    }
+
+    fn extract_messages_text(&self) -> String {
+        let mut text_parts = Vec::new();
+
+        // Extract text from messages
+        if let Some(messages) = &self.messages {
+            for message in messages {
+                for content_block in &message.content {
+                    match content_block {
+                        ContentBlock::Text { text } => {
+                            text_parts.push(text.clone());
+                        }
+                        ContentBlock::GuardContent { guard_content } => {
+                            if let Some(guard_text) = &guard_content.text {
+                                text_parts.push(guard_text.text.clone());
+                            }
+                        }
+                        _ => {} // Skip non-text content blocks
+                    }
+                }
+            }
+        }
+
+        // Extract text from system prompts
+        if let Some(system) = &self.system {
+            for system_block in system {
+                match system_block {
+                    SystemContentBlock::Text { text } => {
+                        text_parts.push(text.clone());
+                    }
+                    SystemContentBlock::GuardContent {
+                        text: Some(guard_text),
+                    } => {
+                        text_parts.push(guard_text.text.clone());
+                    }
+                    SystemContentBlock::GuardContent { text: None } => {
+                        // No text content in this guard content block
+                    }
+                }
+            }
+        }
+
+        text_parts.join(" ")
+    }
+
+    fn get_recent_user_message(&self) -> Option<String> {
+        self.messages
+            .as_ref()?
+            .iter()
+            .rev() // Start from the most recent message
+            .find(|msg| msg.role == ConversationRole::User)
+            .and_then(|msg| {
+                // Extract the first text content block from the user message
+                msg.content.iter().find_map(|content| match content {
+                    ContentBlock::Text { text } => Some(text.clone()),
+                    _ => None,
+                })
+            })
+    }
+
+    fn to_bytes(&self) -> Result<Vec<u8>, ProviderRequestError> {
+        serde_json::to_vec(self).map_err(|e| ProviderRequestError {
+            message: format!("Failed to serialize Bedrock request: {}", e),
+            source: Some(Box::new(e)),
+        })
+    }
+
+    fn metadata(&self) -> &Option<HashMap<String, Value>> {
+        &self.metadata
+    }
+
+    fn remove_metadata_key(&mut self, key: &str) -> bool {
+        if let Some(ref mut metadata) = self.metadata {
+            metadata.remove(key).is_some()
+        } else {
+            false
+        }
+    }
+}
+
+// ============================================================================
+// CONVERSE RESPONSE STRUCTURES
+// ============================================================================
+
+/// Amazon Bedrock Converse response
+#[skip_serializing_none]
+#[derive(Serialize, Deserialize, Debug, Clone)]
+pub struct ConverseResponse {
+    /// The result from the call to Converse
+    pub output: ConverseOutput,
+    /// The reason why the model stopped generating output
+    #[serde(rename = "stopReason")]
+    pub stop_reason: StopReason,
+    /// Token usage information
+    pub usage: BedrockTokenUsage,
+    /// Metrics for the call
+    pub metrics: Option<ConverseMetrics>,
+    /// Additional model response fields
+    #[serde(rename = "additionalModelResponseFields")]
+    pub additional_model_response_fields: Option<Value>,
+    /// Performance configuration used
+    #[serde(rename = "performanceConfig")]
+    pub performance_config: Option<PerformanceConfiguration>,
+    /// Trace information for guardrails
+    pub trace: Option<ConverseTrace>,
+}
+
+/// Amazon Bedrock ConverseStream response events
+#[derive(Serialize, Deserialize, Debug, Clone)]
+#[serde(untagged)]
+pub enum ConverseStreamEvent {
+    MessageStart(MessageStartEvent),
+    ContentBlockStart(ContentBlockStartEvent),
+    ContentBlockDelta(ContentBlockDeltaEvent),
+    ContentBlockStop(ContentBlockStopEvent),
+    MessageStop(MessageStopEvent),
+    Metadata(ConverseStreamMetadataEvent),
+    // Error events
+    InternalServerException(BedrockException),
+    ModelStreamErrorException(BedrockException),
+    ServiceUnavailableException(BedrockException),
+    ThrottlingException(BedrockException),
+    ValidationException(BedrockException),
+}
+
+// ============================================================================
+// MESSAGE AND CONTENT STRUCTURES
+// ============================================================================
+
+/// Message in a conversation
+#[derive(Serialize, Deserialize, Debug, Clone)]
+pub struct Message {
+    /// Role of the message sender (user, assistant)
+    pub role: ConversationRole,
+    /// Content blocks in the message
+    pub content: Vec<ContentBlock>,
+}
+
+/// Conversation role enumeration
+#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
+#[serde(rename_all = "lowercase")]
+pub enum ConversationRole {
+    User,
+    Assistant,
+}
+
+/// Content block in a message
+#[derive(Serialize, Deserialize, Debug, Clone)]
+#[serde(untagged)]
+pub enum ContentBlock {
+    Text {
+        text: String,
+    },
+    Image {
+        image: ImageBlock,
+    },
+    Document {
+        document: DocumentBlock,
+    },
+    ToolUse {
+        #[serde(rename = "toolUse")]
+        tool_use: ToolUseBlock,
+    },
+    ToolResult {
+        #[serde(rename = "toolResult")]
+        tool_result: ToolResultBlock,
+    },
+    GuardContent {
+        #[serde(rename = "guardContent")]
+        guard_content: GuardContentBlock,
+    },
+}
+
+/// Image block structure
+#[derive(Serialize, Deserialize, Debug, Clone)]
+pub struct ImageBlock {
+    pub source: ImageSource,
+}
+
+/// Document block structure
+#[derive(Serialize, Deserialize, Debug, Clone)]
+pub struct DocumentBlock {
+    pub source: DocumentSource,
+    pub name: Option<String>,
+}
+
+/// Tool use block structure
+#[derive(Serialize, Deserialize, Debug, Clone)]
+pub struct ToolUseBlock {
+    #[serde(rename = "toolUseId")]
+    pub tool_use_id: String,
+    pub name: String,
+    pub input: Value,
+}
+
+/// Tool result block structure
+#[derive(Serialize, Deserialize, Debug, Clone)]
+pub struct ToolResultBlock {
+    #[serde(rename = "toolUseId")]
+    pub tool_use_id: String,
+    pub content: Vec<ToolResultContentBlock>,
+    pub status: Option<ToolResultStatus>,
+}
+
+/// Guard content block structure
+#[derive(Serialize, Deserialize, Debug, Clone)]
+pub struct GuardContentBlock {
+    pub text: Option<GuardContentText>,
+}
+
+/// System content block for system prompts
+#[derive(Serialize, Deserialize, Debug, Clone)]
+#[serde(tag = "type")]
+pub enum SystemContentBlock {
+    #[serde(rename = "text")]
+    Text { text: String },
+    #[serde(rename = "guardContent")]
+    GuardContent { text: Option<GuardContentText> },
+}
+
+/// Image source for vision capabilities
+#[derive(Serialize, Deserialize, Debug, Clone)]
+#[serde(tag = "type")]
+pub enum ImageSource {
+    #[serde(rename = "base64")]
+    Base64 {
+        #[serde(rename = "mediaType")]
+        media_type: String,
+        data: String,
+    },
+}
+
+/// Document source for document processing
+#[derive(Serialize, Deserialize, Debug, Clone)]
+#[serde(tag = "type")]
+pub enum DocumentSource {
+    #[serde(rename = "base64")]
+    Base64 {
+        #[serde(rename = "mediaType")]
+        media_type: String,
+        data: String,
+    },
+}
+
+/// Tool result content block
+#[derive(Serialize, Deserialize, Debug, Clone)]
+#[serde(tag = "type")]
+pub enum ToolResultContentBlock {
+    #[serde(rename = "text")]
+    Text { text: String },
+    #[serde(rename = "image")]
+    Image { source: ImageSource },
+    #[serde(rename = "json")]
+    Json { json: Value },
+}
+
+/// Tool result status
+#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
+#[serde(rename_all = "lowercase")]
+pub enum ToolResultStatus {
+    Success,
+    Error,
+}
+
+/// Guard content text with qualifiers
+#[derive(Serialize, Deserialize, Debug, Clone)]
+pub struct GuardContentText {
+    pub text: String,
+    pub qualifiers: Option<Vec<GuardContentQualifier>>,
+}
+
+/// Guard content qualifier
+#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
+#[serde(rename_all = "lowercase")]
+pub enum GuardContentQualifier {
+    Grounding,
+    Relevance,
+    Harmfulness,
+    Helpfulness,
+}
+
+// ============================================================================
+// INFERENCE AND TOOL CONFIGURATION
+// ============================================================================
+
+/// Inference configuration for the model
+#[skip_serializing_none]
+#[derive(Serialize, Deserialize, Debug, Clone)]
+pub struct InferenceConfiguration {
+    /// Maximum tokens to generate
+    #[serde(rename = "maxTokens")]
+    pub max_tokens: Option<u32>,
+    /// Temperature for randomness (0.0 to 1.0)
+    pub temperature: Option<f32>,
+    /// Top-p sampling parameter (0.0 to 1.0)
+    #[serde(rename = "topP")]
+    pub top_p: Option<f32>,
+    /// Stop sequences to halt generation
+    #[serde(rename = "stopSequences")]
+    pub stop_sequences: Option<Vec<String>>,
+}
+
+/// Tool configuration for function calling
+#[skip_serializing_none]
+#[derive(Serialize, Deserialize, Debug, Clone)]
+pub struct ToolConfiguration {
+    /// Available tools for the model
+    pub tools: Option<Vec<Tool>>,
+    /// Tool choice configuration
+    #[serde(rename = "toolChoice")]
+    pub tool_choice: Option<ToolChoice>,
+}
+
+/// Tool definition
+#[derive(Serialize, Deserialize, Debug, Clone)]
+#[serde(untagged)]
+pub enum Tool {
+    ToolSpec {
+        #[serde(rename = "toolSpec")]
+        tool_spec: ToolSpecDefinition,
+    },
+}
+
+/// Tool specification definition
+#[derive(Serialize, Deserialize, Debug, Clone)]
+pub struct ToolSpecDefinition {
+    pub name: String,
+    pub description: Option<String>,
+    #[serde(rename = "inputSchema")]
+    pub input_schema: ToolInputSchema,
+}
+
+/// Tool input schema
+#[derive(Serialize, Deserialize, Debug, Clone)]
+pub struct ToolInputSchema {
+    pub json: Value,
+}
+
+/// Tool choice configuration
+#[derive(Serialize, Deserialize, Debug, Clone)]
+#[serde(untagged)]
+pub enum ToolChoice {
+    Auto {
+        #[serde(rename = "auto")]
+        auto: AutoChoice,
+    },
+    Any {
+        #[serde(rename = "any")]
+        any: AnyChoice,
+    },
+    Tool {
+        #[serde(rename = "tool")]
+        tool: ToolChoiceSpec,
+    },
+}
+
+/// Auto tool choice (empty object)
+#[derive(Serialize, Deserialize, Debug, Clone)]
+pub struct AutoChoice {}
+
+/// Any tool choice (empty object)
+#[derive(Serialize, Deserialize, Debug, Clone)]
+pub struct AnyChoice {}
+
+/// Specific tool choice
+#[derive(Serialize, Deserialize, Debug, Clone)]
+pub struct ToolChoiceSpec {
+    pub name: String,
+}
+
+// ============================================================================
+// GUARDRAIL CONFIGURATION
+// ============================================================================
+
+/// Guardrail configuration for content filtering
+#[skip_serializing_none]
+#[derive(Serialize, Deserialize, Debug, Clone)]
+pub struct GuardrailConfiguration {
+    /// Guardrail identifier
+    #[serde(rename = "guardrailIdentifier")]
+    pub guardrail_identifier: String,
+    /// Guardrail version
+    #[serde(rename = "guardrailVersion")]
+    pub guardrail_version: String,
+    /// Trace setting
+    pub trace: Option<GuardrailTrace>,
+}
+
+/// Guardrail configuration for streaming (has additional field)
+#[skip_serializing_none]
+#[derive(Serialize, Deserialize, Debug, Clone)]
+pub struct GuardrailStreamConfiguration {
+    /// Guardrail identifier
+    #[serde(rename = "guardrailIdentifier")]
+    pub guardrail_identifier: String,
+    /// Guardrail version
+    #[serde(rename = "guardrailVersion")]
+    pub guardrail_version: String,
+    /// Stream processing mode
+    #[serde(rename = "streamProcessingMode")]
+    pub stream_processing_mode: Option<String>,
+    /// Trace setting
+    pub trace: Option<GuardrailTrace>,
+}
+
+/// Guardrail trace setting
+#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
+#[serde(rename_all = "UPPERCASE")]
+pub enum GuardrailTrace {
+    Enabled,
+    Disabled,
+}
+
+// ============================================================================
+// PERFORMANCE CONFIGURATION
+// ============================================================================
+
+/// Performance configuration for latency optimization
+#[skip_serializing_none]
+#[derive(Serialize, Deserialize, Debug, Clone)]
+pub struct PerformanceConfiguration {
+    /// Latency optimization setting
+    pub latency: Option<PerformanceLatency>,
+}
+
+/// Performance latency setting
+#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
+#[serde(rename_all = "lowercase")]
+pub enum PerformanceLatency {
+    Standard,
+    Optimized,
+}
+
+// ============================================================================
+// RESPONSE OUTPUT STRUCTURES
+// ============================================================================
+
+/// Converse output (union type)
+#[derive(Serialize, Deserialize, Debug, Clone)]
+#[serde(untagged)]
+pub enum ConverseOutput {
+    Message { message: Message },
+}
+
+/// Stop reason enumeration
+#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
+#[serde(rename_all = "snake_case")]
+pub enum StopReason {
+    EndTurn,
+    ToolUse,
+    MaxTokens,
+    StopSequence,
+    GuardrailIntervened,
+    ContentFiltered,
+}
+
+/// Token usage information for Bedrock Converse API
+#[skip_serializing_none]
+#[derive(Serialize, Deserialize, Debug, Clone, Default)]
+pub struct BedrockTokenUsage {
+    /// Input tokens processed
+    #[serde(rename = "inputTokens")]
+    pub input_tokens: u32,
+    /// Output tokens generated
+    #[serde(rename = "outputTokens")]
+    pub output_tokens: u32,
+    /// Total tokens used
+    #[serde(rename = "totalTokens")]
+    pub total_tokens: u32,
+    /// Server tool usage (for function calling)
+    #[serde(rename = "serverToolUsage")]
+    pub server_tool_usage: Option<Value>,
+    /// Cache read input tokens
+    #[serde(rename = "cacheReadInputTokens")]
+    pub cache_read_input_tokens: Option<u32>,
+    /// Cache write input tokens
+    #[serde(rename = "cacheWriteInputTokens")]
+    pub cache_write_input_tokens: Option<u32>,
+}
+
+/// Converse metrics
+#[derive(Serialize, Deserialize, Debug, Clone)]
+pub struct ConverseMetrics {
+    /// Latency in milliseconds
+    #[serde(rename = "latencyMs")]
+    pub latency_ms: u64,
+}
+
+/// Converse trace information
+#[skip_serializing_none]
+#[derive(Serialize, Deserialize, Debug, Clone)]
+pub struct ConverseTrace {
+    /// Guardrail trace information
+    pub guardrail: Option<GuardrailTraceAssessment>,
+    /// Prompt router trace information
+    #[serde(rename = "promptRouter")]
+    pub prompt_router: Option<PromptRouterTrace>,
+}
+
+/// Guardrail trace assessment (simplified)
+#[skip_serializing_none]
+#[derive(Serialize, Deserialize, Debug, Clone)]
+pub struct GuardrailTraceAssessment {
+    /// Action reason
+    #[serde(rename = "actionReason")]
+    pub action_reason: Option<String>,
+    /// Model output
+    #[serde(rename = "modelOutput")]
+    pub model_output: Option<Vec<String>>,
+    /// Input assessment
+    #[serde(rename = "inputAssessment")]
+    pub input_assessment: Option<HashMap<String, Value>>,
+    /// Output assessments
+    #[serde(rename = "outputAssessments")]
+    pub output_assessments: Option<HashMap<String, Vec<Value>>>,
+}
+
+/// Prompt router trace
+#[derive(Serialize, Deserialize, Debug, Clone)]
+pub struct PromptRouterTrace {
+    /// Invoked model ID
+    #[serde(rename = "invokedModelId")]
+    pub invoked_model_id: String,
+}
+
+// ============================================================================
+// STREAMING EVENT STRUCTURES
+// ============================================================================
+
+/// Message start event
+#[derive(Serialize, Deserialize, Debug, Clone)]
+pub struct MessageStartEvent {
+    /// Role of the message
+    pub role: ConversationRole,
+}
+
+/// Content block start event
+#[derive(Serialize, Deserialize, Debug, Clone)]
+pub struct ContentBlockStartEvent {
+    /// Content block index
+    #[serde(rename = "contentBlockIndex")]
+    pub content_block_index: i32,
+    /// Start information
+    pub start: ContentBlockStart,
+}
+
+/// Content block start information
+#[derive(Serialize, Deserialize, Debug, Clone)]
+#[serde(untagged)]
+pub enum ContentBlockStart {
+    ToolUse {
+        #[serde(rename = "toolUse")]
+        tool_use: ToolUseStart,
+    },
+}
+
+/// Tool use start information
+#[derive(Serialize, Deserialize, Debug, Clone)]
+pub struct ToolUseStart {
+    #[serde(rename = "toolUseId")]
+    pub tool_use_id: String,
+    pub name: String,
+}
+
+/// Content block delta event
+#[derive(Serialize, Deserialize, Debug, Clone)]
+pub struct ContentBlockDeltaEvent {
+    /// Content block index
+    #[serde(rename = "contentBlockIndex")]
+    pub content_block_index: i32,
+    /// Delta information
+    pub delta: ContentBlockDelta,
+}
+
+/// Content block delta information
+#[derive(Serialize, Deserialize, Debug, Clone)]
+#[serde(untagged)]
+pub enum ContentBlockDelta {
+    Text {
+        text: String,
+    },
+    ToolUse {
+        #[serde(rename = "toolUse")]
+        tool_use: ToolUseDelta,
+    },
+}
+
+#[derive(Serialize, Deserialize, Debug, Clone)]
+pub struct ToolUseDelta {
+    pub input: String,
+}
+
+/// Content block stop event
+#[derive(Serialize, Deserialize, Debug, Clone)]
+pub struct ContentBlockStopEvent {
+    /// Content block index
+    #[serde(rename = "contentBlockIndex")]
+    pub content_block_index: i32,
+}
+
+/// Message stop event
+#[skip_serializing_none]
+#[derive(Serialize, Deserialize, Debug, Clone)]
+pub struct MessageStopEvent {
+    /// Stop reason
+    #[serde(rename = "stopReason")]
+    pub stop_reason: StopReason,
+    /// Additional model response fields
+    #[serde(rename = "additionalModelResponseFields")]
+    pub additional_model_response_fields: Option<Value>,
+}
+
+/// Stream metadata event
+#[skip_serializing_none]
+#[derive(Serialize, Deserialize, Debug, Clone)]
+pub struct ConverseStreamMetadataEvent {
+    /// Token usage
+    pub usage: BedrockTokenUsage,
+    /// Stream metrics
+    pub metrics: Option<ConverseStreamMetrics>,
+    /// Trace information
+    pub trace: Option<ConverseStreamTrace>,
+    /// Performance configuration
+    #[serde(rename = "performanceConfig")]
+    pub performance_config: Option<PerformanceConfiguration>,
+}
+
+/// Stream metrics
+#[derive(Serialize, Deserialize, Debug, Clone)]
+pub struct ConverseStreamMetrics {
+    /// Latency in milliseconds
+    #[serde(rename = "latencyMs")]
+    pub latency_ms: u64,
+}
+
+/// Stream trace information
+#[skip_serializing_none]
+#[derive(Serialize, Deserialize, Debug, Clone)]
+pub struct ConverseStreamTrace {
+    /// Guardrail trace
+    pub guardrail: Option<GuardrailTraceAssessment>,
+    /// Prompt router trace
+    #[serde(rename = "promptRouter")]
+    pub prompt_router: Option<PromptRouterTrace>,
+}
+
+// ============================================================================
+// PROMPT MANAGEMENT
+// ============================================================================
+
+/// Prompt variable values for Prompt management
+#[derive(Serialize, Deserialize, Debug, Clone)]
+#[serde(untagged)]
+pub enum PromptVariableValues {
+    Text { text: String },
+}
+
+// ============================================================================
+// ERROR TYPES
+// ============================================================================
+
+/// Bedrock exception structure
+#[skip_serializing_none]
+#[derive(Serialize, Deserialize, Debug, Clone)]
+pub struct BedrockException {
+    /// Exception message
+    pub message: Option<String>,
+    /// Original status code (for model errors)
+    #[serde(rename = "originalStatusCode")]
+    pub original_status_code: Option<i32>,
+    /// Resource name (for model errors)
+    #[serde(rename = "resourceName")]
+    pub resource_name: Option<String>,
+    /// Original message (for stream errors)
+    #[serde(rename = "originalMessage")]
+    pub original_message: Option<String>,
+}
+
+/// Bedrock-specific error types
+#[derive(Error, Debug)]
+pub enum BedrockError {
+    #[error("Access denied: {message}")]
+    AccessDenied { message: String },
+
+    #[error("Internal server error: {message}")]
+    InternalServer { message: String },
+
+    #[error("Model error: {message}")]
+    ModelError {
+        message: String,
+        original_status_code: Option<i32>,
+        resource_name: Option<String>,
+    },
+
+    #[error("Model not ready: {message}")]
+    ModelNotReady { message: String },
+
+    #[error("Model timeout: {message}")]
+    ModelTimeout { message: String },
+
+    #[error("Resource not found: {message}")]
+    ResourceNotFound { message: String },
+
+    #[error("Service unavailable: {message}")]
+    ServiceUnavailable { message: String },
+
+    #[error("Throttling: {message}")]
+    Throttling { message: String },
+
+    #[error("Validation error: {message}")]
+    Validation { message: String },
+
+    #[error("Serialization error: {0}")]
+    Serialization(#[from] serde_json::Error),
+}
+
+// ============================================================================
+// TRAIT IMPLEMENTATIONS
+// ============================================================================
+
+// Note: Trait implementations will be added later when we implement transformations
+// For now, we're focusing on modeling the request/response shapes
+
+impl crate::providers::response::TokenUsage for BedrockTokenUsage {
+    fn completion_tokens(&self) -> usize {
+        self.output_tokens as usize
+    }
+
+    fn prompt_tokens(&self) -> usize {
+        self.input_tokens as usize
+    }
+
+    fn total_tokens(&self) -> usize {
+        self.total_tokens as usize
+    }
+}
+
+// ============================================================================
+// EVENT STREAM PARSING
+// ============================================================================
+
+/// Convert from aws-smithy-eventstream DecodedFrame to ConverseStreamEvent
+impl TryFrom<&aws_smithy_eventstream::frame::DecodedFrame> for ConverseStreamEvent {
+    type Error = BedrockError;
+
+    fn try_from(
+        frame: &aws_smithy_eventstream::frame::DecodedFrame,
+    ) -> Result<Self, Self::Error> {
+        // Only process Complete frames, skip Incomplete
+        let message = match frame {
+            aws_smithy_eventstream::frame::DecodedFrame::Complete(msg) => msg,
+            aws_smithy_eventstream::frame::DecodedFrame::Incomplete => {
+                return Err(BedrockError::Validation {
+                    message: "Expected Complete frame, got Incomplete".to_string(),
+                })
+            }
+        };
+
+        // Extract the :event-type and :message-type headers
+        let event_type = message
+            .headers()
+            .iter()
+            .find(|h| h.name().as_str() == ":event-type")
+            .and_then(|h| h.value().as_string().ok())
+            .ok_or_else(|| BedrockError::Validation {
+                message: "Missing :event-type header".to_string(),
+            })?
+            .as_str();
+
+        let message_type = message
+            .headers()
+            .iter()
+            .find(|h| h.name().as_str() == ":message-type")
+            .and_then(|h| h.value().as_string().ok())
+            .ok_or_else(|| BedrockError::Validation {
+                message: "Missing :message-type header".to_string(),
+            })?
+            .as_str();
+
+        let payload = message.payload();
+
+        // Parse the event based on message type and event type
+        match message_type {
+            "event" => match event_type {
+                "messageStart" => {
+                    let event: MessageStartEvent =
+                        serde_json::from_slice(payload).map_err(BedrockError::Serialization)?;
+                    Ok(ConverseStreamEvent::MessageStart(event))
+                }
+                "contentBlockStart" => {
+                    let event: ContentBlockStartEvent =
+                        serde_json::from_slice(payload).map_err(BedrockError::Serialization)?;
+                    Ok(ConverseStreamEvent::ContentBlockStart(event))
+                }
+                "contentBlockDelta" => {
+                    let event: ContentBlockDeltaEvent =
+                        serde_json::from_slice(payload).map_err(BedrockError::Serialization)?;
+                    Ok(ConverseStreamEvent::ContentBlockDelta(event))
+                }
+                "contentBlockStop" => {
+                    let event: ContentBlockStopEvent =
+                        serde_json::from_slice(payload).map_err(BedrockError::Serialization)?;
+                    Ok(ConverseStreamEvent::ContentBlockStop(event))
+                }
+                "messageStop" => {
+                    let event: MessageStopEvent =
+                        serde_json::from_slice(payload).map_err(BedrockError::Serialization)?;
+                    Ok(ConverseStreamEvent::MessageStop(event))
+                }
+                "metadata" => {
+                    let event: ConverseStreamMetadataEvent =
+                        serde_json::from_slice(payload).map_err(BedrockError::Serialization)?;
+                    Ok(ConverseStreamEvent::Metadata(event))
+                }
+                unknown => Err(BedrockError::Validation {
+                    message: format!("Unknown event type: {}", unknown),
+                }),
+            },
+            "exception" => match event_type {
+                "internalServerException" => {
+                    let exception: BedrockException =
+                        serde_json::from_slice(payload).map_err(BedrockError::Serialization)?;
+                    Ok(ConverseStreamEvent::InternalServerException(exception))
+                }
+                "modelStreamErrorException" => {
+                    let exception: BedrockException =
+                        serde_json::from_slice(payload).map_err(BedrockError::Serialization)?;
+                    Ok(ConverseStreamEvent::ModelStreamErrorException(exception))
+                }
+                "serviceUnavailableException" => {
+                    let exception: BedrockException =
+                        serde_json::from_slice(payload).map_err(BedrockError::Serialization)?;
+                    Ok(ConverseStreamEvent::ServiceUnavailableException(exception))
+                }
+                "throttlingException" => {
+                    let exception: BedrockException =
+                        serde_json::from_slice(payload).map_err(BedrockError::Serialization)?;
+                    Ok(ConverseStreamEvent::ThrottlingException(exception))
+                }
+                "validationException" => {
+                    let exception: BedrockException =
+                        serde_json::from_slice(payload).map_err(BedrockError::Serialization)?;
+                    Ok(ConverseStreamEvent::ValidationException(exception))
+                }
+                unknown => Err(BedrockError::Validation {
+                    message: format!("Unknown exception type: {}", unknown),
+                }),
+            },
+            unknown => Err(BedrockError::Validation {
+                message: format!("Unknown message type: {}", unknown),
+            }),
+        }
+    }
+}
+
+impl Into<String> for ConverseStreamEvent {
+    fn into(self) -> String {
+        let transformed_json = serde_json::to_string(&self).unwrap_or_default();
+        let event_type = match &self {
+            ConverseStreamEvent::MessageStart { .. } => "message_start",
+            ConverseStreamEvent::ContentBlockStart { .. } => "content_block_start",
+            ConverseStreamEvent::ContentBlockDelta { .. } => "content_block_delta",
+            ConverseStreamEvent::ContentBlockStop { .. } => "content_block_stop",
+            ConverseStreamEvent::MessageStop { .. } => "message_stop",
+            ConverseStreamEvent::Metadata { .. } => "metadata",
+            ConverseStreamEvent::InternalServerException { .. } => "internal_server_exception",
+            ConverseStreamEvent::ModelStreamErrorException { .. } => "model_stream_error_exception",
+            ConverseStreamEvent::ServiceUnavailableException { .. } => {
+                "service_unavailable_exception"
+            }
+            ConverseStreamEvent::ThrottlingException { .. } => "throttling_exception",
+            ConverseStreamEvent::ValidationException { .. } => "validation_exception",
+        };
+
+        let event = format!("event: {}\n", event_type);
+        let data = format!("data: {}\n\n", transformed_json);
+        event + &data
+    }
+}
+
+// Implement ProviderStreamResponse for ConverseStreamEvent
+impl ProviderStreamResponse for ConverseStreamEvent {
+    fn content_delta(&self) -> Option<&str> {
+        match self {
+            ConverseStreamEvent::ContentBlockDelta(event) => match &event.delta {
+                ContentBlockDelta::Text { text } => Some(text),
+                ContentBlockDelta::ToolUse { .. } => None,
+            },
+            _ => None,
+        }
+    }
+
+    fn is_final(&self) -> bool {
+        matches!(self, ConverseStreamEvent::MessageStop(_))
+    }
+
+    fn role(&self) -> Option<&str> {
+        match self {
+            ConverseStreamEvent::MessageStart(event) => Some(event.role.as_str()),
+            _ => None,
+        }
+    }
+
+    fn event_type(&self) -> Option<&str> {
+        Some(match self {
+            ConverseStreamEvent::MessageStart(_) => "messageStart",
+            ConverseStreamEvent::ContentBlockStart(_) => "contentBlockStart",
+            ConverseStreamEvent::ContentBlockDelta(_) => "contentBlockDelta",
+            ConverseStreamEvent::ContentBlockStop(_) => "contentBlockStop",
+            ConverseStreamEvent::MessageStop(_) => "messageStop",
+            ConverseStreamEvent::Metadata(_) => "metadata",
+            ConverseStreamEvent::InternalServerException(_) => "internalServerException",
+            ConverseStreamEvent::ModelStreamErrorException(_) => "modelStreamErrorException",
+            ConverseStreamEvent::ServiceUnavailableException(_) => "serviceUnavailableException",
+            ConverseStreamEvent::ThrottlingException(_) => "throttlingException",
+            ConverseStreamEvent::ValidationException(_) => "validationException",
+        })
+    }
+}
+
+// Add as_str helper for ConversationRole
+impl ConversationRole {
+    pub fn as_str(&self) -> &'static str {
+        match self {
+            ConversationRole::User => "user",
+            ConversationRole::Assistant => "assistant",
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use serde_json::json;
+
+    #[test]
+    fn test_tool_serialization_format() {
+        let tool = Tool::ToolSpec {
+            tool_spec: ToolSpecDefinition {
+                name: "get_weather".to_string(),
+                description: Some("Get the current weather for a specified city".to_string()),
+                input_schema: ToolInputSchema {
+                    json: json!({
+                        "type": "object",
+                        "properties": {
+                            "city": {
+                                "type": "string",
+                                "description": "The city to get weather for"
+                            }
+                        },
+                        "required": ["city"]
+                    }),
+                },
+            },
+        };
+
+        let serialized = serde_json::to_value(&tool).unwrap();
+        println!(
+            "Tool serialization: {}",
+            serde_json::to_string_pretty(&serialized).unwrap()
+        );
+
+        // Verify the structure matches Bedrock API expectations
+        assert!(serialized.get("toolSpec").is_some());
+        assert!(serialized.get("type").is_none()); // Should not have a type field
+
+        let tool_spec = serialized.get("toolSpec").unwrap();
+        assert_eq!(tool_spec.get("name").unwrap(), "get_weather");
+        assert_eq!(
+            tool_spec.get("description").unwrap(),
+            "Get the current weather for a specified city"
+        );
+        assert!(tool_spec.get("inputSchema").is_some());
+    }
+
+    #[test]
+    fn test_tool_choice_serialization_format() {
+        // Test Auto choice
+        let auto_choice = ToolChoice::Auto {
+            auto: AutoChoice {},
+        };
+        let serialized = serde_json::to_value(&auto_choice).unwrap();
+        println!(
+            "Auto ToolChoice serialization: {}",
+            serde_json::to_string_pretty(&serialized).unwrap()
+        );
+
+        assert!(serialized.get("auto").is_some());
+        assert!(serialized.get("type").is_none()); // Should not have a type field
+
+        // Test Tool choice
+        let tool_choice = ToolChoice::Tool {
+            tool: ToolChoiceSpec {
+                name: "get_weather".to_string(),
+            },
+        };
+        let serialized = serde_json::to_value(&tool_choice).unwrap();
+        println!(
+            "Tool ToolChoice serialization: {}",
+            serde_json::to_string_pretty(&serialized).unwrap()
+        );
+
+        assert!(serialized.get("tool").is_some());
+        assert!(serialized.get("type").is_none()); // Should not have a type field
+
+        let tool_spec = serialized.get("tool").unwrap();
+        assert_eq!(tool_spec.get("name").unwrap(), "get_weather");
+    }
+}
diff --git a/crates/hermesllm/src/apis/amazon_bedrock_binary_frame.rs b/crates/hermesllm/src/apis/amazon_bedrock_binary_frame.rs
new file mode 100644
index 00000000..bacbad62
--- /dev/null
+++ b/crates/hermesllm/src/apis/amazon_bedrock_binary_frame.rs
@@ -0,0 +1,65 @@
+use aws_smithy_eventstream::frame::DecodedFrame;
+use aws_smithy_eventstream::frame::MessageFrameDecoder;
+use bytes::Buf;
+use std::collections::HashSet;
+
+/// AWS Event Stream frame decoder wrapper
+pub struct BedrockBinaryFrameDecoder<B>
+where
+    B: Buf,
+{
+    decoder: MessageFrameDecoder,
+    buffer: B,
+    content_block_start_indices: HashSet<i32>,
+}
+
+impl BedrockBinaryFrameDecoder<bytes::BytesMut> {
+    /// This is a convenience constructor that creates a BytesMut buffer internally
+    pub fn from_bytes(bytes: &[u8]) -> Self {
+        let buffer = bytes::BytesMut::from(bytes);
+        Self {
+            decoder: MessageFrameDecoder::new(),
+            buffer,
+            content_block_start_indices: std::collections::HashSet::new(),
+        }
+    }
+}
+
+impl<B> BedrockBinaryFrameDecoder<B>
+where
+    B: Buf,
+{
+    pub fn new(buffer: B) -> Self {
+        Self {
+            decoder: MessageFrameDecoder::new(),
+            buffer,
+            content_block_start_indices: HashSet::new(),
+        }
+    }
+
+    pub fn decode_frame(&mut self) -> Option<DecodedFrame> {
+        match self.decoder.decode_frame(&mut self.buffer) {
+            Ok(frame) => Some(frame),
+            Err(_e) => None, // Fatal decode error
+        }
+    }
+
+    pub fn buffer_mut(&mut self) -> &mut B {
+        &mut self.buffer
+    }
+
+    /// Check if there are any bytes remaining in the buffer
+    pub fn has_remaining(&self) -> bool {
+        self.buffer.has_remaining()
+    }
+
+    /// Check if a content_block_start event has been sent for the given index
+    pub fn has_content_block_start_been_sent(&self, index: i32) -> bool {
+        self.content_block_start_indices.contains(&index)
+    }
+
+    /// Mark that a content_block_start event has been sent for the given index
+    pub fn set_content_block_start_sent(&mut self, index: i32) {
+        self.content_block_start_indices.insert(index);
+    }
+}
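
A sketch of how the decoder and the `TryFrom<&DecodedFrame>` / `Into<String>` impls are meant to compose (not part of the diff; the chunk boundary and error policy are assumptions):

```rust
use hermesllm::apis::amazon_bedrock::ConverseStreamEvent;
use hermesllm::apis::amazon_bedrock_binary_frame::BedrockBinaryFrameDecoder;

/// Drain whatever complete frames are buffered in `chunk` and re-emit them as SSE text.
fn frames_to_sse(chunk: &[u8]) -> String {
    let mut decoder = BedrockBinaryFrameDecoder::from_bytes(chunk);
    let mut out = String::new();
    while decoder.has_remaining() {
        // None means the decoder hit a fatal error; stop rather than spin.
        let Some(frame) = decoder.decode_frame() else { break };
        // Incomplete frames are rejected by the TryFrom impl, so we stop and
        // wait for more bytes; unknown event types also land in the Err arm.
        match ConverseStreamEvent::try_from(&frame) {
            Ok(event) => {
                let sse: String = event.into(); // "event: ...\ndata: {...}\n\n"
                out.push_str(&sse);
            }
            Err(_) => break,
        }
    }
    out
}
```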
diff --git a/crates/hermesllm/src/apis/anthropic.rs b/crates/hermesllm/src/apis/anthropic.rs
index a261be3c..f91b381c 100644
--- a/crates/hermesllm/src/apis/anthropic.rs
+++ b/crates/hermesllm/src/apis/anthropic.rs
@@ -5,9 +5,9 @@ use serde_with::skip_serializing_none;
 use std::collections::HashMap;

 use super::ApiDefinition;
-use crate::clients::transformer::ExtractText;
 use crate::providers::request::{ProviderRequest, ProviderRequestError};
 use crate::providers::response::{ProviderResponse, ProviderStreamResponse};
+use crate::transforms::lib::ExtractText;
 use crate::MESSAGES_PATH;

 // Enum for all supported Anthropic APIs
diff --git a/crates/hermesllm/src/apis/mod.rs b/crates/hermesllm/src/apis/mod.rs
index 99158dfa..7d84e3ab 100644
--- a/crates/hermesllm/src/apis/mod.rs
+++ b/crates/hermesllm/src/apis/mod.rs
@@ -1,7 +1,19 @@
+pub mod amazon_bedrock;
+pub mod amazon_bedrock_binary_frame;
 pub mod anthropic;
 pub mod openai;
-pub use anthropic::*;
-pub use openai::*;
+pub mod sse;
+
+// Explicit exports to avoid naming conflicts
+pub use amazon_bedrock::{AmazonBedrockApi, ConverseRequest, ConverseStreamRequest};
+pub use amazon_bedrock::{
+    Message as BedrockMessage, Tool as BedrockTool, ToolChoice as BedrockToolChoice,
+};
+pub use anthropic::{AnthropicApi, MessagesRequest, MessagesResponse, MessagesStreamEvent};
+pub use openai::{
+    ChatCompletionsRequest, ChatCompletionsResponse, ChatCompletionsStreamResponse, OpenAIApi,
+};
+pub use openai::{Message as OpenAIMessage, Tool as OpenAITool, ToolChoice as OpenAIToolChoice};

 pub trait ApiDefinition {
     /// Returns the endpoint path for this API
diff --git a/crates/hermesllm/src/apis/openai.rs b/crates/hermesllm/src/apis/openai.rs
index 58e4c8a5..82c5d1a1 100644
--- a/crates/hermesllm/src/apis/openai.rs
+++ b/crates/hermesllm/src/apis/openai.rs
@@ -6,9 +6,9 @@ use std::fmt::Display;
 use thiserror::Error;

 use super::ApiDefinition;
-use crate::clients::transformer::ExtractText;
 use crate::providers::request::{ProviderRequest, ProviderRequestError};
 use crate::providers::response::{ProviderResponse, ProviderStreamResponse, TokenUsage};
+use crate::transforms::lib::ExtractText;
 use crate::CHAT_COMPLETIONS_PATH;

 // ============================================================================
diff --git a/crates/hermesllm/src/apis/sse.rs b/crates/hermesllm/src/apis/sse.rs
new file mode 100644
index 00000000..b8a9b492
--- /dev/null
+++ b/crates/hermesllm/src/apis/sse.rs
@@ -0,0 +1,196 @@
+use crate::providers::response::ProviderStreamResponse;
+use crate::providers::response::ProviderStreamResponseType;
+use serde::{Deserialize, Serialize};
+use std::error::Error;
+use std::fmt;
+use std::str::FromStr;
+
+// ============================================================================
+// SSE EVENT CONTAINER
+// ============================================================================
+
+/// Represents a single Server-Sent Event with the complete wire format
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct SseEvent {
+    #[serde(rename = "data")]
+    pub data: Option<String>, // The JSON payload after "data: "
+
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub event: Option<String>, // Optional event type (e.g., "message_start", "content_block_delta")
+
+    #[serde(skip_serializing, skip_deserializing)]
+    pub raw_line: String, // The complete line as received including "data: " prefix and "\n\n"
+
+    #[serde(skip_serializing, skip_deserializing)]
+    pub sse_transform_buffer: String, // The (possibly transformed) line to emit downstream
+
+    #[serde(skip_serializing, skip_deserializing)]
+    pub provider_stream_response: Option<ProviderStreamResponseType>, // Parsed provider stream response object
+}
+
+impl SseEvent {
+    /// Check if this event represents the end of the stream
+    pub fn is_done(&self) -> bool {
+        self.data == Some("[DONE]".into())
+    }
+
+    /// Check if this event should be skipped during processing
+    /// This includes ping messages and other provider-specific events that don't contain content
+    pub fn should_skip(&self) -> bool {
+        // Skip ping messages (commonly used by providers for connection keep-alive)
+        self.data == Some(r#"{"type": "ping"}"#.into())
+    }
+
+    /// Check if this is an event-only SSE event (no data payload)
+    pub fn is_event_only(&self) -> bool {
+        self.event.is_some() && self.data.is_none()
+    }
+
+    /// Get the parsed provider response if available
+    pub fn provider_response(&self) -> Result<&dyn ProviderStreamResponse, std::io::Error> {
+        self.provider_stream_response
+            .as_ref()
+            .map(|resp| resp as &dyn ProviderStreamResponse)
+            .ok_or_else(|| {
+                std::io::Error::new(std::io::ErrorKind::NotFound, "Provider response not found")
+            })
+    }
+}
+
+impl FromStr for SseEvent {
+    type Err = SseParseError;
+
+    fn from_str(line: &str) -> Result<Self, Self::Err> {
+        if line.starts_with("data: ") {
+            let data: String = line[6..].to_string(); // Remove "data: " prefix
+            if data.is_empty() {
+                return Err(SseParseError {
+                    message: "Empty data field is not a valid SSE event".to_string(),
+                });
+            }
+            Ok(SseEvent {
+                data: Some(data),
+                event: None,
+                raw_line: line.to_string(),
+                sse_transform_buffer: line.to_string(),
+                provider_stream_response: None,
+            })
+        } else if line.starts_with("event: ") {
+            // used by Anthropic
+            let event_type = line[7..].to_string();
+            if event_type.is_empty() {
+                return Err(SseParseError {
+                    message: "Empty event field is not a valid SSE event".to_string(),
+                });
+            }
+            Ok(SseEvent {
+                data: None,
+                event: Some(event_type),
+                raw_line: line.to_string(),
+                sse_transform_buffer: line.to_string(),
+                provider_stream_response: None,
+            })
+        } else {
+            Err(SseParseError {
+                message: format!("Line does not start with 'data: ' or 'event: ': {}", line),
+            })
+        }
+    }
+}
+
+impl fmt::Display for SseEvent {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(f, "{}", self.sse_transform_buffer)
+    }
+}
+
+// Into implementation to convert SseEvent to bytes for response buffer
+impl Into<Vec<u8>> for SseEvent {
+    fn into(self) -> Vec<u8> {
+        format!("{}\n\n", self.sse_transform_buffer).into_bytes()
+    }
+}
+
+#[derive(Debug)]
+pub struct SseParseError {
+    pub message: String,
+}
+
+impl fmt::Display for SseParseError {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(f, "SSE parse error: {}", self.message)
+    }
+}
+
+impl Error for SseParseError {}
+
+/// Generic SSE (Server-Sent Events) streaming iterator container
+/// Parses raw SSE lines into SseEvent objects
+pub struct SseStreamIter<I>
+where
+    I: Iterator,
+    I::Item: AsRef<str>,
+{
+    pub lines: I,
+    pub done_seen: bool,
+}
+
+impl<I> SseStreamIter<I>
+where
+    I: Iterator,
+    I::Item: AsRef<str>,
+{
+    pub fn new(lines: I) -> Self {
+        Self {
+            lines,
+            done_seen: false,
+        }
+    }
+}
+
+// TryFrom implementation to parse bytes into SseStreamIter (text-based SSE lines);
+// binary AWS Event Stream frames are handled separately by BedrockBinaryFrameDecoder
+impl TryFrom<&[u8]> for SseStreamIter<std::vec::IntoIter<String>> {
+    type Error = Box<dyn Error>;
+
+    fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
+        // Parse as text-based SSE format
+        let s = std::str::from_utf8(bytes)?;
+        let lines: Vec<String> = s.lines().map(|line| line.to_string()).collect();
+        Ok(SseStreamIter::new(lines.into_iter()))
+    }
+}
+
+impl<I> Iterator for SseStreamIter<I>
+where
+    I: Iterator,
+    I::Item: AsRef<str>,
+{
+    type Item = SseEvent;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        // If we already returned [DONE], terminate the stream
+        if self.done_seen {
+            return None;
+        }
+
+        for line in &mut self.lines {
+            let line_str = line.as_ref();
+
+            // Try to parse as either data: or event: line
+            if let Ok(event) = line_str.parse::<SseEvent>() {
+                // For data: lines, check if this is the [DONE] marker
+                if event.data.is_some() && event.is_done() {
+                    self.done_seen = true;
+                    return Some(event); // Return [DONE] event for transformation
+                }
+                // For data: lines, skip events that should be filtered at the transport layer
+                if event.data.is_some() && event.should_skip() {
+                    continue;
+                }
+                return Some(event);
+            }
+        }
+        None
+    }
+}
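
A sketch of the iterator's filtering behavior (not part of the diff; the chunk contents are made up):

```rust
use hermesllm::apis::sse::{SseEvent, SseStreamIter};

fn sketch() -> Result<(), Box<dyn std::error::Error>> {
    // A provider chunk carrying a keep-alive ping, one content event, and the [DONE] marker.
    let bytes: &[u8] = b"data: {\"type\": \"ping\"}\ndata: {\"id\":\"1\"}\ndata: [DONE]\n";

    let events: Vec<SseEvent> = SseStreamIter::try_from(bytes)?.collect();

    // The ping is filtered by should_skip(); [DONE] is passed through exactly once.
    assert_eq!(events.len(), 2);
    assert!(events.last().unwrap().is_done());
    Ok(())
}
```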
diff --git a/crates/hermesllm/src/clients/endpoints.rs b/crates/hermesllm/src/clients/endpoints.rs
index 263ca674..4e23c942 100644
--- a/crates/hermesllm/src/clients/endpoints.rs
+++ b/crates/hermesllm/src/clients/endpoints.rs
@@ -1,30 +1,5 @@
-//! Supported endpoint registry for LLM APIs
-//!
-//! This module provides a simple registry to check which API endpoint paths
-//! we support across different providers.
-//!
-//! # Examples
-//!
-//! ```rust
-//! use hermesllm::clients::endpoints::supported_endpoints;
-//!
-//! // Check if we support an endpoint
-//! use hermesllm::clients::endpoints::SupportedAPIs;
-//! assert!(SupportedAPIs::from_endpoint("/v1/chat/completions").is_some());
-//! assert!(SupportedAPIs::from_endpoint("/v1/messages").is_some());
-//! assert!(!SupportedAPIs::from_endpoint("/v1/unknown").is_some());
-//!
-//! // Get all supported endpoints
-//! let endpoints = supported_endpoints();
-//! assert_eq!(endpoints.len(), 2);
-//! assert!(endpoints.contains(&"/v1/chat/completions"));
-//! assert!(endpoints.contains(&"/v1/messages"));
-//! ```
-
-use crate::{
-    apis::{AnthropicApi, ApiDefinition, OpenAIApi},
-    ProviderId,
-};
+use crate::apis::{AmazonBedrockApi, AnthropicApi, ApiDefinition, OpenAIApi};
+use crate::ProviderId;
 use std::fmt;

 /// Unified enum representing all supported API endpoints across providers
@@ -34,6 +9,14 @@ pub enum SupportedAPIs {
     AnthropicMessagesAPI(AnthropicApi),
 }

+#[derive(Debug, Clone, PartialEq)]
+pub enum SupportedUpstreamAPIs {
+    OpenAIChatCompletions(OpenAIApi),
+    AnthropicMessagesAPI(AnthropicApi),
+    AmazonBedrockConverse(AmazonBedrockApi),
+    AmazonBedrockConverseStream(AmazonBedrockApi),
+}
+
 impl fmt::Display for SupportedAPIs {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
         match self {
@@ -74,11 +57,21 @@ impl SupportedAPIs {
         provider_id: &ProviderId,
         request_path: &str,
         model_id: &str,
+        is_streaming: bool,
     ) -> String {
         let default_endpoint = "/v1/chat/completions".to_string();
         match self {
             SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages) => match provider_id {
                 ProviderId::Anthropic => "/v1/messages".to_string(),
+                ProviderId::AmazonBedrock => {
+                    if request_path.starts_with("/v1/") && !is_streaming {
+                        format!("/model/{}/converse", model_id)
+                    } else if request_path.starts_with("/v1/") && is_streaming {
+                        format!("/model/{}/converse-stream", model_id)
+                    } else {
+                        default_endpoint
+                    }
+                }
                 _ => default_endpoint,
             },
             _ => match provider_id {
@@ -117,6 +110,17 @@
                         default_endpoint
                     }
                 }
+                ProviderId::AmazonBedrock => {
+                    if request_path.starts_with("/v1/") {
+                        if !is_streaming {
+                            format!("/model/{}/converse", model_id)
+                        } else {
+                            format!("/model/{}/converse-stream", model_id)
+                        }
+                    } else {
+                        default_endpoint
+                    }
+                }
                 _ => default_endpoint,
             },
         }
@@ -161,7 +165,6 @@ mod tests {
     fn test_is_supported_endpoint() {
         // OpenAI endpoints
        assert!(SupportedAPIs::from_endpoint("/v1/chat/completions").is_some());
-
         // Anthropic endpoints
         assert!(SupportedAPIs::from_endpoint("/v1/messages").is_some());
@@ -174,7 +177,7 @@
     #[test]
     fn test_supported_endpoints() {
         let endpoints = supported_endpoints();
-        assert_eq!(endpoints.len(), 2);
+        assert_eq!(endpoints.len(), 2); // We have 2 APIs defined
         assert!(endpoints.contains(&"/v1/chat/completions"));
         assert!(endpoints.contains(&"/v1/messages"));
     }
@@ -217,7 +220,6 @@
                 endpoint
             );
         }
-
         // Total should match
         assert_eq!(
             endpoints.len(),
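
A sketch of how `AmazonBedrockApi::from_endpoint` classifies the rewritten Bedrock paths that the routing above produces (not part of the diff; model IDs are illustrative):

```rust
use hermesllm::apis::amazon_bedrock::AmazonBedrockApi;
use hermesllm::apis::ApiDefinition;

fn sketch() {
    // Path shape the gateway rewrites streaming /v1/* requests into.
    let path = "/model/anthropic.claude-3-haiku-20240307-v1:0/converse-stream";

    let api = AmazonBedrockApi::from_endpoint(path);
    assert_eq!(api, Some(AmazonBedrockApi::ConverseStream));
    assert!(api.unwrap().supports_streaming());

    // The non-streaming endpoint maps to Converse.
    assert_eq!(
        AmazonBedrockApi::from_endpoint("/model/my-model/converse"),
        Some(AmazonBedrockApi::Converse)
    );
}
```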
diff --git a/crates/hermesllm/src/clients/transformer.rs b/crates/hermesllm/src/clients/transformer.rs
index 11caae6f..cbb9bbe7 100644
--- a/crates/hermesllm/src/clients/transformer.rs
+++ b/crates/hermesllm/src/clients/transformer.rs
@@ -1,1164 +1,6 @@
-//! API request/response transformers between Anthropic and OpenAI APIs
-//!
-//! This module provides clean, bidirectional conversion between different LLM API formats
-//! using Rust's standard `TryFrom` and `Into` traits. The organization follows a logical flow:
-//!
-//! 1. **Main Request Transformations** - Core TryFrom implementations for requests
-//! 2. **Main Response Transformations** - Core TryFrom implementations for responses
-//! 3. **Streaming Transformations** - Bidirectional streaming event conversion
-//! 4. **Standard Rust Trait Implementations** - Into/TryFrom implementations for type conversions
-//! 5. **Helper Functions** - Utility functions organized by domain
-//!
-//! # Examples
-//!
-//! ```rust
-//! use hermesllm::apis::{
-//!     MessagesRequest, ChatCompletionsRequest, MessagesRole, MessagesMessage,
-//!     MessagesMessageContent, MessagesSystemPrompt,
-//! };
-//! use hermesllm::clients::TransformError;
-//! use std::convert::TryInto;
-//!
-//! // Transform Anthropic to OpenAI
-//! let anthropic_req = MessagesRequest {
-//!     model: "claude-3-sonnet".to_string(),
-//!     system: None,
-//!     messages: vec![],
-//!     max_tokens: 1024,
-//!     container: None,
-//!     mcp_servers: None,
-//!     service_tier: None,
-//!     thinking: None,
-//!     temperature: None,
-//!     top_p: None,
-//!     top_k: None,
-//!     stream: None,
-//!     stop_sequences: None,
-//!     tools: None,
-//!     tool_choice: None,
-//!     metadata: None,
-//! };
-//! let openai_req: Result<ChatCompletionsRequest, TransformError> = anthropic_req.try_into();
-//! # Ok::<(), Box<dyn std::error::Error>>(())
-//! ```
+// Re-export new transformation modules for backward compatibility

-use super::TransformError;
-use crate::apis::*;
-use serde_json::Value;
-use std::time::{SystemTime, UNIX_EPOCH};
-
-// ============================================================================
-// CONSTANTS
-// ============================================================================
-
-/// Default maximum tokens when converting from OpenAI to Anthropic and no max_tokens is specified
-const DEFAULT_MAX_TOKENS: u32 = 4096;
-
-// ============================================================================
-// UTILITY TRAITS - Shared traits for content manipulation
-// ============================================================================
-
-/// Trait for extracting text content from various types
-pub trait ExtractText {
-    fn extract_text(&self) -> String;
-}
-
-/// Trait for utility functions on content collections
-trait ContentUtils {
-    fn extract_tool_calls(&self) -> Result<Option<Vec<ToolCall>>, TransformError>;
-    fn split_for_openai(
-        &self,
-    ) -> Result<(Vec<ContentPart>, Vec<ToolCall>, Vec<(String, String, bool)>), TransformError>;
-}
-
-// ============================================================================
-// MAIN REQUEST TRANSFORMATIONS
-// ============================================================================
-
-type AnthropicMessagesRequest = MessagesRequest;
-
-impl TryFrom<AnthropicMessagesRequest> for ChatCompletionsRequest {
-    type Error = TransformError;
-
-    fn try_from(req: AnthropicMessagesRequest) -> Result<Self, Self::Error> {
-        let mut openai_messages: Vec<Message> = Vec::new();
-
-        // Convert system prompt to system message if present
-        if let Some(system) = req.system {
-            openai_messages.push(system.into());
-        }
-
-        // Convert messages
-        for message in req.messages {
-            let converted_messages: Vec<Message> = message.try_into()?;
-            openai_messages.extend(converted_messages);
-        }
-
-        // Convert tools and tool choice
-        let openai_tools = req.tools.map(|tools| convert_anthropic_tools(tools));
-        let (openai_tool_choice, parallel_tool_calls) =
-            convert_anthropic_tool_choice(req.tool_choice);
-
-        let mut _chat_completions_req: ChatCompletionsRequest = ChatCompletionsRequest {
-            model: req.model,
-            messages: openai_messages,
-            temperature: req.temperature,
-            top_p: req.top_p,
-            max_completion_tokens: Some(req.max_tokens),
-            stream: req.stream,
-            stop: req.stop_sequences,
-            tools: openai_tools,
-            tool_choice: openai_tool_choice,
-            parallel_tool_calls,
-            ..Default::default()
-        };
-        _chat_completions_req.suppress_max_tokens_if_o3();
-        _chat_completions_req.fix_temperature_if_gpt5();
-        Ok(_chat_completions_req)
-    }
-}
-
-impl TryFrom<ChatCompletionsRequest> for AnthropicMessagesRequest {
-    type Error = TransformError;
-
-    fn try_from(req: ChatCompletionsRequest) -> Result<Self, Self::Error> {
-        let mut system_prompt = None;
-        let mut messages = Vec::new();
-
-        for message in req.messages {
-            match message.role {
-                Role::System => {
-                    system_prompt = Some(message.into());
-                }
-                _ => {
-                    let anthropic_message: MessagesMessage = message.try_into()?;
-                    messages.push(anthropic_message);
-                }
-            }
-        }
-
-        // Convert tools and tool choice
-        let anthropic_tools = req.tools.map(|tools| convert_openai_tools(tools));
-        let anthropic_tool_choice =
-            convert_openai_tool_choice(req.tool_choice, req.parallel_tool_calls);
-
-        Ok(AnthropicMessagesRequest {
-            model: req.model,
-            system: system_prompt,
-            messages,
-            max_tokens: req
-                .max_completion_tokens
-                .or(req.max_tokens)
-                .unwrap_or(DEFAULT_MAX_TOKENS),
-            container: None,
-            mcp_servers: None,
-            service_tier: None,
-            thinking: None,
-            temperature: req.temperature,
-            top_p: req.top_p,
-            top_k: None, // OpenAI doesn't have top_k
-            stream: req.stream,
-            stop_sequences: req.stop,
-            tools: anthropic_tools,
-            tool_choice: anthropic_tool_choice,
-            metadata: None,
-        })
-    }
-}
-
-// ============================================================================
-// MAIN RESPONSE TRANSFORMATIONS
-// ============================================================================
-
-impl TryFrom<MessagesResponse> for ChatCompletionsResponse {
-    type Error = TransformError;
-
-    fn try_from(resp: MessagesResponse) -> Result<Self, Self::Error> {
-        let content = convert_anthropic_content_to_openai(&resp.content)?;
-        let finish_reason: FinishReason = resp.stop_reason.into();
-        let tool_calls = resp.content.extract_tool_calls()?;
-
-        // Convert MessageContent to String for response
-        let content_string = match content {
-            MessageContent::Text(text) => Some(text),
-            MessageContent::Parts(parts) => {
-                let text = parts.extract_text();
-                if text.is_empty() {
-                    None
-                } else {
-                    Some(text)
-                }
-            }
-        };
-
-        let message = ResponseMessage {
-            role: Role::Assistant,
-            content: content_string,
-            refusal: None,
-            annotations: None,
-            audio: None,
-            function_call: None,
-            tool_calls,
-        };
-
-        let choice = Choice {
-            index: 0,
-            message,
-            finish_reason: Some(finish_reason),
-            logprobs: None,
-        };
-
-        let usage = Usage {
-            prompt_tokens: resp.usage.input_tokens,
-            completion_tokens: resp.usage.output_tokens,
-            total_tokens: resp.usage.input_tokens + resp.usage.output_tokens,
-            prompt_tokens_details: None,
-            completion_tokens_details: None,
-        };
-
-        Ok(ChatCompletionsResponse {
-            id: resp.id,
-            object: Some("chat.completion".to_string()),
-            created: current_timestamp(),
-            model: resp.model,
-            choices: vec![choice],
-            usage,
-            system_fingerprint: None,
-            service_tier: None,
-        })
-    }
-}
-
-impl TryFrom for
MessagesResponse { - type Error = TransformError; - - fn try_from(resp: ChatCompletionsResponse) -> Result { - let choice = resp - .choices - .into_iter() - .next() - .ok_or_else(|| TransformError::MissingField("choices".to_string()))?; - - let content = convert_openai_message_to_anthropic_content(&choice.message.to_message())?; - let stop_reason = choice - .finish_reason - .map(|fr| fr.into()) - .unwrap_or(MessagesStopReason::EndTurn); - - let usage = MessagesUsage { - input_tokens: resp.usage.prompt_tokens, - output_tokens: resp.usage.completion_tokens, - cache_creation_input_tokens: None, - cache_read_input_tokens: None, - }; - - Ok(MessagesResponse { - id: resp.id, - obj_type: "message".to_string(), - role: MessagesRole::Assistant, - content, - model: resp.model, - stop_reason, - stop_sequence: None, - usage, - container: None, - }) - } -} - -// ============================================================================ -// STREAMING TRANSFORMATIONS -// ============================================================================ - -impl TryFrom for ChatCompletionsStreamResponse { - type Error = TransformError; - - fn try_from(event: MessagesStreamEvent) -> Result { - match event { - MessagesStreamEvent::MessageStart { message } => Ok(create_openai_chunk( - &message.id, - &message.model, - MessageDelta { - role: Some(Role::Assistant), - content: None, - refusal: None, - function_call: None, - tool_calls: None, - }, - None, - None, - )), - - MessagesStreamEvent::ContentBlockStart { content_block, .. } => { - convert_content_block_start(content_block) - } - - MessagesStreamEvent::ContentBlockDelta { delta, .. } => convert_content_delta(delta), - - MessagesStreamEvent::ContentBlockStop { .. } => Ok(create_empty_openai_chunk()), - - MessagesStreamEvent::MessageDelta { delta, usage } => { - let finish_reason: Option = Some(delta.stop_reason.into()); - let openai_usage: Option = Some(usage.into()); - - Ok(create_openai_chunk( - "stream", - "unknown", - MessageDelta { - role: None, - content: None, - refusal: None, - function_call: None, - tool_calls: None, - }, - finish_reason, - openai_usage, - )) - } - - MessagesStreamEvent::MessageStop => Ok(create_openai_chunk( - "stream", - "unknown", - MessageDelta { - role: None, - content: None, - refusal: None, - function_call: None, - tool_calls: None, - }, - Some(FinishReason::Stop), - None, - )), - - MessagesStreamEvent::Ping => Ok(ChatCompletionsStreamResponse { - id: "stream".to_string(), - object: Some("chat.completion.chunk".to_string()), - created: current_timestamp(), - model: "unknown".to_string(), - choices: vec![], - usage: None, - system_fingerprint: None, - service_tier: None, - }), - } - } -} - -impl TryFrom for MessagesStreamEvent { - type Error = TransformError; - - fn try_from(resp: ChatCompletionsStreamResponse) -> Result { - if resp.choices.is_empty() { - return Ok(MessagesStreamEvent::Ping); - } - - let choice = &resp.choices[0]; - - // Handle final chunk with usage - let has_usage = resp.usage.is_some(); - if let Some(usage) = resp.usage { - if let Some(finish_reason) = &choice.finish_reason { - let anthropic_stop_reason: MessagesStopReason = finish_reason.clone().into(); - return Ok(MessagesStreamEvent::MessageDelta { - delta: MessagesMessageDelta { - stop_reason: anthropic_stop_reason, - stop_sequence: None, - }, - usage: usage.into(), - }); - } - } - - // Handle role start - if let Some(Role::Assistant) = choice.delta.role { - return Ok(MessagesStreamEvent::MessageStart { - message: MessagesStreamMessage { - id: resp.id, - 
obj_type: "message".to_string(), - role: MessagesRole::Assistant, - content: vec![], - model: resp.model, - stop_reason: None, - stop_sequence: None, - usage: MessagesUsage { - input_tokens: 0, - output_tokens: 0, - cache_creation_input_tokens: None, - cache_read_input_tokens: None, - }, - }, - }); - } - - // Handle content delta - if let Some(content) = &choice.delta.content { - if !content.is_empty() { - return Ok(MessagesStreamEvent::ContentBlockDelta { - index: 0, - delta: MessagesContentDelta::TextDelta { - text: content.clone(), - }, - }); - } - } - - // Handle tool calls - if let Some(tool_calls) = &choice.delta.tool_calls { - return convert_tool_call_deltas(tool_calls.clone()); - } - - // Handle finish reason - generate MessageDelta only (MessageStop comes later) - if let Some(finish_reason) = &choice.finish_reason { - // If we have usage data, it was already handled above - // If not, we need to generate MessageDelta with default usage - if !has_usage { - let anthropic_stop_reason: MessagesStopReason = finish_reason.clone().into(); - return Ok(MessagesStreamEvent::MessageDelta { - delta: MessagesMessageDelta { - stop_reason: anthropic_stop_reason, - stop_sequence: None, - }, - usage: MessagesUsage { - input_tokens: 0, - output_tokens: 0, - cache_creation_input_tokens: None, - cache_read_input_tokens: None, - }, - }); - } - // If usage was already handled above, we don't need to do anything more here - // MessageStop will be handled when [DONE] is encountered - } - - // Default to ping for unhandled cases - Ok(MessagesStreamEvent::Ping) - } -} - -// ============================================================================ -// STANDARD RUST TRAIT IMPLEMENTATIONS - Using Into/TryFrom for conversions -// ============================================================================ - -// System Prompt Conversions -impl Into for MessagesSystemPrompt { - fn into(self) -> Message { - let system_content = match self { - MessagesSystemPrompt::Single(text) => MessageContent::Text(text), - MessagesSystemPrompt::Blocks(blocks) => MessageContent::Text(blocks.extract_text()), - }; - - Message { - role: Role::System, - content: system_content, - name: None, - tool_calls: None, - tool_call_id: None, - } - } -} - -impl Into for Message { - fn into(self) -> MessagesSystemPrompt { - let system_text = match self.content { - MessageContent::Text(text) => text, - MessageContent::Parts(parts) => parts.extract_text(), - }; - MessagesSystemPrompt::Single(system_text) - } -} - -// Message Conversions -impl TryFrom for Vec { - type Error = TransformError; - - fn try_from(message: MessagesMessage) -> Result { - let mut result = Vec::new(); - - match message.content { - MessagesMessageContent::Single(text) => { - result.push(Message { - role: message.role.into(), - content: MessageContent::Text(text), - name: None, - tool_calls: None, - tool_call_id: None, - }); - } - MessagesMessageContent::Blocks(blocks) => { - let (content_parts, tool_calls, tool_results) = blocks.split_for_openai()?; - // Add tool result messages - for (tool_use_id, result_text, _is_error) in tool_results { - result.push(Message { - role: Role::Tool, - content: MessageContent::Text(result_text), - name: None, - tool_calls: None, - tool_call_id: Some(tool_use_id), - }); - } - - // Only create main message if there's actual content or tool calls - // Skip creating empty content messages (e.g., when message only contains tool_result blocks) - if !content_parts.is_empty() || !tool_calls.is_empty() { - let content = 
build_openai_content(content_parts, &tool_calls); - let main_message = Message { - role: message.role.into(), - content, - name: None, - tool_calls: if tool_calls.is_empty() { - None - } else { - Some(tool_calls) - }, - tool_call_id: None, - }; - result.push(main_message); - } - } - } - - Ok(result) - } -} - -impl TryFrom for MessagesMessage { - type Error = TransformError; - - fn try_from(message: Message) -> Result { - let role = match message.role { - Role::User => MessagesRole::User, - Role::Assistant => MessagesRole::Assistant, - Role::Tool => { - // Tool messages become user messages with tool results - let tool_call_id = message.tool_call_id.ok_or_else(|| { - TransformError::MissingField( - "tool_call_id required for Tool messages".to_string(), - ) - })?; - - return Ok(MessagesMessage { - role: MessagesRole::User, - content: MessagesMessageContent::Blocks(vec![ - MessagesContentBlock::ToolResult { - tool_use_id: tool_call_id, - is_error: None, - content: ToolResultContent::Blocks(vec![MessagesContentBlock::Text { - text: message.content.extract_text(), - cache_control: None, - }]), - cache_control: None, - }, - ]), - }); - } - Role::System => { - return Err(TransformError::UnsupportedConversion( - "System messages should be handled separately".to_string(), - )); - } - }; - - let content_blocks = convert_openai_message_to_anthropic_content(&message)?; - let content = build_anthropic_content(content_blocks); - - Ok(MessagesMessage { role, content }) - } -} - -// Role Conversions -impl Into for MessagesRole { - fn into(self) -> Role { - match self { - MessagesRole::User => Role::User, - MessagesRole::Assistant => Role::Assistant, - } - } -} - -// Content Utilities -impl ContentUtils for Vec { - fn extract_tool_calls(&self) -> Result>, TransformError> { - let mut tool_calls = Vec::new(); - - for block in self { - match block { - MessagesContentBlock::ToolUse { - id, name, input, .. - } - | MessagesContentBlock::ServerToolUse { id, name, input } - | MessagesContentBlock::McpToolUse { id, name, input } => { - let arguments = serde_json::to_string(&input)?; - tool_calls.push(ToolCall { - id: id.clone(), - call_type: "function".to_string(), - function: FunctionCall { - name: name.clone(), - arguments, - }, - }); - } - _ => continue, - } - } - - Ok(if tool_calls.is_empty() { - None - } else { - Some(tool_calls) - }) - } - - fn split_for_openai( - &self, - ) -> Result<(Vec, Vec, Vec<(String, String, bool)>), TransformError> - { - let mut content_parts = Vec::new(); - let mut tool_calls = Vec::new(); - let mut tool_results = Vec::new(); - - for block in self { - match block { - MessagesContentBlock::Text { text, .. } => { - content_parts.push(ContentPart::Text { text: text.clone() }); - } - MessagesContentBlock::Image { source } => { - let url = convert_image_source_to_url(source); - content_parts.push(ContentPart::ImageUrl { - image_url: ImageUrl { - url, - detail: Some("auto".to_string()), - }, - }); - } - MessagesContentBlock::ToolUse { - id, name, input, .. - } - | MessagesContentBlock::ServerToolUse { id, name, input } - | MessagesContentBlock::McpToolUse { id, name, input } => { - let arguments = serde_json::to_string(&input)?; - tool_calls.push(ToolCall { - id: id.clone(), - call_type: "function".to_string(), - function: FunctionCall { - name: name.clone(), - arguments, - }, - }); - } - MessagesContentBlock::ToolResult { - tool_use_id, - content, - is_error, - .. 
- } => { - let result_text = content.extract_text(); - tool_results.push(( - tool_use_id.clone(), - result_text, - is_error.unwrap_or(false), - )); - } - MessagesContentBlock::WebSearchToolResult { - tool_use_id, - content, - is_error, - } - | MessagesContentBlock::CodeExecutionToolResult { - tool_use_id, - content, - is_error, - } - | MessagesContentBlock::McpToolResult { - tool_use_id, - content, - is_error, - } => { - let result_text = content.extract_text(); - tool_results.push(( - tool_use_id.clone(), - result_text, - is_error.unwrap_or(false), - )); - } - _ => { - // Skip unsupported content types - continue; - } - } - } - - Ok((content_parts, tool_calls, tool_results)) - } -} - -// Stop Reason Conversions -impl Into for MessagesStopReason { - fn into(self) -> FinishReason { - match self { - MessagesStopReason::EndTurn => FinishReason::Stop, - MessagesStopReason::MaxTokens => FinishReason::Length, - MessagesStopReason::StopSequence => FinishReason::Stop, - MessagesStopReason::ToolUse => FinishReason::ToolCalls, - MessagesStopReason::PauseTurn => FinishReason::Stop, - MessagesStopReason::Refusal => FinishReason::ContentFilter, - } - } -} - -impl Into for FinishReason { - fn into(self) -> MessagesStopReason { - match self { - FinishReason::Stop => MessagesStopReason::EndTurn, - FinishReason::Length => MessagesStopReason::MaxTokens, - FinishReason::ToolCalls => MessagesStopReason::ToolUse, - FinishReason::ContentFilter => MessagesStopReason::Refusal, - FinishReason::FunctionCall => MessagesStopReason::ToolUse, - } - } -} - -// Usage Conversions -impl Into for MessagesUsage { - fn into(self) -> Usage { - Usage { - prompt_tokens: self.input_tokens, - completion_tokens: self.output_tokens, - total_tokens: self.input_tokens + self.output_tokens, - prompt_tokens_details: None, - completion_tokens_details: None, - } - } -} - -impl Into for Usage { - fn into(self) -> MessagesUsage { - MessagesUsage { - input_tokens: self.prompt_tokens, - output_tokens: self.completion_tokens, - cache_creation_input_tokens: None, - cache_read_input_tokens: None, - } - } -} - -// ============================================================================ -// HELPER FUNCTIONS - Organized by domain -// ============================================================================ - -/// Helper to create a current unix timestamp -fn current_timestamp() -> u64 { - SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap() - .as_secs() -} - -/// Helper to create OpenAI streaming chunk -fn create_openai_chunk( - id: &str, - model: &str, - delta: MessageDelta, - finish_reason: Option, - usage: Option, -) -> ChatCompletionsStreamResponse { - ChatCompletionsStreamResponse { - id: id.to_string(), - object: Some("chat.completion.chunk".to_string()), - created: current_timestamp(), - model: model.to_string(), - choices: vec![StreamChoice { - index: 0, - delta, - finish_reason, - logprobs: None, - }], - usage, - system_fingerprint: None, - service_tier: None, - } -} - -/// Helper to create empty OpenAI streaming chunk -fn create_empty_openai_chunk() -> ChatCompletionsStreamResponse { - create_openai_chunk( - "stream", - "unknown", - MessageDelta { - role: None, - content: None, - refusal: None, - function_call: None, - tool_calls: None, - }, - None, - None, - ) -} - -/// Convert Anthropic tools to OpenAI format -fn convert_anthropic_tools(tools: Vec) -> Vec { - tools - .into_iter() - .map(|tool| Tool { - tool_type: "function".to_string(), - function: Function { - name: tool.name, - description: tool.description, - 
parameters: tool.input_schema, - strict: None, - }, - }) - .collect() -} - -/// Convert OpenAI tools to Anthropic format -fn convert_openai_tools(tools: Vec) -> Vec { - tools - .into_iter() - .map(|tool| MessagesTool { - name: tool.function.name, - description: tool.function.description, - input_schema: tool.function.parameters, - }) - .collect() -} - -/// Convert Anthropic tool choice to OpenAI format -fn convert_anthropic_tool_choice( - tool_choice: Option, -) -> (Option, Option) { - match tool_choice { - Some(choice) => { - let openai_choice = match choice.kind { - MessagesToolChoiceType::Auto => ToolChoice::Type(ToolChoiceType::Auto), - MessagesToolChoiceType::Any => ToolChoice::Type(ToolChoiceType::Required), - MessagesToolChoiceType::None => ToolChoice::Type(ToolChoiceType::None), - MessagesToolChoiceType::Tool => { - if let Some(name) = choice.name { - ToolChoice::Function { - choice_type: "function".to_string(), - function: FunctionChoice { name }, - } - } else { - ToolChoice::Type(ToolChoiceType::Auto) - } - } - }; - let parallel = choice.disable_parallel_tool_use.map(|disable| !disable); - (Some(openai_choice), parallel) - } - None => (None, None), - } -} - -/// Convert OpenAI tool choice to Anthropic format -fn convert_openai_tool_choice( - tool_choice: Option, - parallel_tool_calls: Option, -) -> Option { - tool_choice.map(|choice| match choice { - ToolChoice::Type(tool_type) => match tool_type { - ToolChoiceType::Auto => MessagesToolChoice { - kind: MessagesToolChoiceType::Auto, - name: None, - disable_parallel_tool_use: parallel_tool_calls.map(|p| !p), - }, - ToolChoiceType::Required => MessagesToolChoice { - kind: MessagesToolChoiceType::Any, - name: None, - disable_parallel_tool_use: parallel_tool_calls.map(|p| !p), - }, - ToolChoiceType::None => MessagesToolChoice { - kind: MessagesToolChoiceType::None, - name: None, - disable_parallel_tool_use: None, - }, - }, - ToolChoice::Function { function, .. } => MessagesToolChoice { - kind: MessagesToolChoiceType::Tool, - name: Some(function.name), - disable_parallel_tool_use: parallel_tool_calls.map(|p| !p), - }, - }) -} - -/// Build OpenAI message content from parts and tool calls -fn build_openai_content( - content_parts: Vec, - tool_calls: &[ToolCall], -) -> MessageContent { - if content_parts.len() == 1 && tool_calls.is_empty() { - match &content_parts[0] { - ContentPart::Text { text } => MessageContent::Text(text.clone()), - _ => MessageContent::Parts(content_parts), - } - } else if content_parts.is_empty() { - MessageContent::Text("".to_string()) - } else { - MessageContent::Parts(content_parts) - } -} - -/// Build Anthropic message content from content blocks -fn build_anthropic_content(content_blocks: Vec) -> MessagesMessageContent { - if content_blocks.len() == 1 { - match &content_blocks[0] { - MessagesContentBlock::Text { text, .. } => MessagesMessageContent::Single(text.clone()), - _ => MessagesMessageContent::Blocks(content_blocks), - } - } else if content_blocks.is_empty() { - MessagesMessageContent::Single("".to_string()) - } else { - MessagesMessageContent::Blocks(content_blocks) - } -} - -/// Convert Anthropic content blocks to OpenAI message content -fn convert_anthropic_content_to_openai( - content: &[MessagesContentBlock], -) -> Result { - let mut text_parts = Vec::new(); - - for block in content { - match block { - MessagesContentBlock::Text { text, .. } => { - text_parts.push(text.clone()); - } - MessagesContentBlock::Thinking { thinking, .. 
} => { - text_parts.push(format!("thinking: {}", thinking)); - } - _ => { - // Skip other content types for basic text conversion - continue; - } - } - } - - Ok(MessageContent::Text(text_parts.join("\n"))) -} - -/// Convert OpenAI message to Anthropic content blocks -fn convert_openai_message_to_anthropic_content( - message: &Message, -) -> Result, TransformError> { - let mut blocks = Vec::new(); - - // Handle regular content - match &message.content { - MessageContent::Text(text) => { - if !text.is_empty() { - blocks.push(MessagesContentBlock::Text { - text: text.clone(), - cache_control: None, - }); - } - } - MessageContent::Parts(parts) => { - for part in parts { - match part { - ContentPart::Text { text } => { - blocks.push(MessagesContentBlock::Text { - text: text.clone(), - cache_control: None, - }); - } - ContentPart::ImageUrl { image_url } => { - let source = convert_image_url_to_source(image_url); - blocks.push(MessagesContentBlock::Image { source }); - } - } - } - } - } - - // Handle tool calls - if let Some(tool_calls) = &message.tool_calls { - for tool_call in tool_calls { - let input: Value = serde_json::from_str(&tool_call.function.arguments)?; - blocks.push(MessagesContentBlock::ToolUse { - id: tool_call.id.clone(), - name: tool_call.function.name.clone(), - input, - cache_control: None, - }); - } - } - - Ok(blocks) -} - -/// Convert image source to URL -fn convert_image_source_to_url(source: &MessagesImageSource) -> String { - match source { - MessagesImageSource::Base64 { media_type, data } => { - format!("data:{};base64,{}", media_type, data) - } - MessagesImageSource::Url { url } => url.clone(), - } -} - -/// Convert image URL to Anthropic image source -fn convert_image_url_to_source(image_url: &ImageUrl) -> MessagesImageSource { - if image_url.url.starts_with("data:") { - // Parse data URL - let parts: Vec<&str> = image_url.url.splitn(2, ',').collect(); - if parts.len() == 2 { - let header = parts[0]; - let data = parts[1]; - let media_type = header - .strip_prefix("data:") - .and_then(|s| s.split(';').next()) - .unwrap_or("image/jpeg") - .to_string(); - - MessagesImageSource::Base64 { - media_type, - data: data.to_string(), - } - } else { - MessagesImageSource::Url { - url: image_url.url.clone(), - } - } - } else { - MessagesImageSource::Url { - url: image_url.url.clone(), - } - } -} - -/// Convert content block start to OpenAI chunk -fn convert_content_block_start( - content_block: MessagesContentBlock, -) -> Result { - match content_block { - MessagesContentBlock::Text { .. } => { - // No immediate output for text block start - Ok(create_empty_openai_chunk()) - } - MessagesContentBlock::ToolUse { id, name, .. } - | MessagesContentBlock::ServerToolUse { id, name, .. } - | MessagesContentBlock::McpToolUse { id, name, .. 
} => { - // Tool use start → OpenAI chunk with tool_calls - Ok(create_openai_chunk( - "stream", - "unknown", - MessageDelta { - role: None, - content: None, - refusal: None, - function_call: None, - tool_calls: Some(vec![ToolCallDelta { - index: 0, - id: Some(id), - call_type: Some("function".to_string()), - function: Some(FunctionCallDelta { - name: Some(name), - arguments: Some("".to_string()), - }), - }]), - }, - None, - None, - )) - } - _ => Err(TransformError::UnsupportedContent( - "Unsupported content block type in stream start".to_string(), - )), - } -} - -/// Convert content delta to OpenAI chunk -fn convert_content_delta( - delta: MessagesContentDelta, -) -> Result<ChatCompletionsStreamResponse, TransformError> { - match delta { - MessagesContentDelta::TextDelta { text } => Ok(create_openai_chunk( - "stream", - "unknown", - MessageDelta { - role: None, - content: Some(text), - refusal: None, - function_call: None, - tool_calls: None, - }, - None, - None, - )), - MessagesContentDelta::ThinkingDelta { thinking } => Ok(create_openai_chunk( - "stream", - "unknown", - MessageDelta { - role: None, - content: Some(format!("thinking: {}", thinking)), - refusal: None, - function_call: None, - tool_calls: None, - }, - None, - None, - )), - MessagesContentDelta::InputJsonDelta { partial_json } => Ok(create_openai_chunk( - "stream", - "unknown", - MessageDelta { - role: None, - content: None, - refusal: None, - function_call: None, - tool_calls: Some(vec![ToolCallDelta { - index: 0, - id: None, - call_type: None, - function: Some(FunctionCallDelta { - name: None, - arguments: Some(partial_json), - }), - }]), - }, - None, - None, - )), - } -} - -/// Convert tool call deltas to Anthropic stream events -fn convert_tool_call_deltas( - tool_calls: Vec<ToolCallDelta>, -) -> Result<MessagesStreamEvent, TransformError> { - for tool_call in tool_calls { - if let Some(id) = &tool_call.id { - // Tool call start - if let Some(function) = &tool_call.function { - if let Some(name) = &function.name { - return Ok(MessagesStreamEvent::ContentBlockStart { - index: tool_call.index, - content_block: MessagesContentBlock::ToolUse { - id: id.clone(), - name: name.clone(), - input: Value::Object(serde_json::Map::new()), - cache_control: None, - }, - }); - } - } - } else if let Some(function) = &tool_call.function { - if let Some(arguments) = &function.arguments { - // Tool arguments delta - return Ok(MessagesStreamEvent::ContentBlockDelta { - index: tool_call.index, - delta: MessagesContentDelta::InputJsonDelta { - partial_json: arguments.clone(), - }, - }); - } - } - } - - // Fallback to ping if no valid tool call found - Ok(MessagesStreamEvent::Ping) -} +// Keeping the tests to make sure the refactoring didn't break anything // ============================================================================ // TESTS @@ -1166,8 +8,11 @@ fn convert_tool_call_deltas( #[cfg(test)] mod tests { - use super::*; + use crate::apis::anthropic::*; + use crate::apis::openai::*; + use crate::transforms::*; use serde_json::json; + type AnthropicMessagesRequest = MessagesRequest; #[test] fn test_anthropic_to_openai_basic_request() { diff --git a/crates/hermesllm/src/lib.rs b/crates/hermesllm/src/lib.rs index 2789947b..77289f4b 100644 --- a/crates/hermesllm/src/lib.rs +++ b/crates/hermesllm/src/lib.rs @@ -4,12 +4,16 @@ pub mod apis; pub mod clients; pub mod providers; +pub mod transforms; // Re-export important types and traits +pub use apis::amazon_bedrock_binary_frame::BedrockBinaryFrameDecoder; +pub use apis::sse::{SseEvent, SseStreamIter}; +pub use aws_smithy_eventstream::frame::DecodedFrame; pub use 
providers::id::ProviderId; pub use providers::request::{ProviderRequest, ProviderRequestError, ProviderRequestType}; pub use providers::response::{ ProviderResponse, ProviderResponseError, ProviderResponseType, ProviderStreamResponse, - ProviderStreamResponseType, SseEvent, SseStreamIter, TokenUsage, + ProviderStreamResponseType, TokenUsage, }; //TODO: Refactor such that commons doesn't depend on Hermes. For now this will clean up strings @@ -18,6 +22,8 @@ pub const MESSAGES_PATH: &str = "/v1/messages"; #[cfg(test)] mod tests { + use crate::clients::endpoints::SupportedUpstreamAPIs; + use super::*; #[test] @@ -40,7 +46,7 @@ mod tests { let client_api = SupportedAPIs::OpenAIChatCompletions(crate::apis::OpenAIApi::ChatCompletions); let upstream_api = - SupportedAPIs::OpenAIChatCompletions(crate::apis::OpenAIApi::ChatCompletions); + SupportedUpstreamAPIs::OpenAIChatCompletions(crate::apis::OpenAIApi::ChatCompletions); // Test the new simplified architecture - create SseStreamIter directly let sse_iter = SseStreamIter::try_from(sse_data.as_bytes()); @@ -77,4 +83,156 @@ mod tests { let final_event = streaming_iter.next(); assert!(final_event.is_none()); // Should be None because iterator stops at [DONE] } + + /// Test AWS Event Stream decoding for Bedrock ConverseStream responses. + /// + /// This test demonstrates how to: + /// 1. Use MessageFrameDecoder to decode AWS Event Stream frames + /// 2. Handle chunked network arrivals with buffering + /// 3. Extract event types from message headers + /// 4. Parse JSON payloads from decoded messages + /// 5. Reconstruct streaming content from contentBlockDelta events + /// + /// The decoder handles frame boundaries automatically - you just keep calling + /// decode_frame() until it returns Incomplete, which means you've processed + /// all complete frames in the buffer. + #[test] + fn test_amazon_bedrock_streaming_response() { + use aws_smithy_eventstream::frame::{DecodedFrame, MessageFrameDecoder}; + use bytes::{Buf, BytesMut}; + use std::fs; + use std::path::PathBuf; + + // Read the response.hex file from tests/e2e directory + // Use absolute path to avoid cargo test working directory issues + let test_file = + PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("../../tests/e2e/response.hex"); + let response_data = fs::read(&test_file) + .unwrap_or_else(|e| panic!("Failed to read {:?}: {}", test_file, e)); + + println!("📊 Response data size: {} bytes\n", response_data.len()); + + // Create decoder and buffer that implements Buf trait + // BytesMut automatically tracks position as decoder advances it! 
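+ // A quick sketch of the wire format the decoder is handling (per the application/vnd.amazon.eventstream encoding, as we understand it; integers are big-endian). Each message is framed as: + // [4-byte total length][4-byte headers length][4-byte prelude CRC32] + // [headers: name-len, name, value-type, value, ...][payload bytes][4-byte message CRC32] + // The ":event-type" header names the event (e.g. "contentBlockDelta") and the payload is a JSON document. MessageFrameDecoder validates the CRCs and handles the framing, so the loop below only has to inspect decoded messages.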
+ let mut decoder = MessageFrameDecoder::new(); + let mut simulated_network_buffer = BytesMut::new(); + let mut frame_count = 0; + let mut content_chunks = Vec::new(); + + // Simulate chunked network arrivals - process as data comes in + let chunk_sizes = vec![50, 100, 75, 200, 150, 300, 500, 1000]; + let mut offset = 0; + let mut chunk_num = 0; + + println!("🔄 Simulating chunked network arrivals...\n"); + + // Process chunks as they "arrive" from the network + while offset < response_data.len() { + // Receive next chunk from network + let chunk_size = chunk_sizes[chunk_num % chunk_sizes.len()]; + let end = (offset + chunk_size).min(response_data.len()); + let chunk = &response_data[offset..end]; + + chunk_num += 1; + simulated_network_buffer.extend_from_slice(chunk); + offset = end; + + println!( + "📦 Chunk {}: Received {} bytes (buffer: {} bytes total, {} bytes remaining)", + chunk_num, + chunk.len(), + simulated_network_buffer.len(), + simulated_network_buffer.remaining() + ); + + // Try to decode all complete frames from buffer + // The Buf trait tracks position automatically! + loop { + let bytes_before = simulated_network_buffer.remaining(); + match decoder.decode_frame(&mut simulated_network_buffer) { + Ok(DecodedFrame::Complete(message)) => { + frame_count += 1; + let consumed = bytes_before - simulated_network_buffer.remaining(); + + println!( + " ✅ Frame {}: decoded ({} bytes, {} bytes remaining)", + frame_count, + consumed, + simulated_network_buffer.remaining() + ); + + // Get event type from headers + let event_type = message + .headers() + .iter() + .find(|h| h.name().as_str() == ":event-type") + .and_then(|h| { + h.value().as_string().ok().map(|s| s.as_str().to_string()) + }); + + if let Some(ref evt) = event_type { + println!(" Event: {}", evt); + } + + // Parse payload and extract content + let payload = message.payload(); + if !payload.is_empty() { + if let Ok(json) = serde_json::from_slice::<serde_json::Value>(payload) { + if event_type.as_deref() == Some("contentBlockDelta") { + if let Some(delta) = json.get("delta") { + if let Some(text) = + delta.get("text").and_then(|t| t.as_str()) + { + println!(" 📝 Content: \"{}\"", text); + content_chunks.push(text.to_string()); + } + } + } + } + } // Continue loop to check for more complete frames in buffer + } + Ok(DecodedFrame::Incomplete) => { + // Not enough data for a complete frame - need more chunks + println!( + " ⏳ Incomplete frame ({} bytes remaining) - waiting for more data\n", + simulated_network_buffer.remaining() + ); + break; // Wait for next chunk + } + Err(e) => { + panic!("❌ Frame decode error: {}", e); + } + } + } + } + + println!("\n━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"); + println!("📋 Summary:"); + println!(" Total chunks received: {}", chunk_num); + println!(" Total frames decoded: {}", frame_count); + println!(" Total content chunks: {}", content_chunks.len()); + println!( + " Final buffer remaining: {} bytes", + simulated_network_buffer.remaining() + ); + + if !content_chunks.is_empty() { + let full_text = content_chunks.join(""); + println!("\n📄 Full reconstructed content:"); + println!("{}", full_text); + println!("\n Characters: {}", full_text.len()); + println!(" Estimated tokens: ~{}", full_text.len() / 4); + } + + // Ensure we decoded at least one frame + assert!(frame_count > 0, "Should decode at least one frame"); + + // Ensure all data was consumed - if buffer has remaining bytes, it's a partial frame + assert_eq!( + simulated_network_buffer.remaining(), + 0, + "All bytes should be consumed, {} bytes remain", 
simulated_network_buffer.remaining() + ); + } } diff --git a/crates/hermesllm/src/providers/id.rs b/crates/hermesllm/src/providers/id.rs index b898d7d7..94a6205a 100644 --- a/crates/hermesllm/src/providers/id.rs +++ b/crates/hermesllm/src/providers/id.rs @@ -1,5 +1,5 @@ -use crate::apis::{AnthropicApi, OpenAIApi}; -use crate::clients::endpoints::SupportedAPIs; +use crate::apis::{AmazonBedrockApi, AnthropicApi, OpenAIApi}; +use crate::clients::endpoints::{SupportedAPIs, SupportedUpstreamAPIs}; use std::fmt::Display; /// Provider identifier enum - simple enum for identifying providers @@ -19,7 +19,8 @@ pub enum ProviderId { Ollama, Moonshotai, Zhipu, - Qwen, // alias for Qwen + Qwen, + AmazonBedrock, } impl From<&str> for ProviderId { @@ -39,7 +40,8 @@ impl From<&str> for ProviderId { "ollama" => ProviderId::Ollama, "moonshotai" => ProviderId::Moonshotai, "zhipu" => ProviderId::Zhipu, - "qwen" => ProviderId::Qwen, // alias for Zhipu + "qwen" => ProviderId::Qwen, // alias for Qwen + "amazon_bedrock" => ProviderId::AmazonBedrock, _ => panic!("Unknown provider: {}", value), } } @@ -47,16 +49,20 @@ impl From<&str> for ProviderId { impl ProviderId { /// Given a client API, return the compatible upstream API for this provider - pub fn compatible_api_for_client(&self, client_api: &SupportedAPIs) -> SupportedAPIs { + pub fn compatible_api_for_client( + &self, + client_api: &SupportedAPIs, + is_streaming: bool, + ) -> SupportedUpstreamAPIs { match (self, client_api) { // Claude/Anthropic providers natively support Anthropic APIs (ProviderId::Anthropic, SupportedAPIs::AnthropicMessagesAPI(_)) => { - SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages) + SupportedUpstreamAPIs::AnthropicMessagesAPI(AnthropicApi::Messages) } ( ProviderId::Anthropic, SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions), - ) => SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions), + ) => SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions), // OpenAI-compatible providers only support OpenAI chat completions ( @@ -75,7 +81,7 @@ impl ProviderId { | ProviderId::Zhipu | ProviderId::Qwen, SupportedAPIs::AnthropicMessagesAPI(_), - ) => SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions), + ) => SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions), ( ProviderId::OpenAI @@ -93,7 +99,27 @@ impl ProviderId { | ProviderId::Zhipu | ProviderId::Qwen, SupportedAPIs::OpenAIChatCompletions(_), - ) => SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions), + ) => SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions), + + // Amazon Bedrock natively supports Bedrock APIs + (ProviderId::AmazonBedrock, SupportedAPIs::OpenAIChatCompletions(_)) => { + if is_streaming { + SupportedUpstreamAPIs::AmazonBedrockConverseStream( + AmazonBedrockApi::ConverseStream, + ) + } else { + SupportedUpstreamAPIs::AmazonBedrockConverse(AmazonBedrockApi::Converse) + } + } + (ProviderId::AmazonBedrock, SupportedAPIs::AnthropicMessagesAPI(_)) => { + if is_streaming { + SupportedUpstreamAPIs::AmazonBedrockConverseStream( + AmazonBedrockApi::ConverseStream, + ) + } else { + SupportedUpstreamAPIs::AmazonBedrockConverse(AmazonBedrockApi::Converse) + } + } } } } @@ -116,6 +142,7 @@ impl Display for ProviderId { ProviderId::Moonshotai => write!(f, "moonshotai"), ProviderId::Zhipu => write!(f, "zhipu"), ProviderId::Qwen => write!(f, "qwen"), + ProviderId::AmazonBedrock => write!(f, "amazon_bedrock"), } } } diff --git a/crates/hermesllm/src/providers/request.rs 
b/crates/hermesllm/src/providers/request.rs index 1cee7169..a8bcfa29 100644 --- a/crates/hermesllm/src/providers/request.rs +++ b/crates/hermesllm/src/providers/request.rs @@ -1,6 +1,9 @@ use crate::apis::anthropic::MessagesRequest; use crate::apis::openai::ChatCompletionsRequest; + +use crate::apis::amazon_bedrock::{ConverseRequest, ConverseStreamRequest}; use crate::clients::endpoints::SupportedAPIs; +use crate::clients::endpoints::SupportedUpstreamAPIs; use serde_json::Value; use std::collections::HashMap; @@ -10,6 +13,8 @@ use std::fmt; pub enum ProviderRequestType { ChatCompletionsRequest(ChatCompletionsRequest), MessagesRequest(MessagesRequest), + BedrockConverse(ConverseRequest), + BedrockConverseStream(ConverseStreamRequest), //add more request types here } pub trait ProviderRequest: Send + Sync { @@ -42,6 +47,8 @@ impl ProviderRequest for ProviderRequestType { match self { Self::ChatCompletionsRequest(r) => r.model(), Self::MessagesRequest(r) => r.model(), + Self::BedrockConverse(r) => r.model(), + Self::BedrockConverseStream(r) => r.model(), } } @@ -49,6 +56,8 @@ impl ProviderRequest for ProviderRequestType { match self { Self::ChatCompletionsRequest(r) => r.set_model(model), Self::MessagesRequest(r) => r.set_model(model), + Self::BedrockConverse(r) => r.set_model(model), + Self::BedrockConverseStream(r) => r.set_model(model), } } @@ -56,6 +65,8 @@ impl ProviderRequest for ProviderRequestType { match self { Self::ChatCompletionsRequest(r) => r.is_streaming(), Self::MessagesRequest(r) => r.is_streaming(), + Self::BedrockConverse(_) => false, + Self::BedrockConverseStream(_) => true, } } @@ -63,6 +74,8 @@ impl ProviderRequest for ProviderRequestType { match self { Self::ChatCompletionsRequest(r) => r.extract_messages_text(), Self::MessagesRequest(r) => r.extract_messages_text(), + Self::BedrockConverse(r) => r.extract_messages_text(), + Self::BedrockConverseStream(r) => r.extract_messages_text(), } } @@ -70,6 +83,8 @@ impl ProviderRequest for ProviderRequestType { match self { Self::ChatCompletionsRequest(r) => r.get_recent_user_message(), Self::MessagesRequest(r) => r.get_recent_user_message(), + Self::BedrockConverse(r) => r.get_recent_user_message(), + Self::BedrockConverseStream(r) => r.get_recent_user_message(), } } @@ -77,6 +92,8 @@ impl ProviderRequest for ProviderRequestType { match self { Self::ChatCompletionsRequest(r) => r.to_bytes(), Self::MessagesRequest(r) => r.to_bytes(), + Self::BedrockConverse(r) => r.to_bytes(), + Self::BedrockConverseStream(r) => r.to_bytes(), } } @@ -84,6 +101,8 @@ impl ProviderRequest for ProviderRequestType { match self { Self::ChatCompletionsRequest(r) => r.metadata(), Self::MessagesRequest(r) => r.metadata(), + Self::BedrockConverse(r) => r.metadata(), + Self::BedrockConverseStream(r) => r.metadata(), } } @@ -91,6 +110,8 @@ impl ProviderRequest for ProviderRequestType { match self { Self::ChatCompletionsRequest(r) => r.remove_metadata_key(key), Self::MessagesRequest(r) => r.remove_metadata_key(key), + Self::BedrockConverse(r) => r.remove_metadata_key(key), + Self::BedrockConverseStream(r) => r.remove_metadata_key(key), } } } @@ -120,27 +141,27 @@ impl TryFrom<(&[u8], &SupportedAPIs)> for ProviderRequestType { } /// Conversion from one ProviderRequestType to a different ProviderRequestType (SupportedAPIs) -impl TryFrom<(ProviderRequestType, &SupportedAPIs)> for ProviderRequestType { +impl TryFrom<(ProviderRequestType, &SupportedUpstreamAPIs)> for ProviderRequestType { type Error = ProviderRequestError; fn try_from( - (request, upstream_api): 
(ProviderRequestType, &SupportedAPIs), + (client_request, upstream_api): (ProviderRequestType, &SupportedUpstreamAPIs), ) -> Result<Self, Self::Error> { - match (request, upstream_api) { + match (client_request, upstream_api) { // Same API - no conversion needed, just clone the reference ( ProviderRequestType::ChatCompletionsRequest(chat_req), - SupportedAPIs::OpenAIChatCompletions(_), + SupportedUpstreamAPIs::OpenAIChatCompletions(_), ) => Ok(ProviderRequestType::ChatCompletionsRequest(chat_req)), ( ProviderRequestType::MessagesRequest(messages_req), - SupportedAPIs::AnthropicMessagesAPI(_), + SupportedUpstreamAPIs::AnthropicMessagesAPI(_), ) => Ok(ProviderRequestType::MessagesRequest(messages_req)), // Cross-API conversion - cloning is necessary for transformation ( ProviderRequestType::ChatCompletionsRequest(chat_req), - SupportedAPIs::AnthropicMessagesAPI(_), + SupportedUpstreamAPIs::AnthropicMessagesAPI(_), ) => { let messages_req = MessagesRequest::try_from(chat_req).map_err(|e| ProviderRequestError { @@ -155,7 +176,7 @@ impl TryFrom<(ProviderRequestType, &SupportedAPIs)> for ProviderRequestType { ( ProviderRequestType::MessagesRequest(messages_req), - SupportedAPIs::OpenAIChatCompletions(_), + SupportedUpstreamAPIs::OpenAIChatCompletions(_), ) => { let chat_req = ChatCompletionsRequest::try_from(messages_req).map_err(|e| { ProviderRequestError { @@ -168,6 +189,69 @@ impl TryFrom<(ProviderRequestType, &SupportedAPIs)> for ProviderRequestType { })?; Ok(ProviderRequestType::ChatCompletionsRequest(chat_req)) } + + // Cross-API conversions: OpenAI/Anthropic to Amazon Bedrock + ( + ProviderRequestType::ChatCompletionsRequest(chat_req), + SupportedUpstreamAPIs::AmazonBedrockConverse(_), + ) => { + let bedrock_req = ConverseRequest::try_from(chat_req) + .map_err(|e| ProviderRequestError { + message: format!("Failed to convert ChatCompletionsRequest to Amazon Bedrock request: {}", e), + source: Some(Box::new(e)) + })?; + Ok(ProviderRequestType::BedrockConverse(bedrock_req)) + } + + ( + ProviderRequestType::ChatCompletionsRequest(chat_req), + SupportedUpstreamAPIs::AmazonBedrockConverseStream(_), + ) => { + let bedrock_req = ConverseStreamRequest::try_from(chat_req) + .map_err(|e| ProviderRequestError { + message: format!("Failed to convert ChatCompletionsRequest to Amazon Bedrock stream request: {}", e), + source: Some(Box::new(e)) + })?; + Ok(ProviderRequestType::BedrockConverseStream(bedrock_req)) + } + ( + ProviderRequestType::MessagesRequest(messages_req), + SupportedUpstreamAPIs::AmazonBedrockConverse(_), + ) => { + let bedrock_req = + ConverseRequest::try_from(messages_req).map_err(|e| ProviderRequestError { + message: format!( + "Failed to convert MessagesRequest to Amazon Bedrock request: {}", + e + ), + source: Some(Box::new(e)), + })?; + Ok(ProviderRequestType::BedrockConverse(bedrock_req)) + } + ( + ProviderRequestType::MessagesRequest(messages_req), + SupportedUpstreamAPIs::AmazonBedrockConverseStream(_), + ) => { + let bedrock_req = ConverseStreamRequest::try_from(messages_req).map_err(|e| { + ProviderRequestError { + message: format!( + "Failed to convert MessagesRequest to Amazon Bedrock stream request: {}", + e + ), + source: Some(Box::new(e)), + } + })?; + Ok(ProviderRequestType::BedrockConverseStream(bedrock_req)) + } + + // Amazon Bedrock to other APIs conversions + (ProviderRequestType::BedrockConverse(_), _) => { + todo!("Amazon Bedrock to ChatCompletionsRequest conversion not implemented yet") + } + + (ProviderRequestType::BedrockConverseStream(_), _) => { + todo!("Amazon Bedrock Stream to ChatCompletionsRequest conversion not implemented yet") + }
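+ // Usage sketch (illustrative, not part of the patch): a client request bound for Bedrock + // is converted by picking the upstream API first, e.g. + // let upstream = SupportedUpstreamAPIs::AmazonBedrockConverse(AmazonBedrockApi::Converse); + // let bedrock_req = ProviderRequestType::try_from((client_request, &upstream))?; + // after which bedrock_req can be serialized with to_bytes() and sent to the + // /model/{model_id}/converse endpoint resolved in endpoints.rs.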
} } } @@ -201,7 +285,7 @@ mod tests { use crate::apis::openai::ChatCompletionsRequest; use crate::apis::openai::OpenAIApi::ChatCompletions; use crate::clients::endpoints::SupportedAPIs; - use crate::clients::transformer::ExtractText; + use crate::transforms::lib::ExtractText; use serde_json::json; #[test] diff --git a/crates/hermesllm/src/providers/response.rs b/crates/hermesllm/src/providers/response.rs index 5f4607df..c6c37693 100644 --- a/crates/hermesllm/src/providers/response.rs +++ b/crates/hermesllm/src/providers/response.rs @@ -1,15 +1,18 @@ -use crate::providers::id::ProviderId; -use serde::{Deserialize, Serialize}; +use serde::Serialize; use std::convert::TryFrom; use std::error::Error; use std::fmt; -use std::str::FromStr; +use crate::apis::amazon_bedrock::ConverseResponse; +use crate::apis::amazon_bedrock::ConverseStreamEvent; use crate::apis::anthropic::MessagesResponse; use crate::apis::anthropic::MessagesStreamEvent; use crate::apis::openai::ChatCompletionsResponse; use crate::apis::openai::ChatCompletionsStreamResponse; +use crate::apis::sse::SseEvent; use crate::clients::endpoints::SupportedAPIs; +use crate::clients::endpoints::SupportedUpstreamAPIs; +use crate::providers::id::ProviderId; /// Trait for token usage information pub trait TokenUsage { @@ -30,6 +33,7 @@ pub enum ProviderResponseType { pub enum ProviderStreamResponseType { ChatCompletionsStreamResponse(ChatCompletionsStreamResponse), MessagesStreamEvent(MessagesStreamEvent), + ConverseStreamEvent(ConverseStreamEvent), } pub trait ProviderResponse: Send + Sync { @@ -58,7 +62,6 @@ impl ProviderResponse for ProviderResponseType { } } } - pub trait ProviderStreamResponse: Send + Sync { /// Get the content delta for this chunk fn content_delta(&self) -> Option<&str>; @@ -78,6 +81,7 @@ impl ProviderStreamResponse for ProviderStreamResponseType { match self { ProviderStreamResponseType::ChatCompletionsStreamResponse(resp) => resp.content_delta(), ProviderStreamResponseType::MessagesStreamEvent(resp) => resp.content_delta(), + ProviderStreamResponseType::ConverseStreamEvent(resp) => resp.content_delta(), } } @@ -85,6 +89,7 @@ match self { ProviderStreamResponseType::ChatCompletionsStreamResponse(resp) => resp.is_final(), ProviderStreamResponseType::MessagesStreamEvent(resp) => resp.is_final(), + ProviderStreamResponseType::ConverseStreamEvent(resp) => resp.is_final(), } } @@ -92,6 +97,7 @@ match self { ProviderStreamResponseType::ChatCompletionsStreamResponse(resp) => resp.role(), ProviderStreamResponseType::MessagesStreamEvent(resp) => resp.role(), + ProviderStreamResponseType::ConverseStreamEvent(resp) => resp.role(), } } @@ -99,116 +105,31 @@ match self { ProviderStreamResponseType::ChatCompletionsStreamResponse(_resp) => None, // OpenAI doesn't use event types ProviderStreamResponseType::MessagesStreamEvent(resp) => resp.event_type(), + ProviderStreamResponseType::ConverseStreamEvent(resp) => resp.event_type(), } } } -// ============================================================================ -// SSE EVENT CONTAINER -// ============================================================================ - -/// Represents a single Server-Sent Event with the complete wire format -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct SseEvent { - #[serde(rename = 
"data")] - pub data: Option, // The JSON payload after "data: " - - #[serde(skip_serializing_if = "Option::is_none")] - pub event: Option, // Optional event type (e.g., "message_start", "content_block_delta") - - #[serde(skip_serializing, skip_deserializing)] - pub raw_line: String, // The complete line as received including "data: " prefix and "\n\n" - - #[serde(skip_serializing, skip_deserializing)] - pub sse_transform_buffer: String, // The complete line as received including "data: " prefix and "\n\n" - - #[serde(skip_serializing, skip_deserializing)] - pub provider_stream_response: Option, // Parsed provider stream response object -} - -impl SseEvent { - /// Check if this event represents the end of the stream - pub fn is_done(&self) -> bool { - self.data == Some("[DONE]".into()) - } - - /// Check if this event should be skipped during processing - /// This includes ping messages and other provider-specific events that don't contain content - pub fn should_skip(&self) -> bool { - // Skip ping messages (commonly used by providers for connection keep-alive) - self.data == Some(r#"{"type": "ping"}"#.into()) - } - - /// Check if this is an event-only SSE event (no data payload) - pub fn is_event_only(&self) -> bool { - self.event.is_some() && self.data.is_none() - } - - /// Get the parsed provider response if available - pub fn provider_response(&self) -> Result<&dyn ProviderStreamResponse, std::io::Error> { - self.provider_stream_response - .as_ref() - .map(|resp| resp as &dyn ProviderStreamResponse) - .ok_or_else(|| { - std::io::Error::new(std::io::ErrorKind::NotFound, "Provider response not found") - }) - } -} - -impl FromStr for SseEvent { - type Err = SseParseError; - - fn from_str(line: &str) -> Result { - if line.starts_with("data: ") { - let data: String = line[6..].to_string(); // Remove "data: " prefix - if data.is_empty() { - return Err(SseParseError { - message: "Empty data field is not a valid SSE event".to_string(), - }); +impl Into for ProviderStreamResponseType { + fn into(self) -> String { + match self { + ProviderStreamResponseType::MessagesStreamEvent(event) => { + // Use the Into implementation for proper SSE formatting with event lines + event.into() } - Ok(SseEvent { - data: Some(data), - event: None, - raw_line: line.to_string(), - sse_transform_buffer: line.to_string(), - provider_stream_response: None, - }) - } else if line.starts_with("event: ") { - //used by Anthropic - let event_type = line[7..].to_string(); - if event_type.is_empty() { - return Err(SseParseError { - message: "Empty event field is not a valid SSE event".to_string(), - }); + ProviderStreamResponseType::ConverseStreamEvent(event) => { + // Use the Into implementation for proper SSE formatting with event lines + event.into() + } + ProviderStreamResponseType::ChatCompletionsStreamResponse(_) => { + // For OpenAI, use simple data line format + let json = serde_json::to_string(&self).unwrap_or_default(); + format!("data: {}\n\n", json) } - Ok(SseEvent { - data: None, - event: Some(event_type), - raw_line: line.to_string(), - sse_transform_buffer: line.to_string(), - provider_stream_response: None, - }) - } else { - Err(SseParseError { - message: format!("Line does not start with 'data: ' or 'event: ': {}", line), - }) } } } -impl fmt::Display for SseEvent { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}", self.sse_transform_buffer) - } -} - -// Into implementation to convert SseEvent to bytes for response buffer -impl Into> for SseEvent { - fn into(self) -> Vec { - 
format!("{}\n\n", self.sse_transform_buffer).into_bytes() - } -} - // --- Response transformation logic for client API compatibility --- impl TryFrom<(&[u8], &SupportedAPIs, &ProviderId)> for ProviderResponseType { type Error = std::io::Error; @@ -216,19 +137,28 @@ impl TryFrom<(&[u8], &SupportedAPIs, &ProviderId)> for ProviderResponseType { fn try_from( (bytes, client_api, provider_id): (&[u8], &SupportedAPIs, &ProviderId), ) -> Result { - let upstream_api = provider_id.compatible_api_for_client(client_api); + let upstream_api = provider_id.compatible_api_for_client(client_api, false); match (&upstream_api, client_api) { - (SupportedAPIs::OpenAIChatCompletions(_), SupportedAPIs::OpenAIChatCompletions(_)) => { + ( + SupportedUpstreamAPIs::OpenAIChatCompletions(_), + SupportedAPIs::OpenAIChatCompletions(_), + ) => { let resp: ChatCompletionsResponse = ChatCompletionsResponse::try_from(bytes) .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; Ok(ProviderResponseType::ChatCompletionsResponse(resp)) } - (SupportedAPIs::AnthropicMessagesAPI(_), SupportedAPIs::AnthropicMessagesAPI(_)) => { + ( + SupportedUpstreamAPIs::AnthropicMessagesAPI(_), + SupportedAPIs::AnthropicMessagesAPI(_), + ) => { let resp: MessagesResponse = serde_json::from_slice(bytes) .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; Ok(ProviderResponseType::MessagesResponse(resp)) } - (SupportedAPIs::AnthropicMessagesAPI(_), SupportedAPIs::OpenAIChatCompletions(_)) => { + ( + SupportedUpstreamAPIs::AnthropicMessagesAPI(_), + SupportedAPIs::OpenAIChatCompletions(_), + ) => { let anthropic_resp: MessagesResponse = serde_json::from_slice(bytes) .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; @@ -242,7 +172,10 @@ impl TryFrom<(&[u8], &SupportedAPIs, &ProviderId)> for ProviderResponseType { })?; Ok(ProviderResponseType::ChatCompletionsResponse(chat_resp)) } - (SupportedAPIs::OpenAIChatCompletions(_), SupportedAPIs::AnthropicMessagesAPI(_)) => { + ( + SupportedUpstreamAPIs::OpenAIChatCompletions(_), + SupportedAPIs::AnthropicMessagesAPI(_), + ) => { let openai_resp: ChatCompletionsResponse = ChatCompletionsResponse::try_from(bytes) .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; @@ -255,69 +188,130 @@ impl TryFrom<(&[u8], &SupportedAPIs, &ProviderId)> for ProviderResponseType { })?; Ok(ProviderResponseType::MessagesResponse(messages_resp)) } + // Amazon Bedrock transformations + ( + SupportedUpstreamAPIs::AmazonBedrockConverse(_), + SupportedAPIs::OpenAIChatCompletions(_), + ) => { + let bedrock_resp: ConverseResponse = serde_json::from_slice(bytes) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; + + // Transform to OpenAI ChatCompletions format using the transformer + let chat_resp: ChatCompletionsResponse = bedrock_resp.try_into().map_err(|e| { + std::io::Error::new( + std::io::ErrorKind::InvalidData, + format!("Transformation error: {}", e), + ) + })?; + Ok(ProviderResponseType::ChatCompletionsResponse(chat_resp)) + } + ( + SupportedUpstreamAPIs::AmazonBedrockConverse(_), + SupportedAPIs::AnthropicMessagesAPI(_), + ) => { + let bedrock_resp: ConverseResponse = serde_json::from_slice(bytes) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; + + // Transform to Anthropic Messages format using the transformer + let messages_resp: MessagesResponse = bedrock_resp.try_into().map_err(|e| { + std::io::Error::new( + std::io::ErrorKind::InvalidData, + format!("Transformation error: {}", e), + ) + })?; + 
Ok(ProviderResponseType::MessagesResponse(messages_resp)) + } + _ => Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "Unsupported API combination for response transformation", + )), } } } // Stream response transformation logic for client API compatibility -impl TryFrom<(&[u8], &SupportedAPIs, &SupportedAPIs)> for ProviderStreamResponseType { +impl TryFrom<(&[u8], &SupportedAPIs, &SupportedUpstreamAPIs)> for ProviderStreamResponseType { type Error = Box; fn try_from( - (bytes, client_api, upstream_api): (&[u8], &SupportedAPIs, &SupportedAPIs), + (bytes, client_api, upstream_api): (&[u8], &SupportedAPIs, &SupportedUpstreamAPIs), ) -> Result { + // Special case: Handle [DONE] marker for OpenAI -> Anthropic conversion + if bytes == b"[DONE]" && matches!(client_api, SupportedAPIs::AnthropicMessagesAPI(_)) { + return Ok(ProviderStreamResponseType::MessagesStreamEvent( + crate::apis::anthropic::MessagesStreamEvent::MessageStop, + )); + } match (upstream_api, client_api) { - (SupportedAPIs::OpenAIChatCompletions(_), SupportedAPIs::OpenAIChatCompletions(_)) => { - let resp: crate::apis::openai::ChatCompletionsStreamResponse = - serde_json::from_slice(bytes)?; + // OpenAI upstream + ( + SupportedUpstreamAPIs::OpenAIChatCompletions(_), + SupportedAPIs::OpenAIChatCompletions(_), + ) => { + let resp = serde_json::from_slice(bytes)?; Ok(ProviderStreamResponseType::ChatCompletionsStreamResponse( resp, )) } - (SupportedAPIs::AnthropicMessagesAPI(_), SupportedAPIs::AnthropicMessagesAPI(_)) => { - let resp: crate::apis::anthropic::MessagesStreamEvent = - serde_json::from_slice(bytes)?; - Ok(ProviderStreamResponseType::MessagesStreamEvent(resp)) - } - (SupportedAPIs::AnthropicMessagesAPI(_), SupportedAPIs::OpenAIChatCompletions(_)) => { - let anthropic_resp: crate::apis::anthropic::MessagesStreamEvent = - serde_json::from_slice(bytes)?; - - // Transform to OpenAI ChatCompletions stream format using the transformer - let chat_resp: crate::apis::openai::ChatCompletionsStreamResponse = - anthropic_resp.try_into()?; - Ok(ProviderStreamResponseType::ChatCompletionsStreamResponse( - chat_resp, - )) - } - (SupportedAPIs::OpenAIChatCompletions(_), SupportedAPIs::AnthropicMessagesAPI(_)) => { - // Special case: Handle [DONE] marker for OpenAI -> Anthropic conversion - if bytes == b"[DONE]" { - return Ok(ProviderStreamResponseType::MessagesStreamEvent( - crate::apis::anthropic::MessagesStreamEvent::MessageStop, - )); - } - + ( + SupportedUpstreamAPIs::OpenAIChatCompletions(_), + SupportedAPIs::AnthropicMessagesAPI(_), + ) => { let openai_resp: crate::apis::openai::ChatCompletionsStreamResponse = serde_json::from_slice(bytes)?; - - // Transform to Anthropic Messages stream format using the transformer - let messages_resp: crate::apis::anthropic::MessagesStreamEvent = - openai_resp.try_into()?; + let anthropic_resp = openai_resp.try_into()?; Ok(ProviderStreamResponseType::MessagesStreamEvent( - messages_resp, + anthropic_resp, )) } + + // Anthropic upstream + ( + SupportedUpstreamAPIs::AnthropicMessagesAPI(_), + SupportedAPIs::AnthropicMessagesAPI(_), + ) => { + let resp = serde_json::from_slice(bytes)?; + Ok(ProviderStreamResponseType::MessagesStreamEvent(resp)) + } + ( + SupportedUpstreamAPIs::AnthropicMessagesAPI(_), + SupportedAPIs::OpenAIChatCompletions(_), + ) => { + let anthropic_resp: crate::apis::anthropic::MessagesStreamEvent = + serde_json::from_slice(bytes)?; + let openai_resp = anthropic_resp.try_into()?; + Ok(ProviderStreamResponseType::ChatCompletionsStreamResponse( + openai_resp, + )) + } 
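+ // Mapping sketch for the Bedrock arm below: a decoded "contentBlockDelta" frame carries + // a JSON payload shaped roughly like + // {"contentBlockIndex": 0, "delta": {"text": "Hello"}} + // and the ConverseStreamEvent -> MessagesStreamEvent conversion re-emits it as an + // Anthropic content_block_delta event: + // {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "Hello"}} + // (payload shapes are illustrative; the exact fields live in apis::amazon_bedrock)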
+
+            // Amazon Bedrock ConverseStream upstream
+            (
+                SupportedUpstreamAPIs::AmazonBedrockConverseStream(_),
+                SupportedAPIs::AnthropicMessagesAPI(_),
+            ) => {
+                let bedrock_resp: crate::apis::amazon_bedrock::ConverseStreamEvent =
+                    serde_json::from_slice(bytes)?;
+                let anthropic_resp = bedrock_resp.try_into()?;
+                Ok(ProviderStreamResponseType::MessagesStreamEvent(
+                    anthropic_resp,
+                ))
+            }
+            _ => Err(std::io::Error::new(
+                std::io::ErrorKind::InvalidData,
+                "Unsupported API combination for response transformation",
+            )
+            .into()),
         }
     }
 }
 
 // TryFrom implementation to convert raw bytes to SseEvent with parsed provider response
-impl TryFrom<(SseEvent, &SupportedAPIs, &SupportedAPIs)> for SseEvent {
+impl TryFrom<(SseEvent, &SupportedAPIs, &SupportedUpstreamAPIs)> for SseEvent {
     type Error = Box<dyn std::error::Error + Send + Sync>;
 
     fn try_from(
-        (sse_event, client_api, upstream_api): (SseEvent, &SupportedAPIs, &SupportedAPIs),
+        (sse_event, client_api, upstream_api): (SseEvent, &SupportedAPIs, &SupportedUpstreamAPIs),
     ) -> Result<Self, Self::Error> {
         // Create a new transformed event based on the original
         let mut transformed_event = sse_event;
@@ -326,157 +320,132 @@ impl TryFrom<(SseEvent, &SupportedAPIs, &SupportedAPIs)> for SseEvent {
         if transformed_event.data.is_some() {
             let data_str = transformed_event.data.as_ref().unwrap();
             let data_bytes = data_str.as_bytes();
-            let transformed_response =
+            let transformed_response: ProviderStreamResponseType =
                 ProviderStreamResponseType::try_from((data_bytes, client_api, upstream_api))?;
-            let transformed_json = serde_json::to_string(&transformed_response)?;
-            transformed_event.sse_transform_buffer = format!("data: {}\n\n", transformed_json);
+
+            // Convert to SSE string explicitly to avoid type ambiguity
+            let sse_string: String = transformed_response.clone().into();
+            transformed_event.sse_transform_buffer = sse_string;
             transformed_event.provider_stream_response = Some(transformed_response);
         }
 
         match (client_api, upstream_api) {
-            (SupportedAPIs::OpenAIChatCompletions(_), SupportedAPIs::OpenAIChatCompletions(_)) => {
-                // No transformation needed
-            }
-            (SupportedAPIs::AnthropicMessagesAPI(_), SupportedAPIs::AnthropicMessagesAPI(_)) => {
-                // No transformation needed
-            }
-            (SupportedAPIs::AnthropicMessagesAPI(_), SupportedAPIs::OpenAIChatCompletions(_)) => {
+            (
+                SupportedAPIs::AnthropicMessagesAPI(_),
+                SupportedUpstreamAPIs::OpenAIChatCompletions(_),
+            ) => {
                 if let Some(provider_response) = &transformed_event.provider_stream_response {
                     if let Some(event_type) = provider_response.event_type() {
                         // This ensures the required Anthropic sequence: MessageStart → ContentBlockStart → ContentBlockDelta(s)
                         if event_type == "message_start" {
-                            let content_block_start_json = serde_json::json!({
-                                "type": "content_block_start",
-                                "index": 0,
-                                "content_block": {
-                                    "type": "text",
-                                    "text": ""
-                                }
-                            });
+                            // Create ContentBlockStart event and format it using Into<String>
+                            let content_block_start = MessagesStreamEvent::ContentBlockStart {
+                                index: 0,
+                                content_block: crate::apis::anthropic::MessagesContentBlock::Text {
+                                    text: String::new(),
+                                    cache_control: None,
+                                },
+                            };
+                            let content_block_start_sse: String = content_block_start.into();
+
+                            // Format as proper SSE: MessageStart first, then ContentBlockStart
+                            // The sse_transform_buffer already contains the properly formatted MessageStart
                             transformed_event.sse_transform_buffer = format!(
-                                "event: {}\n{}\nevent: content_block_start\ndata: {}\n\n",
-                                event_type,
-                                transformed_event.sse_transform_buffer,
-                                content_block_start_json,
+                                "{}{}",
+                                transformed_event.sse_transform_buffer, content_block_start_sse,
                             );
                         } else if event_type == "message_delta" {
-                            let content_block_stop_json = serde_json::json!({
-                                "type": "content_block_stop",
-                                "index": 0
-                            });
+                            // Create ContentBlockStop event and format it using Into<String>
+                            let content_block_stop =
+                                MessagesStreamEvent::ContentBlockStop { index: 0 };
+                            let content_block_stop_sse: String = content_block_stop.into();
+
+                            // Format as proper SSE: ContentBlockStop first, then MessageDelta
                             transformed_event.sse_transform_buffer = format!(
-                                "event: content_block_stop\ndata: {}\n\nevent: {}\n{}",
-                                content_block_stop_json,
-                                event_type,
-                                transformed_event.sse_transform_buffer
-                            );
-                        } else {
-                            transformed_event.sse_transform_buffer = format!(
-                                "event: {}\n{}",
-                                event_type, transformed_event.sse_transform_buffer
+                                "{}{}",
+                                content_block_stop_sse, transformed_event.sse_transform_buffer
                             );
                         }
+                        // For other event types, the sse_transform_buffer already has the correct format from Into<String>
                     }
                     // If event_type is None, we just keep the data line as-is without an event line
                     // This handles cases where the transformation might not produce a valid event type
                 }
             }
-            (SupportedAPIs::OpenAIChatCompletions(_), SupportedAPIs::AnthropicMessagesAPI(_)) => {
+            (
+                SupportedAPIs::OpenAIChatCompletions(_),
+                SupportedUpstreamAPIs::AnthropicMessagesAPI(_),
+            ) => {
                 if transformed_event.is_event_only() && transformed_event.event.is_some() {
                     transformed_event.sse_transform_buffer = format!("\n"); // suppress the event upstream for OpenAI
                 }
             }
+            _ => {
+                // Other combinations can be handled here as needed
+            }
         }
 
         Ok(transformed_event)
     }
 }
 
-#[derive(Debug)]
-pub struct SseParseError {
-    pub message: String,
-}
-
-impl fmt::Display for SseParseError {
-    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-        write!(f, "SSE parse error: {}", self.message)
-    }
-}
-
-impl Error for SseParseError {}
-
-// ============================================================================
-// GENERIC SSE STREAMING ITERATOR (Container Only)
-// ============================================================================
-
-/// Generic SSE (Server-Sent Events) streaming iterator container
-/// Parses raw SSE lines into SseEvent objects
-pub struct SseStreamIter<I>
-where
-    I: Iterator,
-    I::Item: AsRef<str>,
+// TryFrom implementation to convert AWS Event Stream DecodedFrame to ProviderStreamResponseType
+impl
+    TryFrom<(
+        &aws_smithy_eventstream::frame::DecodedFrame,
+        &SupportedAPIs,
+        &SupportedUpstreamAPIs,
+    )> for ProviderStreamResponseType
 {
-    pub lines: I,
-    pub done_seen: bool,
-}
-
-impl<I> SseStreamIter<I>
-where
-    I: Iterator,
-    I::Item: AsRef<str>,
-{
-    pub fn new(lines: I) -> Self {
-        Self {
-            lines,
-            done_seen: false,
-        }
-    }
-}
-
-// TryFrom implementation to parse bytes into SseStreamIter
-impl TryFrom<&[u8]> for SseStreamIter<std::vec::IntoIter<String>> {
     type Error = Box<dyn std::error::Error + Send + Sync>;
 
-    fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
-        let s = std::str::from_utf8(bytes)?;
-        let lines: Vec<String> = s.lines().map(|line| line.to_string()).collect();
-        Ok(SseStreamIter::new(lines.into_iter()))
-    }
-}
+    fn try_from(
+        (frame, client_api, upstream_api): (
+            &aws_smithy_eventstream::frame::DecodedFrame,
+            &SupportedAPIs,
+            &SupportedUpstreamAPIs,
+        ),
+    ) -> Result<Self, Self::Error> {
+        use aws_smithy_eventstream::frame::DecodedFrame;
 
-impl<I> Iterator for SseStreamIter<I>
-where
-    I: Iterator,
-    I::Item: AsRef<str>,
-{
-    type Item = SseEvent;
+        match frame {
+            DecodedFrame::Complete(_) => {
+                // We have a complete frame - parse it based on upstream API
+                match (upstream_api, client_api) {
+                    (
+                        SupportedUpstreamAPIs::AmazonBedrockConverseStream(_),
+                        SupportedAPIs::AnthropicMessagesAPI(_),
+                    ) => {
+                        // Parse the DecodedFrame into ConverseStreamEvent
+                        let bedrock_event =
+                            crate::apis::amazon_bedrock::ConverseStreamEvent::try_from(frame)?;
+                        let anthropic_event: crate::apis::anthropic::MessagesStreamEvent =
+                            bedrock_event.try_into()?;
 
-    fn next(&mut self) -> Option<Self::Item> {
-        // If we already returned [DONE], terminate the stream
-        if self.done_seen {
-            return None;
-        }
-
-        for line in &mut self.lines {
-            let line_str = line.as_ref();
-
-            // Try to parse as either data: or event: line
-            if let Ok(event) = line_str.parse::<SseEvent>() {
-                // For data: lines, check if this is the [DONE] marker
-                if event.data.is_some() && event.is_done() {
-                    self.done_seen = true;
-                    return Some(event); // Return [DONE] event for transformation
+                        Ok(ProviderStreamResponseType::MessagesStreamEvent(
+                            anthropic_event,
+                        ))
+                    }
+                    (
+                        SupportedUpstreamAPIs::AmazonBedrockConverseStream(_),
+                        SupportedAPIs::OpenAIChatCompletions(_),
+                    ) => {
+                        // Parse the DecodedFrame into ConverseStreamEvent
+                        let bedrock_event =
+                            crate::apis::amazon_bedrock::ConverseStreamEvent::try_from(frame)?;
+                        let openai_event: crate::apis::openai::ChatCompletionsStreamResponse =
+                            bedrock_event.try_into()?;
+                        Ok(ProviderStreamResponseType::ChatCompletionsStreamResponse(
+                            openai_event,
+                        ))
+                    }
+                    _ => Err("Unsupported API combination for event-stream decoding".into()),
                 }
-                // For data: lines, skip events that should be filtered at the transport layer
-                if event.data.is_some() && event.should_skip() {
-                    continue;
-                }
-                return Some(event);
+            }
+            DecodedFrame::Incomplete => {
+                Err("Cannot convert incomplete frame to provider response".into())
             }
         }
-        None
     }
 }
 
@@ -503,8 +472,10 @@ impl Error for ProviderResponseError {
 #[cfg(test)]
 mod tests {
     use super::*;
+    use crate::apis::amazon_bedrock_binary_frame::BedrockBinaryFrameDecoder;
     use crate::apis::anthropic::AnthropicApi;
     use crate::apis::openai::OpenAIApi;
+    use crate::apis::sse::SseStreamIter;
     use crate::clients::endpoints::SupportedAPIs;
     use crate::providers::id::ProviderId;
     use serde_json::json;
 
@@ -869,8 +840,9 @@ mod tests {
         // Test that [DONE] marker is properly converted to MessageStop in the transformation layer
         let done_bytes = b"[DONE]";
         let client_api = SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages);
-        let upstream_api =
-            SupportedAPIs::OpenAIChatCompletions(crate::apis::openai::OpenAIApi::ChatCompletions);
+        let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(
+            crate::apis::openai::OpenAIApi::ChatCompletions,
+        );
 
         let result = ProviderStreamResponseType::try_from((
             done_bytes.as_slice(),
@@ -890,4 +862,707 @@ mod tests {
             panic!("Expected MessagesStreamEvent::MessageStop");
         }
     }
+
+    #[test]
+    fn test_bedrock_event_stream_decoder_basic() {
+        use bytes::BytesMut;
+
+        // Create a simple test with minimal data
+        let mut buffer = BytesMut::new();
+
+        // Add some arbitrary bytes (not a real event-stream frame, just for testing the decoder)
+        buffer.extend_from_slice(b"test data");
+
+        let mut decoder = BedrockBinaryFrameDecoder::new(&mut buffer);
+
+        // The decoder should return Incomplete for incomplete/invalid data
+        // This signals the caller to wait for more data
+        let result = decoder.decode_frame();
+        assert!(result.is_some());
+        assert!(matches!(
+            result.unwrap(),
+            aws_smithy_eventstream::frame::DecodedFrame::Incomplete
+        ));
+
+        // Verify we can still access the buffer
+        assert!(decoder.has_remaining());
+    }
+
+    #[test]
+    fn
test_bedrock_event_stream_decoder_with_real_frames() { + use bytes::BytesMut; + use std::fs; + use std::path::PathBuf; + + // Read the actual response.hex file from tests/e2e directory + let test_file = + PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("../../tests/e2e/response.hex"); + + // Only run this test if the file exists + if !test_file.exists() { + println!("Skipping test - response.hex not found"); + return; + } + + let response_data = fs::read(&test_file).unwrap(); + let mut buffer = BytesMut::from(&response_data[..]); + + let mut decoder = BedrockBinaryFrameDecoder::new(&mut buffer); + let mut frame_count = 0; + + // Decode all frames + loop { + match decoder.decode_frame() { + Some(aws_smithy_eventstream::frame::DecodedFrame::Complete(message)) => { + frame_count += 1; + + // Verify we can access headers + let event_type = message + .headers() + .iter() + .find(|h| h.name().as_str() == ":event-type") + .and_then(|h| h.value().as_string().ok()); + + assert!(event_type.is_some(), "Frame should have :event-type header"); + } + Some(aws_smithy_eventstream::frame::DecodedFrame::Incomplete) => { + // End of buffer, no more complete frames available + break; + } + None => { + // Decode error + panic!("Decode error encountered"); + } + } + } + + // We should have decoded multiple frames + assert!(frame_count > 0, "Should have decoded at least one frame"); + } + + #[test] + fn test_bedrock_event_stream_decoder_chunked_data() { + use bytes::BytesMut; + use std::fs; + use std::path::PathBuf; + + // Read the actual response.hex file + let test_file = + PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("../../tests/e2e/response.hex"); + + if !test_file.exists() { + println!("Skipping test - response.hex not found"); + return; + } + + let response_data = fs::read(&test_file).unwrap(); + + // Simulate chunked network arrivals with realistic chunk sizes + // Using varying chunk sizes to test partial frame handling + let mut buffer = BytesMut::new(); + let chunk_size_pattern = vec![500, 1000, 750, 1200, 800, 1500]; + let mut offset = 0; + let mut total_frames = 0; + let mut chunk_num = 0; + + // CRITICAL: Create ONE decoder and reuse it across chunks + // The MessageFrameDecoder maintains state about partial frames + let mut decoder = BedrockBinaryFrameDecoder::new(&mut buffer); + + // Process all data in chunks + while offset < response_data.len() { + let chunk_size = chunk_size_pattern[chunk_num % chunk_size_pattern.len()]; + chunk_num += 1; + + let end = (offset + chunk_size).min(response_data.len()); + let chunk = &response_data[offset..end]; + + // Add new data to the buffer (accessing via buffer_mut()) + decoder.buffer_mut().extend_from_slice(chunk); + offset = end; + + // Process all available complete frames from this chunk + loop { + match decoder.decode_frame() { + Some(aws_smithy_eventstream::frame::DecodedFrame::Complete(_)) => { + total_frames += 1; + } + Some(aws_smithy_eventstream::frame::DecodedFrame::Incomplete) => { + // Need more data - wait for next chunk + break; + } + None => { + // Decode error + panic!("Decode error in chunked test"); + } + } + } + } + + assert!( + total_frames > 0, + "Should have decoded frames from chunked data" + ); + } + + #[test] + fn test_bedrock_decoded_frame_to_provider_response() { + test_bedrock_conversion(false); + } + + #[test] + #[ignore] // Run with: cargo test -- --ignored --nocapture + fn test_bedrock_decoded_frame_to_provider_response_verbose() { + test_bedrock_conversion(true); + } + + #[test] + fn 
test_bedrock_decoded_frame_with_tool_use() { + test_bedrock_conversion_with_tools(false); + } + + #[test] + #[ignore] // Run with: cargo test -- --ignored --nocapture + fn test_bedrock_decoded_frame_with_tool_use_verbose() { + test_bedrock_conversion_with_tools(true); + } + + fn test_bedrock_conversion(verbose: bool) { + use bytes::BytesMut; + use std::fs; + use std::path::PathBuf; + + // Read the actual response.hex file from tests/e2e directory + let test_file = + PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("../../tests/e2e/response.hex"); + + // Only run this test if the file exists + if !test_file.exists() { + println!("Skipping test - response.hex not found"); + return; + } + + let response_data = fs::read(&test_file).unwrap(); + let mut buffer = BytesMut::from(&response_data[..]); + + let mut decoder = BedrockBinaryFrameDecoder::new(&mut buffer); + + let client_api = + SupportedAPIs::AnthropicMessagesAPI(crate::apis::anthropic::AnthropicApi::Messages); + let upstream_api = SupportedUpstreamAPIs::AmazonBedrockConverseStream( + crate::apis::amazon_bedrock::AmazonBedrockApi::ConverseStream, + ); + + let mut conversion_count = 0; + let mut message_start_seen = false; + + // Decode and convert frames + loop { + match decoder.decode_frame() { + Some(frame @ aws_smithy_eventstream::frame::DecodedFrame::Complete(_)) => { + // Convert DecodedFrame to ProviderStreamResponseType + let result = + ProviderStreamResponseType::try_from((&frame, &client_api, &upstream_api)); + + match result { + Ok(provider_response) => { + conversion_count += 1; + + // Verify we got a MessagesStreamEvent + assert!(matches!( + provider_response, + ProviderStreamResponseType::MessagesStreamEvent(_) + )); + + if verbose { + // Print the SSE string output + let sse_string: String = provider_response.clone().into(); + println!("{}", sse_string); + } + + // Check for MessageStart event + if let ProviderStreamResponseType::MessagesStreamEvent(ref event) = + provider_response + { + if matches!( + event, + crate::apis::anthropic::MessagesStreamEvent::MessageStart { .. 
} + ) { + message_start_seen = true; + } + } + } + Err(e) => { + println!("Conversion error (frame {}): {}", conversion_count, e); + } + } + } + Some(aws_smithy_eventstream::frame::DecodedFrame::Incomplete) => { + // End of buffer + break; + } + None => { + panic!("Decode error"); + } + } + } + + assert!( + conversion_count > 0, + "Should have converted at least one frame" + ); + assert!(message_start_seen, "Should have seen MessageStart event"); + } + + fn test_bedrock_conversion_with_tools(verbose: bool) { + use bytes::BytesMut; + use std::fs; + use std::path::PathBuf; + + // Read the actual response_with_tools.hex file from tests/e2e directory + let test_file = PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .join("../../tests/e2e/response_with_tools.hex"); + + // Only run this test if the file exists + if !test_file.exists() { + println!("Skipping test - response_with_tools.hex not found"); + return; + } + + let response_data = fs::read(&test_file).unwrap(); + let mut buffer = BytesMut::from(&response_data[..]); + + let mut decoder = BedrockBinaryFrameDecoder::new(&mut buffer); + + let client_api = + SupportedAPIs::AnthropicMessagesAPI(crate::apis::anthropic::AnthropicApi::Messages); + let upstream_api = SupportedUpstreamAPIs::AmazonBedrockConverseStream( + crate::apis::amazon_bedrock::AmazonBedrockApi::ConverseStream, + ); + + let mut conversion_count = 0; + let mut message_start_seen = false; + let mut content_block_start_seen = false; + let mut content_block_delta_tool_use_seen = false; + + // Decode and convert frames + loop { + match decoder.decode_frame() { + Some(frame @ aws_smithy_eventstream::frame::DecodedFrame::Complete(_)) => { + // Convert DecodedFrame to ProviderStreamResponseType + let result = + ProviderStreamResponseType::try_from((&frame, &client_api, &upstream_api)); + + match result { + Ok(provider_response) => { + conversion_count += 1; + + // Verify we got a MessagesStreamEvent + assert!(matches!( + provider_response, + ProviderStreamResponseType::MessagesStreamEvent(_) + )); + + if verbose { + // Print the SSE string output + let sse_string: String = provider_response.clone().into(); + println!("{}", sse_string); + } + + // Check for specific events related to tool use + if let ProviderStreamResponseType::MessagesStreamEvent(ref event) = + provider_response + { + match event { + crate::apis::anthropic::MessagesStreamEvent::MessageStart { .. } => { + message_start_seen = true; + } + crate::apis::anthropic::MessagesStreamEvent::ContentBlockStart { .. } => { + content_block_start_seen = true; + } + crate::apis::anthropic::MessagesStreamEvent::ContentBlockDelta { delta, .. } => { + if matches!(delta, crate::apis::anthropic::MessagesContentDelta::InputJsonDelta { .. 
}) { + content_block_delta_tool_use_seen = true; + } + } + _ => {} + } + } + } + Err(e) => { + println!("Conversion error (frame {}): {}", conversion_count, e); + } + } + } + Some(aws_smithy_eventstream::frame::DecodedFrame::Incomplete) => { + // End of buffer + break; + } + None => { + panic!("Decode error"); + } + } + } + + assert!( + conversion_count > 0, + "Should have converted at least one frame" + ); + assert!(message_start_seen, "Should have seen MessageStart event"); + assert!( + content_block_start_seen, + "Should have seen ContentBlockStart event for tool use" + ); + assert!( + content_block_delta_tool_use_seen, + "Should have seen ContentBlockDelta with ToolUseDelta" + ); + } + + #[test] + fn test_sse_event_transformation_openai_to_anthropic_message_start() { + use crate::apis::anthropic::AnthropicApi; + use crate::apis::openai::OpenAIApi; + + // Create an OpenAI stream response that represents a role start (which becomes message_start in Anthropic) + let openai_stream_chunk = json!({ + "id": "chatcmpl-123", + "object": "chat.completion.chunk", + "created": 1234567890, + "model": "gpt-4", + "choices": [{ + "index": 0, + "delta": {"role": "assistant"}, + "finish_reason": null + }] + }); + + // Create SSE event with this data + let sse_event = SseEvent { + data: Some(openai_stream_chunk.to_string()), + event: None, + raw_line: format!("data: {}", openai_stream_chunk.to_string()), + sse_transform_buffer: format!("data: {}", openai_stream_chunk.to_string()), + provider_stream_response: None, + }; + + let client_api = SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages); + let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions); + + // Transform the event + let result = SseEvent::try_from((sse_event, &client_api, &upstream_api)); + assert!(result.is_ok()); + + let transformed = result.unwrap(); + + // Verify the transformation includes both message_start and content_block_start + let buffer = transformed.sse_transform_buffer; + assert!( + buffer.contains("event: message_start"), + "Should contain message_start event" + ); + assert!( + buffer.contains("event: content_block_start"), + "Should contain content_block_start event" + ); + + // Verify proper SSE format with event lines before data lines + assert!(buffer.find("event: message_start").unwrap() < buffer.find("data:").unwrap()); + assert!(buffer.find("content_block_start").is_some()); + } + + #[test] + fn test_sse_event_transformation_openai_to_anthropic_message_delta() { + use crate::apis::anthropic::AnthropicApi; + use crate::apis::openai::OpenAIApi; + + // Create an OpenAI stream response with finish_reason (which becomes message_delta in Anthropic) + let openai_stream_chunk = json!({ + "id": "chatcmpl-123", + "object": "chat.completion.chunk", + "created": 1234567890, + "model": "gpt-4", + "choices": [{ + "index": 0, + "delta": {}, + "finish_reason": "stop" + }], + "usage": { + "prompt_tokens": 10, + "completion_tokens": 25, + "total_tokens": 35 + } + }); + + // Create SSE event with this data + let sse_event = SseEvent { + data: Some(openai_stream_chunk.to_string()), + event: None, + raw_line: format!("data: {}", openai_stream_chunk.to_string()), + sse_transform_buffer: format!("data: {}", openai_stream_chunk.to_string()), + provider_stream_response: None, + }; + + let client_api = SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages); + let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions); + + // Transform the event + let result = 
SseEvent::try_from((sse_event, &client_api, &upstream_api)); + assert!(result.is_ok()); + + let transformed = result.unwrap(); + + // Verify the transformation includes both content_block_stop and message_delta + let buffer = transformed.sse_transform_buffer; + assert!( + buffer.contains("event: content_block_stop"), + "Should contain content_block_stop event" + ); + assert!( + buffer.contains("event: message_delta"), + "Should contain message_delta event" + ); + + // Verify content_block_stop comes before message_delta + let stop_pos = buffer.find("content_block_stop").unwrap(); + let delta_pos = buffer.find("message_delta").unwrap(); + assert!( + stop_pos < delta_pos, + "content_block_stop should come before message_delta" + ); + } + + #[test] + fn test_sse_event_transformation_openai_to_anthropic_content_delta() { + use crate::apis::anthropic::AnthropicApi; + use crate::apis::openai::OpenAIApi; + + // Create an OpenAI stream response with content (which becomes content_block_delta in Anthropic) + let openai_stream_chunk = json!({ + "id": "chatcmpl-123", + "object": "chat.completion.chunk", + "created": 1234567890, + "model": "gpt-4", + "choices": [{ + "index": 0, + "delta": {"content": "Hello"}, + "finish_reason": null + }] + }); + + // Create SSE event with this data + let sse_event = SseEvent { + data: Some(openai_stream_chunk.to_string()), + event: None, + raw_line: format!("data: {}", openai_stream_chunk.to_string()), + sse_transform_buffer: format!("data: {}", openai_stream_chunk.to_string()), + provider_stream_response: None, + }; + + let client_api = SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages); + let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions); + + // Transform the event + let result = SseEvent::try_from((sse_event, &client_api, &upstream_api)); + assert!(result.is_ok()); + + let transformed = result.unwrap(); + + // Verify the transformation is a content_block_delta (no extra events injected) + let buffer = transformed.sse_transform_buffer; + assert!( + buffer.contains("event: content_block_delta"), + "Should contain content_block_delta event" + ); + assert!( + !buffer.contains("content_block_start"), + "Should not inject content_block_start for content delta" + ); + assert!( + !buffer.contains("content_block_stop"), + "Should not inject content_block_stop for content delta" + ); + + // Verify the content is preserved + assert!(buffer.contains("Hello"), "Should preserve the content text"); + } + + #[test] + fn test_sse_event_transformation_anthropic_to_openai_suppresses_event_lines() { + use crate::apis::anthropic::AnthropicApi; + use crate::apis::openai::OpenAIApi; + + // Create an Anthropic event-only SSE line (no data) + let sse_event = SseEvent { + data: None, + event: Some("message_start".to_string()), + raw_line: "event: message_start".to_string(), + sse_transform_buffer: "event: message_start".to_string(), + provider_stream_response: None, + }; + + let client_api = SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions); + let upstream_api = SupportedUpstreamAPIs::AnthropicMessagesAPI(AnthropicApi::Messages); + + // Transform the event + let result = SseEvent::try_from((sse_event, &client_api, &upstream_api)); + assert!(result.is_ok()); + + let transformed = result.unwrap(); + + // Verify the event line is suppressed (replaced with just newline) + assert_eq!( + transformed.sse_transform_buffer, "\n", + "Event-only lines should be suppressed to newline for OpenAI" + ); + assert!( + 
transformed.is_event_only(), + "Should still be marked as event-only" + ); + } + + #[test] + fn test_sse_event_transformation_anthropic_to_openai_preserves_data() { + use crate::apis::anthropic::AnthropicApi; + use crate::apis::openai::OpenAIApi; + + // Create an Anthropic message_start event with data + let anthropic_event = json!({ + "type": "message_start", + "message": { + "id": "msg_123", + "type": "message", + "role": "assistant", + "content": [], + "model": "claude-3-sonnet", + "stop_reason": null, + "usage": {"input_tokens": 10, "output_tokens": 0} + } + }); + + let sse_event = SseEvent { + data: Some(anthropic_event.to_string()), + event: None, + raw_line: format!("data: {}", anthropic_event.to_string()), + sse_transform_buffer: format!("data: {}", anthropic_event.to_string()), + provider_stream_response: None, + }; + + let client_api = SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions); + let upstream_api = SupportedUpstreamAPIs::AnthropicMessagesAPI(AnthropicApi::Messages); + + // Transform the event + let result = SseEvent::try_from((sse_event, &client_api, &upstream_api)); + assert!(result.is_ok()); + + let transformed = result.unwrap(); + + // Verify data is transformed to OpenAI format + let buffer = transformed.sse_transform_buffer; + assert!(buffer.starts_with("data: "), "Should have data: prefix"); + assert!( + !buffer.contains("event:"), + "Should not have event: lines for OpenAI" + ); + + // Verify provider response was parsed + assert!(transformed.provider_stream_response.is_some()); + } + + #[test] + fn test_sse_event_transformation_no_change_for_matching_apis() { + use crate::apis::openai::OpenAIApi; + + // Create an OpenAI stream response + let openai_stream_chunk = json!({ + "id": "chatcmpl-123", + "object": "chat.completion.chunk", + "created": 1234567890, + "model": "gpt-4", + "choices": [{ + "index": 0, + "delta": {"content": "Hello"}, + "finish_reason": null + }] + }); + + let original_data = openai_stream_chunk.to_string(); + let sse_event = SseEvent { + data: Some(original_data.clone()), + event: None, + raw_line: format!("data: {}", original_data), + sse_transform_buffer: format!("data: {}\n\n", original_data), + provider_stream_response: None, + }; + + let client_api = SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions); + let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions); + + // Transform the event + let result = SseEvent::try_from((sse_event, &client_api, &upstream_api)); + assert!(result.is_ok()); + + let transformed = result.unwrap(); + + // Verify minimal transformation - just SSE formatting, no API conversion + let buffer = transformed.sse_transform_buffer; + assert!(buffer.starts_with("data: "), "Should preserve data: prefix"); + assert!(!buffer.contains("event:"), "Should not add event: lines"); + + // Verify provider response was parsed + assert!(transformed.provider_stream_response.is_some()); + } + + #[test] + fn test_sse_event_transformation_preserves_provider_response() { + use crate::apis::anthropic::AnthropicApi; + use crate::apis::openai::OpenAIApi; + + // Create an OpenAI stream response + let openai_stream_chunk = json!({ + "id": "chatcmpl-123", + "object": "chat.completion.chunk", + "created": 1234567890, + "model": "gpt-4", + "choices": [{ + "index": 0, + "delta": {"content": "Test"}, + "finish_reason": null + }] + }); + + let sse_event = SseEvent { + data: Some(openai_stream_chunk.to_string()), + event: None, + raw_line: format!("data: {}", 
openai_stream_chunk.to_string()),
+            sse_transform_buffer: format!("data: {}", openai_stream_chunk.to_string()),
+            provider_stream_response: None,
+        };
+
+        let client_api = SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages);
+        let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions);
+
+        // Transform the event
+        let result = SseEvent::try_from((sse_event, &client_api, &upstream_api));
+        assert!(result.is_ok());
+
+        let transformed = result.unwrap();
+
+        // Verify provider_stream_response is populated
+        assert!(
+            transformed.provider_stream_response.is_some(),
+            "Should parse and store provider response"
+        );
+
+        // Verify we can access the provider response
+        let provider_response = transformed.provider_response();
+        assert!(
+            provider_response.is_ok(),
+            "Should be able to access provider response"
+        );
+
+        // Verify the content delta is accessible
+        let content = provider_response.unwrap().content_delta();
+        assert_eq!(content, Some("Test"), "Should preserve content delta");
+    }
 }
diff --git a/crates/hermesllm/src/transforms/lib.rs b/crates/hermesllm/src/transforms/lib.rs
new file mode 100644
index 00000000..53a7621e
--- /dev/null
+++ b/crates/hermesllm/src/transforms/lib.rs
@@ -0,0 +1,231 @@
+use crate::apis::anthropic::{MessagesContentBlock, MessagesImageSource};
+use crate::apis::openai::{ContentPart, FunctionCall, ImageUrl, Message, MessageContent, ToolCall};
+use crate::clients::TransformError;
+use serde_json::Value;
+use std::time::{SystemTime, UNIX_EPOCH};
+
+pub trait ExtractText {
+    fn extract_text(&self) -> String;
+}
+
+/// Trait for utility functions on content collections
+pub trait ContentUtils {
+    fn extract_tool_calls(&self) -> Result<Option<Vec<ToolCall>>, TransformError>;
+    fn split_for_openai(
+        &self,
+    ) -> Result<(Vec<ContentPart>, Vec<ToolCall>, Vec<(String, String, bool)>), TransformError>;
+}
+
+/// Helper to create a current unix timestamp
+pub fn current_timestamp() -> u64 {
+    SystemTime::now()
+        .duration_since(UNIX_EPOCH)
+        .unwrap()
+        .as_secs()
+}
+
+// Content Utilities
+impl ContentUtils for Vec<MessagesContentBlock> {
+    fn extract_tool_calls(&self) -> Result<Option<Vec<ToolCall>>, TransformError> {
+        let mut tool_calls = Vec::new();
+
+        for block in self {
+            match block {
+                MessagesContentBlock::ToolUse {
+                    id, name, input, ..
+                }
+                | MessagesContentBlock::ServerToolUse { id, name, input }
+                | MessagesContentBlock::McpToolUse { id, name, input } => {
+                    let arguments = serde_json::to_string(&input)?;
+                    tool_calls.push(ToolCall {
+                        id: id.clone(),
+                        call_type: "function".to_string(),
+                        function: FunctionCall {
+                            name: name.clone(),
+                            arguments,
+                        },
+                    });
+                }
+                _ => continue,
+            }
+        }
+
+        Ok(if tool_calls.is_empty() {
+            None
+        } else {
+            Some(tool_calls)
+        })
+    }
+
+    fn split_for_openai(
+        &self,
+    ) -> Result<(Vec<ContentPart>, Vec<ToolCall>, Vec<(String, String, bool)>), TransformError>
+    {
+        let mut content_parts = Vec::new();
+        let mut tool_calls = Vec::new();
+        let mut tool_results = Vec::new();
+
+        for block in self {
+            match block {
+                MessagesContentBlock::Text { text, .. } => {
+                    content_parts.push(ContentPart::Text { text: text.clone() });
+                }
+                MessagesContentBlock::Image { source } => {
+                    let url = convert_image_source_to_url(source);
+                    content_parts.push(ContentPart::ImageUrl {
+                        image_url: ImageUrl {
+                            url,
+                            detail: Some("auto".to_string()),
+                        },
+                    });
+                }
+                MessagesContentBlock::ToolUse {
+                    id, name, input, ..
+                }
+                | MessagesContentBlock::ServerToolUse { id, name, input }
+                | MessagesContentBlock::McpToolUse { id, name, input } => {
+                    let arguments = serde_json::to_string(&input)?;
+                    tool_calls.push(ToolCall {
+                        id: id.clone(),
+                        call_type: "function".to_string(),
+                        function: FunctionCall {
+                            name: name.clone(),
+                            arguments,
+                        },
+                    });
+                }
+                MessagesContentBlock::ToolResult {
+                    tool_use_id,
+                    content,
+                    is_error,
+                    ..
+                } => {
+                    let result_text = content.extract_text();
+                    tool_results.push((
+                        tool_use_id.clone(),
+                        result_text,
+                        is_error.unwrap_or(false),
+                    ));
+                }
+                MessagesContentBlock::WebSearchToolResult {
+                    tool_use_id,
+                    content,
+                    is_error,
+                }
+                | MessagesContentBlock::CodeExecutionToolResult {
+                    tool_use_id,
+                    content,
+                    is_error,
+                }
+                | MessagesContentBlock::McpToolResult {
+                    tool_use_id,
+                    content,
+                    is_error,
+                } => {
+                    let result_text = content.extract_text();
+                    tool_results.push((
+                        tool_use_id.clone(),
+                        result_text,
+                        is_error.unwrap_or(false),
+                    ));
+                }
+                _ => {
+                    // Skip unsupported content types
+                    continue;
+                }
+            }
+        }
+
+        Ok((content_parts, tool_calls, tool_results))
+    }
+}
+
+/// Convert image source to URL
+pub fn convert_image_source_to_url(source: &MessagesImageSource) -> String {
+    match source {
+        MessagesImageSource::Base64 { media_type, data } => {
+            format!("data:{};base64,{}", media_type, data)
+        }
+        MessagesImageSource::Url { url } => url.clone(),
+    }
+}
+
+/// Convert image URL to Anthropic image source
+fn convert_image_url_to_source(image_url: &ImageUrl) -> MessagesImageSource {
+    if image_url.url.starts_with("data:") {
+        // Parse data URL
+        let parts: Vec<&str> = image_url.url.splitn(2, ',').collect();
+        if parts.len() == 2 {
+            let header = parts[0];
+            let data = parts[1];
+            let media_type = header
+                .strip_prefix("data:")
+                .and_then(|s| s.split(';').next())
+                .unwrap_or("image/jpeg")
+                .to_string();
+
+            MessagesImageSource::Base64 {
+                media_type,
+                data: data.to_string(),
+            }
+        } else {
+            MessagesImageSource::Url {
+                url: image_url.url.clone(),
+            }
+        }
+    } else {
+        MessagesImageSource::Url {
+            url: image_url.url.clone(),
+        }
+    }
+}
+
+/// Convert OpenAI message to Anthropic content blocks
+pub fn convert_openai_message_to_anthropic_content(
+    message: &Message,
+) -> Result<Vec<MessagesContentBlock>, TransformError> {
+    let mut blocks = Vec::new();
+
+    // Handle regular content
+    match &message.content {
+        MessageContent::Text(text) => {
+            if !text.is_empty() {
+                blocks.push(MessagesContentBlock::Text {
+                    text: text.clone(),
+                    cache_control: None,
+                });
+            }
+        }
+        MessageContent::Parts(parts) => {
+            for part in parts {
+                match part {
+                    ContentPart::Text { text } => {
+                        blocks.push(MessagesContentBlock::Text {
+                            text: text.clone(),
+                            cache_control: None,
+                        });
+                    }
+                    ContentPart::ImageUrl { image_url } => {
+                        let source = convert_image_url_to_source(image_url);
+                        blocks.push(MessagesContentBlock::Image { source });
+                    }
+                }
+            }
+        }
+    }
+
+    // Handle tool calls
+    if let Some(tool_calls) = &message.tool_calls {
+        for tool_call in tool_calls {
+            let input: Value = serde_json::from_str(&tool_call.function.arguments)?;
+            blocks.push(MessagesContentBlock::ToolUse {
+                id: tool_call.id.clone(),
+                name: tool_call.function.name.clone(),
+                input,
+                cache_control: None,
+            });
+        }
+    }
+
+    Ok(blocks)
+}
diff --git a/crates/hermesllm/src/transforms/mod.rs b/crates/hermesllm/src/transforms/mod.rs
new file mode 100644
index 00000000..3fb4e397
--- /dev/null
+++ b/crates/hermesllm/src/transforms/mod.rs
@@ -0,0 +1,25 @@
+//! API transformation modules
+//!
+//! This module provides organized transformations between the two main LLM API formats:
+//! - `/v1/chat/completions` (OpenAI format)
+//! - `/v1/messages` (Anthropic format)
+//!
+//! Provider-specific transformations (Bedrock, Groq, etc.) are handled internally
+//! by the gateway, but the external API surface remains these two standard formats.
+//! The transformations are split into logical modules for maintainability.
+
+pub mod lib;
+pub mod request;
+pub mod response;
+
+// Re-export commonly used items for convenience
+pub use lib::*;
+pub use request::*;
+pub use response::*;
+
+// ============================================================================
+// CONSTANTS
+// ============================================================================
+
+/// Default maximum tokens when converting from OpenAI to Anthropic and no max_tokens is specified
+pub const DEFAULT_MAX_TOKENS: u32 = 4096;
diff --git a/crates/hermesllm/src/transforms/request/from_anthropic.rs b/crates/hermesllm/src/transforms/request/from_anthropic.rs
new file mode 100644
index 00000000..877faaa8
--- /dev/null
+++ b/crates/hermesllm/src/transforms/request/from_anthropic.rs
@@ -0,0 +1,704 @@
+use crate::apis::amazon_bedrock::{
+    AnyChoice, AutoChoice, ContentBlock, ConversationRole, ConverseRequest, ImageBlock,
+    ImageSource, InferenceConfiguration, Message as BedrockMessage, SystemContentBlock,
+    Tool as BedrockTool, ToolChoice as BedrockToolChoice, ToolChoiceSpec, ToolConfiguration,
+    ToolInputSchema, ToolResultBlock, ToolResultContentBlock, ToolResultStatus, ToolSpecDefinition,
+    ToolUseBlock,
+};
+use crate::apis::anthropic::{
+    MessagesMessage, MessagesMessageContent, MessagesRequest, MessagesRole, MessagesStopReason,
+    MessagesSystemPrompt, MessagesTool, MessagesToolChoice, MessagesToolChoiceType, MessagesUsage,
+    ToolResultContent,
+};
+use crate::apis::openai::{
+    ChatCompletionsRequest, ContentPart, FinishReason, Function, FunctionChoice, Message,
+    MessageContent, Role, Tool, ToolCall, ToolChoice, ToolChoiceType, Usage,
+};
+use crate::clients::TransformError;
+use crate::transforms::lib::*;
+
+type AnthropicMessagesRequest = MessagesRequest;
+
+// Conversion from Anthropic MessagesRequest to OpenAI ChatCompletionsRequest
+impl TryFrom<AnthropicMessagesRequest> for ChatCompletionsRequest {
+    type Error = TransformError;
+
+    fn try_from(req: AnthropicMessagesRequest) -> Result<Self, Self::Error> {
+        let mut openai_messages: Vec<Message> = Vec::new();
+
+        // Convert system prompt to system message if present
+        if let Some(system) = req.system {
+            openai_messages.push(system.into());
+        }
+
+        // Convert messages
+        for message in req.messages {
+            let converted_messages: Vec<Message> = message.try_into()?;
+            openai_messages.extend(converted_messages);
+        }
+
+        // Convert tools and tool choice
+        let openai_tools = req.tools.map(|tools| convert_anthropic_tools(tools));
+        let (openai_tool_choice, parallel_tool_calls) =
+            convert_anthropic_tool_choice(req.tool_choice);
+
+        let mut _chat_completions_req: ChatCompletionsRequest = ChatCompletionsRequest {
+            model: req.model,
+            messages: openai_messages,
+            temperature: req.temperature,
+            top_p: req.top_p,
+            max_completion_tokens: Some(req.max_tokens),
+            stream: req.stream,
+            stop: req.stop_sequences,
+            tools: openai_tools,
+            tool_choice: openai_tool_choice,
+            parallel_tool_calls,
+            ..Default::default()
+        };
+        _chat_completions_req.suppress_max_tokens_if_o3();
+        _chat_completions_req.fix_temperature_if_gpt5();
+        Ok(_chat_completions_req)
+    }
+}
+
+// Conversion from Anthropic MessagesRequest to Amazon Bedrock ConverseRequest
+impl TryFrom<AnthropicMessagesRequest> for ConverseRequest {
+    type Error = TransformError;
+
+    fn try_from(req: AnthropicMessagesRequest) -> Result<Self, Self::Error> {
+        // Convert system prompt to SystemContentBlock if present
+        let system: Option<Vec<SystemContentBlock>> = req.system.map(|system_prompt| {
+            let text = match system_prompt {
+                MessagesSystemPrompt::Single(text) => text,
+                MessagesSystemPrompt::Blocks(blocks) => blocks.extract_text(),
+            };
+            vec![SystemContentBlock::Text { text }]
+        });
+
+        // Convert messages to Bedrock format
+        let messages = if req.messages.is_empty() {
+            None
+        } else {
+            let mut bedrock_messages = Vec::new();
+            for anthropic_message in req.messages {
+                let bedrock_message: BedrockMessage = anthropic_message.try_into()?;
+                bedrock_messages.push(bedrock_message);
+            }
+            Some(bedrock_messages)
+        };
+
+        // Build inference configuration
+        // Anthropic always requires max_tokens, so we should always include inferenceConfig
+        let inference_config = Some(InferenceConfiguration {
+            max_tokens: Some(req.max_tokens),
+            temperature: req.temperature,
+            top_p: req.top_p,
+            stop_sequences: req.stop_sequences,
+        });
+
+        // Convert tools and tool choice to ToolConfiguration
+        let tool_config = if req.tools.is_some() || req.tool_choice.is_some() {
+            let tools = req.tools.map(|anthropic_tools| {
+                anthropic_tools
+                    .into_iter()
+                    .map(|tool| BedrockTool::ToolSpec {
+                        tool_spec: ToolSpecDefinition {
+                            name: tool.name,
+                            description: tool.description,
+                            input_schema: ToolInputSchema {
+                                json: tool.input_schema,
+                            },
+                        },
+                    })
+                    .collect()
+            });
+
+            let tool_choice = req.tool_choice.map(|choice| {
+                match choice.kind {
+                    MessagesToolChoiceType::Auto => BedrockToolChoice::Auto {
+                        auto: AutoChoice {},
+                    },
+                    MessagesToolChoiceType::Any => BedrockToolChoice::Any { any: AnyChoice {} },
+                    MessagesToolChoiceType::None => BedrockToolChoice::Auto {
+                        auto: AutoChoice {},
+                    }, // Bedrock doesn't have explicit "none"
+                    MessagesToolChoiceType::Tool => {
+                        if let Some(name) = choice.name {
+                            BedrockToolChoice::Tool {
+                                tool: ToolChoiceSpec { name },
+                            }
+                        } else {
+                            BedrockToolChoice::Auto {
+                                auto: AutoChoice {},
+                            }
+                        }
+                    }
+                }
+            });
+
+            Some(ToolConfiguration { tools, tool_choice })
+        } else {
+            None
+        };
+
+        Ok(ConverseRequest {
+            model_id: req.model,
+            messages,
+            system,
+            inference_config,
+            tool_config,
+            stream: req.stream.unwrap_or(false),
+            guardrail_config: None,
+            additional_model_request_fields: None,
+            additional_model_response_field_paths: None,
+            performance_config: None,
+            prompt_variables: None,
+            request_metadata: None,
+            metadata: None,
+        })
+    }
+}
+
+// Message Conversions
+impl TryFrom<MessagesMessage> for Vec<Message> {
+    type Error = TransformError;
+
+    fn try_from(message: MessagesMessage) -> Result<Self, Self::Error> {
+        let mut result = Vec::new();
+
+        match message.content {
+            MessagesMessageContent::Single(text) => {
+                result.push(Message {
+                    role: message.role.into(),
+                    content: MessageContent::Text(text),
+                    name: None,
+                    tool_calls: None,
+                    tool_call_id: None,
+                });
+            }
+            MessagesMessageContent::Blocks(blocks) => {
+                let (content_parts, tool_calls, tool_results) = blocks.split_for_openai()?;
+                // Add tool result messages
+                for (tool_use_id, result_text, _is_error) in tool_results {
+                    result.push(Message {
+                        role: Role::Tool,
+                        content: MessageContent::Text(result_text),
+                        name: None,
+                        tool_calls: None,
+                        tool_call_id: Some(tool_use_id),
+                    });
+                }
+
+                // Only create main message if there's actual content or tool calls
+                // Skip creating empty content messages (e.g., when message only contains tool_result blocks)
+                if !content_parts.is_empty() || !tool_calls.is_empty() {
+                    let content = build_openai_content(content_parts, &tool_calls);
+                    let main_message = Message {
+                        role: message.role.into(),
+                        content,
+                        name: None,
+                        tool_calls: if tool_calls.is_empty() {
+                            None
+                        } else {
+                            Some(tool_calls)
+                        },
+                        tool_call_id: None,
+                    };
+                    result.push(main_message);
+                }
+            }
+        }
+
+        Ok(result)
+    }
+}
+
+// Role Conversions
+impl Into<Role> for MessagesRole {
+    fn into(self) -> Role {
+        match self {
+            MessagesRole::User => Role::User,
+            MessagesRole::Assistant => Role::Assistant,
+        }
+    }
+}
+
+impl Into<MessagesStopReason> for FinishReason {
+    fn into(self) -> MessagesStopReason {
+        match self {
+            FinishReason::Stop => MessagesStopReason::EndTurn,
+            FinishReason::Length => MessagesStopReason::MaxTokens,
+            FinishReason::ToolCalls => MessagesStopReason::ToolUse,
+            FinishReason::ContentFilter => MessagesStopReason::Refusal,
+            FinishReason::FunctionCall => MessagesStopReason::ToolUse,
+        }
+    }
+}
+
+impl Into<MessagesUsage> for Usage {
+    fn into(self) -> MessagesUsage {
+        MessagesUsage {
+            input_tokens: self.prompt_tokens,
+            output_tokens: self.completion_tokens,
+            cache_creation_input_tokens: None,
+            cache_read_input_tokens: None,
+        }
+    }
+}
+
+// System Prompt Conversions
+impl Into<Message> for MessagesSystemPrompt {
+    fn into(self) -> Message {
+        let system_content = match self {
+            MessagesSystemPrompt::Single(text) => MessageContent::Text(text),
+            MessagesSystemPrompt::Blocks(blocks) => MessageContent::Text(blocks.extract_text()),
+        };
+
+        Message {
+            role: Role::System,
+            content: system_content,
+            name: None,
+            tool_calls: None,
+            tool_call_id: None,
+        }
+    }
+}
+
+// Utility Functions
+/// Convert Anthropic tools to OpenAI format
+fn convert_anthropic_tools(tools: Vec<MessagesTool>) -> Vec<Tool> {
+    tools
+        .into_iter()
+        .map(|tool| Tool {
+            tool_type: "function".to_string(),
+            function: Function {
+                name: tool.name,
+                description: tool.description,
+                parameters: tool.input_schema,
+                strict: None,
+            },
+        })
+        .collect()
+}
+
+/// Convert Anthropic tool choice to OpenAI format
+fn convert_anthropic_tool_choice(
+    tool_choice: Option<MessagesToolChoice>,
+) -> (Option<ToolChoice>, Option<bool>) {
+    match tool_choice {
+        Some(choice) => {
+            let openai_choice = match choice.kind {
+                MessagesToolChoiceType::Auto => ToolChoice::Type(ToolChoiceType::Auto),
+                MessagesToolChoiceType::Any => ToolChoice::Type(ToolChoiceType::Required),
+                MessagesToolChoiceType::None => ToolChoice::Type(ToolChoiceType::None),
+                MessagesToolChoiceType::Tool => {
+                    if let Some(name) = choice.name {
+                        ToolChoice::Function {
+                            choice_type: "function".to_string(),
+                            function: FunctionChoice { name },
+                        }
+                    } else {
+                        ToolChoice::Type(ToolChoiceType::Auto)
+                    }
+                }
+            };
+            let parallel = choice.disable_parallel_tool_use.map(|disable| !disable);
+            (Some(openai_choice), parallel)
+        }
+        None => (None, None),
+    }
+}
+
+/// Build OpenAI message content from parts and tool calls
+fn build_openai_content(
+    content_parts: Vec<ContentPart>,
+    tool_calls: &[ToolCall],
+) -> MessageContent {
+    if content_parts.len() == 1 && tool_calls.is_empty() {
+        match &content_parts[0] {
+            ContentPart::Text { text } => MessageContent::Text(text.clone()),
+            _ => MessageContent::Parts(content_parts),
+        }
+    } else if content_parts.is_empty() {
+        MessageContent::Text("".to_string())
+    } else {
+        MessageContent::Parts(content_parts)
+    }
+}
+
+impl TryFrom<MessagesMessage> for BedrockMessage {
+    type Error = TransformError;
+
+    fn try_from(message: MessagesMessage) -> Result<Self, Self::Error> {
+        let role = match message.role {
+            MessagesRole::User => ConversationRole::User,
+            MessagesRole::Assistant => ConversationRole::Assistant,
+        };
+
+        let mut content_blocks = Vec::new();
+
+
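As a usage note, here is a minimal sketch of driving the Anthropic-to-Bedrock message conversion defined by this `TryFrom` impl, mirroring `test_anthropic_message_to_bedrock_conversion` later in this patch. Types are the ones introduced in this diff; error handling is elided.

```rust
// Sketch: one Anthropic message converts into one Bedrock message via try_into.
fn to_bedrock(msg: MessagesMessage) -> Result<BedrockMessage, TransformError> {
    let bedrock: BedrockMessage = msg.try_into()?;
    // The impl pads otherwise-empty content with a single " " text block,
    // so Bedrock's "at least one content block" requirement always holds.
    debug_assert!(!bedrock.content.is_empty());
    Ok(bedrock)
}
```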
// Convert content blocks + match message.content { + MessagesMessageContent::Single(text) => { + if !text.is_empty() { + content_blocks.push(ContentBlock::Text { text }); + } + } + MessagesMessageContent::Blocks(blocks) => { + for block in blocks { + match block { + crate::apis::anthropic::MessagesContentBlock::Text { text, .. } => { + if !text.is_empty() { + content_blocks.push(ContentBlock::Text { text }); + } + } + crate::apis::anthropic::MessagesContentBlock::ToolUse { + id, + name, + input, + .. + } => { + content_blocks.push(ContentBlock::ToolUse { + tool_use: ToolUseBlock { + tool_use_id: id, + name, + input, + }, + }); + } + crate::apis::anthropic::MessagesContentBlock::ToolResult { + tool_use_id, + is_error, + content, + .. + } => { + // Convert Anthropic ToolResultContent to Bedrock ToolResultContentBlock + let tool_result_content = match content { + ToolResultContent::Text(text) => { + vec![ToolResultContentBlock::Text { text }] + } + ToolResultContent::Blocks(blocks) => { + let mut result_blocks = Vec::new(); + for result_block in blocks { + match result_block { + crate::apis::anthropic::MessagesContentBlock::Text { text, .. } => { + result_blocks.push(ToolResultContentBlock::Text { text }); + } + // For now, skip other content types in tool results + _ => {} + } + } + result_blocks + } + }; + + // Ensure we have at least one content block + let final_content = if tool_result_content.is_empty() { + vec![ToolResultContentBlock::Text { + text: " ".to_string(), + }] + } else { + tool_result_content + }; + + let status = if is_error.unwrap_or(false) { + Some(ToolResultStatus::Error) + } else { + Some(ToolResultStatus::Success) + }; + + content_blocks.push(ContentBlock::ToolResult { + tool_result: ToolResultBlock { + tool_use_id, + content: final_content, + status, + }, + }); + } + crate::apis::anthropic::MessagesContentBlock::Image { source } => { + // Convert Anthropic image to Bedrock image format + match source { + crate::apis::anthropic::MessagesImageSource::Base64 { + media_type, + data, + } => { + content_blocks.push(ContentBlock::Image { + image: ImageBlock { + source: ImageSource::Base64 { media_type, data }, + }, + }); + } + crate::apis::anthropic::MessagesImageSource::Url { .. } => { + // Bedrock doesn't support URL-based images, skip for now + // Could potentially download and convert to base64, but not implemented + } + } + } + // Skip other content types for now (Thinking, Document, etc.) 
+ _ => {} + } + } + } + } + + // Ensure we have at least one content block + if content_blocks.is_empty() { + content_blocks.push(ContentBlock::Text { + text: " ".to_string(), + }); + } + + Ok(BedrockMessage { + role, + content: content_blocks, + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::apis::amazon_bedrock::{ + ContentBlock, ConversationRole, ConverseRequest, SystemContentBlock, + ToolChoice as BedrockToolChoice, + }; + use crate::apis::anthropic::{ + MessagesMessage, MessagesMessageContent, MessagesRequest, MessagesRole, + MessagesSystemPrompt, MessagesTool, MessagesToolChoice, MessagesToolChoiceType, + }; + use serde_json::json; + + #[test] + fn test_anthropic_to_bedrock_basic_request() { + let anthropic_request = MessagesRequest { + model: "claude-3-5-sonnet-20241022".to_string(), + messages: vec![MessagesMessage { + role: MessagesRole::User, + content: MessagesMessageContent::Single("Hello, how are you?".to_string()), + }], + max_tokens: 1000, + container: None, + mcp_servers: None, + system: Some(MessagesSystemPrompt::Single( + "You are a helpful assistant.".to_string(), + )), + metadata: None, + service_tier: None, + thinking: None, + temperature: Some(0.7), + top_p: Some(0.9), + top_k: None, + stream: Some(false), + stop_sequences: Some(vec!["STOP".to_string()]), + tools: None, + tool_choice: None, + }; + + let bedrock_request: ConverseRequest = anthropic_request.try_into().unwrap(); + + assert_eq!(bedrock_request.model_id, "claude-3-5-sonnet-20241022"); + assert!(bedrock_request.system.is_some()); + assert_eq!(bedrock_request.system.as_ref().unwrap().len(), 1); + assert!(bedrock_request.messages.is_some()); + let messages = bedrock_request.messages.as_ref().unwrap(); + assert_eq!(messages.len(), 1); + assert_eq!(messages[0].role, ConversationRole::User); + + if let ContentBlock::Text { text } = &messages[0].content[0] { + assert_eq!(text, "Hello, how are you?"); + } else { + panic!("Expected text content block"); + } + + let inference_config = bedrock_request.inference_config.as_ref().unwrap(); + assert_eq!(inference_config.temperature, Some(0.7)); + assert_eq!(inference_config.top_p, Some(0.9)); + assert_eq!(inference_config.max_tokens, Some(1000)); + assert_eq!( + inference_config.stop_sequences, + Some(vec!["STOP".to_string()]) + ); + } + + #[test] + fn test_anthropic_to_bedrock_with_tools() { + let anthropic_request = MessagesRequest { + model: "claude-3-5-sonnet-20241022".to_string(), + messages: vec![MessagesMessage { + role: MessagesRole::User, + content: MessagesMessageContent::Single("What's the weather like?".to_string()), + }], + max_tokens: 1000, + container: None, + mcp_servers: None, + system: None, + metadata: None, + service_tier: None, + thinking: None, + temperature: None, + top_p: None, + top_k: None, + stream: None, + stop_sequences: None, + tools: Some(vec![MessagesTool { + name: "get_weather".to_string(), + description: Some("Get current weather information".to_string()), + input_schema: json!({ + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city name" + } + }, + "required": ["location"] + }), + }]), + tool_choice: Some(MessagesToolChoice { + kind: MessagesToolChoiceType::Tool, + name: Some("get_weather".to_string()), + disable_parallel_tool_use: None, + }), + }; + + let bedrock_request: ConverseRequest = anthropic_request.try_into().unwrap(); + + assert_eq!(bedrock_request.model_id, "claude-3-5-sonnet-20241022"); + assert!(bedrock_request.tool_config.is_some()); + + let tool_config = 
bedrock_request.tool_config.as_ref().unwrap(); + assert!(tool_config.tools.is_some()); + let tools = tool_config.tools.as_ref().unwrap(); + assert_eq!(tools.len(), 1); + let crate::apis::amazon_bedrock::Tool::ToolSpec { tool_spec } = &tools[0]; + assert_eq!(tool_spec.name, "get_weather"); + assert_eq!( + tool_spec.description, + Some("Get current weather information".to_string()) + ); + + if let Some(BedrockToolChoice::Tool { tool }) = &tool_config.tool_choice { + assert_eq!(tool.name, "get_weather"); + } else { + panic!("Expected specific tool choice"); + } + } + + #[test] + fn test_anthropic_to_bedrock_auto_tool_choice() { + let anthropic_request = MessagesRequest { + model: "claude-3-5-sonnet-20241022".to_string(), + messages: vec![MessagesMessage { + role: MessagesRole::User, + content: MessagesMessageContent::Single("Help me with something".to_string()), + }], + max_tokens: 500, + container: None, + mcp_servers: None, + system: None, + metadata: None, + service_tier: None, + thinking: None, + temperature: None, + top_p: None, + top_k: None, + stream: None, + stop_sequences: None, + tools: Some(vec![MessagesTool { + name: "help_tool".to_string(), + description: Some("A helpful tool".to_string()), + input_schema: json!({ + "type": "object", + "properties": {} + }), + }]), + tool_choice: Some(MessagesToolChoice { + kind: MessagesToolChoiceType::Auto, + name: None, + disable_parallel_tool_use: None, + }), + }; + + let bedrock_request: ConverseRequest = anthropic_request.try_into().unwrap(); + + assert!(bedrock_request.tool_config.is_some()); + let tool_config = bedrock_request.tool_config.as_ref().unwrap(); + assert!(matches!( + tool_config.tool_choice, + Some(BedrockToolChoice::Auto { .. }) + )); + } + + #[test] + fn test_anthropic_to_bedrock_multi_message_conversation() { + let anthropic_request = MessagesRequest { + model: "claude-3-5-sonnet-20241022".to_string(), + messages: vec![ + MessagesMessage { + role: MessagesRole::User, + content: MessagesMessageContent::Single("Hello".to_string()), + }, + MessagesMessage { + role: MessagesRole::Assistant, + content: MessagesMessageContent::Single( + "Hi there! 
How can I help you?".to_string(),
+                    ),
+                },
+                MessagesMessage {
+                    role: MessagesRole::User,
+                    content: MessagesMessageContent::Single("What's 2+2?".to_string()),
+                },
+            ],
+            max_tokens: 100,
+            container: None,
+            mcp_servers: None,
+            system: Some(MessagesSystemPrompt::Single("Be concise".to_string())),
+            metadata: None,
+            service_tier: None,
+            thinking: None,
+            temperature: Some(0.5),
+            top_p: None,
+            top_k: None,
+            stream: None,
+            stop_sequences: None,
+            tools: None,
+            tool_choice: None,
+        };
+
+        let bedrock_request: ConverseRequest = anthropic_request.try_into().unwrap();
+
+        assert!(bedrock_request.messages.is_some());
+        let messages = bedrock_request.messages.as_ref().unwrap();
+        assert_eq!(messages.len(), 3);
+        assert_eq!(messages[0].role, ConversationRole::User);
+        assert_eq!(messages[1].role, ConversationRole::Assistant);
+        assert_eq!(messages[2].role, ConversationRole::User);
+
+        // Check system prompt
+        assert!(bedrock_request.system.is_some());
+        if let SystemContentBlock::Text { text } = &bedrock_request.system.as_ref().unwrap()[0] {
+            assert_eq!(text, "Be concise");
+        } else {
+            panic!("Expected system text block");
+        }
+    }
+
+    #[test]
+    fn test_anthropic_message_to_bedrock_conversion() {
+        let anthropic_message = MessagesMessage {
+            role: MessagesRole::User,
+            content: MessagesMessageContent::Single("Test message".to_string()),
+        };
+
+        let bedrock_message: BedrockMessage = anthropic_message.try_into().unwrap();
+
+        assert_eq!(bedrock_message.role, ConversationRole::User);
+        assert_eq!(bedrock_message.content.len(), 1);
+
+        if let ContentBlock::Text { text } = &bedrock_message.content[0] {
+            assert_eq!(text, "Test message");
+        } else {
+            panic!("Expected text content block");
+        }
+    }
+}
diff --git a/crates/hermesllm/src/transforms/request/from_openai.rs b/crates/hermesllm/src/transforms/request/from_openai.rs
new file mode 100644
index 00000000..df4a9557
--- /dev/null
+++ b/crates/hermesllm/src/transforms/request/from_openai.rs
@@ -0,0 +1,782 @@
+use crate::apis::amazon_bedrock::{
+    AnyChoice, AutoChoice, ContentBlock, ConversationRole, ConverseRequest, InferenceConfiguration,
+    Message as BedrockMessage, SystemContentBlock, Tool as BedrockTool,
+    ToolChoice as BedrockToolChoice, ToolChoiceSpec, ToolConfiguration, ToolInputSchema,
+    ToolSpecDefinition,
+};
+use crate::apis::anthropic::{
+    MessagesContentBlock, MessagesMessage, MessagesMessageContent, MessagesRequest, MessagesRole,
+    MessagesSystemPrompt, MessagesTool, MessagesToolChoice, MessagesToolChoiceType,
+    ToolResultContent,
+};
+use crate::apis::openai::{
+    ChatCompletionsRequest, Message, MessageContent, Role, Tool, ToolChoice, ToolChoiceType,
+};
+use crate::clients::TransformError;
+use crate::transforms::lib::ExtractText;
+use crate::transforms::lib::*;
+use crate::transforms::*;
+
+type AnthropicMessagesRequest = MessagesRequest;
+
+// ============================================================================
+// MAIN REQUEST TRANSFORMATIONS
+// ============================================================================
+
+impl Into<MessagesSystemPrompt> for Message {
+    fn into(self) -> MessagesSystemPrompt {
+        let system_text = match self.content {
+            MessageContent::Text(text) => text,
+            MessageContent::Parts(parts) => parts.extract_text(),
+        };
+        MessagesSystemPrompt::Single(system_text)
+    }
+}
+
+impl TryFrom<Message> for MessagesMessage {
+    type Error = TransformError;
+
+    fn try_from(message: Message) -> Result<Self, Self::Error> {
+        let role = match message.role {
+            Role::User => MessagesRole::User,
+            Role::Assistant => MessagesRole::Assistant,
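Before the `Role::Tool` arm that follows, a short sketch of what this conversion produces for a tool-role message, per the arm's own logic. The id `"call_abc"` and the weather text are illustrative only.

```rust
// Sketch: an OpenAI tool-role message maps to an Anthropic *user* message whose
// content is a single ToolResult block carrying the original tool_call_id.
fn tool_msg_example() -> Result<MessagesMessage, TransformError> {
    let msg = Message {
        role: Role::Tool,
        content: MessageContent::Text("72F, sunny".to_string()),
        name: None,
        tool_calls: None,
        tool_call_id: Some("call_abc".to_string()), // hypothetical id
    };
    let converted: MessagesMessage = msg.try_into()?;
    // converted.role == MessagesRole::User; its ToolResult's tool_use_id is "call_abc".
    Ok(converted)
}
```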
diff --git a/crates/hermesllm/src/transforms/request/from_openai.rs b/crates/hermesllm/src/transforms/request/from_openai.rs
new file mode 100644
index 00000000..df4a9557
--- /dev/null
+++ b/crates/hermesllm/src/transforms/request/from_openai.rs
@@ -0,0 +1,782 @@
+use crate::apis::amazon_bedrock::{
+    AnyChoice, AutoChoice, ContentBlock, ConversationRole, ConverseRequest, InferenceConfiguration,
+    Message as BedrockMessage, SystemContentBlock, Tool as BedrockTool,
+    ToolChoice as BedrockToolChoice, ToolChoiceSpec, ToolConfiguration, ToolInputSchema,
+    ToolSpecDefinition,
+};
+use crate::apis::anthropic::{
+    MessagesContentBlock, MessagesMessage, MessagesMessageContent, MessagesRequest, MessagesRole,
+    MessagesSystemPrompt, MessagesTool, MessagesToolChoice, MessagesToolChoiceType,
+    ToolResultContent,
+};
+use crate::apis::openai::{
+    ChatCompletionsRequest, Message, MessageContent, Role, Tool, ToolChoice, ToolChoiceType,
+};
+use crate::clients::TransformError;
+use crate::transforms::lib::ExtractText;
+use crate::transforms::lib::*;
+use crate::transforms::*;
+
+type AnthropicMessagesRequest = MessagesRequest;
+
+// ============================================================================
+// MAIN REQUEST TRANSFORMATIONS
+// ============================================================================
+
+impl Into<MessagesSystemPrompt> for Message {
+    fn into(self) -> MessagesSystemPrompt {
+        let system_text = match self.content {
+            MessageContent::Text(text) => text,
+            MessageContent::Parts(parts) => parts.extract_text(),
+        };
+        MessagesSystemPrompt::Single(system_text)
+    }
+}
+
+impl TryFrom<Message> for MessagesMessage {
+    type Error = TransformError;
+
+    fn try_from(message: Message) -> Result<Self, Self::Error> {
+        let role = match message.role {
+            Role::User => MessagesRole::User,
+            Role::Assistant => MessagesRole::Assistant,
+            Role::Tool => {
+                // Tool messages become user messages with tool results
+                let tool_call_id = message.tool_call_id.ok_or_else(|| {
+                    TransformError::MissingField(
+                        "tool_call_id required for Tool messages".to_string(),
+                    )
+                })?;
+
+                return Ok(MessagesMessage {
+                    role: MessagesRole::User,
+                    content: MessagesMessageContent::Blocks(vec![
+                        MessagesContentBlock::ToolResult {
+                            tool_use_id: tool_call_id,
+                            is_error: None,
+                            content: ToolResultContent::Blocks(vec![MessagesContentBlock::Text {
+                                text: message.content.extract_text(),
+                                cache_control: None,
+                            }]),
+                            cache_control: None,
+                        },
+                    ]),
+                });
+            }
+            Role::System => {
+                return Err(TransformError::UnsupportedConversion(
+                    "System messages should be handled separately".to_string(),
+                ));
+            }
+        };
+
+        let content_blocks = convert_openai_message_to_anthropic_content(&message)?;
+        let content = build_anthropic_content(content_blocks);
+
+        Ok(MessagesMessage { role, content })
+    }
+}
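The `Role::Tool` branch above is the one impedance mismatch worth calling out: OpenAI models tool output as its own role, while Anthropic expects a `tool_result` block inside a *user* message. A sketch of the resulting shape (crate paths assumed as above):

```rust
use hermesllm::apis::anthropic::{MessagesMessage, MessagesMessageContent, MessagesRole};
use hermesllm::apis::openai::{Message, MessageContent, Role};

fn demo() {
    let tool_msg = Message {
        role: Role::Tool,
        content: MessageContent::Text("72F and sunny".to_string()),
        name: None,
        tool_call_id: Some("call_123".to_string()), // required, else TransformError::MissingField
        tool_calls: None,
    };

    let converted: MessagesMessage = tool_msg.try_into().unwrap();
    assert_eq!(converted.role, MessagesRole::User); // demoted to a user message
    assert!(matches!(converted.content, MessagesMessageContent::Blocks(_)));
}
```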
+impl TryFrom<Message> for BedrockMessage {
+    type Error = TransformError;
+
+    fn try_from(message: Message) -> Result<Self, Self::Error> {
+        let role = match message.role {
+            Role::User => ConversationRole::User,
+            Role::Assistant => ConversationRole::Assistant,
+            Role::Tool => ConversationRole::User, // Tool results become user messages in Bedrock
+            Role::System => {
+                return Err(TransformError::UnsupportedConversion(
+                    "System messages should be handled separately in Bedrock".to_string(),
+                ));
+            }
+        };
+
+        let mut content_blocks = Vec::new();
+
+        // Handle different message types
+        match message.role {
+            Role::User => {
+                // Convert user message content to content blocks
+                match message.content {
+                    MessageContent::Text(text) => {
+                        if !text.is_empty() {
+                            content_blocks.push(ContentBlock::Text { text });
+                        }
+                    }
+                    MessageContent::Parts(parts) => {
+                        // Convert OpenAI content parts to Bedrock ContentBlocks
+                        for part in parts {
+                            match part {
+                                crate::apis::openai::ContentPart::Text { text } => {
+                                    if !text.is_empty() {
+                                        content_blocks.push(ContentBlock::Text { text });
+                                    }
+                                }
+                                crate::apis::openai::ContentPart::ImageUrl { image_url } => {
+                                    // Convert image URL to Bedrock image format
+                                    if image_url.url.starts_with("data:") {
+                                        if let Some((media_type, data)) =
+                                            parse_data_url(&image_url.url)
+                                        {
+                                            content_blocks.push(ContentBlock::Image {
+                                                image: crate::apis::amazon_bedrock::ImageBlock {
+                                                    source: crate::apis::amazon_bedrock::ImageSource::Base64 {
+                                                        media_type,
+                                                        data,
+                                                    },
+                                                },
+                                            });
+                                        } else {
+                                            return Err(TransformError::UnsupportedConversion(
+                                                format!(
+                                                    "Invalid data URL format: {}",
+                                                    image_url.url
+                                                ),
+                                            ));
+                                        }
+                                    } else {
+                                        return Err(TransformError::UnsupportedConversion(
+                                            "Only base64 data URLs are supported for images in Bedrock".to_string()
+                                        ));
+                                    }
+                                }
+                            }
+                        }
+                    }
+                }
+
+                // Ensure we have at least one content block
+                if content_blocks.is_empty() {
+                    content_blocks.push(ContentBlock::Text {
+                        text: " ".to_string(),
+                    });
+                }
+            }
+            Role::Assistant => {
+                // Handle text content - but only add if non-empty OR if we don't have tool calls
+                let text_content = message.content.extract_text();
+                let has_tool_calls = message
+                    .tool_calls
+                    .as_ref()
+                    .map_or(false, |calls| !calls.is_empty());
+
+                // Add text content if it's non-empty, or if we have no tool calls (to avoid empty content)
+                if !text_content.is_empty() {
+                    content_blocks.push(ContentBlock::Text { text: text_content });
+                } else if !has_tool_calls {
+                    // If we have empty content and no tool calls, add a minimal placeholder
+                    // This prevents the "blank text field" error
+                    content_blocks.push(ContentBlock::Text {
+                        text: " ".to_string(),
+                    });
+                }
+
+                // Convert tool calls to ToolUse content blocks
+                if let Some(tool_calls) = message.tool_calls {
+                    for tool_call in tool_calls {
+                        // Parse the arguments string as JSON
+                        let input: serde_json::Value =
+                            serde_json::from_str(&tool_call.function.arguments).map_err(|e| {
+                                TransformError::UnsupportedConversion(format!(
+                                    "Failed to parse tool arguments as JSON: {}. Arguments: {}",
+                                    e, tool_call.function.arguments
+                                ))
+                            })?;
+
+                        content_blocks.push(ContentBlock::ToolUse {
+                            tool_use: crate::apis::amazon_bedrock::ToolUseBlock {
+                                tool_use_id: tool_call.id,
+                                name: tool_call.function.name,
+                                input,
+                            },
+                        });
+                    }
+                }
+
+                // Bedrock requires at least one content block
+                if content_blocks.is_empty() {
+                    content_blocks.push(ContentBlock::Text {
+                        text: " ".to_string(),
+                    });
+                }
+            }
+            Role::Tool => {
+                // Tool messages become user messages with ToolResult content blocks
+                let tool_call_id = message.tool_call_id.ok_or_else(|| {
+                    TransformError::MissingField(
+                        "tool_call_id required for Tool messages".to_string(),
+                    )
+                })?;
+
+                let tool_content = message.content.extract_text();
+
+                // Create ToolResult content block
+                let tool_result_content = if tool_content.is_empty() {
+                    // Even for tool results, we need non-empty content
+                    vec![crate::apis::amazon_bedrock::ToolResultContentBlock::Text {
+                        text: " ".to_string(),
+                    }]
+                } else {
+                    vec![crate::apis::amazon_bedrock::ToolResultContentBlock::Text {
+                        text: tool_content,
+                    }]
+                };
+
+                content_blocks.push(ContentBlock::ToolResult {
+                    tool_result: crate::apis::amazon_bedrock::ToolResultBlock {
+                        tool_use_id: tool_call_id,
+                        content: tool_result_content,
+                        status: Some(crate::apis::amazon_bedrock::ToolResultStatus::Success), // Default to success
+                    },
+                });
+            }
+            Role::System => {
+                // Already handled above with early return
+                unreachable!()
+            }
+        }
+
+        Ok(BedrockMessage {
+            role,
+            content: content_blocks,
+        })
+    }
+}
+
+impl TryFrom<ChatCompletionsRequest> for AnthropicMessagesRequest {
+    type Error = TransformError;
+
+    fn try_from(req: ChatCompletionsRequest) -> Result<Self, Self::Error> {
+        let mut system_prompt = None;
+        let mut messages = Vec::new();
+
+        for message in req.messages {
+            match message.role {
+                Role::System => {
+                    system_prompt = Some(message.into());
+                }
+                _ => {
+                    let anthropic_message: MessagesMessage = message.try_into()?;
+                    messages.push(anthropic_message);
+                }
+            }
+        }
+
+        // Convert tools and tool choice
+        let anthropic_tools = req.tools.map(|tools| convert_openai_tools(tools));
+        let anthropic_tool_choice =
+            convert_openai_tool_choice(req.tool_choice, req.parallel_tool_calls);
+
+        Ok(AnthropicMessagesRequest {
+            model: req.model,
+            system: system_prompt,
+            messages,
+            max_tokens: req
+                .max_completion_tokens
+                .or(req.max_tokens)
+                .unwrap_or(DEFAULT_MAX_TOKENS),
+            container: None,
+            mcp_servers: None,
+            service_tier: None,
+            thinking: None,
+            temperature: req.temperature,
+            top_p: req.top_p,
+            top_k: None, // OpenAI doesn't have top_k
+            stream: req.stream,
+            stop_sequences: req.stop,
+            tools: anthropic_tools,
+            tool_choice: anthropic_tool_choice,
+            metadata: None,
+        })
+    }
+}
+
+impl TryFrom<ChatCompletionsRequest> for ConverseRequest {
+    type Error = TransformError;
+
+    fn try_from(req: ChatCompletionsRequest) -> Result<Self, Self::Error> {
+        // Separate system messages from user/assistant messages
+        let mut system_messages = Vec::new();
+        let mut conversation_messages = Vec::new();
+
+        for message in req.messages {
+            match message.role {
+                Role::System => {
+                    let system_text = match message.content {
MessageContent::Text(text) => text, + MessageContent::Parts(parts) => parts.extract_text(), + }; + system_messages.push(SystemContentBlock::Text { text: system_text }); + } + _ => { + let bedrock_message: BedrockMessage = message.try_into()?; + conversation_messages.push(bedrock_message); + } + } + } + + // Convert system messages + let system = if system_messages.is_empty() { + None + } else { + Some(system_messages) + }; + + // Convert conversation messages + let messages = if conversation_messages.is_empty() { + None + } else { + Some(conversation_messages) + }; + + // Build inference configuration + let max_tokens = req.max_completion_tokens.or(req.max_tokens); + let inference_config = if max_tokens.is_some() + || req.temperature.is_some() + || req.top_p.is_some() + || req.stop.is_some() + { + Some(InferenceConfiguration { + max_tokens, + temperature: req.temperature, + top_p: req.top_p, + stop_sequences: req.stop, + }) + } else { + None + }; + + // Convert tools and tool choice to ToolConfiguration + let tool_config = if req.tools.is_some() || req.tool_choice.is_some() { + let tools = req.tools.map(|openai_tools| { + openai_tools + .into_iter() + .map(|tool| BedrockTool::ToolSpec { + tool_spec: ToolSpecDefinition { + name: tool.function.name, + description: tool.function.description, + input_schema: ToolInputSchema { + json: tool.function.parameters, + }, + }, + }) + .collect() + }); + + let tool_choice = req + .tool_choice + .map(|choice| { + match choice { + ToolChoice::Type(tool_type) => match tool_type { + ToolChoiceType::Auto => BedrockToolChoice::Auto { + auto: AutoChoice {}, + }, + ToolChoiceType::Required => { + BedrockToolChoice::Any { any: AnyChoice {} } + } + ToolChoiceType::None => BedrockToolChoice::Auto { + auto: AutoChoice {}, + }, // Bedrock doesn't have explicit "none" + }, + ToolChoice::Function { function, .. 
} => BedrockToolChoice::Tool {
+                        tool: ToolChoiceSpec {
+                            name: function.name,
+                        },
+                    },
+                }
+            })
+            .or_else(|| {
+                // If tools are present but no tool_choice specified, default to "auto"
+                if tools.is_some() {
+                    Some(BedrockToolChoice::Auto {
+                        auto: AutoChoice {},
+                    })
+                } else {
+                    None
+                }
+            });
+
+            Some(ToolConfiguration { tools, tool_choice })
+        } else {
+            None
+        };
+
+        Ok(ConverseRequest {
+            model_id: req.model,
+            messages,
+            system,
+            inference_config,
+            tool_config,
+            stream: req.stream.unwrap_or(false),
+            guardrail_config: None,
+            additional_model_request_fields: None,
+            additional_model_response_field_paths: None,
+            performance_config: None,
+            prompt_variables: None,
+            request_metadata: None,
+            metadata: None,
+        })
+    }
+}
+
+/// Convert OpenAI tools to Anthropic format
+fn convert_openai_tools(tools: Vec<Tool>) -> Vec<MessagesTool> {
+    tools
+        .into_iter()
+        .map(|tool| MessagesTool {
+            name: tool.function.name,
+            description: tool.function.description,
+            input_schema: tool.function.parameters,
+        })
+        .collect()
+}
+
+/// Convert OpenAI tool choice to Anthropic format
+fn convert_openai_tool_choice(
+    tool_choice: Option<ToolChoice>,
+    parallel_tool_calls: Option<bool>,
+) -> Option<MessagesToolChoice> {
+    tool_choice.map(|choice| match choice {
+        ToolChoice::Type(tool_type) => match tool_type {
+            ToolChoiceType::Auto => MessagesToolChoice {
+                kind: MessagesToolChoiceType::Auto,
+                name: None,
+                disable_parallel_tool_use: parallel_tool_calls.map(|p| !p),
+            },
+            ToolChoiceType::Required => MessagesToolChoice {
+                kind: MessagesToolChoiceType::Any,
+                name: None,
+                disable_parallel_tool_use: parallel_tool_calls.map(|p| !p),
+            },
+            ToolChoiceType::None => MessagesToolChoice {
+                kind: MessagesToolChoiceType::None,
+                name: None,
+                disable_parallel_tool_use: None,
+            },
+        },
+        ToolChoice::Function { function, .. } => MessagesToolChoice {
+            kind: MessagesToolChoiceType::Tool,
+            name: Some(function.name),
+            disable_parallel_tool_use: parallel_tool_calls.map(|p| !p),
+        },
+    })
+}
+
+/// Build Anthropic message content from content blocks
+fn build_anthropic_content(content_blocks: Vec<MessagesContentBlock>) -> MessagesMessageContent {
+    if content_blocks.len() == 1 {
+        match &content_blocks[0] {
+            MessagesContentBlock::Text { text, .. } => MessagesMessageContent::Single(text.clone()),
+            _ => MessagesMessageContent::Blocks(content_blocks),
+        }
+    } else if content_blocks.is_empty() {
+        MessagesMessageContent::Single("".to_string())
+    } else {
+        MessagesMessageContent::Blocks(content_blocks)
+    }
+}
+
+/// Parse a data URL into media type and base64 data
+/// Supports format: data:image/jpeg;base64,<base64-data>
+fn parse_data_url(url: &str) -> Option<(String, String)> {
+    if !url.starts_with("data:") {
+        return None;
+    }
+
+    let without_prefix = &url[5..]; // Remove "data:" prefix
+    let parts: Vec<&str> = without_prefix.splitn(2, ',').collect();
+
+    if parts.len() != 2 {
+        return None;
+    }
+
+    let header = parts[0];
+    let data = parts[1];
+
+    // Parse header: "image/jpeg;base64" or just "image/jpeg"
+    let header_parts: Vec<&str> = header.split(';').collect();
+    if header_parts.is_empty() {
+        return None;
+    }
+
+    let media_type = header_parts[0].to_string();
+
+    // Check if it's base64 encoded
+    if header_parts.len() > 1 && header_parts[1] == "base64" {
+        Some((media_type, data.to_string()))
+    } else {
+        // For now, only support base64 encoding
+        None
+    }
+}
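`parse_data_url` silently rejects anything that is not base64-encoded, so image parts using plain or URL-encoded data URLs will surface as `UnsupportedConversion` further up. Its contract, written as the kind of unit test one might add to this module:

```rust
#[test]
fn parse_data_url_contract() {
    // Well-formed base64 data URL → (media type, raw base64 payload)
    assert_eq!(
        parse_data_url("data:image/png;base64,iVBORw0KGgo="),
        Some(("image/png".to_string(), "iVBORw0KGgo=".to_string()))
    );
    // Non-base64 payloads are rejected rather than mis-decoded
    assert_eq!(parse_data_url("data:text/plain,hello"), None);
    // Missing the comma separator entirely
    assert_eq!(parse_data_url("data:image/png;base64"), None);
}
```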
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::apis::amazon_bedrock::{
+        ContentBlock, ConversationRole, ConverseRequest, SystemContentBlock,
+        ToolChoice as BedrockToolChoice,
+    };
+    use crate::apis::openai::{
+        ChatCompletionsRequest, Function, FunctionChoice, Message, MessageContent, Role, Tool,
+        ToolChoice, ToolChoiceType,
+    };
+    use serde_json::json;
+
+    #[test]
+    fn test_openai_to_bedrock_basic_request() {
+        let openai_request = ChatCompletionsRequest {
+            model: "gpt-4".to_string(),
+            messages: vec![
+                Message {
+                    role: Role::System,
+                    content:
MessageContent::Text("What's the weather like?".to_string()), + name: None, + tool_call_id: None, + tool_calls: None, + }], + temperature: None, + top_p: None, + max_completion_tokens: Some(1000), + stop: None, + stream: None, + tools: Some(vec![Tool { + tool_type: "function".to_string(), + function: Function { + name: "get_weather".to_string(), + description: Some("Get current weather information".to_string()), + parameters: json!({ + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city name" + } + }, + "required": ["location"] + }), + strict: None, + }, + }]), + tool_choice: Some(ToolChoice::Function { + choice_type: "function".to_string(), + function: FunctionChoice { + name: "get_weather".to_string(), + }, + }), + ..Default::default() + }; + + let bedrock_request: ConverseRequest = openai_request.try_into().unwrap(); + + assert_eq!(bedrock_request.model_id, "gpt-4"); + assert!(bedrock_request.tool_config.is_some()); + + let tool_config = bedrock_request.tool_config.as_ref().unwrap(); + assert!(tool_config.tools.is_some()); + let tools = tool_config.tools.as_ref().unwrap(); + assert_eq!(tools.len(), 1); + + let crate::apis::amazon_bedrock::Tool::ToolSpec { tool_spec } = &tools[0]; + assert_eq!(tool_spec.name, "get_weather"); + assert_eq!( + tool_spec.description, + Some("Get current weather information".to_string()) + ); + + if let Some(BedrockToolChoice::Tool { tool }) = &tool_config.tool_choice { + assert_eq!(tool.name, "get_weather"); + } else { + panic!("Expected specific tool choice"); + } + } + + #[test] + fn test_openai_to_bedrock_auto_tool_choice() { + let openai_request = ChatCompletionsRequest { + model: "gpt-4".to_string(), + messages: vec![Message { + role: Role::User, + content: MessageContent::Text("Help me with something".to_string()), + name: None, + tool_call_id: None, + tool_calls: None, + }], + temperature: None, + top_p: None, + max_completion_tokens: Some(500), + stop: None, + stream: None, + tools: Some(vec![Tool { + tool_type: "function".to_string(), + function: Function { + name: "help_tool".to_string(), + description: Some("A helpful tool".to_string()), + parameters: json!({ + "type": "object", + "properties": {} + }), + strict: None, + }, + }]), + tool_choice: Some(ToolChoice::Type(ToolChoiceType::Auto)), + ..Default::default() + }; + + let bedrock_request: ConverseRequest = openai_request.try_into().unwrap(); + + assert!(bedrock_request.tool_config.is_some()); + let tool_config = bedrock_request.tool_config.as_ref().unwrap(); + assert!(matches!( + tool_config.tool_choice, + Some(BedrockToolChoice::Auto { .. }) + )); + } + + #[test] + fn test_openai_to_bedrock_multi_message_conversation() { + let openai_request = ChatCompletionsRequest { + model: "gpt-4".to_string(), + messages: vec![ + Message { + role: Role::System, + content: MessageContent::Text("Be concise".to_string()), + name: None, + tool_call_id: None, + tool_calls: None, + }, + Message { + role: Role::User, + content: MessageContent::Text("Hello".to_string()), + name: None, + tool_call_id: None, + tool_calls: None, + }, + Message { + role: Role::Assistant, + content: MessageContent::Text("Hi there! 
How can I help you?".to_string()), + name: None, + tool_call_id: None, + tool_calls: None, + }, + Message { + role: Role::User, + content: MessageContent::Text("What's 2+2?".to_string()), + name: None, + tool_call_id: None, + tool_calls: None, + }, + ], + temperature: Some(0.5), + top_p: None, + max_completion_tokens: Some(100), + stop: None, + stream: None, + tools: None, + tool_choice: None, + ..Default::default() + }; + + let bedrock_request: ConverseRequest = openai_request.try_into().unwrap(); + + assert!(bedrock_request.messages.is_some()); + let messages = bedrock_request.messages.as_ref().unwrap(); + assert_eq!(messages.len(), 3); // System message is separate + assert_eq!(messages[0].role, ConversationRole::User); + assert_eq!(messages[1].role, ConversationRole::Assistant); + assert_eq!(messages[2].role, ConversationRole::User); + + // Check system prompt + assert!(bedrock_request.system.is_some()); + if let SystemContentBlock::Text { text } = &bedrock_request.system.as_ref().unwrap()[0] { + assert_eq!(text, "Be concise"); + } else { + panic!("Expected system text block"); + } + } + + #[test] + fn test_openai_message_to_bedrock_conversion() { + let openai_message = Message { + role: Role::User, + content: MessageContent::Text("Test message".to_string()), + name: None, + tool_call_id: None, + tool_calls: None, + }; + + let bedrock_message: BedrockMessage = openai_message.try_into().unwrap(); + + assert_eq!(bedrock_message.role, ConversationRole::User); + assert_eq!(bedrock_message.content.len(), 1); + + if let ContentBlock::Text { text } = &bedrock_message.content[0] { + assert_eq!(text, "Test message"); + } else { + panic!("Expected text content block"); + } + } +} diff --git a/crates/hermesllm/src/transforms/request/mod.rs b/crates/hermesllm/src/transforms/request/mod.rs new file mode 100644 index 00000000..5fbdf0b1 --- /dev/null +++ b/crates/hermesllm/src/transforms/request/mod.rs @@ -0,0 +1,4 @@ +//! Request transformation modules + +pub mod from_anthropic; +pub mod from_openai; diff --git a/crates/hermesllm/src/transforms/response/mod.rs b/crates/hermesllm/src/transforms/response/mod.rs new file mode 100644 index 00000000..3ce75123 --- /dev/null +++ b/crates/hermesllm/src/transforms/response/mod.rs @@ -0,0 +1,3 @@ +//! 
Response transformation modules
+pub mod to_anthropic;
+pub mod to_openai;
diff --git a/crates/hermesllm/src/transforms/response/to_anthropic.rs b/crates/hermesllm/src/transforms/response/to_anthropic.rs
new file mode 100644
index 00000000..1c6ce238
--- /dev/null
+++ b/crates/hermesllm/src/transforms/response/to_anthropic.rs
@@ -0,0 +1,1051 @@
+use crate::apis::amazon_bedrock::{
+    ContentBlockDelta, ConverseOutput, ConverseResponse, ConverseStreamEvent, StopReason,
+};
+use crate::apis::anthropic::{
+    MessagesContentBlock, MessagesContentDelta, MessagesMessageDelta, MessagesResponse,
+    MessagesRole, MessagesStopReason, MessagesStreamEvent, MessagesStreamMessage, MessagesUsage,
+};
+use crate::apis::openai::{
+    ChatCompletionsResponse, ChatCompletionsStreamResponse, Role, ToolCallDelta,
+};
+use crate::clients::TransformError;
+use crate::transforms::lib::*;
+use serde_json::Value;
+
+// ============================================================================
+// STANDARD RUST TRAIT IMPLEMENTATIONS - Using Into/TryFrom for convenience
+// ============================================================================
+
+impl TryFrom<ChatCompletionsResponse> for MessagesResponse {
+    type Error = TransformError;
+
+    fn try_from(resp: ChatCompletionsResponse) -> Result<Self, Self::Error> {
+        let choice = resp
+            .choices
+            .into_iter()
+            .next()
+            .ok_or_else(|| TransformError::MissingField("choices".to_string()))?;
+
+        let content = convert_openai_message_to_anthropic_content(&choice.message.to_message())?;
+        let stop_reason = choice
+            .finish_reason
+            .map(|fr| fr.into())
+            .unwrap_or(MessagesStopReason::EndTurn);
+
+        let usage = MessagesUsage {
+            input_tokens: resp.usage.prompt_tokens,
+            output_tokens: resp.usage.completion_tokens,
+            cache_creation_input_tokens: None,
+            cache_read_input_tokens: None,
+        };
+
+        Ok(MessagesResponse {
+            id: resp.id,
+            obj_type: "message".to_string(),
+            role: MessagesRole::Assistant,
+            content,
+            model: resp.model,
+            stop_reason,
+            stop_sequence: None,
+            usage,
+            container: None,
+        })
+    }
+}
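Note that the impl above keeps only the first choice and folds OpenAI's split token counters into Anthropic's shape. A hedged sketch of the happy path (field names come from this diff; the numeric placeholder values and public paths are assumptions):

```rust
use hermesllm::apis::anthropic::MessagesResponse;
use hermesllm::apis::openai::{
    ChatCompletionsResponse, Choice, FinishReason, ResponseMessage, Role, Usage,
};

fn demo() {
    let openai = ChatCompletionsResponse {
        id: "chatcmpl-1".to_string(),
        object: Some("chat.completion".to_string()),
        created: 0,
        model: "gpt-4".to_string(),
        choices: vec![Choice {
            index: 0,
            message: ResponseMessage {
                role: Role::Assistant,
                content: Some("Hi!".to_string()),
                refusal: None,
                annotations: None,
                audio: None,
                function_call: None,
                tool_calls: None,
            },
            finish_reason: Some(FinishReason::Stop),
            logprobs: None,
        }],
        usage: Usage {
            prompt_tokens: 3,
            completion_tokens: 2,
            total_tokens: 5,
            prompt_tokens_details: None,
            completion_tokens_details: None,
        },
        system_fingerprint: None,
        service_tier: None,
    };

    let anthropic: MessagesResponse = openai.try_into().unwrap();
    assert_eq!(anthropic.usage.input_tokens, 3); // prompt → input, completion → output
}
```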
+impl TryFrom<ConverseResponse> for MessagesResponse {
+    type Error = TransformError;
+
+    fn try_from(resp: ConverseResponse) -> Result<Self, Self::Error> {
+        // Extract the message from the ConverseOutput
+        let message = match resp.output {
+            ConverseOutput::Message { message } => message,
+        };
+
+        // Convert Bedrock message content to Anthropic content blocks
+        let content = convert_bedrock_message_to_anthropic_content(&message)?;
+
+        // Convert Bedrock ConversationRole to Anthropic MessagesRole
+        let role = match message.role {
+            crate::apis::amazon_bedrock::ConversationRole::User => MessagesRole::User,
+            crate::apis::amazon_bedrock::ConversationRole::Assistant => MessagesRole::Assistant,
+        };
+
+        // Convert Bedrock stop reason to Anthropic stop reason
+        let stop_reason = match resp.stop_reason {
+            StopReason::EndTurn => MessagesStopReason::EndTurn,
+            StopReason::ToolUse => MessagesStopReason::ToolUse,
+            StopReason::MaxTokens => MessagesStopReason::MaxTokens,
+            StopReason::StopSequence => MessagesStopReason::EndTurn,
+            StopReason::GuardrailIntervened => MessagesStopReason::Refusal,
+            StopReason::ContentFiltered => MessagesStopReason::Refusal,
+        };
+
+        // Convert token usage
+        let usage = MessagesUsage {
+            input_tokens: resp.usage.input_tokens,
+            output_tokens: resp.usage.output_tokens,
+            cache_creation_input_tokens: resp.usage.cache_write_input_tokens,
+            cache_read_input_tokens: resp.usage.cache_read_input_tokens,
+        };
+
+        // Generate a response ID (Bedrock doesn't provide one)
+        let id = format!(
+            "bedrock-{}",
+            std::time::SystemTime::now()
+                .duration_since(std::time::UNIX_EPOCH)
+                .unwrap_or_default()
+                .as_nanos()
+        );
+
+        // Extract model ID from trace information if available, otherwise use fallback
+        let model = resp
+            .trace
+            .as_ref()
+            .and_then(|trace| trace.prompt_router.as_ref())
+            .map(|router| router.invoked_model_id.clone())
+            .unwrap_or_else(|| "bedrock-model".to_string());
+
+        Ok(MessagesResponse {
+            id,
+            obj_type: "message".to_string(),
+            role,
+            content,
+            model,
+            stop_reason,
+            stop_sequence: None, // TODO: Could extract from additional_model_response_fields if needed
+            usage,
+            container: None,
+        })
+    }
+}
+
+impl TryFrom<ChatCompletionsStreamResponse> for MessagesStreamEvent {
+    type Error = TransformError;
+
+    fn try_from(resp: ChatCompletionsStreamResponse) -> Result<Self, Self::Error> {
+        if resp.choices.is_empty() {
+            return Ok(MessagesStreamEvent::Ping);
+        }
+
+        let choice = &resp.choices[0];
+
+        // Handle final chunk with usage
+        let has_usage = resp.usage.is_some();
+        if let Some(usage) = resp.usage {
+            if let Some(finish_reason) = &choice.finish_reason {
+                let anthropic_stop_reason: MessagesStopReason = finish_reason.clone().into();
+                return Ok(MessagesStreamEvent::MessageDelta {
+                    delta: MessagesMessageDelta {
+                        stop_reason: anthropic_stop_reason,
+                        stop_sequence: None,
+                    },
+                    usage: usage.into(),
+                });
+            }
+        }
+
+        // Handle role start
+        if let Some(Role::Assistant) = choice.delta.role {
+            return Ok(MessagesStreamEvent::MessageStart {
+                message: MessagesStreamMessage {
+                    id: resp.id,
+                    obj_type: "message".to_string(),
+                    role: MessagesRole::Assistant,
+                    content: vec![],
+                    model: resp.model,
+                    stop_reason: None,
+                    stop_sequence: None,
+                    usage: MessagesUsage {
+                        input_tokens: 0,
+                        output_tokens: 0,
+                        cache_creation_input_tokens: None,
+                        cache_read_input_tokens: None,
+                    },
+                },
+            });
+        }
+
+        // Handle content delta
+        if let Some(content) = &choice.delta.content {
+            if !content.is_empty() {
+                return Ok(MessagesStreamEvent::ContentBlockDelta {
+                    index: 0,
+                    delta: MessagesContentDelta::TextDelta {
+                        text: content.clone(),
+                    },
+                });
+            }
+        }
+
+        // Handle tool calls
+        if let Some(tool_calls) = &choice.delta.tool_calls {
+            return convert_tool_call_deltas(tool_calls.clone());
+        }
+
+        // Handle finish reason - generate MessageDelta only (MessageStop comes later)
+        if let Some(finish_reason) = &choice.finish_reason {
+            // If we have usage data, it was already handled above
+            // If not, we need to generate MessageDelta with default usage
+            if !has_usage {
+                let anthropic_stop_reason: MessagesStopReason = finish_reason.clone().into();
+                return Ok(MessagesStreamEvent::MessageDelta {
+                    delta: MessagesMessageDelta {
+                        stop_reason: anthropic_stop_reason,
+                        stop_sequence: None,
+                    },
+                    usage: MessagesUsage {
+                        input_tokens: 0,
+                        output_tokens: 0,
+                        cache_creation_input_tokens: None,
+                        cache_read_input_tokens: None,
+                    },
+                });
+            }
+            // If usage was already handled above, we don't need to do anything more here
+            // MessageStop will be handled when [DONE] is encountered
+        }
+
+        // Default to ping for unhandled cases
+        Ok(MessagesStreamEvent::Ping)
+    }
+}
+
+impl Into<String> for MessagesStreamEvent {
+    fn into(self) -> String {
+        let transformed_json = serde_json::to_string(&self).unwrap_or_default();
+        let event_type = match &self {
+            MessagesStreamEvent::MessageStart { .. } => "message_start",
+            MessagesStreamEvent::ContentBlockStart { .. } => "content_block_start",
+            MessagesStreamEvent::ContentBlockDelta { .. } => "content_block_delta",
+            MessagesStreamEvent::ContentBlockStop { .. } => "content_block_stop",
+            MessagesStreamEvent::MessageDelta { .. } => "message_delta",
+            MessagesStreamEvent::MessageStop => "message_stop",
+            MessagesStreamEvent::Ping => "ping",
+        };
+
+        let event = format!("event: {}\n", event_type);
+        let data = format!("data: {}\n\n", transformed_json);
+        event + &data
+    }
+}
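The `Into<String>` impl is the piece that actually writes Anthropic-style SSE frames back to the client, so it is worth seeing the framing on a trivial event (a sketch; the exact JSON body depends on the serde representation of `MessagesStreamEvent`):

```rust
use hermesllm::apis::anthropic::MessagesStreamEvent;

fn demo() {
    let frame: String = MessagesStreamEvent::Ping.into();
    // Two-line SSE frame terminated by a blank line:
    //   event: ping
    //   data: { ...serialized event... }
    assert!(frame.starts_with("event: ping\n"));
    assert!(frame.contains("\ndata: "));
    assert!(frame.ends_with("\n\n"));
}
```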
} => "message_delta", + MessagesStreamEvent::MessageStop => "message_stop", + MessagesStreamEvent::Ping => "ping", + }; + + let event = format!("event: {}\n", event_type); + let data = format!("data: {}\n\n", transformed_json); + event + &data + } +} + +impl TryFrom for MessagesStreamEvent { + type Error = TransformError; + + fn try_from(event: ConverseStreamEvent) -> Result { + match event { + // MessageStart - convert to Anthropic MessageStart + ConverseStreamEvent::MessageStart(start_event) => { + let role = match start_event.role { + crate::apis::amazon_bedrock::ConversationRole::User => MessagesRole::User, + crate::apis::amazon_bedrock::ConversationRole::Assistant => { + MessagesRole::Assistant + } + }; + + Ok(MessagesStreamEvent::MessageStart { + message: MessagesStreamMessage { + id: format!( + "bedrock-stream-{}", + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_nanos() + ), + obj_type: "message".to_string(), + role, + content: vec![], + model: "bedrock-model".to_string(), + stop_reason: None, + stop_sequence: None, + usage: MessagesUsage { + input_tokens: 0, + output_tokens: 0, + cache_creation_input_tokens: None, + cache_read_input_tokens: None, + }, + }, + }) + } + + // ContentBlockStart - convert to Anthropic ContentBlockStart + ConverseStreamEvent::ContentBlockStart(start_event) => { + // Note: Bedrock sends tool_use_id and name at start, with input coming in subsequent deltas + // Anthropic expects the same pattern, so we initialize with an empty input object + match start_event.start { + crate::apis::amazon_bedrock::ContentBlockStart::ToolUse { tool_use } => { + Ok(MessagesStreamEvent::ContentBlockStart { + index: start_event.content_block_index as u32, + content_block: MessagesContentBlock::ToolUse { + id: tool_use.tool_use_id, + name: tool_use.name, + input: Value::Object(serde_json::Map::new()), // Empty - will be filled by deltas + cache_control: None, + }, + }) + } + } + } + + // ContentBlockDelta - convert to Anthropic ContentBlockDelta + ConverseStreamEvent::ContentBlockDelta(delta_event) => { + let delta = match delta_event.delta { + ContentBlockDelta::Text { text } => MessagesContentDelta::TextDelta { text }, + ContentBlockDelta::ToolUse { tool_use } => { + MessagesContentDelta::InputJsonDelta { + partial_json: tool_use.input, + } + } + }; + + Ok(MessagesStreamEvent::ContentBlockDelta { + index: delta_event.content_block_index as u32, + delta, + }) + } + + // ContentBlockStop - convert to Anthropic ContentBlockStop + ConverseStreamEvent::ContentBlockStop(stop_event) => { + Ok(MessagesStreamEvent::ContentBlockStop { + index: stop_event.content_block_index as u32, + }) + } + + // MessageStop - convert to Anthropic MessageDelta with stop reason + MessageStop + ConverseStreamEvent::MessageStop(stop_event) => { + let anthropic_stop_reason = match stop_event.stop_reason { + StopReason::EndTurn => MessagesStopReason::EndTurn, + StopReason::ToolUse => MessagesStopReason::ToolUse, + StopReason::MaxTokens => MessagesStopReason::MaxTokens, + StopReason::StopSequence => MessagesStopReason::EndTurn, + StopReason::GuardrailIntervened => MessagesStopReason::Refusal, + StopReason::ContentFiltered => MessagesStopReason::Refusal, + }; + + // Return MessageDelta (MessageStop will be sent separately by the streaming handler) + Ok(MessagesStreamEvent::MessageDelta { + delta: MessagesMessageDelta { + stop_reason: anthropic_stop_reason, + stop_sequence: None, + }, + usage: MessagesUsage { + input_tokens: 0, + output_tokens: 0, + 
+/// Convert tool call deltas to Anthropic stream events
+fn convert_tool_call_deltas(
+    tool_calls: Vec<ToolCallDelta>,
+) -> Result<MessagesStreamEvent, TransformError> {
+    for tool_call in tool_calls {
+        if let Some(id) = &tool_call.id {
+            // Tool call start
+            if let Some(function) = &tool_call.function {
+                if let Some(name) = &function.name {
+                    return Ok(MessagesStreamEvent::ContentBlockStart {
+                        index: tool_call.index,
+                        content_block: MessagesContentBlock::ToolUse {
+                            id: id.clone(),
+                            name: name.clone(),
+                            input: Value::Object(serde_json::Map::new()),
+                            cache_control: None,
+                        },
+                    });
+                }
+            }
+        } else if let Some(function) = &tool_call.function {
+            if let Some(arguments) = &function.arguments {
+                // Tool arguments delta
+                return Ok(MessagesStreamEvent::ContentBlockDelta {
+                    index: tool_call.index,
+                    delta: MessagesContentDelta::InputJsonDelta {
+                        partial_json: arguments.clone(),
+                    },
+                });
+            }
+        }
+    }
+
+    // Fallback to ping if no valid tool call found
+    Ok(MessagesStreamEvent::Ping)
+}
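`convert_tool_call_deltas` relies on OpenAI's two-phase tool-call convention: the first delta for a call carries `id` plus the function `name` (mapped to `ContentBlockStart`), and subsequent deltas carry only argument fragments (mapped to `InputJsonDelta`). Illustrative inputs for both phases (struct fields as declared in this diff; the values are hypothetical):

```rust
use hermesllm::apis::openai::{FunctionCallDelta, ToolCallDelta};

fn demo() -> (ToolCallDelta, ToolCallDelta) {
    // Phase 1: announces the call → ContentBlockStart with an empty input object
    let start = ToolCallDelta {
        index: 0,
        id: Some("call_123".to_string()),
        call_type: Some("function".to_string()),
        function: Some(FunctionCallDelta {
            name: Some("get_weather".to_string()),
            arguments: Some(String::new()),
        }),
    };
    // Phase 2: argument fragments → InputJsonDelta, accumulated client-side
    let fragment = ToolCallDelta {
        index: 0,
        id: None,
        call_type: None,
        function: Some(FunctionCallDelta {
            name: None,
            arguments: Some(r#"{"location":"SF"#.to_string()),
        }),
    };
    (start, fragment)
}
```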
+/// Convert Bedrock Message to Anthropic content blocks
+///
+/// This function handles the conversion between Amazon Bedrock Converse API format
+/// and Anthropic's Messages API format. Key differences handled:
+///
+/// 1. **Image/Document Sources**: Bedrock supports base64 and S3 locations, while
+///    Anthropic supports base64, URLs, and file IDs. Currently only base64 is supported.
+/// 2. **Tool Result Status**: Bedrock uses enum status (Success/Error), Anthropic uses
+///    boolean is_error field.
+/// 3. **Document Names**: Bedrock includes optional document names, Anthropic doesn't.
+/// 4. **JSON Content**: Bedrock has native JSON content blocks, converted to text for Anthropic.
+///
+/// Note on S3/URL handling: Converting S3 locations or URLs would require async operations
+/// to download and convert to base64, which is not implemented in this synchronous function.
+fn convert_bedrock_message_to_anthropic_content(
+    message: &crate::apis::amazon_bedrock::Message,
+) -> Result<Vec<MessagesContentBlock>, TransformError> {
+    use crate::apis::amazon_bedrock::ContentBlock;
+
+    let mut content_blocks = Vec::new();
+
+    for content_block in &message.content {
+        match content_block {
+            ContentBlock::Text { text } => {
+                content_blocks.push(MessagesContentBlock::Text {
+                    text: text.clone(),
+                    cache_control: None,
+                });
+            }
+            ContentBlock::ToolUse { tool_use } => {
+                content_blocks.push(MessagesContentBlock::ToolUse {
+                    id: tool_use.tool_use_id.clone(),
+                    name: tool_use.name.clone(),
+                    input: tool_use.input.clone(),
+                    cache_control: None,
+                });
+            }
+            ContentBlock::ToolResult { tool_result } => {
+                // Convert tool result content blocks
+                let mut tool_result_blocks = Vec::new();
+                for result_content in &tool_result.content {
+                    match result_content {
+                        crate::apis::amazon_bedrock::ToolResultContentBlock::Text { text } => {
+                            tool_result_blocks.push(MessagesContentBlock::Text {
+                                text: text.clone(),
+                                cache_control: None,
+                            });
+                        }
+                        crate::apis::amazon_bedrock::ToolResultContentBlock::Image { source } => {
+                            // Convert Bedrock ImageSource to Anthropic format
+                            match source {
+                                crate::apis::amazon_bedrock::ImageSource::Base64 {
+                                    media_type,
+                                    data,
+                                } => {
+                                    tool_result_blocks.push(MessagesContentBlock::Image {
+                                        source:
+                                            crate::apis::anthropic::MessagesImageSource::Base64 {
+                                                media_type: media_type.clone(),
+                                                data: data.clone(),
+                                            },
+                                    });
+                                } // Note: S3Location is not yet implemented in the current Bedrock API definition
+                                  // but would need async handling when added
+                            }
+                        }
+                        crate::apis::amazon_bedrock::ToolResultContentBlock::Json { json } => {
+                            // Convert JSON content to text representation
+                            tool_result_blocks.push(MessagesContentBlock::Text {
+                                text: serde_json::to_string(&json).unwrap_or_default(),
+                                cache_control: None,
+                            });
+                        }
+                    }
+                }
+
+                use crate::apis::anthropic::ToolResultContent;
+                content_blocks.push(MessagesContentBlock::ToolResult {
+                    tool_use_id: tool_result.tool_use_id.clone(),
+                    is_error: tool_result
+                        .status
+                        .as_ref()
+                        .map(|s| matches!(s, crate::apis::amazon_bedrock::ToolResultStatus::Error)),
+                    content: ToolResultContent::Blocks(tool_result_blocks),
+                    cache_control: None,
+                });
+            }
+            ContentBlock::Image { image } => {
+                // Convert Bedrock ImageSource to Anthropic format
+                match &image.source {
+                    crate::apis::amazon_bedrock::ImageSource::Base64 { media_type, data } => {
+                        content_blocks.push(MessagesContentBlock::Image {
+                            source: crate::apis::anthropic::MessagesImageSource::Base64 {
+                                media_type: media_type.clone(),
+                                data: data.clone(),
+                            },
+                        });
+                    } // Note: S3Location would require async handling if implemented
+                }
+            }
+            ContentBlock::Document { document } => {
+                // Convert Bedrock DocumentSource to Anthropic format
+                // Note: Bedrock's 'name' field is lost in conversion as Anthropic doesn't support it
+                match &document.source {
+                    crate::apis::amazon_bedrock::DocumentSource::Base64 { media_type, data } => {
+                        content_blocks.push(MessagesContentBlock::Document {
+                            source: crate::apis::anthropic::MessagesDocumentSource::Base64 {
+                                media_type: media_type.clone(),
+                                data: data.clone(),
+                            },
+                        });
+                    } // Note: S3Location would require async handling if implemented
+                }
+            }
+            ContentBlock::GuardContent { guard_content } => {
+                // Convert guard content to text block
+                if let Some(guard_text) = &guard_content.text {
+                    content_blocks.push(MessagesContentBlock::Text {
+                        text: guard_text.text.clone(),
+                        cache_control: None,
+                    });
+                }
+            }
+        }
+    }
+
Ok(content_blocks) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::apis::amazon_bedrock::{ + BedrockTokenUsage, ContentBlock, ConversationRole, ConverseOutput, ConverseResponse, + ConverseTrace, Message as BedrockMessage, PromptRouterTrace, StopReason, + ToolResultContentBlock, ToolResultStatus, + }; + use crate::apis::anthropic::{ + MessagesContentBlock, MessagesResponse, MessagesRole, MessagesStopReason, ToolResultContent, + }; + use serde_json::json; + + #[test] + fn test_bedrock_to_anthropic_basic_response() { + let bedrock_response = ConverseResponse { + output: ConverseOutput::Message { + message: BedrockMessage { + role: ConversationRole::Assistant, + content: vec![ContentBlock::Text { + text: "Hello! How can I help you today?".to_string(), + }], + }, + }, + stop_reason: StopReason::EndTurn, + usage: BedrockTokenUsage { + input_tokens: 10, + output_tokens: 25, + total_tokens: 35, + cache_write_input_tokens: None, + cache_read_input_tokens: None, + ..Default::default() + }, + metrics: None, + trace: None, + additional_model_response_fields: None, + performance_config: None, + }; + + let anthropic_response: MessagesResponse = bedrock_response.try_into().unwrap(); + + assert_eq!(anthropic_response.obj_type, "message"); + assert_eq!(anthropic_response.role, MessagesRole::Assistant); + assert_eq!(anthropic_response.model, "bedrock-model"); + assert_eq!(anthropic_response.stop_reason, MessagesStopReason::EndTurn); + assert!(anthropic_response.id.starts_with("bedrock-")); + + // Check content + assert_eq!(anthropic_response.content.len(), 1); + if let MessagesContentBlock::Text { text, .. } = &anthropic_response.content[0] { + assert_eq!(text, "Hello! How can I help you today?"); + } else { + panic!("Expected text content block"); + } + + // Check usage + assert_eq!(anthropic_response.usage.input_tokens, 10); + assert_eq!(anthropic_response.usage.output_tokens, 25); + assert_eq!(anthropic_response.usage.cache_creation_input_tokens, None); + assert_eq!(anthropic_response.usage.cache_read_input_tokens, None); + } + + #[test] + fn test_bedrock_to_anthropic_with_tool_use() { + let bedrock_response = ConverseResponse { + output: ConverseOutput::Message { + message: BedrockMessage { + role: ConversationRole::Assistant, + content: vec![ + ContentBlock::Text { + text: "I'll help you check the weather.".to_string(), + }, + ContentBlock::ToolUse { + tool_use: crate::apis::amazon_bedrock::ToolUseBlock { + tool_use_id: "tool_12345".to_string(), + name: "get_weather".to_string(), + input: json!({ + "location": "San Francisco" + }), + }, + }, + ], + }, + }, + stop_reason: StopReason::ToolUse, + usage: BedrockTokenUsage { + input_tokens: 15, + output_tokens: 30, + total_tokens: 45, + cache_write_input_tokens: None, + cache_read_input_tokens: None, + ..Default::default() + }, + metrics: None, + trace: None, + additional_model_response_fields: None, + performance_config: None, + }; + + let anthropic_response: MessagesResponse = bedrock_response.try_into().unwrap(); + + assert_eq!(anthropic_response.stop_reason, MessagesStopReason::ToolUse); + assert_eq!(anthropic_response.content.len(), 2); + + // Check text content + if let MessagesContentBlock::Text { text, .. } = &anthropic_response.content[0] { + assert_eq!(text, "I'll help you check the weather."); + } else { + panic!("Expected text content block"); + } + + // Check tool use content + if let MessagesContentBlock::ToolUse { + id, name, input, .. 
+ } = &anthropic_response.content[1] + { + assert_eq!(id, "tool_12345"); + assert_eq!(name, "get_weather"); + assert_eq!(input["location"], "San Francisco"); + } else { + panic!("Expected tool use content block"); + } + } + + #[test] + fn test_bedrock_to_anthropic_stop_reason_conversions() { + let test_cases = vec![ + (StopReason::EndTurn, MessagesStopReason::EndTurn), + (StopReason::ToolUse, MessagesStopReason::ToolUse), + (StopReason::MaxTokens, MessagesStopReason::MaxTokens), + (StopReason::StopSequence, MessagesStopReason::EndTurn), + (StopReason::GuardrailIntervened, MessagesStopReason::Refusal), + (StopReason::ContentFiltered, MessagesStopReason::Refusal), + ]; + + for (bedrock_stop_reason, expected_anthropic_stop_reason) in test_cases { + let bedrock_response = ConverseResponse { + output: ConverseOutput::Message { + message: BedrockMessage { + role: ConversationRole::Assistant, + content: vec![ContentBlock::Text { + text: "Test response".to_string(), + }], + }, + }, + stop_reason: bedrock_stop_reason, + usage: BedrockTokenUsage { + input_tokens: 5, + output_tokens: 10, + total_tokens: 15, + ..Default::default() + }, + metrics: None, + trace: None, + additional_model_response_fields: None, + performance_config: None, + }; + + let anthropic_response: MessagesResponse = bedrock_response.try_into().unwrap(); + assert_eq!( + anthropic_response.stop_reason, + expected_anthropic_stop_reason + ); + } + } + + #[test] + fn test_bedrock_to_anthropic_with_cache_tokens() { + let bedrock_response = ConverseResponse { + output: ConverseOutput::Message { + message: BedrockMessage { + role: ConversationRole::Assistant, + content: vec![ContentBlock::Text { + text: "Cached response".to_string(), + }], + }, + }, + stop_reason: StopReason::EndTurn, + usage: BedrockTokenUsage { + input_tokens: 100, + output_tokens: 50, + total_tokens: 150, + cache_write_input_tokens: Some(20), + cache_read_input_tokens: Some(10), + ..Default::default() + }, + metrics: None, + trace: None, + additional_model_response_fields: None, + performance_config: None, + }; + + let anthropic_response: MessagesResponse = bedrock_response.try_into().unwrap(); + + assert_eq!(anthropic_response.usage.input_tokens, 100); + assert_eq!(anthropic_response.usage.output_tokens, 50); + assert_eq!( + anthropic_response.usage.cache_creation_input_tokens, + Some(20) + ); + assert_eq!(anthropic_response.usage.cache_read_input_tokens, Some(10)); + } + + #[test] + fn test_bedrock_to_anthropic_with_tool_result() { + let bedrock_response = ConverseResponse { + output: ConverseOutput::Message { + message: BedrockMessage { + role: ConversationRole::Assistant, + content: vec![ + ContentBlock::Text { + text: "Here's the weather information:".to_string(), + }, + ContentBlock::ToolResult { + tool_result: crate::apis::amazon_bedrock::ToolResultBlock { + tool_use_id: "tool_67890".to_string(), + content: vec![ToolResultContentBlock::Text { + text: "Temperature: 72°F, Sunny".to_string(), + }], + status: Some(ToolResultStatus::Success), + }, + }, + ], + }, + }, + stop_reason: StopReason::EndTurn, + usage: BedrockTokenUsage { + input_tokens: 20, + output_tokens: 35, + total_tokens: 55, + cache_write_input_tokens: None, + cache_read_input_tokens: None, + ..Default::default() + }, + metrics: None, + trace: None, + additional_model_response_fields: None, + performance_config: None, + }; + + let anthropic_response: MessagesResponse = bedrock_response.try_into().unwrap(); + + assert_eq!(anthropic_response.content.len(), 2); + + // Check text content + if let 
MessagesContentBlock::Text { text, .. } = &anthropic_response.content[0] { + assert_eq!(text, "Here's the weather information:"); + } else { + panic!("Expected text content block"); + } + + // Check tool result content + if let MessagesContentBlock::ToolResult { + tool_use_id, + content, + .. + } = &anthropic_response.content[1] + { + assert_eq!(tool_use_id, "tool_67890"); + if let ToolResultContent::Blocks(blocks) = content { + assert_eq!(blocks.len(), 1); + if let MessagesContentBlock::Text { text, .. } = &blocks[0] { + assert_eq!(text, "Temperature: 72°F, Sunny"); + } else { + panic!("Expected text content in tool result"); + } + } else { + panic!("Expected blocks in tool result content"); + } + } else { + panic!("Expected tool result content block"); + } + } + + #[test] + fn test_bedrock_to_anthropic_mixed_content() { + let bedrock_response = ConverseResponse { + output: ConverseOutput::Message { + message: BedrockMessage { + role: ConversationRole::Assistant, + content: vec![ + ContentBlock::Text { + text: "I can help with multiple tasks.".to_string(), + }, + ContentBlock::ToolUse { + tool_use: crate::apis::amazon_bedrock::ToolUseBlock { + tool_use_id: "tool_1".to_string(), + name: "search".to_string(), + input: json!({"query": "weather"}), + }, + }, + ContentBlock::Text { + text: "Let me also check another source.".to_string(), + }, + ContentBlock::ToolUse { + tool_use: crate::apis::amazon_bedrock::ToolUseBlock { + tool_use_id: "tool_2".to_string(), + name: "lookup".to_string(), + input: json!({"id": "12345"}), + }, + }, + ], + }, + }, + stop_reason: StopReason::ToolUse, + usage: BedrockTokenUsage { + input_tokens: 25, + output_tokens: 40, + total_tokens: 65, + cache_write_input_tokens: None, + cache_read_input_tokens: None, + ..Default::default() + }, + metrics: None, + trace: None, + additional_model_response_fields: None, + performance_config: None, + }; + + let anthropic_response: MessagesResponse = bedrock_response.try_into().unwrap(); + + assert_eq!(anthropic_response.content.len(), 4); + assert_eq!(anthropic_response.stop_reason, MessagesStopReason::ToolUse); + + // Verify the sequence: text -> tool_use -> text -> tool_use + if let MessagesContentBlock::Text { text, .. } = &anthropic_response.content[0] { + assert_eq!(text, "I can help with multiple tasks."); + } else { + panic!("Expected first content to be text"); + } + + if let MessagesContentBlock::ToolUse { id, name, .. } = &anthropic_response.content[1] { + assert_eq!(id, "tool_1"); + assert_eq!(name, "search"); + } else { + panic!("Expected second content to be tool use"); + } + + if let MessagesContentBlock::Text { text, .. } = &anthropic_response.content[2] { + assert_eq!(text, "Let me also check another source."); + } else { + panic!("Expected third content to be text"); + } + + if let MessagesContentBlock::ToolUse { id, name, .. 
} = &anthropic_response.content[3] { + assert_eq!(id, "tool_2"); + assert_eq!(name, "lookup"); + } else { + panic!("Expected fourth content to be tool use"); + } + } + + #[test] + fn test_convert_bedrock_message_to_anthropic_content() { + let bedrock_message = BedrockMessage { + role: ConversationRole::Assistant, + content: vec![ + ContentBlock::Text { + text: "Hello world!".to_string(), + }, + ContentBlock::ToolUse { + tool_use: crate::apis::amazon_bedrock::ToolUseBlock { + tool_use_id: "test_tool".to_string(), + name: "test_function".to_string(), + input: json!({"param": "value"}), + }, + }, + ], + }; + + let content_blocks = + convert_bedrock_message_to_anthropic_content(&bedrock_message).unwrap(); + + assert_eq!(content_blocks.len(), 2); + + if let MessagesContentBlock::Text { text, .. } = &content_blocks[0] { + assert_eq!(text, "Hello world!"); + } else { + panic!("Expected text content block"); + } + + if let MessagesContentBlock::ToolUse { + id, name, input, .. + } = &content_blocks[1] + { + assert_eq!(id, "test_tool"); + assert_eq!(name, "test_function"); + assert_eq!(input["param"], "value"); + } else { + panic!("Expected tool use content block"); + } + } + + #[test] + fn test_bedrock_to_anthropic_role_conversion() { + // Test Assistant role + let assistant_response = ConverseResponse { + output: ConverseOutput::Message { + message: BedrockMessage { + role: ConversationRole::Assistant, + content: vec![ContentBlock::Text { + text: "I am an assistant".to_string(), + }], + }, + }, + stop_reason: StopReason::EndTurn, + usage: BedrockTokenUsage { + input_tokens: 5, + output_tokens: 10, + total_tokens: 15, + ..Default::default() + }, + metrics: None, + trace: None, + additional_model_response_fields: None, + performance_config: None, + }; + + let anthropic_response: MessagesResponse = assistant_response.try_into().unwrap(); + assert_eq!(anthropic_response.role, MessagesRole::Assistant); + + // Test User role + let user_response = ConverseResponse { + output: ConverseOutput::Message { + message: BedrockMessage { + role: ConversationRole::User, + content: vec![ContentBlock::Text { + text: "I am a user".to_string(), + }], + }, + }, + stop_reason: StopReason::EndTurn, + usage: BedrockTokenUsage { + input_tokens: 5, + output_tokens: 10, + total_tokens: 15, + ..Default::default() + }, + metrics: None, + trace: None, + additional_model_response_fields: None, + performance_config: None, + }; + + let anthropic_response: MessagesResponse = user_response.try_into().unwrap(); + assert_eq!(anthropic_response.role, MessagesRole::User); + } + + #[test] + fn test_bedrock_to_anthropic_model_extraction() { + // Test model extraction from trace information + let bedrock_response = ConverseResponse { + output: ConverseOutput::Message { + message: BedrockMessage { + role: ConversationRole::Assistant, + content: vec![ContentBlock::Text { + text: "Test response".to_string(), + }], + }, + }, + stop_reason: StopReason::EndTurn, + usage: BedrockTokenUsage { + input_tokens: 10, + output_tokens: 5, + total_tokens: 15, + ..Default::default() + }, + metrics: None, + trace: Some(ConverseTrace { + guardrail: None, + prompt_router: Some(PromptRouterTrace { + invoked_model_id: "anthropic.claude-3-sonnet-20240229-v1:0".to_string(), + }), + }), + additional_model_response_fields: None, + performance_config: None, + }; + + let anthropic_response: MessagesResponse = bedrock_response.try_into().unwrap(); + + // Should extract model ID from trace + assert_eq!( + anthropic_response.model, + 
"anthropic.claude-3-sonnet-20240229-v1:0" + ); + + // Test fallback when no trace information is available + let bedrock_response_no_trace = ConverseResponse { + output: ConverseOutput::Message { + message: BedrockMessage { + role: ConversationRole::Assistant, + content: vec![ContentBlock::Text { + text: "Test response".to_string(), + }], + }, + }, + stop_reason: StopReason::EndTurn, + usage: BedrockTokenUsage { + input_tokens: 10, + output_tokens: 5, + total_tokens: 15, + ..Default::default() + }, + metrics: None, + trace: None, + additional_model_response_fields: None, + performance_config: None, + }; + + let anthropic_response_fallback: MessagesResponse = + bedrock_response_no_trace.try_into().unwrap(); + + // Should use fallback model name + assert_eq!(anthropic_response_fallback.model, "bedrock-model"); + } +} diff --git a/crates/hermesllm/src/transforms/response/to_openai.rs b/crates/hermesllm/src/transforms/response/to_openai.rs new file mode 100644 index 00000000..acbdb420 --- /dev/null +++ b/crates/hermesllm/src/transforms/response/to_openai.rs @@ -0,0 +1,1171 @@ +use crate::apis::amazon_bedrock::{ + ConverseOutput, ConverseResponse, ConverseStreamEvent, StopReason, +}; +use crate::apis::anthropic::{ + MessagesContentBlock, MessagesContentDelta, MessagesResponse, MessagesStopReason, + MessagesStreamEvent, MessagesUsage, +}; +use crate::apis::openai::{ + ChatCompletionsResponse, ChatCompletionsStreamResponse, Choice, FinishReason, + FunctionCallDelta, MessageContent, MessageDelta, ResponseMessage, Role, StreamChoice, + ToolCallDelta, Usage, +}; +use crate::clients::TransformError; +use crate::transforms::lib::*; + +// ============================================================================ +// MAIN RESPONSE TRANSFORMATIONS +// ============================================================================ + +// Usage Conversions +impl Into for MessagesUsage { + fn into(self) -> Usage { + Usage { + prompt_tokens: self.input_tokens, + completion_tokens: self.output_tokens, + total_tokens: self.input_tokens + self.output_tokens, + prompt_tokens_details: None, + completion_tokens_details: None, + } + } +} + +impl TryFrom for ChatCompletionsResponse { + type Error = TransformError; + + fn try_from(resp: MessagesResponse) -> Result { + let content = convert_anthropic_content_to_openai(&resp.content)?; + let finish_reason: FinishReason = resp.stop_reason.into(); + let tool_calls = resp.content.extract_tool_calls()?; + + // Convert MessageContent to String for response + let content_string = match content { + MessageContent::Text(text) => Some(text), + MessageContent::Parts(parts) => { + let text = parts.extract_text(); + if text.is_empty() { + None + } else { + Some(text) + } + } + }; + + let message = ResponseMessage { + role: Role::Assistant, + content: content_string, + refusal: None, + annotations: None, + audio: None, + function_call: None, + tool_calls, + }; + + let choice = Choice { + index: 0, + message, + finish_reason: Some(finish_reason), + logprobs: None, + }; + + let usage = Usage { + prompt_tokens: resp.usage.input_tokens, + completion_tokens: resp.usage.output_tokens, + total_tokens: resp.usage.input_tokens + resp.usage.output_tokens, + prompt_tokens_details: None, + completion_tokens_details: None, + }; + + Ok(ChatCompletionsResponse { + id: resp.id, + object: Some("chat.completion".to_string()), + created: current_timestamp(), + model: resp.model, + choices: vec![choice], + usage, + system_fingerprint: None, + service_tier: None, + }) + } +} + +impl TryFrom for 
+impl TryFrom<MessagesResponse> for ChatCompletionsResponse {
+    type Error = TransformError;
+
+    fn try_from(resp: MessagesResponse) -> Result<Self, Self::Error> {
+        let content = convert_anthropic_content_to_openai(&resp.content)?;
+        let finish_reason: FinishReason = resp.stop_reason.into();
+        let tool_calls = resp.content.extract_tool_calls()?;
+
+        // Convert MessageContent to String for response
+        let content_string = match content {
+            MessageContent::Text(text) => Some(text),
+            MessageContent::Parts(parts) => {
+                let text = parts.extract_text();
+                if text.is_empty() {
+                    None
+                } else {
+                    Some(text)
+                }
+            }
+        };
+
+        let message = ResponseMessage {
+            role: Role::Assistant,
+            content: content_string,
+            refusal: None,
+            annotations: None,
+            audio: None,
+            function_call: None,
+            tool_calls,
+        };
+
+        let choice = Choice {
+            index: 0,
+            message,
+            finish_reason: Some(finish_reason),
+            logprobs: None,
+        };
+
+        let usage = Usage {
+            prompt_tokens: resp.usage.input_tokens,
+            completion_tokens: resp.usage.output_tokens,
+            total_tokens: resp.usage.input_tokens + resp.usage.output_tokens,
+            prompt_tokens_details: None,
+            completion_tokens_details: None,
+        };
+
+        Ok(ChatCompletionsResponse {
+            id: resp.id,
+            object: Some("chat.completion".to_string()),
+            created: current_timestamp(),
+            model: resp.model,
+            choices: vec![choice],
+            usage,
+            system_fingerprint: None,
+            service_tier: None,
+        })
+    }
+}
+
+impl TryFrom<ConverseResponse> for ChatCompletionsResponse {
+    type Error = TransformError;
+
+    fn try_from(resp: ConverseResponse) -> Result<Self, Self::Error> {
+        // Extract the message from the ConverseOutput
+        let message = match resp.output {
+            ConverseOutput::Message { message } => message,
+        };
+
+        // Convert Bedrock ConversationRole to OpenAI Role
+        let role = match message.role {
+            crate::apis::amazon_bedrock::ConversationRole::User => Role::User,
+            crate::apis::amazon_bedrock::ConversationRole::Assistant => Role::Assistant,
+        };
+
+        // Convert Bedrock message content to OpenAI format
+        let (content, tool_calls) = convert_bedrock_message_to_openai(&message)?;
+
+        // Convert Bedrock stop reason to OpenAI finish reason
+        let finish_reason = match resp.stop_reason {
+            StopReason::EndTurn => FinishReason::Stop,
+            StopReason::ToolUse => FinishReason::ToolCalls,
+            StopReason::MaxTokens => FinishReason::Length,
+            StopReason::StopSequence => FinishReason::Stop,
+            StopReason::GuardrailIntervened => FinishReason::ContentFilter,
+            StopReason::ContentFiltered => FinishReason::ContentFilter,
+        };
+
+        // Create response message
+        let response_message = ResponseMessage {
+            role,
+            content,
+            refusal: None,
+            annotations: None,
+            audio: None,
+            function_call: None,
+            tool_calls,
+        };
+
+        // Create choice
+        let choice = Choice {
+            index: 0,
+            message: response_message,
+            finish_reason: Some(finish_reason),
+            logprobs: None,
+        };
+
+        // Convert token usage
+        let usage = Usage {
+            prompt_tokens: resp.usage.input_tokens,
+            completion_tokens: resp.usage.output_tokens,
+            total_tokens: resp.usage.total_tokens,
+            prompt_tokens_details: None,
+            completion_tokens_details: None,
+        };
+
+        // Generate a response ID (using timestamp since Bedrock doesn't provide one)
+        let id = format!(
+            "bedrock-{}",
+            std::time::SystemTime::now()
+                .duration_since(std::time::UNIX_EPOCH)
+                .unwrap_or_default()
+                .as_nanos()
+        );
+
+        // Extract model ID from trace information if available, otherwise use fallback
+        let model = resp
+            .trace
+            .as_ref()
+            .and_then(|trace| trace.prompt_router.as_ref())
+            .map(|router| router.invoked_model_id.clone())
+            .unwrap_or_else(|| "bedrock-model".to_string());
+
+        Ok(ChatCompletionsResponse {
+            id,
+            object: Some("chat.completion".to_string()),
+            created: current_timestamp(),
+            model,
+            choices: vec![choice],
+            usage,
+            system_fingerprint: None,
+            service_tier: None,
+        })
+    }
+}
+
+// ============================================================================
+// STREAMING TRANSFORMATIONS
+// ============================================================================
+
+impl TryFrom<MessagesStreamEvent> for ChatCompletionsStreamResponse {
+    type Error = TransformError;
+
+    fn try_from(event: MessagesStreamEvent) -> Result<Self, Self::Error> {
+        match event {
+            MessagesStreamEvent::MessageStart { message } => Ok(create_openai_chunk(
+                &message.id,
+                &message.model,
+                MessageDelta {
+                    role: Some(Role::Assistant),
+                    content: None,
+                    refusal: None,
+                    function_call: None,
+                    tool_calls: None,
+                },
+                None,
+                None,
+            )),
+
+            MessagesStreamEvent::ContentBlockStart { content_block, .. } => {
+                convert_content_block_start(content_block)
+            }
+
+            MessagesStreamEvent::ContentBlockDelta { delta, .. } => convert_content_delta(delta),
+
+            MessagesStreamEvent::ContentBlockStop { .. } => Ok(create_empty_openai_chunk()),
+
+            MessagesStreamEvent::MessageDelta { delta, usage } => {
+                let finish_reason: Option<FinishReason> = Some(delta.stop_reason.into());
+                let openai_usage: Option<Usage> = Some(usage.into());
+
+                Ok(create_openai_chunk(
+                    "stream",
+                    "unknown",
+                    MessageDelta {
+                        role: None,
+                        content: None,
+                        refusal: None,
+                        function_call: None,
+                        tool_calls: None,
+                    },
+                    finish_reason,
+                    openai_usage,
+                ))
+            }
+
+            MessagesStreamEvent::MessageStop => Ok(create_openai_chunk(
+                "stream",
+                "unknown",
+                MessageDelta {
+                    role: None,
+                    content: None,
+                    refusal: None,
+                    function_call: None,
+                    tool_calls: None,
+                },
+                Some(FinishReason::Stop),
+                None,
+            )),
+
+            MessagesStreamEvent::Ping => Ok(ChatCompletionsStreamResponse {
+                id: "stream".to_string(),
+                object: Some("chat.completion.chunk".to_string()),
+                created: current_timestamp(),
+                model: "unknown".to_string(),
+                choices: vec![],
+                usage: None,
+                system_fingerprint: None,
+                service_tier: None,
+            }),
+        }
+    }
+}
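A sketch of the common text path through the impl above: an Anthropic `content_block_delta` becomes an OpenAI chunk whose `delta.content` carries the same text (the `choices[0].delta` shape is assumed from `create_openai_chunk` below):

```rust
use hermesllm::apis::anthropic::{MessagesContentDelta, MessagesStreamEvent};
use hermesllm::apis::openai::ChatCompletionsStreamResponse;

fn demo() {
    let event = MessagesStreamEvent::ContentBlockDelta {
        index: 0,
        delta: MessagesContentDelta::TextDelta {
            text: "Hello".to_string(),
        },
    };

    let chunk: ChatCompletionsStreamResponse = event.try_into().unwrap();
    assert_eq!(chunk.choices[0].delta.content.as_deref(), Some("Hello"));
}
```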
+ +// ============================================================================ +// STREAMING TRANSFORMATIONS +// ============================================================================ + +impl TryFrom<MessagesStreamEvent> for ChatCompletionsStreamResponse { + type Error = TransformError; + + fn try_from(event: MessagesStreamEvent) -> Result<Self, Self::Error> { + match event { + MessagesStreamEvent::MessageStart { message } => Ok(create_openai_chunk( + &message.id, + &message.model, + MessageDelta { + role: Some(Role::Assistant), + content: None, + refusal: None, + function_call: None, + tool_calls: None, + }, + None, + None, + )), + + MessagesStreamEvent::ContentBlockStart { content_block, .. } => { + convert_content_block_start(content_block) + } + + MessagesStreamEvent::ContentBlockDelta { delta, .. } => convert_content_delta(delta), + + MessagesStreamEvent::ContentBlockStop { .. } => Ok(create_empty_openai_chunk()), + + MessagesStreamEvent::MessageDelta { delta, usage } => { + let finish_reason: Option<FinishReason> = Some(delta.stop_reason.into()); + let openai_usage: Option<Usage> = Some(usage.into()); + + Ok(create_openai_chunk( + "stream", + "unknown", + MessageDelta { + role: None, + content: None, + refusal: None, + function_call: None, + tool_calls: None, + }, + finish_reason, + openai_usage, + )) + } + + MessagesStreamEvent::MessageStop => Ok(create_openai_chunk( + "stream", + "unknown", + MessageDelta { + role: None, + content: None, + refusal: None, + function_call: None, + tool_calls: None, + }, + Some(FinishReason::Stop), + None, + )), + + MessagesStreamEvent::Ping => Ok(ChatCompletionsStreamResponse { + id: "stream".to_string(), + object: Some("chat.completion.chunk".to_string()), + created: current_timestamp(), + model: "unknown".to_string(), + choices: vec![], + usage: None, + system_fingerprint: None, + service_tier: None, + }), + } + } +} + +impl TryFrom<ConverseStreamEvent> for ChatCompletionsStreamResponse { + type Error = TransformError; + + fn try_from(event: ConverseStreamEvent) -> Result<Self, Self::Error> { + match event { + ConverseStreamEvent::MessageStart(start_event) => { + let role = match start_event.role { + crate::apis::amazon_bedrock::ConversationRole::User => Role::User, + crate::apis::amazon_bedrock::ConversationRole::Assistant => Role::Assistant, + }; + + Ok(create_openai_chunk( + "stream", + "unknown", + MessageDelta { + role: Some(role), + content: None, + refusal: None, + function_call: None, + tool_calls: None, + }, + None, + None, + )) + } + + ConverseStreamEvent::ContentBlockStart(start_event) => { + use crate::apis::amazon_bedrock::ContentBlockStart; + + match start_event.start { + ContentBlockStart::ToolUse { tool_use } => Ok(create_openai_chunk( + "stream", + "unknown", + MessageDelta { + role: None, + content: None, + refusal: None, + function_call: None, + tool_calls: Some(vec![ToolCallDelta { + index: start_event.content_block_index as u32, + id: Some(tool_use.tool_use_id), + call_type: Some("function".to_string()), + function: Some(FunctionCallDelta { + name: Some(tool_use.name), + arguments: Some("".to_string()), + }), + }]), + }, + None, + None, + )), + } + } + + ConverseStreamEvent::ContentBlockDelta(delta_event) => { + use crate::apis::amazon_bedrock::ContentBlockDelta; + + match delta_event.delta { + ContentBlockDelta::Text { text } => Ok(create_openai_chunk( + "stream", + "unknown", + MessageDelta { + role: None, + content: Some(text), + refusal: None, + function_call: None, + tool_calls: None, + }, + None, + None, + )), + ContentBlockDelta::ToolUse { tool_use } => Ok(create_openai_chunk( + "stream", + "unknown", + MessageDelta { + role: None, + content: None, + refusal: None, + function_call: None, + tool_calls: Some(vec![ToolCallDelta { + index: delta_event.content_block_index as u32, + id: None, + call_type: None, + function: Some(FunctionCallDelta { + name: None, + arguments: Some(tool_use.input), + }), + }]), + }, + None, + None, + )), + } + } + + ConverseStreamEvent::ContentBlockStop(_) => Ok(create_empty_openai_chunk()), + + ConverseStreamEvent::MessageStop(stop_event) => { + let finish_reason = match stop_event.stop_reason { + StopReason::EndTurn => FinishReason::Stop, + StopReason::ToolUse => FinishReason::ToolCalls, + StopReason::MaxTokens => FinishReason::Length, + StopReason::StopSequence => FinishReason::Stop, + StopReason::GuardrailIntervened => FinishReason::ContentFilter, + StopReason::ContentFiltered => FinishReason::ContentFilter, + }; + + Ok(create_openai_chunk( + "stream", + "unknown", + MessageDelta { + role: None, + content: None, + refusal: None, + function_call: None, + tool_calls: None, + }, + Some(finish_reason), + None, + )) + } + + ConverseStreamEvent::Metadata(metadata_event) => { + let usage = Usage { + prompt_tokens: metadata_event.usage.input_tokens, + completion_tokens: metadata_event.usage.output_tokens, + total_tokens: metadata_event.usage.total_tokens, + prompt_tokens_details: None, + completion_tokens_details: None, + }; + + Ok(create_openai_chunk( + "stream", + "unknown", + MessageDelta { + role: None, + content: None, + refusal: None, + function_call: None, + tool_calls: None, + }, + None, + Some(usage), + )) + } + + // Error events - convert to empty chunks (errors should be handled elsewhere) + ConverseStreamEvent::InternalServerException(_) + | ConverseStreamEvent::ModelStreamErrorException(_) + | ConverseStreamEvent::ServiceUnavailableException(_) + | ConverseStreamEvent::ThrottlingException(_) + | ConverseStreamEvent::ValidationException(_) => Ok(create_empty_openai_chunk()), + } + } +}
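+// Event-to-chunk mapping implemented above, for one streamed Bedrock tool
+// call ("stream"/"unknown" are placeholder chunk ids until richer stream
+// metadata is plumbed through):
+//
+//     MessageStart(assistant)     -> delta { role: assistant }
+//     ContentBlockStart(ToolUse)  -> delta { tool_calls: [id, name, args: ""] }
+//     ContentBlockDelta(ToolUse)  -> delta { tool_calls: [arguments fragment] }
+//     ContentBlockStop            -> empty chunk
+//     MessageStop(ToolUse)        -> finish_reason: tool_calls
+//     Metadata                    -> usage-only chunk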
+ +/// Convert content block start to OpenAI chunk +fn convert_content_block_start( + content_block: MessagesContentBlock, +) -> Result<ChatCompletionsStreamResponse, TransformError> { + match content_block { + MessagesContentBlock::Text { .. } => { + // No immediate output for text block start + Ok(create_empty_openai_chunk()) + } + MessagesContentBlock::ToolUse { id, name, .. } + | MessagesContentBlock::ServerToolUse { id, name, .. } + | MessagesContentBlock::McpToolUse { id, name, .. } => { + // Tool use start → OpenAI chunk with tool_calls + Ok(create_openai_chunk( + "stream", + "unknown", + MessageDelta { + role: None, + content: None, + refusal: None, + function_call: None, + tool_calls: Some(vec![ToolCallDelta { + index: 0, + id: Some(id), + call_type: Some("function".to_string()), + function: Some(FunctionCallDelta { + name: Some(name), + arguments: Some("".to_string()), + }), + }]), + }, + None, + None, + )) + } + _ => Err(TransformError::UnsupportedContent( + "Unsupported content block type in stream start".to_string(), + )), + } +} + +/// Convert content delta to OpenAI chunk +fn convert_content_delta( + delta: MessagesContentDelta, +) -> Result<ChatCompletionsStreamResponse, TransformError> { + match delta { + MessagesContentDelta::TextDelta { text } => Ok(create_openai_chunk( + "stream", + "unknown", + MessageDelta { + role: None, + content: Some(text), + refusal: None, + function_call: None, + tool_calls: None, + }, + None, + None, + )), + MessagesContentDelta::ThinkingDelta { thinking } => Ok(create_openai_chunk( + "stream", + "unknown", + MessageDelta { + role: None, + content: Some(format!("thinking: {}", thinking)), + refusal: None, + function_call: None, + tool_calls: None, + }, + None, + None, + )), + MessagesContentDelta::InputJsonDelta { partial_json } => Ok(create_openai_chunk( + "stream", + "unknown", + MessageDelta { + role: None, + content: None, + refusal: None, + function_call: None, + tool_calls: Some(vec![ToolCallDelta { + index: 0, + id: None, + call_type: None, + function: Some(FunctionCallDelta { + name: None, + arguments: Some(partial_json), + }), + }]), + }, + None, + None, + )), + } +} + +/// Helper to create OpenAI streaming chunk +fn create_openai_chunk( + id: &str, + model: &str, + delta: MessageDelta, + finish_reason: Option<FinishReason>, + usage: Option<Usage>, +) -> ChatCompletionsStreamResponse { + ChatCompletionsStreamResponse { + id: id.to_string(), + object: Some("chat.completion.chunk".to_string()), + created: current_timestamp(), + model: model.to_string(), + choices: vec![StreamChoice { + index: 0, + delta, + finish_reason, + logprobs: None, + }], + usage, + system_fingerprint: None, + service_tier: None, + } +} + +/// Helper to create empty OpenAI streaming chunk +fn create_empty_openai_chunk() -> ChatCompletionsStreamResponse { + create_openai_chunk( + "stream", + "unknown", + MessageDelta { + role: None, + content: None, + refusal: None, + function_call: None, + tool_calls: None, + }, + None, + None, + ) +} + +/// Convert Anthropic content blocks to OpenAI message content +fn convert_anthropic_content_to_openai( + content: &[MessagesContentBlock], +) -> Result<MessageContent, TransformError> { + let mut text_parts = Vec::new(); + + for block in content { + match block { + MessagesContentBlock::Text { text, .. } => { + text_parts.push(text.clone()); + } + MessagesContentBlock::Thinking { thinking, .. } => { + text_parts.push(format!("thinking: {}", thinking)); + } + _ => { + // Skip other content types for basic text conversion + continue; + } + } + } + + Ok(MessageContent::Text(text_parts.join("\n"))) +} + +// Stop Reason Conversions +impl Into<FinishReason> for MessagesStopReason { + fn into(self) -> FinishReason { + match self { + MessagesStopReason::EndTurn => FinishReason::Stop, + MessagesStopReason::MaxTokens => FinishReason::Length, + MessagesStopReason::StopSequence => FinishReason::Stop, + MessagesStopReason::ToolUse => FinishReason::ToolCalls, + MessagesStopReason::PauseTurn => FinishReason::Stop, + MessagesStopReason::Refusal => FinishReason::ContentFilter, + } + } +} + +/// Convert Bedrock Message to OpenAI content and tool calls +/// This function extracts text content and tool calls from a Bedrock message +fn convert_bedrock_message_to_openai( + message: &crate::apis::amazon_bedrock::Message, +) -> Result<(Option<String>, Option<Vec<ToolCall>>), TransformError> { + use crate::apis::amazon_bedrock::ContentBlock; + use crate::apis::openai::{FunctionCall, ToolCall}; + + let mut text_content = String::new(); + let mut tool_calls = Vec::new(); + + for content_block in &message.content { + match content_block { + ContentBlock::Text { text } => { + text_content.push_str(text); + } + ContentBlock::ToolUse { tool_use } => { + tool_calls.push(ToolCall { + id: tool_use.tool_use_id.clone(), + call_type: "function".to_string(), + function: FunctionCall { + name: tool_use.name.clone(), + arguments: serde_json::to_string(&tool_use.input).unwrap_or_default(), + }, + }); + } + _ => continue, + } + } + + let content = if text_content.is_empty() { + None + } else { + Some(text_content) + }; + let tool_calls = if tool_calls.is_empty() { + None + } else { + Some(tool_calls) + }; + + Ok((content, tool_calls)) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::apis::amazon_bedrock::{ + BedrockTokenUsage, ContentBlock, ConversationRole, ConverseOutput, ConverseResponse, + ConverseTrace, Message as BedrockMessage, PromptRouterTrace, StopReason, + }; + use crate::apis::openai::{ChatCompletionsResponse, FinishReason, Role}; + use serde_json::json; + + #[test] + fn test_bedrock_to_openai_basic_response() { + let bedrock_response = ConverseResponse { + output: ConverseOutput::Message { + message: BedrockMessage { + role: ConversationRole::Assistant, + content: vec![ContentBlock::Text { + text: "Hello! 
How can I help you today?".to_string(), + }], + }, + }, + stop_reason: StopReason::EndTurn, + usage: BedrockTokenUsage { + input_tokens: 10, + output_tokens: 25, + total_tokens: 35, + ..Default::default() + }, + metrics: None, + trace: None, + additional_model_response_fields: None, + performance_config: None, + }; + + let openai_response: ChatCompletionsResponse = bedrock_response.try_into().unwrap(); + + assert_eq!(openai_response.object, Some("chat.completion".to_string())); + assert_eq!(openai_response.model, "bedrock-model"); + assert!(openai_response.id.starts_with("bedrock-")); + + // Check choices + assert_eq!(openai_response.choices.len(), 1); + let choice = &openai_response.choices[0]; + assert_eq!(choice.index, 0); + assert_eq!(choice.message.role, Role::Assistant); + assert_eq!( + choice.message.content, + Some("Hello! How can I help you today?".to_string()) + ); + assert_eq!(choice.finish_reason, Some(FinishReason::Stop)); + assert!(choice.message.tool_calls.is_none()); + + // Check usage + assert_eq!(openai_response.usage.prompt_tokens, 10); + assert_eq!(openai_response.usage.completion_tokens, 25); + assert_eq!(openai_response.usage.total_tokens, 35); + } + + #[test] + fn test_bedrock_to_openai_with_tool_use() { + let bedrock_response = ConverseResponse { + output: ConverseOutput::Message { + message: BedrockMessage { + role: ConversationRole::Assistant, + content: vec![ + ContentBlock::Text { + text: "I'll help you check the weather.".to_string(), + }, + ContentBlock::ToolUse { + tool_use: crate::apis::amazon_bedrock::ToolUseBlock { + tool_use_id: "tool_12345".to_string(), + name: "get_weather".to_string(), + input: json!({ + "location": "San Francisco" + }), + }, + }, + ], + }, + }, + stop_reason: StopReason::ToolUse, + usage: BedrockTokenUsage { + input_tokens: 15, + output_tokens: 30, + total_tokens: 45, + ..Default::default() + }, + metrics: None, + trace: None, + additional_model_response_fields: None, + performance_config: None, + }; + + let openai_response: ChatCompletionsResponse = bedrock_response.try_into().unwrap(); + + assert_eq!( + openai_response.choices[0].finish_reason, + Some(FinishReason::ToolCalls) + ); + assert_eq!( + openai_response.choices[0].message.content, + Some("I'll help you check the weather.".to_string()) + ); + + // Check tool calls + let tool_calls = openai_response.choices[0] + .message + .tool_calls + .as_ref() + .unwrap(); + assert_eq!(tool_calls.len(), 1); + + let tool_call = &tool_calls[0]; + assert_eq!(tool_call.id, "tool_12345"); + assert_eq!(tool_call.call_type, "function"); + assert_eq!(tool_call.function.name, "get_weather"); + + let args: serde_json::Value = serde_json::from_str(&tool_call.function.arguments).unwrap(); + assert_eq!(args["location"], "San Francisco"); + } + + #[test] + fn test_bedrock_to_openai_stop_reason_conversions() { + let test_cases = vec![ + (StopReason::EndTurn, FinishReason::Stop), + (StopReason::ToolUse, FinishReason::ToolCalls), + (StopReason::MaxTokens, FinishReason::Length), + (StopReason::StopSequence, FinishReason::Stop), + (StopReason::GuardrailIntervened, FinishReason::ContentFilter), + (StopReason::ContentFiltered, FinishReason::ContentFilter), + ]; + + for (bedrock_stop_reason, expected_openai_finish_reason) in test_cases { + let bedrock_response = ConverseResponse { + output: ConverseOutput::Message { + message: BedrockMessage { + role: ConversationRole::Assistant, + content: vec![ContentBlock::Text { + text: "Test response".to_string(), + }], + }, + }, + stop_reason: bedrock_stop_reason, + 
usage: BedrockTokenUsage { + input_tokens: 5, + output_tokens: 10, + total_tokens: 15, + ..Default::default() + }, + metrics: None, + trace: None, + additional_model_response_fields: None, + performance_config: None, + }; + + let openai_response: ChatCompletionsResponse = bedrock_response.try_into().unwrap(); + assert_eq!( + openai_response.choices[0].finish_reason, + Some(expected_openai_finish_reason) + ); + } + } + + #[test] + fn test_bedrock_to_openai_multiple_tool_calls() { + let bedrock_response = ConverseResponse { + output: ConverseOutput::Message { + message: BedrockMessage { + role: ConversationRole::Assistant, + content: vec![ + ContentBlock::Text { + text: "I'll help with multiple tasks.".to_string(), + }, + ContentBlock::ToolUse { + tool_use: crate::apis::amazon_bedrock::ToolUseBlock { + tool_use_id: "tool_1".to_string(), + name: "search".to_string(), + input: json!({"query": "weather"}), + }, + }, + ContentBlock::ToolUse { + tool_use: crate::apis::amazon_bedrock::ToolUseBlock { + tool_use_id: "tool_2".to_string(), + name: "lookup".to_string(), + input: json!({"id": "12345"}), + }, + }, + ], + }, + }, + stop_reason: StopReason::ToolUse, + usage: BedrockTokenUsage { + input_tokens: 25, + output_tokens: 40, + total_tokens: 65, + ..Default::default() + }, + metrics: None, + trace: None, + additional_model_response_fields: None, + performance_config: None, + }; + + let openai_response: ChatCompletionsResponse = bedrock_response.try_into().unwrap(); + + assert_eq!( + openai_response.choices[0].finish_reason, + Some(FinishReason::ToolCalls) + ); + assert_eq!( + openai_response.choices[0].message.content, + Some("I'll help with multiple tasks.".to_string()) + ); + + // Check multiple tool calls + let tool_calls = openai_response.choices[0] + .message + .tool_calls + .as_ref() + .unwrap(); + assert_eq!(tool_calls.len(), 2); + + // First tool call + assert_eq!(tool_calls[0].id, "tool_1"); + assert_eq!(tool_calls[0].function.name, "search"); + let args1: serde_json::Value = + serde_json::from_str(&tool_calls[0].function.arguments).unwrap(); + assert_eq!(args1["query"], "weather"); + + // Second tool call + assert_eq!(tool_calls[1].id, "tool_2"); + assert_eq!(tool_calls[1].function.name, "lookup"); + let args2: serde_json::Value = + serde_json::from_str(&tool_calls[1].function.arguments).unwrap(); + assert_eq!(args2["id"], "12345"); + } + + #[test] + fn test_bedrock_to_openai_mixed_content() { + let bedrock_response = ConverseResponse { + output: ConverseOutput::Message { + message: BedrockMessage { + role: ConversationRole::Assistant, + content: vec![ + ContentBlock::Text { + text: "First part. ".to_string(), + }, + ContentBlock::ToolUse { + tool_use: crate::apis::amazon_bedrock::ToolUseBlock { + tool_use_id: "tool_mid".to_string(), + name: "calculate".to_string(), + input: json!({"expr": "2+2"}), + }, + }, + ContentBlock::Text { + text: "Second part.".to_string(), + }, + ], + }, + }, + stop_reason: StopReason::ToolUse, + usage: BedrockTokenUsage { + input_tokens: 20, + output_tokens: 35, + total_tokens: 55, + ..Default::default() + }, + metrics: None, + trace: None, + additional_model_response_fields: None, + performance_config: None, + }; + + let openai_response: ChatCompletionsResponse = bedrock_response.try_into().unwrap(); + + // Content should be combined text parts (no separator added) + assert_eq!( + openai_response.choices[0].message.content, + Some("First part. 
Second part.".to_string()) + ); + + // Should have one tool call + let tool_calls = openai_response.choices[0] + .message + .tool_calls + .as_ref() + .unwrap(); + assert_eq!(tool_calls.len(), 1); + assert_eq!(tool_calls[0].id, "tool_mid"); + assert_eq!(tool_calls[0].function.name, "calculate"); + } + + #[test] + fn test_bedrock_to_openai_empty_content() { + let bedrock_response = ConverseResponse { + output: ConverseOutput::Message { + message: BedrockMessage { + role: ConversationRole::Assistant, + content: vec![ContentBlock::ToolUse { + tool_use: crate::apis::amazon_bedrock::ToolUseBlock { + tool_use_id: "tool_only".to_string(), + name: "action".to_string(), + input: json!({}), + }, + }], + }, + }, + stop_reason: StopReason::ToolUse, + usage: BedrockTokenUsage { + input_tokens: 10, + output_tokens: 5, + total_tokens: 15, + ..Default::default() + }, + metrics: None, + trace: None, + additional_model_response_fields: None, + performance_config: None, + }; + + let openai_response: ChatCompletionsResponse = bedrock_response.try_into().unwrap(); + + // Content should be None when there's no text + assert_eq!(openai_response.choices[0].message.content, None); + + // Should have tool call + let tool_calls = openai_response.choices[0] + .message + .tool_calls + .as_ref() + .unwrap(); + assert_eq!(tool_calls.len(), 1); + assert_eq!(tool_calls[0].id, "tool_only"); + } + + #[test] + fn test_convert_bedrock_message_to_openai() { + let bedrock_message = BedrockMessage { + role: ConversationRole::Assistant, + content: vec![ + ContentBlock::Text { + text: "Hello world!".to_string(), + }, + ContentBlock::ToolUse { + tool_use: crate::apis::amazon_bedrock::ToolUseBlock { + tool_use_id: "test_tool".to_string(), + name: "test_function".to_string(), + input: json!({"param": "value"}), + }, + }, + ], + }; + + let (content, tool_calls) = convert_bedrock_message_to_openai(&bedrock_message).unwrap(); + + assert_eq!(content, Some("Hello world!".to_string())); + + let tool_calls = tool_calls.unwrap(); + assert_eq!(tool_calls.len(), 1); + assert_eq!(tool_calls[0].id, "test_tool"); + assert_eq!(tool_calls[0].function.name, "test_function"); + + let args: serde_json::Value = + serde_json::from_str(&tool_calls[0].function.arguments).unwrap(); + assert_eq!(args["param"], "value"); + } + + #[test] + fn test_bedrock_to_openai_role_conversion() { + // Test Assistant role + let assistant_response = ConverseResponse { + output: ConverseOutput::Message { + message: BedrockMessage { + role: ConversationRole::Assistant, + content: vec![ContentBlock::Text { + text: "I am an assistant".to_string(), + }], + }, + }, + stop_reason: StopReason::EndTurn, + usage: BedrockTokenUsage { + input_tokens: 5, + output_tokens: 10, + total_tokens: 15, + ..Default::default() + }, + metrics: None, + trace: None, + additional_model_response_fields: None, + performance_config: None, + }; + + let openai_response: ChatCompletionsResponse = assistant_response.try_into().unwrap(); + assert_eq!(openai_response.choices[0].message.role, Role::Assistant); + + // Test User role + let user_response = ConverseResponse { + output: ConverseOutput::Message { + message: BedrockMessage { + role: ConversationRole::User, + content: vec![ContentBlock::Text { + text: "I am a user".to_string(), + }], + }, + }, + stop_reason: StopReason::EndTurn, + usage: BedrockTokenUsage { + input_tokens: 5, + output_tokens: 10, + total_tokens: 15, + ..Default::default() + }, + metrics: None, + trace: None, + additional_model_response_fields: None, + performance_config: None, + }; + 
+ let openai_response: ChatCompletionsResponse = user_response.try_into().unwrap(); + assert_eq!(openai_response.choices[0].message.role, Role::User); + } + + #[test] + fn test_bedrock_to_openai_model_extraction() { + // Test model extraction from trace information + let bedrock_response = ConverseResponse { + output: ConverseOutput::Message { + message: BedrockMessage { + role: ConversationRole::Assistant, + content: vec![ContentBlock::Text { + text: "Test response".to_string(), + }], + }, + }, + stop_reason: StopReason::EndTurn, + usage: BedrockTokenUsage { + input_tokens: 10, + output_tokens: 5, + total_tokens: 15, + ..Default::default() + }, + metrics: None, + trace: Some(ConverseTrace { + guardrail: None, + prompt_router: Some(PromptRouterTrace { + invoked_model_id: "anthropic.claude-3-sonnet-20240229-v1:0".to_string(), + }), + }), + additional_model_response_fields: None, + performance_config: None, + }; + + let openai_response: ChatCompletionsResponse = bedrock_response.try_into().unwrap(); + + // Should extract model ID from trace + assert_eq!( + openai_response.model, + "anthropic.claude-3-sonnet-20240229-v1:0" + ); + + // Test fallback when no trace information is available + let bedrock_response_no_trace = ConverseResponse { + output: ConverseOutput::Message { + message: BedrockMessage { + role: ConversationRole::Assistant, + content: vec![ContentBlock::Text { + text: "Test response".to_string(), + }], + }, + }, + stop_reason: StopReason::EndTurn, + usage: BedrockTokenUsage { + input_tokens: 10, + output_tokens: 5, + total_tokens: 15, + ..Default::default() + }, + metrics: None, + trace: None, + additional_model_response_fields: None, + performance_config: None, + }; + + let openai_response_fallback: ChatCompletionsResponse = + bedrock_response_no_trace.try_into().unwrap(); + + // Should use fallback model name + assert_eq!(openai_response_fallback.model, "bedrock-model"); + } + + #[test] + fn test_bedrock_to_openai_with_multimedia_content() { + use crate::apis::amazon_bedrock::ImageSource; + + let bedrock_response = ConverseResponse { + output: ConverseOutput::Message { + message: BedrockMessage { + role: ConversationRole::Assistant, + content: vec![ + ContentBlock::Text { + text: "Here's the analysis:".to_string(), + }, + ContentBlock::Image { + image: crate::apis::amazon_bedrock::ImageBlock { + source: ImageSource::Base64 { + media_type: "image/jpeg".to_string(), + data: "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChwGA60e6kgAAAABJRU5ErkJggg==".to_string(), + }, + }, + } + ], + }, + }, + stop_reason: StopReason::EndTurn, + usage: BedrockTokenUsage { + input_tokens: 50, + output_tokens: 75, + total_tokens: 125, + ..Default::default() + }, + metrics: None, + trace: None, + additional_model_response_fields: None, + performance_config: None, + }; + + let openai_response: ChatCompletionsResponse = bedrock_response.try_into().unwrap(); + + assert_eq!( + openai_response.choices[0].finish_reason, + Some(FinishReason::Stop) + ); + + let content = openai_response.choices[0].message.content.as_ref().unwrap(); + + // Check that text content is preserved (image blocks are currently ignored) + assert!(content.contains("Here's the analysis:")); + // Note: Image blocks are not converted to text in the current implementation + } +} diff --git a/crates/llm_gateway/Cargo.toml b/crates/llm_gateway/Cargo.toml index b2557477..281e05be 100644 --- a/crates/llm_gateway/Cargo.toml +++ b/crates/llm_gateway/Cargo.toml @@ -23,6 +23,7 @@ thiserror = "1.0.64" derivative = "2.2.0" 
sha2 = "0.10.8" hermesllm = { version = "0.1.0", path = "../hermesllm" } +bytes = "1.10" [dev-dependencies] serial_test = "3.1.1" diff --git a/crates/llm_gateway/src/stream_context.rs b/crates/llm_gateway/src/stream_context.rs index f11134cb..785b5b72 100644 --- a/crates/llm_gateway/src/stream_context.rs +++ b/crates/llm_gateway/src/stream_context.rs @@ -1,3 +1,4 @@ +use hermesllm::clients::endpoints::SupportedUpstreamAPIs; use http::StatusCode; use log::{debug, info, warn}; use proxy_wasm::hostcalls::get_current_time; @@ -12,8 +13,8 @@ use std::time::{Duration, SystemTime, UNIX_EPOCH}; use crate::metrics::Metrics; use common::configuration::{LlmProvider, LlmProviderType, Overrides}; use common::consts::{ - ARCH_PROVIDER_HINT_HEADER, ARCH_ROUTING_HEADER, HEALTHZ_PATH, RATELIMIT_SELECTOR_HEADER_KEY, - REQUEST_ID_HEADER, TRACE_PARENT_HEADER, + ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER, ARCH_ROUTING_HEADER, HEALTHZ_PATH, + RATELIMIT_SELECTOR_HEADER_KEY, REQUEST_ID_HEADER, TRACE_PARENT_HEADER, }; use common::errors::ServerError; use common::llm_providers::LlmProviders; @@ -21,9 +22,15 @@ use common::ratelimit::Header; use common::stats::{IncrementingMetric, RecordingMetric}; use common::tracing::{Event, Span, TraceData, Traceparent}; use common::{ratelimit, routing, tokenizer}; +use hermesllm::apis::amazon_bedrock_binary_frame::BedrockBinaryFrameDecoder; +use hermesllm::apis::anthropic::{MessagesContentBlock, MessagesStreamEvent}; +use hermesllm::apis::sse::{SseEvent, SseStreamIter}; use hermesllm::clients::endpoints::SupportedAPIs; -use hermesllm::providers::response::{ProviderResponse, SseEvent, SseStreamIter}; -use hermesllm::{ProviderId, ProviderRequest, ProviderRequestType, ProviderResponseType}; +use hermesllm::providers::response::ProviderResponse; +use hermesllm::{ + DecodedFrame, ProviderId, ProviderRequest, ProviderRequestType, ProviderResponseType, + ProviderStreamResponseType, +}; pub struct StreamContext { metrics: Rc, @@ -33,7 +40,7 @@ pub struct StreamContext { /// The API that is requested by the client (before compatibility mapping) client_api: Option, /// The API that should be used for the upstream provider (after compatibility mapping) - resolved_api: Option, + resolved_api: Option, llm_providers: Rc, llm_provider: Option>, request_id: Option, @@ -45,8 +52,8 @@ pub struct StreamContext { traces_queue: Arc>>, overrides: Rc>, user_message: Option, - /// Store upstream response status code to handle error responses gracefully upstream_status_code: Option, + binary_frame_decoder: Option>, } impl StreamContext { @@ -75,6 +82,7 @@ impl StreamContext { request_body_sent_time: None, user_message: None, upstream_status_code: None, + binary_frame_decoder: None, } } @@ -108,6 +116,7 @@ impl StreamContext { .model .as_ref() .unwrap_or(&"".to_string()), + self.streaming_response, ); if target_endpoint != request_path { self.set_http_request_header(":path", Some(&target_endpoint)); @@ -148,14 +157,19 @@ impl StreamContext { // Set API-specific headers based on the resolved upstream API match self.resolved_api.as_ref() { - Some(SupportedAPIs::AnthropicMessagesAPI(_)) => { + Some(SupportedUpstreamAPIs::AnthropicMessagesAPI(_)) => { // Anthropic API requires x-api-key and anthropic-version headers // Remove any existing Authorization header since Anthropic doesn't use it self.remove_http_request_header("Authorization"); self.set_http_request_header("x-api-key", Some(llm_provider_api_key_value)); self.set_http_request_header("anthropic-version", Some("2023-06-01")); } - 
diff --git a/crates/llm_gateway/src/stream_context.rs b/crates/llm_gateway/src/stream_context.rs index f11134cb..785b5b72 100644 --- a/crates/llm_gateway/src/stream_context.rs +++ b/crates/llm_gateway/src/stream_context.rs @@ -1,3 +1,4 @@ +use hermesllm::clients::endpoints::SupportedUpstreamAPIs; use http::StatusCode; use log::{debug, info, warn}; use proxy_wasm::hostcalls::get_current_time; @@ -12,8 +13,8 @@ use std::time::{Duration, SystemTime, UNIX_EPOCH}; use crate::metrics::Metrics; use common::configuration::{LlmProvider, LlmProviderType, Overrides}; use common::consts::{ - ARCH_PROVIDER_HINT_HEADER, ARCH_ROUTING_HEADER, HEALTHZ_PATH, RATELIMIT_SELECTOR_HEADER_KEY, - REQUEST_ID_HEADER, TRACE_PARENT_HEADER, + ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER, ARCH_ROUTING_HEADER, HEALTHZ_PATH, + RATELIMIT_SELECTOR_HEADER_KEY, REQUEST_ID_HEADER, TRACE_PARENT_HEADER, }; use common::errors::ServerError; use common::llm_providers::LlmProviders; @@ -21,9 +22,15 @@ use common::ratelimit::Header; use common::stats::{IncrementingMetric, RecordingMetric}; use common::tracing::{Event, Span, TraceData, Traceparent}; use common::{ratelimit, routing, tokenizer}; +use hermesllm::apis::amazon_bedrock_binary_frame::BedrockBinaryFrameDecoder; +use hermesllm::apis::anthropic::{MessagesContentBlock, MessagesStreamEvent}; +use hermesllm::apis::sse::{SseEvent, SseStreamIter}; use hermesllm::clients::endpoints::SupportedAPIs; -use hermesllm::providers::response::{ProviderResponse, SseEvent, SseStreamIter}; -use hermesllm::{ProviderId, ProviderRequest, ProviderRequestType, ProviderResponseType}; +use hermesllm::providers::response::ProviderResponse; +use hermesllm::{ + DecodedFrame, ProviderId, ProviderRequest, ProviderRequestType, ProviderResponseType, + ProviderStreamResponseType, +}; pub struct StreamContext { metrics: Rc<Metrics>, @@ -33,7 +40,7 @@ pub struct StreamContext { /// The API that is requested by the client (before compatibility mapping) client_api: Option<SupportedAPIs>, /// The API that should be used for the upstream provider (after compatibility mapping) - resolved_api: Option<SupportedAPIs>, + resolved_api: Option<SupportedUpstreamAPIs>, llm_providers: Rc<LlmProviders>, llm_provider: Option<Rc<LlmProvider>>, request_id: Option<String>, @@ -45,8 +52,8 @@ pub struct StreamContext { traces_queue: Arc<Mutex<VecDeque<TraceData>>>, overrides: Rc<Option<Overrides>>, user_message: Option<String>, - /// Store upstream response status code to handle error responses gracefully upstream_status_code: Option<StatusCode>, + binary_frame_decoder: Option<BedrockBinaryFrameDecoder>, } impl StreamContext { @@ -75,6 +82,7 @@ impl StreamContext { request_body_sent_time: None, user_message: None, upstream_status_code: None, + binary_frame_decoder: None, } } @@ -108,6 +116,7 @@ impl StreamContext { .model .as_ref() .unwrap_or(&"".to_string()), + self.streaming_response, ); if target_endpoint != request_path { self.set_http_request_header(":path", Some(&target_endpoint)); @@ -148,14 +157,19 @@ impl StreamContext { // Set API-specific headers based on the resolved upstream API match self.resolved_api.as_ref() { - Some(SupportedAPIs::AnthropicMessagesAPI(_)) => { + Some(SupportedUpstreamAPIs::AnthropicMessagesAPI(_)) => { // Anthropic API requires x-api-key and anthropic-version headers // Remove any existing Authorization header since Anthropic doesn't use it self.remove_http_request_header("Authorization"); self.set_http_request_header("x-api-key", Some(llm_provider_api_key_value)); self.set_http_request_header("anthropic-version", Some("2023-06-01")); } - Some(SupportedAPIs::OpenAIChatCompletions(_)) | None => { + Some( + SupportedUpstreamAPIs::OpenAIChatCompletions(_) + | SupportedUpstreamAPIs::AmazonBedrockConverse(_) + | SupportedUpstreamAPIs::AmazonBedrockConverseStream(_), + ) + | None => { // OpenAI and default: use Authorization Bearer token // Remove any existing x-api-key header since OpenAI doesn't use it self.remove_http_request_header("x-api-key"); @@ -410,7 +424,16 @@ impl StreamContext { match self.client_api.as_ref() { Some(client_api) => { let client_api = client_api.clone(); // Clone to avoid borrowing issues - let upstream_api = provider_id.compatible_api_for_client(&client_api); + let upstream_api = + provider_id.compatible_api_for_client(&client_api, self.streaming_response); + + // Check if this is Bedrock binary stream + if matches!( + upstream_api, + SupportedUpstreamAPIs::AmazonBedrockConverseStream(_) + ) { + return self.handle_bedrock_binary_stream(body, &client_api, &upstream_api); + } // Parse body into SSE iterator using TryFrom let sse_iter: SseStreamIter> = @@ -487,6 +510,127 @@ } } + fn handle_bedrock_binary_stream( + &mut self, + body: &[u8], + client_api: &SupportedAPIs, + upstream_api: &SupportedUpstreamAPIs, + ) -> Result<Vec<u8>, Action> { + // Initialize decoder if not present + if self.binary_frame_decoder.is_none() { + self.binary_frame_decoder = Some(BedrockBinaryFrameDecoder::from_bytes(&[])); + } + + // Add incoming bytes to buffer + let decoder = self.binary_frame_decoder.as_mut().unwrap(); + decoder.buffer_mut().extend_from_slice(body); + + let mut response_buffer = Vec::new(); + loop { + let decoded_frame = self.binary_frame_decoder.as_mut().unwrap().decode_frame(); + match decoded_frame { + Some(DecodedFrame::Complete(ref frame_ref)) => { + let frame = DecodedFrame::Complete(frame_ref.clone()); + match ProviderStreamResponseType::try_from((&frame, client_api, upstream_api)) { + Ok(provider_response) => { + self.record_ttft_if_needed(); + + // Handle ContentBlockStart and ContentBlockDelta events + match &provider_response { + ProviderStreamResponseType::MessagesStreamEvent(evt) => { + match evt { + MessagesStreamEvent::ContentBlockStart { + index, .. + } => { + // Mark that we've seen ContentBlockStart for this index + self.binary_frame_decoder + .as_mut() + .unwrap() + .set_content_block_start_sent(*index as i32); + debug!( + "[ARCHGW_REQ_ID:{}] BEDROCK_CONTENT_BLOCK_START_TRACKED: index={}", + self.request_identifier(), + *index + ); + } + MessagesStreamEvent::ContentBlockDelta { + index, .. 
+ } => { + // Check if ContentBlockStart was sent for this index + let needs_start = !self + .binary_frame_decoder + .as_ref() + .unwrap() + .has_content_block_start_been_sent(*index as i32); + + if needs_start { + // Emit empty ContentBlockStart before delta + let content_block_start = + MessagesStreamEvent::ContentBlockStart { + index: *index, + content_block: MessagesContentBlock::Text { + text: String::new(), + cache_control: None, + }, + }; + let start_sse: String = content_block_start.into(); + response_buffer + .extend_from_slice(start_sse.as_bytes()); + + // Mark that we've now sent it + self.binary_frame_decoder + .as_mut() + .unwrap() + .set_content_block_start_sent(*index as i32); + + debug!( + "[ARCHGW_REQ_ID:{}] BEDROCK_INJECTED_CONTENT_BLOCK_START: index={}", + self.request_identifier(), + *index + ); + } + } + _ => {} + } + } + _ => {} + } + + let sse_string: String = provider_response.into(); + response_buffer.extend_from_slice(sse_string.as_bytes()); + } + Err(e) => { + warn!( + "[ARCHGW_REQ_ID:{}] BEDROCK_FRAME_CONVERSION_ERROR: {}", + self.request_identifier(), + e + ); + } + } + } + Some(DecodedFrame::Incomplete) => { + // Incomplete frame - buffer retains partial data, wait for more bytes + debug!( + "[ARCHGW_REQ_ID:{}] BEDROCK_INCOMPLETE_FRAME: waiting for more data", + self.request_identifier() + ); + break; + } + None => { + // Decode error + warn!( + "[ARCHGW_REQ_ID:{}] BEDROCK_DECODE_ERROR", + self.request_identifier() + ); + return Err(Action::Continue); + } + } + } + + // Return accumulated complete frames (may be empty if all frames incomplete) + Ok(response_buffer) + } + fn handle_non_streaming_response( &mut self, body: &[u8], @@ -578,6 +722,11 @@ impl HttpContext for StreamContext { return Action::Continue; } + self.streaming_response = self + .get_http_request_header(ARCH_IS_STREAMING_HEADER) + .map(|val| val == "true") + .unwrap_or(false); + let use_agent_orchestrator = match self.overrides.as_ref() { Some(overrides) => overrides.use_agent_orchestrator.unwrap_or_default(), None => false, @@ -612,7 +761,17 @@ impl HttpContext for StreamContext { (self.client_api.as_ref(), self.llm_provider.as_ref()) { let provider_id = provider.to_provider_id(); - self.resolved_api = Some(provider_id.compatible_api_for_client(api)); + self.resolved_api = + Some(provider_id.compatible_api_for_client(api, self.streaming_response)); + + debug!( + "[ARCHGW_REQ_ID:{}] ROUTING_INFO: provider='{}' client_api={:?} resolved_api={:?} request_path='{}'", + self.request_identifier(), + provider.to_provider_id(), + api, + self.resolved_api, + request_path + ); } else { self.resolved_api = None; } @@ -697,7 +856,7 @@ impl HttpContext for StreamContext { //We need to deserialize the request body based on the resolved API let mut deserialized_client_request: ProviderRequestType = match self.client_api.as_ref() { Some(the_client_api) => { - debug!( + info!( "[ARCHGW_REQ_ID:{}] CLIENT_REQUEST_RECEIVED: api={:?} body_size={}", self.request_identifier(), the_client_api, @@ -795,7 +954,10 @@ impl HttpContext for StreamContext { ); // Use provider interface for streaming detection and setup - self.streaming_response = deserialized_client_request.is_streaming(); + // If streaming_response is not already set from headers, get it from the parsed request + if !self.streaming_response { + self.streaming_response = deserialized_client_request.is_streaming(); + } // Use provider interface for text extraction (after potential mutation) let input_tokens_str = 
deserialized_client_request.extract_messages_text(); diff --git a/demos/use_cases/model_alias_routing/arch_config_with_aliases.yaml b/demos/use_cases/model_alias_routing/arch_config_with_aliases.yaml index 794ed117..0a626be9 100644 --- a/demos/use_cases/model_alias_routing/arch_config_with_aliases.yaml +++ b/demos/use_cases/model_alias_routing/arch_config_with_aliases.yaml @@ -35,6 +35,10 @@ llm_providers: access_key: $AZURE_API_KEY base_url: https://katanemo.openai.azure.com + - model: amazon_bedrock/us.amazon.nova-premier-v1:0 + access_key: $AWS_BEARER_TOKEN_BEDROCK + base_url: https://bedrock-runtime.us-west-2.amazonaws.com + # Ollama Models - model: ollama/llama3.1 base_url: http://host.docker.internal:11434 @@ -71,3 +75,6 @@ model_aliases: creative-model: target: claude-sonnet-4-20250514 + + coding-model: + target: us.amazon.nova-premier-v1:0 diff --git a/docs/source/concepts/llm_providers/supported_providers.rst b/docs/source/concepts/llm_providers/supported_providers.rst index a33e58f8..2a58f328 100644 --- a/docs/source/concepts/llm_providers/supported_providers.rst +++ b/docs/source/concepts/llm_providers/supported_providers.rst @@ -517,6 +517,36 @@ Azure OpenAI access_key: $AZURE_OPENAI_API_KEY base_url: https://your-resource.openai.azure.com +Amazon Bedrock +~~~~~~~~~~~~~~ + +**Provider Prefix:** ``amazon_bedrock/`` + +**API Endpoint:** Arch automatically constructs the endpoint as: + - Non-streaming: ``/model/{model-id}/converse`` + - Streaming: ``/model/{model-id}/converse-stream`` + +**Authentication:** AWS Bearer Token + Base URL - Get your API Keys from `AWS Bedrock Console <https://console.aws.amazon.com/bedrock/>`_ → Discover → API Keys. + +**Supported Chat Models:** All Amazon Bedrock foundation models including Claude (Anthropic), Nova (Amazon), Llama (Meta), Mistral AI, and Cohere Command models. + +.. 
code-block:: yaml + + llm_providers: + # Amazon Nova models + - model: amazon_bedrock/us.amazon.nova-premier-v1:0 + access_key: $AWS_BEARER_TOKEN_BEDROCK + base_url: https://bedrock-runtime.us-west-2.amazonaws.com + default: true + + - model: amazon_bedrock/us.amazon.nova-pro-v1:0 + access_key: $AWS_BEARER_TOKEN_BEDROCK + base_url: https://bedrock-runtime.us-west-2.amazonaws.com + + # Claude on Bedrock + - model: amazon_bedrock/us.anthropic.claude-3-5-sonnet-20241022-v2:0 + access_key: $AWS_BEARER_TOKEN_BEDROCK + base_url: https://bedrock-runtime.us-west-2.amazonaws.com Qwen (Alibaba) ~~~~~~~~~~~~~~ @@ -540,8 +570,7 @@ Qwen (Alibaba) # Multiple deployments - model: qwen/qwen3-coder access_key: $DASHSCOPE_API_KEY - base_url: "https://dashscope-intl.aliyuncs.com", - + base_url: "https://dashscope-intl.aliyuncs.com" Ollama ~~~~~~ diff --git a/tests/e2e/response.hex b/tests/e2e/response.hex new file mode 100644 index 00000000..c96504e2 Binary files /dev/null and b/tests/e2e/response.hex differ diff --git a/tests/e2e/response_with_tools.hex b/tests/e2e/response_with_tools.hex new file mode 100644 index 00000000..5aa41165 Binary files /dev/null and b/tests/e2e/response_with_tools.hex differ diff --git a/tests/e2e/test_model_alias_routing.py b/tests/e2e/test_model_alias_routing.py index d5a289a6..5b1d3719 100644 --- a/tests/e2e/test_model_alias_routing.py +++ b/tests/e2e/test_model_alias_routing.py @@ -403,3 +403,381 @@ def test_anthropic_thinking_mode_streaming(): final_block_types = [blk.type for blk in final.content] assert "text" in final_block_types assert "thinking" in final_block_types + + +def test_openai_client_with_coding_model_alias_and_tools(): + """Test OpenAI client using 'coding-model' alias (maps to Bedrock) with coding question and tools""" + logger.info("Testing OpenAI client with 'coding-model' alias -> Bedrock with tools") + + base_url = LLM_GATEWAY_ENDPOINT.replace("/v1/chat/completions", "") + client = openai.OpenAI( + api_key="test-key", + base_url=f"{base_url}/v1", + ) + + completion = client.chat.completions.create( + model="coding-model", # This should resolve to us.amazon.nova-premier-v1:0 + max_tokens=1000, + messages=[ + { + "role": "user", + "content": "I need to write a Python function that calculates the factorial of a number. 
Can you help me write and run it?", + } + ], + tools=[ + { + "type": "function", + "function": { + "name": "run_python_code", + "description": "Execute Python code and return the result", + "parameters": { + "type": "object", + "properties": { + "code": { + "type": "string", + "description": "Python code to execute", + } + }, + "required": ["code"], + }, + }, + } + ], + tool_choice="auto", + ) + + response_content = completion.choices[0].message.content + tool_calls = completion.choices[0].message.tool_calls + # Should get either text response or tool calls for coding assistance + assert response_content is not None or ( + tool_calls is not None and len(tool_calls) > 0 + ) + + +def test_anthropic_client_with_coding_model_alias_and_tools(): + """Test Anthropic client using 'coding-model' alias (maps to Bedrock) with coding question and tools""" + logger.info( + "Testing Anthropic client with 'coding-model' alias -> Bedrock with tools" + ) + + base_url = LLM_GATEWAY_ENDPOINT.replace("/v1/chat/completions", "") + client = anthropic.Anthropic(api_key="test-key", base_url=base_url) + + message = client.messages.create( + model="coding-model", # This should resolve to us.amazon.nova-premier-v1:0 + max_tokens=1000, + messages=[ + { + "role": "user", + "content": "I need to write a Python function that calculates the factorial of a number. Can you help me write and run it?", + } + ], + tools=[ + { + "name": "run_python_code", + "description": "Execute Python code and return the result", + "input_schema": { + "type": "object", + "properties": { + "code": { + "type": "string", + "description": "Python code to execute", + } + }, + "required": ["code"], + }, + } + ], + tool_choice={"type": "auto"}, + ) + + text_content = "".join(b.text for b in message.content if b.type == "text") + tool_use_blocks = [b for b in message.content if b.type == "tool_use"] + + logger.info(f"Response from coding-model alias via Anthropic: {text_content}") + logger.info(f"Tool use blocks: {len(tool_use_blocks)}") + + # Should get either text response or tool use blocks for coding assistance + assert text_content or len(tool_use_blocks) > 0 + + +@pytest.mark.flaky(retries=0) # Disable retries to see the actual failure +def test_anthropic_client_with_coding_model_alias_and_tools_streaming(): + """Test Anthropic client using 'coding-model' alias (maps to Bedrock) with coding question and tools - streaming""" + logger.info( + "Testing Anthropic client with 'coding-model' alias -> Bedrock with tools (streaming)" + ) + + base_url = LLM_GATEWAY_ENDPOINT.replace("/v1/chat/completions", "") + client = anthropic.Anthropic(api_key="test-key", base_url=base_url) + + text_chunks = [] + tool_use_blocks = [] + all_events = [] # Capture all events for debugging + + try: + with client.messages.stream( + model="coding-model", # This should resolve to us.amazon.nova-premier-v1:0 + max_tokens=1000, + messages=[ + { + "role": "user", + "content": "I need to write a Python function that calculates the factorial of a number. 
Can you help me write and run it?", + } + ], + tools=[ + { + "name": "run_python_code", + "description": "Execute Python code and return the result", + "input_schema": { + "type": "object", + "properties": { + "code": { + "type": "string", + "description": "Python code to execute", + } + }, + "required": ["code"], + }, + } + ], + tool_choice={"type": "auto"}, + ) as stream: + for event in stream: + # Extract index if available + index = getattr(event, "index", None) + + # Log and capture all events for debugging + all_events.append( + {"type": event.type, "index": index, "event": str(event)[:200]} + ) + logger.info(f"Event #{len(all_events)}: {event.type} [index={index}]") + + # Collect text deltas + if event.type == "content_block_delta" and hasattr(event, "delta"): + if event.delta.type == "text_delta": + text_chunks.append(event.delta.text) + + # Collect tool use blocks + if event.type == "content_block_start" and hasattr( + event, "content_block" + ): + if event.content_block.type == "tool_use": + tool_use_blocks.append(event.content_block) + + final_message = stream.get_final_message() + except Exception as e: + logger.error(f"Exception during streaming: {type(e).__name__}: {e}") + logger.error(f"Events received before error: {len(all_events)}") + logger.error(f"Text chunks collected: {len(text_chunks)}") + logger.error(f"Tool use blocks collected: {len(tool_use_blocks)}") + logger.error("\nLast 20 events before crash:") + for evt in all_events[-20:]: + logger.error(f" {evt['type']:30s} index={evt['index']}") + raise + + full_text = "".join(text_chunks) + logger.info(f"Streaming response from coding-model with tools: {full_text}") + logger.info(f"Total events received: {len(all_events)}") + logger.info( + f"Text chunks: {len(text_chunks)}, Tool use blocks: {len(tool_use_blocks)}" + ) + + # Should get either text response or tool use blocks for coding assistance + # Modified assertion to be more lenient and provide better error messages + assert ( + full_text or len(tool_use_blocks) > 0 + ), f"Expected text or tool use. Got text_len={len(full_text)}, tools={len(tool_use_blocks)}, events={len(all_events)}" + + # Verify final message structure + assert final_message is not None, "Final message should not be None" + assert ( + final_message.content and len(final_message.content) > 0 + ), f"Final message should have content. 
Got: {final_message.content if final_message else 'None'}" + + +def test_anthropic_client_streaming_with_bedrock(): + """Test Anthropic client using 'coding-model' alias (maps to Bedrock) with streaming""" + logger.info( + "Testing Anthropic client with 'coding-model' alias -> Bedrock (streaming)" + ) + + base_url = LLM_GATEWAY_ENDPOINT.replace("/v1/chat/completions", "") + client = anthropic.Anthropic(api_key="test-key", base_url=base_url) + + text_chunks = [] + + with client.messages.stream( + model="coding-model", # This should resolve to us.amazon.nova-premier-v1:0 + max_tokens=500, + messages=[ + { + "role": "user", + "content": "Write a short 4-line sonnet about coding.", + } + ], + ) as stream: + for event in stream: + # Collect text deltas + if event.type == "content_block_delta" and hasattr(event, "delta"): + if event.delta.type == "text_delta": + text_chunks.append(event.delta.text) + + final_message = stream.get_final_message() + + full_text = "".join(text_chunks) + logger.info(f"Response: {full_text}") + + # Should get a text response + assert len(full_text) > 0, "Expected text response from streaming" + + # Verify final message structure + assert final_message is not None + assert final_message.content and len(final_message.content) > 0 + + +def test_openai_client_streaming_with_bedrock(): + """Test OpenAI client using 'coding-model' alias (maps to Bedrock) with streaming""" + logger.info( + "Testing OpenAI client with 'coding-model' alias -> Bedrock (streaming)" + ) + + base_url = LLM_GATEWAY_ENDPOINT.replace("/v1/chat/completions", "") + client = openai.OpenAI( + api_key="test-key", + base_url=f"{base_url}/v1", + ) + + stream = client.chat.completions.create( + model="coding-model", # This should resolve to us.amazon.nova-premier-v1:0 + max_tokens=500, + messages=[ + { + "role": "user", + "content": "Write a short 4-line sonnet about coding.", + } + ], + stream=True, + ) + + content_chunks = [] + for chunk in stream: + if chunk.choices and len(chunk.choices) > 0: + delta = chunk.choices[0].delta + if delta.content: + content_chunks.append(delta.content) + + full_content = "".join(content_chunks) + logger.info(f"Streaming response from coding-model: {full_content}") + + # Should get a text response + assert len(full_content) > 0, "Expected text response from streaming" + + +def test_openai_client_streaming_with_bedrock_and_tools(): + """Test OpenAI client using 'coding-model' alias (maps to Bedrock) with streaming and tools""" + logger.info( + "Testing OpenAI client with 'coding-model' alias -> Bedrock with tools (streaming)" + ) + + base_url = LLM_GATEWAY_ENDPOINT.replace("/v1/chat/completions", "") + client = openai.OpenAI( + api_key="test-key", + base_url=f"{base_url}/v1", + ) + + stream = client.chat.completions.create( + model="coding-model", # This should resolve to us.amazon.nova-premier-v1:0 + max_tokens=1000, + messages=[ + { + "role": "user", + "content": "I need to write a Python function that calculates the factorial of a number. Can you help me write and run it?. 
You should use the tool to run the code.", + } + ], + tools=[ + { + "type": "function", + "function": { + "name": "run_python_code", + "description": "Execute Python code and return the result", + "parameters": { + "type": "object", + "properties": { + "code": { + "type": "string", + "description": "Python code to execute", + } + }, + "required": ["code"], + }, + }, + } + ], + tool_choice="auto", + stream=True, + ) + + content_chunks = [] + tool_calls = [] + chunk_count = 0 + + for chunk in stream: + chunk_count += 1 + if chunk.choices and len(chunk.choices) > 0: + delta = chunk.choices[0].delta + + # Log what we see in each chunk + has_content = delta.content is not None + has_tool_calls = delta.tool_calls is not None + + if ( + chunk_count % 50 == 0 or has_tool_calls + ): # Log every 50th chunk or any chunk with tool calls + logger.info( + f"Chunk {chunk_count}: content={has_content}, tool_calls={has_tool_calls}" + ) + if has_tool_calls: + logger.info(f" Tool calls in chunk: {delta.tool_calls}") + + # Collect text content + if delta.content: + content_chunks.append(delta.content) + + # Collect tool calls + if delta.tool_calls: + for tool_call in delta.tool_calls: + # Extend or create tool call entries + while len(tool_calls) <= tool_call.index: + tool_calls.append( + { + "id": "", + "type": "function", + "function": {"name": "", "arguments": ""}, + } + ) + + if tool_call.id: + tool_calls[tool_call.index]["id"] = tool_call.id + if tool_call.function: + if tool_call.function.name: + tool_calls[tool_call.index]["function"][ + "name" + ] = tool_call.function.name + if tool_call.function.arguments: + tool_calls[tool_call.index]["function"][ + "arguments" + ] += tool_call.function.arguments + + full_content = "".join(content_chunks) + logger.info(f"Streaming response from coding-model with tools: {full_content}") + logger.info(f"Tool calls collected: {len(tool_calls)}") + + if tool_calls: + for i, tc in enumerate(tool_calls): + logger.info(f" Tool call {i}: {tc['function']['name']}") + + # Should get either text response or tool calls for coding assistance + assert ( + full_content or len(tool_calls) > 0 + ), f"Expected text or tool calls. Got text_len={len(full_content)}, tools={len(tool_calls)}"
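The streaming tool-call test above rebuilds OpenAI `tool_calls` from deltas: fragments are grouped by `index`, the id and function name arrive on the first fragment, and `arguments` accumulates across fragments. The same merge as a self-contained Rust sketch (types simplified from the ToolCallDelta/FunctionCallDelta shapes used in hermesllm; the helper name is illustrative):

    #[derive(Clone, Default)]
    struct PendingToolCall {
        id: String,
        name: String,
        arguments: String,
    }

    /// Fold streamed (index, id?, name?, arguments-fragment?) tuples into
    /// complete tool calls, mirroring the accumulation loop in the test above.
    fn merge_tool_call_deltas(
        deltas: Vec<(usize, Option<String>, Option<String>, Option<String>)>,
    ) -> Vec<PendingToolCall> {
        let mut calls: Vec<PendingToolCall> = Vec::new();
        for (index, id, name, args) in deltas {
            if calls.len() <= index {
                calls.resize(index + 1, PendingToolCall::default()); // new index: grow the list
            }
            if let Some(id) = id {
                calls[index].id = id; // arrives once, on the first fragment
            }
            if let Some(name) = name {
                calls[index].name = name;
            }
            if let Some(args) = args {
                calls[index].arguments.push_str(&args); // JSON arguments concatenate
            }
        }
        calls
    }

    fn main() {
        let calls = merge_tool_call_deltas(vec![
            (0, Some("tool_1".into()), Some("run_python_code".into()), Some("{\"co".into())),
            (0, None, None, Some("de\": \"print(24)\"}".into())),
        ]);
        assert_eq!(calls[0].name, "run_python_code");
        assert_eq!(calls[0].arguments, "{\"code\": \"print(24)\"}");
    }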